@@ -114,10 +114,10 @@ def tile_processor(
114114 parallel_group : torch .distributed .ProcessGroup = None ,
115115 ) -> TileProcessor :
116116 """
117- Property representing the tile autoencoder .
117+ Property representing the tiled encoder or decoder .
118118
119119 Returns:
120- TileProcessor: The tile autoencoder .
120+ TileProcessor: The tiled encoder or decoder .
121121 """
122122 return TileProcessor (
123123 encode_fn = self .encode ,
@@ -152,9 +152,13 @@ def tiled_encode_3d(
152152
153153 Args:
154154 x (torch.Tensor shape:[N C T H W]): The input tensor to be encoded.
155- tile_sample_min_size (int, optional): The minimum size of each tile sample. Defaults to None.
156- tile_sample_min_length (int, optional): The minimum length of each tile sample. Defaults to None.
155+ tile_sample_min_height (int, optional): The minimum height of each tile sample. Defaults to 256.
156+ tile_sample_min_width (int, optional): The minimum width of each tile sample. Defaults to 256.
157+ tile_sample_min_length (int, optional): The minimum length of each tile sample. Defaults to 16.
158+ spatial_tile_overlap_factor (float, optional): Overlap factor for spatial tiles. Defaults to 0.25.
159+ temporal_tile_overlap_factor (float, optional): Overlap factor for temporal tiles. Defaults to 0.
157160 allow_spatial_tiling (bool, optional): Whether spatial tiling is allowed. Defaults to None.
161+ verbose (bool, optional): Whether to print verbose information. Defaults to False.
158162 parallel_group (torch.distributed.ProcessGroup, optional): Distributed encoding group. Defaults to None.
159163 Returns:
160164 torch.Tensor: The encoded tensor.
@@ -189,10 +193,14 @@ def tiled_decode_3d(
189193 Decodes the input tensor using the tile autoencoder.
190194
191195 Args:
192- x (Tensor): The input tensor to be decoded.
193- tile_sample_min_size (int, optional): The minimum size of the tile sample. Defaults to None.
194- tile_sample_min_length (int, optional): The minimum length of the tile sample. Defaults to None.
196+ x (torch.Tensor): The input tensor to be decoded.
197+ tile_sample_min_height (int, optional): The minimum height of each tile sample. Defaults to 256.
198+ tile_sample_min_width (int, optional): The minimum width of each tile sample. Defaults to 256.
199+ tile_sample_min_length (int, optional): The minimum length of each tile sample. Defaults to 16.
200+ spatial_tile_overlap_factor (float, optional): Overlap factor for spatial tiles. Defaults to 0.25.
201+ temporal_tile_overlap_factor (float, optional): Overlap factor for temporal tiles. Defaults to 0.
195202 allow_spatial_tiling (bool, optional): Whether spatial tiling is allowed. Defaults to None.
203+ verbose (bool, optional): Whether to print verbose information. Defaults to False.
196204 parallel_group (torch.distributed.ProcessGroup, optional): Distributed decoding group. Defaults to None.
197205 Returns:
198206 torch.Tensor shape:[N C T H W]: The decoded tensor.
@@ -253,13 +261,13 @@ def encode(self, x, sample_posterior=True):
253261 Encode the input video.
254262
255263 Args:
256- x (torch.Tensor): Input video tensor have shape N C T H W
264+ x (torch.Tensor): Input video tensor has shape N C T H W
257265
258266 Returns:
259267 tuple: Tuple containing the quantized tensor, embedding loss, and additional information.
260268 """
261269 N , C , T , H , W = x .shape
262- if T == 1 :
270+ if T == 1 and self . _temporal_downsample_factor > 1 :
263271 x = x .expand (- 1 , - 1 , 4 , - 1 , - 1 )
264272 x = self .encoder (x )
265273 posterior = DiagonalGaussianDistribution (x )
0 commit comments