Skip to content

Commit 87fe8d2

Browse files
Added NSF-LGF config for Miniboone and Gas
1 parent 270f582 commit 87fe8d2

File tree

1 file changed

+21 -7 lines changed

1 file changed

+21 -7 lines changed

config.py

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -152,15 +152,25 @@ def get_uci_config(dataset, model, use_baseline):
152152
elif model == "nsf":
153153
if dataset in ["power", "gas"]:
154154
config = {
155-
"num_density_layers": 10,
155+
"num_u_channels": 0 if use_baseline else 2,
156+
"num_density_layers": 10 if use_baseline else 7,
156157
"num_hidden_layers": 2,
157158
"num_hidden_channels": 256,
158159
"num_bins": 8,
159-
"dropout_probability": 0. if dataset == "power" else 0.1
160+
"dropout_probability": 0. if dataset == "power" else 0.1,
161+
162+
"st_nets": [128] * 2,
163+
"p_nets": [200] * 2,
164+
"q_nets": [200] * 2,
165+
166+
"lr": 1e-3,
167+
"train_batch_size": 5000
160168
}
161169

162170
elif dataset == "hepmass":
171+
assert use_baseline
163172
config = {
173+
"num_u_channels": 0,
164174
"num_density_layers": 20,
165175
"num_hidden_layers": 1,
166176
"num_hidden_channels": 128,
@@ -170,20 +180,24 @@ def get_uci_config(dataset, model, use_baseline):
170180

171181
elif dataset == "miniboone":
172182
config = {
173-
"num_density_layers": 10,
183+
"num_u_channels": 0 if use_baseline else 10,
184+
185+
"num_density_layers": 10 if use_baseline else 4,
174186
"num_hidden_layers": 1,
175187
"num_hidden_channels": 32,
176188
"num_bins": 4,
177189
"dropout_probability": 0.2,
178190

179-
"lr": 3e-4,
180-
"train_batch_size": 128
191+
"st_nets": [32] * 2,
192+
"p_nets": [64] * 2,
193+
"q_nets": [64] * 2,
194+
195+
"lr": 1e-3,
196+
"train_batch_size": 1000
181197
}
182198

183-
assert use_baseline
184199
config = {
185200
**config,
186-
"num_u_channels": 0,
187201
"tail_bound": 3,
188202
"autoregressive": False,
189203
"batch_norm": False

0 commit comments

Comments (0)