In models.py, the forward method of the SiT class documents its expected input and output shapes as:
Input: (N, C, H, W)
Output: (N, out_channels, H, W)
However, in transport.py, the training_losses method expects x0 / xt to have shape
(N, *, C)
There is also an explicit assertion at L135:
assert model_output.size() == (B, *xt.size()[1:-1], C)
Am I misunderstanding something here?
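
To make the question concrete, here is a minimal sketch of what the assertion compares against. It assumes that B and C are unpacked from the first and last dimensions of xt; that unpacking is my assumption and is not shown in the snippet above. The helper name expected_output_shape is hypothetical, and plain tuples stand in for torch.Size, which is a tuple subclass.

```python
def expected_output_shape(xt_shape):
    """Rebuild the right-hand side of the assertion for a given xt shape.

    Assumption (not confirmed by the quoted code): B is the first
    dimension of xt and C is the last one.
    """
    B, *_, C = xt_shape  # first and last dimensions, middle ones ignored
    return (B, *xt_shape[1:-1], C)


# Image-like layout, as documented in models.py: (N, C, H, W)
image_shape = (8, 4, 32, 32)
# Sequence-like layout, as documented in transport.py: (N, *, C)
sequence_shape = (8, 256, 4)

print(expected_output_shape(image_shape))
print(expected_output_shape(sequence_shape))
```

Under that assumption, both shapes come back unchanged, so the assertion would only require the model output to have the same shape as xt, whichever layout is used. In that reading, the name C simply holds the last dimension (W for an image-shaped input), and (N, *, C) would be a naming convention rather than a layout requirement.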