Skip to content

Commit a932c84

Browse files
authored
Refined parameter documentation of VAE and updated the input tensor processing logic in vae.encode(). (#71)
1 parent aabd1bc commit a932c84

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

inference/model/vae/vae_model.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ def tile_processor(
114114
parallel_group: torch.distributed.ProcessGroup = None,
115115
) -> TileProcessor:
116116
"""
117-
Property representing the tile autoencoder.
117+
Property representing the tiled encoder or decoder.
118118
119119
Returns:
120-
TileProcessor: The tile autoencoder.
120+
TileProcessor: The tiled encoder or decoder.
121121
"""
122122
return TileProcessor(
123123
encode_fn=self.encode,
@@ -152,9 +152,13 @@ def tiled_encode_3d(
152152
153153
Args:
154154
x (torch.Tensor shape:[N C T H W]): The input tensor to be encoded.
155-
tile_sample_min_size (int, optional): The minimum size of each tile sample. Defaults to None.
156-
tile_sample_min_length (int, optional): The minimum length of each tile sample. Defaults to None.
155+
tile_sample_min_height (int, optional): The minimum height of each tile sample. Defaults to 256.
156+
tile_sample_min_width (int, optional): The minimum width of each tile sample. Defaults to 256.
157+
tile_sample_min_length (int, optional): The minimum length of each tile sample. Defaults to 16.
158+
spatial_tile_overlap_factor (float, optional): Overlap factor for spatial tiles. Defaults to 0.25.
159+
temporal_tile_overlap_factor (float, optional): Overlap factor for temporal tiles. Defaults to 0.
157160
allow_spatial_tiling (bool, optional): Whether spatial tiling is allowed. Defaults to None.
161+
verbose (bool, optional): Whether to print verbose information. Defaults to False.
158162
parallel_group (torch.distributed.ProcessGroup, optional): Distributed encoding group. Defaults to None.
159163
Returns:
160164
torch.Tensor: The encoded tensor.
@@ -189,10 +193,14 @@ def tiled_decode_3d(
189193
Decodes the input tensor using the tiled decoder.
190194
191195
Args:
192-
x (Tensor): The input tensor to be decoded.
193-
tile_sample_min_size (int, optional): The minimum size of the tile sample. Defaults to None.
194-
tile_sample_min_length (int, optional): The minimum length of the tile sample. Defaults to None.
196+
x (torch.Tensor): The input tensor to be decoded.
197+
tile_sample_min_height (int, optional): The minimum height of each tile sample. Defaults to 256.
198+
tile_sample_min_width (int, optional): The minimum width of each tile sample. Defaults to 256.
199+
tile_sample_min_length (int, optional): The minimum length of each tile sample. Defaults to 16.
200+
spatial_tile_overlap_factor (float, optional): Overlap factor for spatial tiles. Defaults to 0.25.
201+
temporal_tile_overlap_factor (float, optional): Overlap factor for temporal tiles. Defaults to 0.
195202
allow_spatial_tiling (bool, optional): Whether spatial tiling is allowed. Defaults to None.
203+
verbose (bool, optional): Whether to print verbose information. Defaults to False.
196204
parallel_group (torch.distributed.ProcessGroup, optional): Distributed decoding group. Defaults to None.
197205
Returns:
198206
torch.Tensor shape:[N C T H W]: The decoded tensor.
@@ -253,13 +261,13 @@ def encode(self, x, sample_posterior=True):
253261
Encode the input video.
254262
255263
Args:
256-
x (torch.Tensor): Input video tensor have shape N C T H W
264+
x (torch.Tensor shape:[N C T H W]): The input video tensor to be encoded.
257265
258266
Returns:
259267
tuple: Tuple containing the quantized tensor, embedding loss, and additional information.
260268
"""
261269
N, C, T, H, W = x.shape
262-
if T == 1:
270+
if T == 1 and self._temporal_downsample_factor > 1:
263271
x = x.expand(-1, -1, 4, -1, -1)
264272
x = self.encoder(x)
265273
posterior = DiagonalGaussianDistribution(x)

0 commit comments

Comments
 (0)