Inconsistency between Input dimension of Model.forward and transport.training_losses #38

@QhelDIV

In models.py, the forward method of the SiT class expects the following input and output shapes:
input: (N, C, H, W)
output: (N, out_channels, H, W)

However, in transport.py, the training_losses method expects x0 / xt to have shape
(N, *, C)
There is also an explicit assertion at line 135 of transport.py:
assert model_output.size() == (B, *xt.size()[1:-1], C)
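
To make the question concrete, here is a minimal sketch of what that assertion compares. It assumes that B and C are unpacked as the first and last dimensions of xt; that unpacking is not quoted above, so treat it as an assumption. The helper name expected_output_shape and the example shapes are hypothetical, and plain tuples stand in for tensor sizes.

```python
def expected_output_shape(xt_shape):
    """Rebuild the right-hand side of the assertion from xt's shape."""
    # Assumption: B is the first dimension and C is the last one,
    # whatever those axes mean for the actual tensor layout.
    B, *_, C = xt_shape
    return (B, *xt_shape[1:-1], C)


# Hypothetical shapes, chosen only for illustration.
image_like = (8, 4, 32, 32)   # (N, C, H, W), as in SiT.forward
token_like = (8, 256, 4)      # (N, *, C), as in the training_losses docstring

# Under the assumption above, the expected shape is just xt's own shape,
# regardless of which axis holds the channels.
assert expected_output_shape(image_like) == image_like
assert expected_output_shape(token_like) == token_like

# A model output with a different channel count would not match.
more_channels = (8, 8, 32, 32)  # hypothetical out_channels != C
mismatch = expected_output_shape(image_like) != more_channels
```

If that assumption holds, the assertion only requires the model output to have the same shape as xt, so "C" in the assertion would refer to the last axis rather than to the channel axis specifically.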

Am I misunderstanding something here, or are these two shape conventions inconsistent?
