1919
2020import * as utils from './utils' ;
2121
/** One stored cell of the sparse matrix: its numeric value together with the row and column it sits at. */
22+ type Entry = { value : number ; row : number ; col : number } ;
23+
2224/**
2325 * Internal 2-dimensional sparse matrix class
2426 */
2527export class SparseMatrix {
26- private rows : number [ ] ;
27- private cols : number [ ] ;
28- private values : number [ ] ;
29-
30- private entries = new Map < string , number > ( ) ;
28+ private entries = new Map < string , Entry > ( ) ;
3129
3230 readonly nRows : number = 0 ;
3331 readonly nCols : number = 0 ;
@@ -38,19 +36,20 @@ export class SparseMatrix {
3836 values : number [ ] ,
3937 dims : number [ ]
4038 ) {
41- // TODO: Assert that rows / cols / vals are the same length.
42- this . rows = [ ...rows ] ;
43- this . cols = [ ...cols ] ;
44- this . values = [ ...values ] ;
45-
46- for ( let i = 0 ; i < values . length ; i ++ ) {
47- const key = this . makeKey ( this . rows [ i ] , this . cols [ i ] ) ;
48- this . entries . set ( key , i ) ;
39+ if ( ( rows . length !== cols . length ) || ( rows . length !== values . length ) ) {
40+ throw new Error ( "rows, cols and values arrays must all have the same length" ) ;
4941 }
5042
5143 // TODO: Assert that dims are legit.
5244 this . nRows = dims [ 0 ] ;
5345 this . nCols = dims [ 1 ] ;
46+ for ( let i = 0 ; i < values . length ; i ++ ) {
47+ const row = rows [ i ] ;
48+ const col = cols [ i ] ;
49+ this . checkDims ( row , col ) ;
50+ const key = this . makeKey ( row , col ) ;
51+ this . entries . set ( key , { value : values [ i ] , row, col } ) ;
52+ }
5453 }
5554
5655 private makeKey ( row : number , col : number ) : string {
@@ -60,74 +59,84 @@ export class SparseMatrix {
/**
 * Guard called before a cell is read or written: throws unless the position
 * lies below the matrix dimensions.
 *
 * Only the upper bounds are tested (`row < nRows` and `col < nCols`); no
 * lower-bound or integer check is made here.
 *
 * In this diff the `-` line is the previous error message and the `+` line is
 * the message that replaces it.
 *
 * @param row Row index to validate.
 * @param col Column index to validate.
 * @throws Error when `row < nRows && col < nCols` is false.
 */
6059 private checkDims ( row : number , col : number ) {
6160 const withinBounds = row < this . nRows && col < this . nCols ;
6261 if ( ! withinBounds ) {
63- throw new Error ( 'array index out of bounds ' ) ;
62+ throw new Error ( 'row and/or col specified outside of matrix dimensions ' ) ;
6463 }
6564 }
6665
/**
 * Stores `value` at (`row`, `col`).
 *
 * The position is validated with `checkDims` first, then looked up in
 * `entries` under the key produced by `makeKey`. If no entry exists, a new one
 * is inserted; otherwise the existing entry's value is overwritten in place.
 *
 * In this diff the `-` lines are the removed bookkeeping that used parallel
 * `rows` / `cols` / `values` arrays with the map holding an index; the `+`
 * lines store an `{ value, row, col }` object directly in the map.
 *
 * @param row Row index of the cell.
 * @param col Column index of the cell.
 * @param value Number to store.
 * @throws Error from `checkDims` when the position fails the bounds test.
 */
6766 set ( row : number , col : number , value : number ) {
6867 this . checkDims ( row , col ) ;
6968 const key = this . makeKey ( row , col ) ;
7069 if ( ! this . entries . has ( key ) ) {
71- this . rows . push ( row ) ;
72- this . cols . push ( col ) ;
73- this . values . push ( value ) ;
74- this . entries . set ( key , this . values . length - 1 ) ;
70+ this . entries . set ( key , { value, row, col } ) ;
7571 } else {
76- const index = this . entries . get ( key ) ! ;
77- this . values [ index ] = value ;
72+ this . entries . get ( key ) ! . value = value ;
7873 }
7974 }
8075
/**
 * Reads the value stored at (`row`, `col`).
 *
 * The position is validated with `checkDims`; if `entries` holds an entry for
 * the key produced by `makeKey`, its value is returned, otherwise
 * `defaultValue` is returned.
 *
 * In this diff the `-` lines are the removed lookup through the `values`
 * array; the `+` line reads the value straight from the map entry.
 *
 * @param row Row index of the cell.
 * @param col Column index of the cell.
 * @param defaultValue Returned when no entry is stored at the position (defaults to 0).
 * @returns The stored value, or `defaultValue` when the cell has no entry.
 * @throws Error from `checkDims` when the position fails the bounds test.
 */
8176 get ( row : number , col : number , defaultValue = 0 ) {
8277 this . checkDims ( row , col ) ;
8378 const key = this . makeKey ( row , col ) ;
8479 if ( this . entries . has ( key ) ) {
85- const index = this . entries . get ( key ) ! ;
86- return this . values [ index ] ;
80+ return this . entries . get ( key ) ! . value ;
8780 } else {
8881 return defaultValue ;
8982 }
9083 }
9185
/**
 * Returns every stored entry as an array of `{ value, row, col }` objects.
 *
 * The array itself is freshly created, but the objects pushed into it are the
 * same objects held in `entries` (they are not copied).
 * NOTE(review): mutating a returned element would therefore change the matrix.
 *
 * @param ordered When true (the default), the result is sorted by ascending
 *   row and, within the same row, by ascending col. When false, the order is
 *   whatever order `entries.forEach` visits the entries in.
 * @returns The collected entries.
 */
86+ getAll ( ordered = true ) : { value : number ; row : number ; col : number } [ ] {
87+ const rowColValues : Entry [ ] = [ ] ;
88+ this . entries . forEach ( ( value ) => {
89+ rowColValues . push ( value ) ;
90+ } ) ;
91+ if ( ordered ) { // Ordering the result isn't required for processing but it does make it easier to write tests
92+ rowColValues . sort ( ( a , b ) => {
93+ if ( a . row === b . row ) {
94+ return a . col - b . col ;
95+ } else {
96+ return a . row - b . row ;
97+ }
98+ } ) ;
99+ }
100+ return rowColValues ;
101+ }
102+
/**
 * Returns the matrix dimensions as a new two-element array.
 *
 * @returns `[nRows, nCols]`.
 */
92103 getDims ( ) : number [ ] {
93104 return [ this . nRows , this . nCols ] ;
94105 }
95106
/**
 * Returns the row index of every stored entry as a new array.
 *
 * In this diff the `-` line returned a copy of the removed `rows` array; the
 * `+` line derives the list from `entries`, one element per entry, in the
 * map's iteration order.
 *
 * @returns Row indices, one per stored entry.
 */
96107 getRows ( ) : number [ ] {
97- return [ ... this . rows ] ;
108+ return Array . from ( this . entries , ( [ key , value ] ) => value . row ) ;
98109 }
99110
/**
 * Returns the column index of every stored entry as a new array.
 *
 * In this diff the `-` line returned a copy of the removed `cols` array; the
 * `+` line derives the list from `entries`, one element per entry, in the
 * map's iteration order.
 *
 * @returns Column indices, one per stored entry.
 */
100111 getCols ( ) : number [ ] {
101- return [ ... this . cols ] ;
112+ return Array . from ( this . entries , ( [ key , value ] ) => value . col ) ;
102113 }
103114
/**
 * Returns the value of every stored entry as a new array.
 *
 * In this diff the `-` line returned a copy of the removed `values` array; the
 * `+` line derives the list from `entries`, one element per entry, in the
 * map's iteration order.
 *
 * @returns Stored values, one per entry.
 */
104115 getValues ( ) : number [ ] {
105- return [ ... this . values ] ;
116+ return Array . from ( this . entries , ( [ key , value ] ) => value . value ) ;
106117 }
107118
/**
 * Calls `fn(value, row, col)` once for each stored entry.
 *
 * Only positions that have an entry are visited; positions without one are
 * never passed to `fn`.
 *
 * In this diff the `-` lines are the removed index loop over the parallel
 * arrays; the `+` line iterates `entries` directly.
 *
 * @param fn Callback receiving each entry's value, row and col.
 */
108119 forEach ( fn : ( value : number , row : number , col : number ) => void ) : void {
109- for ( let i = 0 ; i < this . values . length ; i ++ ) {
110- fn ( this . values [ i ] , this . rows [ i ] , this . cols [ i ] ) ;
111- }
120+ this . entries . forEach ( ( value ) => fn ( value . value , value . row , value . col ) ) ;
112121 }
113122
/**
 * Builds a new `SparseMatrix` with the same dimensions and the same stored
 * positions, where each value is replaced by `fn(value, row, col)`.
 *
 * `fn` is applied only to stored entries. In the `+` version the new values
 * are collected by iterating `entries`, and the matching positions come from
 * `getRows()` / `getCols()`, which iterate the same map, so the three arrays
 * passed to the constructor line up element by element.
 *
 * In this diff the `-` lines are the removed loop over the parallel arrays and
 * the removed constructor call that passed `this.rows` / `this.cols`.
 *
 * @param fn Callback producing the new value for each stored entry.
 * @returns A new matrix; this instance is not modified by this method itself.
 */
114123 map ( fn : ( value : number , row : number , col : number ) => number ) : SparseMatrix {
115124 let vals : number [ ] = [ ] ;
116- for ( let i = 0 ; i < this . values . length ; i ++ ) {
117- vals . push ( fn ( this . values [ i ] , this . rows [ i ] , this . cols [ i ] ) ) ;
118- }
125+ this . entries . forEach ( ( value ) => {
126+ vals . push ( fn ( value . value , value . row , value . col ) ) ;
127+ } ) ;
119128 const dims = [ this . nRows , this . nCols ] ;
120- return new SparseMatrix ( this . rows , this . cols , vals , dims ) ;
129+ return new SparseMatrix ( this . getRows ( ) , this . getCols ( ) , vals , dims ) ;
121130 }
122131
/**
 * Expands the matrix into a nested array indexed as `output[row][col]`.
 *
 * An outer array is obtained from `utils.empty(nRows)` and each element is
 * mapped to `utils.zeros(nCols)`; every stored entry's value is then written
 * at its row/col position.
 * NOTE(review): presumably `utils.empty(n)` yields an n-length array and
 * `utils.zeros(n)` an n-length array of zeros — confirm in ./utils.
 *
 * In this diff the `-` lines are the removed loop over the parallel arrays;
 * the `+` lines iterate `entries` instead.
 *
 * @returns The nested array with stored values filled in.
 */
123132 toArray ( ) {
124133 const rows : undefined [ ] = utils . empty ( this . nRows ) ;
125134 const output = rows . map ( ( ) => {
126135 return utils . zeros ( this . nCols ) ;
127136 } ) ;
128- for ( let i = 0 ; i < this . values . length ; i ++ ) {
129- output [ this . rows [ i ] ] [ this . cols [ i ] ] = this . values [ i ] ;
130- }
137+ this . entries . forEach ( ( value ) => {
138+ output [ value . row ] [ value . col ] = value . value ;
139+ } ) ;
131140 return output ;
132141 }
133142}
@@ -338,7 +347,6 @@ function elementWise(
338347 * search logic depends on this data format.
339348 */
340349export function getCSR ( x : SparseMatrix ) {
341- type Entry = { value : number ; row : number ; col : number } ;
342350 const entries : Entry [ ] = [ ] ;
343351
344352 x . forEach ( ( value , row , col ) => {
0 commit comments