1919import { action , computed , observable , reaction } from 'mobx' ;
2020
2121import { BINARY_NEG_POS , ColorRange } from '../lib/colors' ;
22- import { BooleanLitType , CategoryLabel , GeneratedText , GeneratedTextCandidates , LitType , MulticlassPreds , RegressionScore , Scalar } from '../lib/lit_types' ;
22+ import { BooleanLitType , CategoryLabel , GeneratedText , GeneratedTextCandidates , LitType , MulticlassPreds , RegressionScore , Scalar , SparseMultilabelPreds } from '../lib/lit_types' ;
2323import { ClassificationResults , IndexedInput , RegressionResults } from '../lib/types' ;
2424import { createLitType , findSpecKeys , isLitSubtype , mapsContainSame } from '../lib/utils' ;
2525
@@ -68,6 +68,8 @@ export const GEN_TEXT_CANDS_SOURCE_PREFIX = 'GeneratedTextCandidates';
6868export const REGRESSION_SOURCE_PREFIX = 'Regression' ;
6969/** Column source prefix for columns from scalar model outputs. */
7070export const SCALAR_SOURCE_PREFIX = 'Scalar' ;
71+ /** Column source prefix for columns from multilabel model outputs. */
72+ export const MULTILABEL_SOURCE_PREFIX = 'Multilabel' ;
7173
7274/**
7375 * Data service singleton, responsible for maintaining columns of computed data
@@ -109,7 +111,7 @@ export class DataService extends LitService {
109111 }
110112 } , { fireImmediately : true } ) ;
111113
112- // Run other preiction interpreters when necessary.
114+ // Run other prediction interpreters when necessary.
113115 const getPredictionInputs =
114116 ( ) => [ this . appState . currentInputData , this . appState . currentModels ] ;
115117 reaction ( getPredictionInputs , ( ) => {
@@ -124,6 +126,7 @@ export class DataService extends LitService {
124126 this . runGeneratedTextPreds ( model , this . appState . currentInputData ) ;
125127 this . runRegression ( model , this . appState . currentInputData ) ;
126128 this . runScalarPreds ( model , this . appState . currentInputData ) ;
129+ this . runMultiLabelPreds ( model , this . appState . currentInputData ) ;
127130 }
128131 } , { fireImmediately : true } ) ;
129132
@@ -301,6 +304,36 @@ export class DataService extends LitService {
301304 }
302305 }
303306
/**
 * Requests SparseMultilabelPreds outputs from `model` for `data` and adds one
 * data-service column per prediction key in the response.
 *
 * Returns early without adding anything when the model's output spec has no
 * SparseMultilabelPreds fields, or when the prediction call yields a null or
 * empty result list.
 */
private async runMultiLabelPreds(model: string, data: IndexedInput[]) {
  const spec = this.appState.getModelSpec(model);
  const hasMultiLabelOutput =
      findSpecKeys(spec.output, SparseMultilabelPreds).length > 0;
  if (!hasMultiLabelOutput) return;

  const results = await this.apiService.getPreds(
      data, model, this.appState.currentDataset, [SparseMultilabelPreds]);

  // No predictions came back, so there is nothing to turn into columns.
  if (results == null || results.length === 0) return;

  // Every column produced here shares the same source tag for this model.
  const columnSource = `${MULTILABEL_SOURCE_PREFIX}:${model}`;

  // The first result's keys are taken as the set of prediction fields.
  for (const predKey of Object.keys(results[0])) {
    const columnName = this.getColumnName(model, predKey);
    const columnValues = results.map(result => result[predKey]);
    // TODO(b/303457849): maybe possible to directly use the data type from
    // the output spec rather than creating a new one.
    // createLitType is invoked once per column (not hoisted out of the loop).
    const columnType = createLitType(SparseMultilabelPreds);
    this.addColumnFromList(
        columnValues, data, predKey, columnName, columnType, columnSource);
  }
}
336+
304337 @action
305338 async setValuesForNewDatapoints ( datapoints : IndexedInput [ ] ) {
306339 // When new datapoints are created, set their data values for each
0 commit comments