@@ -2,29 +2,11 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from .bijection import Bijection
 
 from ..couplers import SharedCoupler
-
-
-class MaskedLinear(nn.Module):
-    def __init__(self, input_degrees, output_degrees):
-        super().__init__()
-
-        assert len(input_degrees.shape) == len(output_degrees.shape) == 1
-
-        num_input_channels = input_degrees.shape[0]
-        num_output_channels = output_degrees.shape[0]
-
-        self.linear = nn.Linear(num_input_channels, num_output_channels)
-
-        mask = output_degrees.view(-1, 1) >= input_degrees
-        self.register_buffer("mask", mask.to(self.linear.weight.dtype))
-
-    def forward(self, inputs):
-        return F.linear(inputs, self.mask * self.linear.weight, self.linear.bias)
+from ..networks import get_ar_mlp
 
 
 class MADEBijection(Bijection):
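For reference while reviewing: the connectivity rule the removed MaskedLinear encoded (and which get_ar_mlp in ..networks presumably preserves) is one broadcasted comparison of degree vectors: output unit k may see input unit j only if degree(k) >= degree(j). A minimal standalone sketch of that mask construction, with illustrative values not taken from the repo:

    import torch

    # Degrees as in the removed code: inputs get degrees 1..D, hidden units
    # cycle through 1..D-1. Here D = 3 with three hidden units.
    input_degrees = torch.arange(1, 4)        # tensor([1, 2, 3])
    hidden_degrees = torch.arange(3) % 2 + 1  # tensor([1, 2, 1])

    # The broadcasted >= comparison yields the (out, in) connectivity mask
    # that MaskedLinear multiplied elementwise into its weight matrix.
    mask = hidden_degrees.view(-1, 1) >= input_degrees
    print(mask.int())
    # tensor([[1, 0, 0],
    #         [1, 1, 0],
    #         [1, 0, 0]], dtype=torch.int32)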
@@ -36,7 +18,7 @@ def __init__(
     ):
         super().__init__(x_shape=(num_input_channels,), z_shape=(num_input_channels,))
 
-        self.ar_map = self._get_ar_map(
+        self.ar_coupler = self._get_ar_coupler(
             num_input_channels=num_input_channels,
             hidden_channels=hidden_channels,
             activation=activation
@@ -48,7 +30,7 @@ def _z_to_x(self, z, **kwargs):
         x = torch.zeros_like(z)
 
         for dim in range(z.size(1)):
-            result = self.ar_map(x)
+            result = self.ar_coupler(x)
             means = result["shift"]
             log_stds = result["log-scale"]
 
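A note on why this loop calls the coupler once per dimension: the coupler's outputs for dimension dim depend only on x[:, :dim], so after pass dim the first dim + 1 entries of x are final. A hedged sketch of the whole inversion, assuming the usual affine parameterisation z = (x - shift) * exp(-log-scale); the update line itself sits outside this hunk:

    import torch

    def invert_affine_ar(z, ar_coupler):
        # Invert dimension by dimension; pass `dim` can trust x[:, :dim]
        # because the autoregressive masks hide everything at or after dim.
        x = torch.zeros_like(z)
        for dim in range(z.size(1)):
            result = ar_coupler(x)
            x[:, dim] = z[:, dim] * result["log-scale"][:, dim].exp() \
                + result["shift"][:, dim]
        return x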
@@ -57,7 +39,7 @@ def _z_to_x(self, z, **kwargs):
         return {"x": x, "log-jac": self._log_jac_z_to_x(log_stds)}
 
     def _x_to_z(self, x, **kwargs):
-        result = self.ar_map(x)
+        result = self.ar_coupler(x)
         means = result["shift"]
         log_stds = result["log-scale"]
 
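By contrast, the forward direction needs only this single coupler pass, and because z_d depends on x_1..x_d alone the Jacobian is triangular, so its log-determinant is a plain sum over log_stds; that is why _log_jac_z_to_x below is simply the negation of _log_jac_x_to_z. A hedged sketch under the same assumed affine parameterisation:

    import torch

    def affine_ar_forward(x, ar_coupler):
        result = ar_coupler(x)
        z = (x - result["shift"]) * torch.exp(-result["log-scale"])
        # Triangular Jacobian: log|det dz/dx| = -sum_d log_std_d.
        log_jac = -result["log-scale"].sum(dim=1, keepdim=True)
        return z, log_jac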
@@ -71,43 +53,17 @@ def _log_jac_x_to_z(self, log_stds):
     def _log_jac_z_to_x(self, log_stds):
         return -self._log_jac_x_to_z(log_stds)
 
-    def _get_ar_map(
+    def _get_ar_coupler(
         self,
         num_input_channels,
         hidden_channels,
         activation
     ):
         return SharedCoupler(
-            shift_log_scale_net=self._get_ar_mlp(
+            shift_log_scale_net=get_ar_mlp(
                 num_input_channels=num_input_channels,
                 hidden_channels=hidden_channels,
                 num_outputs_per_input=2,
                 activation=activation
             )
         )
-
-    def _get_ar_mlp(
-        self,
-        num_input_channels,
-        hidden_channels,
-        num_outputs_per_input,
-        activation
-    ):
-        assert num_input_channels >= 2
-        assert all([num_input_channels <= d for d in hidden_channels]), "Random initialisation not yet implemented"
-
-        prev_degrees = torch.arange(1, num_input_channels + 1, dtype=torch.int64)
-        layers = []
-
-        for hidden_channels in hidden_channels:
-            degrees = torch.arange(hidden_channels, dtype=torch.int64) % (num_input_channels - 1) + 1
-
-            layers.append(MaskedLinear(prev_degrees, degrees))
-            layers.append(activation())
-
-            prev_degrees = degrees
-
-        degrees = torch.arange(num_input_channels, dtype=torch.int64).repeat(num_outputs_per_input)
-        layers.append(MaskedLinear(prev_degrees, degrees))
-
-        return nn.Sequential(*layers)
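As a sanity check on the degree scheme the removed _get_ar_mlp used (inputs 1..D, hidden units cycling through 1..D-1, outputs 0..D-1 repeated once per output head), one can verify that composing the per-layer masks never gives an output of degree d a path to an input of higher degree; in particular the degree-0 output reaches nothing, so the first dimension's shift and scale are constants, as an autoregressive start requires. A hedged, standalone sketch with illustrative sizes:

    import torch

    D, H = 4, 8
    in_deg = torch.arange(1, D + 1)          # input degrees 1..D
    hid_deg = torch.arange(H) % (D - 1) + 1  # hidden degrees cycle 1..D-1
    out_deg = torch.arange(D).repeat(2)      # 0..D-1 for shift and log-scale

    m1 = (hid_deg.view(-1, 1) >= in_deg).float()   # (H, D) first-layer mask
    m2 = (out_deg.view(-1, 1) >= hid_deg).float()  # (2D, H) last-layer mask

    # reach[k, j] > 0 iff some path connects input j to output k.
    reach = m2 @ m1
    for k in range(2 * D):
        d = int(out_deg[k])
        # Output of degree d must be blind to inputs x_{d+1}..x_D.
        assert reach[k, d:].eq(0).all()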