@@ -152,15 +152,25 @@ def get_uci_config(dataset, model, use_baseline):
152152 elif model == "nsf" :
153153 if dataset in ["power" , "gas" ]:
154154 config = {
155- "num_density_layers" : 10 ,
155+ "num_u_channels" : 0 if use_baseline else 2 ,
156+ "num_density_layers" : 10 if use_baseline else 7 ,
156157 "num_hidden_layers" : 2 ,
157158 "num_hidden_channels" : 256 ,
158159 "num_bins" : 8 ,
159- "dropout_probability" : 0. if dataset == "power" else 0.1
160+ "dropout_probability" : 0. if dataset == "power" else 0.1 ,
161+
162+ "st_nets" : [128 ] * 2 ,
163+ "p_nets" : [200 ] * 2 ,
164+ "q_nets" : [200 ] * 2 ,
165+
166+ "lr" : 1e-3 ,
167+ "train_batch_size" : 5000
160168 }
161169
162170 elif dataset == "hepmass" :
171+ assert use_baseline
163172 config = {
173+ "num_u_channels" : 0 ,
164174 "num_density_layers" : 20 ,
165175 "num_hidden_layers" : 1 ,
166176 "num_hidden_channels" : 128 ,
@@ -170,20 +180,24 @@ def get_uci_config(dataset, model, use_baseline):
170180
171181 elif dataset == "miniboone" :
172182 config = {
173- "num_density_layers" : 10 ,
183+ "num_u_channels" : 0 if use_baseline else 10 ,
184+
185+ "num_density_layers" : 10 if use_baseline else 4 ,
174186 "num_hidden_layers" : 1 ,
175187 "num_hidden_channels" : 32 ,
176188 "num_bins" : 4 ,
177189 "dropout_probability" : 0.2 ,
178190
179- "lr" : 3e-4 ,
180- "train_batch_size" : 128
191+ "st_nets" : [32 ] * 2 ,
192+ "p_nets" : [64 ] * 2 ,
193+ "q_nets" : [64 ] * 2 ,
194+
195+ "lr" : 1e-3 ,
196+ "train_batch_size" : 1000
181197 }
182198
183- assert use_baseline
184199 config = {
185200 ** config ,
186- "num_u_channels" : 0 ,
187201 "tail_bound" : 3 ,
188202 "autoregressive" : False ,
189203 "batch_norm" : False
0 commit comments