Skip to content

Commit 8a3f366

Browse files
bdu91 authored and LIT team committed
Make multi label prediction scores visible in the data table column.
PiperOrigin-RevId: 570758611
1 parent e7115ab commit 8a3f366

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

lit_nlp/client/services/data_service.ts

Lines changed: 35 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
import {action, computed, observable, reaction} from 'mobx';
2020

2121
import {BINARY_NEG_POS, ColorRange} from '../lib/colors';
22-
import {BooleanLitType, CategoryLabel, GeneratedText, GeneratedTextCandidates, LitType, MulticlassPreds, RegressionScore, Scalar} from '../lib/lit_types';
22+
import {BooleanLitType, CategoryLabel, GeneratedText, GeneratedTextCandidates, LitType, MulticlassPreds, RegressionScore, Scalar, SparseMultilabelPreds} from '../lib/lit_types';
2323
import {ClassificationResults, IndexedInput, RegressionResults} from '../lib/types';
2424
import {createLitType, findSpecKeys, isLitSubtype, mapsContainSame} from '../lib/utils';
2525

@@ -68,6 +68,8 @@ export const GEN_TEXT_CANDS_SOURCE_PREFIX = 'GeneratedTextCandidates';
6868
export const REGRESSION_SOURCE_PREFIX = 'Regression';
6969
/** Column source prefix for columns from scalar model outputs. */
7070
export const SCALAR_SOURCE_PREFIX = 'Scalar';
71+
/** Column source prefix for columns built from SparseMultilabelPreds model outputs. */
72+
export const MULTILABEL_SOURCE_PREFIX = 'Multilabel';
7173

7274
/**
7375
* Data service singleton, responsible for maintaining columns of computed data
@@ -109,7 +111,7 @@ export class DataService extends LitService {
109111
}
110112
}, {fireImmediately: true});
111113

112-
// Run other preiction interpreters when necessary.
114+
// Run other prediction interpreters when necessary.
113115
const getPredictionInputs =
114116
() => [this.appState.currentInputData, this.appState.currentModels];
115117
reaction(getPredictionInputs, () => {
@@ -124,6 +126,7 @@ export class DataService extends LitService {
124126
this.runGeneratedTextPreds(model, this.appState.currentInputData);
125127
this.runRegression(model, this.appState.currentInputData);
126128
this.runScalarPreds(model, this.appState.currentInputData);
129+
this.runMultiLabelPreds(model, this.appState.currentInputData);
127130
}
128131
}, {fireImmediately: true});
129132

@@ -301,6 +304,36 @@ export class DataService extends LitService {
301304
}
302305
}
303306

307+
/**
 * Requests SparseMultilabelPreds outputs from `model` for `data` and adds one
 * data service column per returned prediction field.
 *
 * Returns early, adding nothing, when the model's output spec declares no
 * SparseMultilabelPreds fields or when the prediction request yields no rows.
 *
 * @param model name of the model whose multilabel predictions are requested
 * @param data inputs sent to the prediction request; also passed to
 *     addColumnFromList together with the per-field values
 */
private async runMultiLabelPreds(model: string, data: IndexedInput[]) {
  const {output: outputSpec} = this.appState.getModelSpec(model);
  const multiLabelSpecKeys = findSpecKeys(outputSpec, SparseMultilabelPreds);
  if (!multiLabelSpecKeys.length) return;

  const preds = await this.apiService.getPreds(
      data, model, this.appState.currentDataset, [SparseMultilabelPreds]);

  // Nothing to turn into columns when the request produced no predictions.
  if (!preds?.length) return;

  // All columns contributed by this model share the same source tag.
  const columnSource = `${MULTILABEL_SOURCE_PREFIX}:${model}`;

  // The set of columns is taken from the fields present on the first result.
  for (const fieldKey of Object.keys(preds[0])) {
    const columnName = this.getColumnName(model, fieldKey);
    const columnValues = preds.map(p => p[fieldKey]);
    // TODO(b/303457849): maybe possible to directly use the data type from
    // the output spec rather than creating a new one.
    const columnType = createLitType(SparseMultilabelPreds);
    this.addColumnFromList(
        columnValues, data, fieldKey, columnName, columnType, columnSource);
  }
}
336+
304337
@action
305338
async setValuesForNewDatapoints(datapoints: IndexedInput[]) {
306339
// When new datapoints are created, set their data values for each

0 commit comments

Comments
 (0)