
Commit 0f449b3
Author: Rob Cornish (committed)

Added Sum-of-Squares Polynomial Flow bijection

Our implementation simply wraps the existing flow provided by Pyro.

1 parent 8b29ad7 · commit 0f449b3

File tree: 12 files changed, +375 −249 lines

Pipfile

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,6 +14,7 @@ matplotlib = "*"
 tensorboardx = {editable = true,git = "git://github.com/lanpa/tensorboardX.git"}
 pandas = "*"
 pytorch-ignite = "*"
+pyro-ppl = "*"
 
 [requires]
 python_version = "3.7"
```

Pipfile.lock

Lines changed: 135 additions & 100 deletions
Some generated files are not rendered by default.

config.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -1,3 +1,4 @@
+import warnings
 import copy
 
 
@@ -10,7 +11,7 @@ def get_config(dataset, model, use_baseline):
 
 
 def get_config_base(dataset, model, use_baseline):
-    if dataset in ["2uniforms", "8gaussians", "checkerboard", "2spirals"]:
+    if dataset in ["2uniforms", "8gaussians", "checkerboard", "2spirals", "rings"]:
         return get_2d_config(dataset, model, use_baseline)
 
     elif dataset in ["power", "gas", "hepmass", "miniboone"]:
@@ -24,7 +25,7 @@ def get_config_base(dataset, model, use_baseline):
 
 
 def get_2d_config(dataset, model, use_baseline):
-    assert model in ["flat-realnvp", "maf"], f"Invalid model {model} for dataset {dataset}"
+    assert model in ["flat-realnvp", "maf", "sos"], f"Invalid model {model} for dataset {dataset}"
 
     if dataset == "2uniforms":
         if use_baseline:
@@ -90,6 +91,16 @@ def get_2d_config(dataset, model, use_baseline):
         "num_test_elbo_samples": 100
     }
 
+    if model == "sos":
+        warnings.warn("Overriding `num_density_layers`")
+        config["num_density_layers"] = 3 if use_baseline else 2
+        config["num_polynomials_per_layer"] = 2
+        config["polynomial_degree"] = 4
+
+        config["st_nets"] = [10] * 2
+        config["p_nets"] = [30] * 4
+        config["q_nets"] = [30] * 4
+
     return config
 
 
```

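With these overrides in place, a quick sketch of how the new branches compose (a hypothetical call; it assumes `get_config` simply forwards to `get_config_base`, as the hunk header suggests):

```python
# Hypothetical sketch: "rings" now routes to the 2-D config, and
# model="sos" triggers the overrides above.
config = get_config(dataset="rings", model="sos", use_baseline=False)

assert config["num_density_layers"] == 2       # would be 3 with use_baseline=True
assert config["num_polynomials_per_layer"] == 2
assert config["polynomial_degree"] == 4
assert config["p_nets"] == [30, 30, 30, 30]    # [30] * 4
```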
lgf/models/components/bijections/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -35,3 +35,5 @@
     BruteForceInvertible1x1ConvBijection,
     LUInvertible1x1ConvBijection
 )
+
+from .sos import SumOfSquaresPolynomialBijection
```

lgf/models/components/bijections/made.py

Lines changed: 6 additions & 50 deletions
```diff
@@ -2,29 +2,11 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from .bijection import Bijection
 
 from ..couplers import SharedCoupler
-
-
-class MaskedLinear(nn.Module):
-    def __init__(self, input_degrees, output_degrees):
-        super().__init__()
-
-        assert len(input_degrees.shape) == len(output_degrees.shape) == 1
-
-        num_input_channels = input_degrees.shape[0]
-        num_output_channels = output_degrees.shape[0]
-
-        self.linear = nn.Linear(num_input_channels, num_output_channels)
-
-        mask = output_degrees.view(-1, 1) >= input_degrees
-        self.register_buffer("mask", mask.to(self.linear.weight.dtype))
-
-    def forward(self, inputs):
-        return F.linear(inputs, self.mask*self.linear.weight, self.linear.bias)
+from ..networks import get_ar_mlp
 
 
 class MADEBijection(Bijection):
@@ -36,7 +18,7 @@ def __init__(
     ):
         super().__init__(x_shape=(num_input_channels,), z_shape=(num_input_channels,))
 
-        self.ar_map = self._get_ar_map(
+        self.ar_coupler = self._get_ar_coupler(
             num_input_channels=num_input_channels,
             hidden_channels=hidden_channels,
             activation=activation
@@ -48,7 +30,7 @@ def _z_to_x(self, z, **kwargs):
         x = torch.zeros_like(z)
 
         for dim in range(z.size(1)):
-            result = self.ar_map(x)
+            result = self.ar_coupler(x)
             means = result["shift"]
             log_stds = result["log-scale"]
 
@@ -57,7 +39,7 @@ def _z_to_x(self, z, **kwargs):
         return {"x": x, "log-jac": self._log_jac_z_to_x(log_stds)}
 
     def _x_to_z(self, x, **kwargs):
-        result = self.ar_map(x)
+        result = self.ar_coupler(x)
         means = result["shift"]
         log_stds = result["log-scale"]
 
@@ -71,43 +53,17 @@ def _log_jac_x_to_z(self, log_stds):
     def _log_jac_z_to_x(self, log_stds):
         return -self._log_jac_x_to_z(log_stds)
 
-    def _get_ar_map(
+    def _get_ar_coupler(
         self,
         num_input_channels,
         hidden_channels,
         activation
     ):
         return SharedCoupler(
-            shift_log_scale_net=self._get_ar_mlp(
+            shift_log_scale_net=get_ar_mlp(
                 num_input_channels=num_input_channels,
                 hidden_channels=hidden_channels,
                 num_outputs_per_input=2,
                 activation=activation
             )
         )
-
-    def _get_ar_mlp(
-        self,
-        num_input_channels,
-        hidden_channels,
-        num_outputs_per_input,
-        activation
-    ):
-        assert num_input_channels >= 2
-        assert all([num_input_channels <= d for d in hidden_channels]), "Random initialisation not yet implemented"
-
-        prev_degrees = torch.arange(1, num_input_channels + 1, dtype=torch.int64)
-        layers = []
-
-        for hidden_channels in hidden_channels:
-            degrees = torch.arange(hidden_channels, dtype=torch.int64) % (num_input_channels - 1) + 1
-
-            layers.append(MaskedLinear(prev_degrees, degrees))
-            layers.append(activation())
-
-            prev_degrees = degrees
-
-        degrees = torch.arange(num_input_channels, dtype=torch.int64).repeat(num_outputs_per_input)
-        layers.append(MaskedLinear(prev_degrees, degrees))
-
-        return nn.Sequential(*layers)
```
lgf/models/components/bijections/sos.py (new file)

Lines changed: 42 additions & 0 deletions

```diff
@@ -0,0 +1,42 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+
+from pyro.distributions.transforms.polynomial import PolynomialFlow
+from pyro.nn import AutoRegressiveNN
+
+from .bijection import Bijection
+
+
+class SumOfSquaresPolynomialBijection(Bijection):
+    def __init__(
+        self,
+        num_input_channels,
+        hidden_channels,
+        activation,
+        num_polynomials,
+        polynomial_degree,
+    ):
+        super().__init__(x_shape=(num_input_channels,), z_shape=(num_input_channels,))
+
+        arn = AutoRegressiveNN(
+            input_dim=int(num_input_channels),
+            hidden_dims=hidden_channels,
+            param_dims=[(polynomial_degree + 1)*num_polynomials]
+        )
+
+        self.flow = PolynomialFlow(
+            autoregressive_nn=arn,
+            input_dim=int(num_input_channels),
+            count_degree=polynomial_degree,
+            count_sum=num_polynomials
+        )
+
+    def _x_to_z(self, x):
+        z = self.flow._call(x)
+        log_jac = self.flow.log_abs_det_jacobian(None, None).view(x.shape[0], 1)
+        return {
+            "z": z,
+            "log-jac": log_jac
+        }
```
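Two things worth noting about the wrapper: only the density-evaluation direction (`_x_to_z`) is implemented, so sampling would require inverting the polynomial transform; and `log_abs_det_jacobian(None, None)` presumably works because Pyro's `PolynomialFlow` caches the log-determinant during the preceding `_call`, leaving its arguments unused. A minimal usage sketch under the repo layout above (the `activation` argument is accepted for interface uniformity but never used, since Pyro's `AutoRegressiveNN` applies its own nonlinearity):

```python
import torch

from lgf.models.components.bijections import SumOfSquaresPolynomialBijection

bijection = SumOfSquaresPolynomialBijection(
    num_input_channels=2,
    hidden_channels=[30, 30, 30, 30],
    activation=torch.nn.Tanh,  # unused by the wrapper (see note above)
    num_polynomials=2,
    polynomial_degree=4,
)

x = torch.randn(128, 2)
result = bijection._x_to_z(x)   # calling the internal method directly for illustration
print(result["z"].shape)        # torch.Size([128, 2])
print(result["log-jac"].shape)  # torch.Size([128, 1])
```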

lgf/models/components/networks.py

Lines changed: 66 additions & 21 deletions
```diff
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class ConstantNetwork(nn.Module):
@@ -96,27 +97,6 @@ def get_resnet(
     )
 
 
-def get_mlp(
-    num_input_channels,
-    hidden_channels,
-    num_output_channels,
-    activation,
-    log_softmax_outputs=False
-):
-    layers = []
-    prev_num_hidden_channels = num_input_channels
-    for num_hidden_channels in hidden_channels:
-        layers.append(nn.Linear(prev_num_hidden_channels, num_hidden_channels))
-        layers.append(activation())
-        prev_num_hidden_channels = num_hidden_channels
-    layers.append(nn.Linear(prev_num_hidden_channels, num_output_channels))
-
-    if log_softmax_outputs:
-        layers.append(nn.LogSoftmax(dim=1))
-
-    return nn.Sequential(*layers)
-
-
 def get_glow_cnn(num_input_channels, num_hidden_channels, num_output_channels):
     conv1 = nn.Conv2d(
         in_channels=num_input_channels,
@@ -149,3 +129,68 @@ def get_glow_cnn(num_input_channels, num_hidden_channels, num_output_channels):
     relu = nn.ReLU()
 
     return nn.Sequential(conv1, bn1, relu, conv2, bn2, relu, conv3)
+
+
+def get_mlp(
+    num_input_channels,
+    hidden_channels,
+    num_output_channels,
+    activation,
+    log_softmax_outputs=False
+):
+    layers = []
+    prev_num_hidden_channels = num_input_channels
+    for num_hidden_channels in hidden_channels:
+        layers.append(nn.Linear(prev_num_hidden_channels, num_hidden_channels))
+        layers.append(activation())
+        prev_num_hidden_channels = num_hidden_channels
+    layers.append(nn.Linear(prev_num_hidden_channels, num_output_channels))
+
+    if log_softmax_outputs:
+        layers.append(nn.LogSoftmax(dim=1))
+
+    return nn.Sequential(*layers)
+
+
+class MaskedLinear(nn.Module):
+    def __init__(self, input_degrees, output_degrees):
+        super().__init__()
+
+        assert len(input_degrees.shape) == len(output_degrees.shape) == 1
+
+        num_input_channels = input_degrees.shape[0]
+        num_output_channels = output_degrees.shape[0]
+
+        self.linear = nn.Linear(num_input_channels, num_output_channels)
+
+        mask = output_degrees.view(-1, 1) >= input_degrees
+        self.register_buffer("mask", mask.to(self.linear.weight.dtype))
+
+    def forward(self, inputs):
+        return F.linear(inputs, self.mask*self.linear.weight, self.linear.bias)
+
+
+def get_ar_mlp(
+    num_input_channels,
+    hidden_channels,
+    num_outputs_per_input,
+    activation
+):
+    assert num_input_channels >= 2
+    assert all([num_input_channels <= d for d in hidden_channels]), "Random initialisation not yet implemented"
+
+    prev_degrees = torch.arange(1, num_input_channels + 1, dtype=torch.int64)
+    layers = []
+
+    for hidden_channels in hidden_channels:
+        degrees = torch.arange(hidden_channels, dtype=torch.int64) % (num_input_channels - 1) + 1
+
+        layers.append(MaskedLinear(prev_degrees, degrees))
+        layers.append(activation())
+
+        prev_degrees = degrees
+
+    degrees = torch.arange(num_input_channels, dtype=torch.int64).repeat(num_outputs_per_input)
+    layers.append(MaskedLinear(prev_degrees, degrees))
+
+    return nn.Sequential(*layers)
```
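The degree bookkeeping in `get_ar_mlp` is the standard MADE-style construction: a unit of degree d may only receive input from units of degree ≤ d, and the final layer uses degrees 0, …, D−1 so that each output depends only on inputs that precede it. A small standalone sketch of the masks this produces for D = 3, one hidden layer of width 4, and two outputs per input:

```python
import torch

# Degrees as assigned in get_ar_mlp:
D = 3
input_degrees = torch.arange(1, D + 1)          # tensor([1, 2, 3])
hidden_degrees = torch.arange(4) % (D - 1) + 1  # tensor([1, 2, 1, 2])
output_degrees = torch.arange(D).repeat(2)      # tensor([0, 1, 2, 0, 1, 2])

# MaskedLinear keeps weight (i, j) iff output_degrees[i] >= input_degrees[j]:
print((hidden_degrees.view(-1, 1) >= input_degrees).int())
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 0, 0],
#         [1, 1, 0]], dtype=torch.int32)

print((output_degrees.view(-1, 1) >= hidden_degrees).int()[:3])
# tensor([[0, 0, 0, 0],   <- the degree-0 output sees no inputs at all
#         [1, 0, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
```

Composing the two masks, output i touches only inputs 1, …, i, so the Jacobian of the resulting map is triangular — which is what lets `MADEBijection` compute its log-determinant as a simple sum of `log_stds`.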

lgf/models/factory.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -18,7 +18,8 @@
     ViewBijection,
     ConditionalAffineBijection,
     BruteForceInvertible1x1ConvBijection,
-    LUInvertible1x1ConvBijection
+    LUInvertible1x1ConvBijection,
+    SumOfSquaresPolynomialBijection
 )
 from .components.densities import (
     DiagonalGaussianDensity,
@@ -157,6 +158,16 @@ def get_bijection(
         else:
             return BruteForceInvertible1x1ConvBijection(x_shape=x_shape)
 
+    elif layer_config["type"] == "sos":
+        assert len(x_shape) == 1
+        return SumOfSquaresPolynomialBijection(
+            num_input_channels=x_shape[0],
+            hidden_channels=layer_config["hidden_channels"],
+            activation=get_activation(layer_config["activation"]),
+            num_polynomials=layer_config["num_polynomials"],
+            polynomial_degree=layer_config["polynomial_degree"],
+        )
+
     else:
         assert False, f"Invalid layer type {layer_config['type']}"
 
```

lgf/models/schemas.py

Lines changed: 33 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 def get_schema(config):
     model = config["model"]
-    if model in ["glow", "multiscale-realnvp", "flat-realnvp", "maf"]:
+    if model in ["glow", "multiscale-realnvp", "flat-realnvp", "maf", "sos"]:
         return get_schema_from_base(config)
 
     elif model == "pure-cond-affine-mlp":
@@ -85,6 +85,14 @@ def get_base_schema(config):
             hidden_channels=config["g_hidden_channels"]
         )
 
+    elif model == "sos":
+        return get_sos_schema(
+            num_density_layers=config["num_density_layers"],
+            hidden_channels=config["g_hidden_channels"],
+            num_polynomials_per_layer=config["num_polynomials_per_layer"],
+            polynomial_degree=config["polynomial_degree"],
+        )
+
     elif model == "glow":
         return get_glow_schema(
             num_scales=config["num_scales"],
@@ -163,7 +171,7 @@ def get_coupler_net_config(net_spec, model):
             "hidden_channels": net_spec
         }
 
-    elif model in ["pure-cond-affine-mlp", "maf", "flat-realnvp"]:
+    elif model in ["pure-cond-affine-mlp", "maf", "flat-realnvp", "sos"]:
         return {
             "type": "mlp",
             "activation": "tanh",
@@ -330,3 +338,26 @@ def get_maf_schema(
     ]
 
     return result
+
+
+# TODO: Batch norm?
+# TODO: Flip after each layer?
+def get_sos_schema(
+    num_density_layers,
+    hidden_channels,
+    num_polynomials_per_layer,
+    polynomial_degree
+):
+    return [{"type": "flatten"}] + [
+        {
+            "type": "sos",
+            "hidden_channels": hidden_channels,
+            "activation": "tanh",
+            "num_polynomials": num_polynomials_per_layer,
+            "polynomial_degree": polynomial_degree
+        },
+        {
+            "type": "batch-norm",
+            "per_channel": False  # Irrelevant here since we flatten anyway
+        }
+    ] * num_density_layers
```
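Concretely, `get_sos_schema` prepends a single flatten layer and then alternates sos and batch-norm blocks. With example arguments (the hidden widths here are placeholders; the real value comes from `config["g_hidden_channels"]`):

```python
get_sos_schema(
    num_density_layers=2,
    hidden_channels=[30, 30, 30, 30],  # placeholder value for illustration
    num_polynomials_per_layer=2,
    polynomial_degree=4,
)
# == [
#     {"type": "flatten"},
#     {"type": "sos", "hidden_channels": [30, 30, 30, 30], "activation": "tanh",
#      "num_polynomials": 2, "polynomial_degree": 4},
#     {"type": "batch-norm", "per_channel": False},
#     {"type": "sos", "hidden_channels": [30, 30, 30, 30], "activation": "tanh",
#      "num_polynomials": 2, "polynomial_degree": 4},
#     {"type": "batch-norm", "per_channel": False},
# ]
```

Note that `* num_density_layers` repeats references to the same dict objects, which is harmless as long as the schema is only read downstream.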

main.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--model", choices=["maf", "flat-realnvp", "multiscale-realnvp", "glow"])
+parser.add_argument("--model", choices=["sos", "maf", "flat-realnvp", "multiscale-realnvp", "glow"])
 parser.add_argument("--dataset", choices=[
     "2uniforms", "8gaussians", "checkerboard", "2spirals",
     "power", "gas", "hepmass", "miniboone",
```
