Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ def parse_args_onnx(parser):
default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"],
help="For Segment Anything. It corresponds to the number of points per segmentation masks.",
)
input_group.add_argument(
"--visual_seq_length",
type=int,
default=DEFAULT_DUMMY_SHAPES["visual_seq_length"],
help="Visual sequence length",
)

# deprecated argument
parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS)
Expand Down
1 change: 1 addition & 0 deletions optimum/commands/export/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def parse_args_tflite(parser: "ArgumentParser"):
default=None,
help=f"Audio tasks only. Audio sequence length {doc_input}",
)
input_group.add_argument("--visual_seq_length", type=int, default=None, help="Visual sequence length")

quantization_group = parser.add_argument_group("Quantization")
quantization_group.add_argument(
Expand Down
29 changes: 29 additions & 0 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def wrapper(*args, **kwargs):
"num_channels": 3,
"point_batch_size": 3,
"nb_points_per_image": 2,
"visual_seq_length": 16,
# audio
"feature_size": 80,
"nb_max_frames": 3000,
Expand Down Expand Up @@ -806,6 +807,9 @@ class DummyVisionInputGenerator(DummyInputGenerator):
"pixel_mask",
"sample",
"latent_sample",
"visual_embeds",
"visual_token_type_ids",
"visual_attention_mask",
)

def __init__(
Expand All @@ -816,6 +820,7 @@ def __init__(
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = DEFAULT_DUMMY_SHAPES["width"],
height: int = DEFAULT_DUMMY_SHAPES["height"],
visual_seq_length: int = DEFAULT_DUMMY_SHAPES["visual_seq_length"],
**kwargs,
):
self.task = task
Expand All @@ -839,6 +844,8 @@ def __init__(
self.image_size = (self.image_size, self.image_size)
self.batch_size = batch_size
self.height, self.width = self.image_size
self.visual_seq_length = visual_seq_length
self.visual_embedding_dim = getattr(normalized_config, "visual_embedding_dim", 512)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "pixel_mask":
Expand All @@ -848,6 +855,28 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
framework=framework,
dtype=int_dtype,
)
elif input_name in "visual_attention_mask":
return self.random_mask_tensor(
shape=[self.batch_size, self.visual_seq_length],
padding_side="right",
framework=framework,
dtype=int_dtype,
)

elif input_name == "visual_token_type_ids":
return self.random_int_tensor(
shape=[self.batch_size, self.visual_seq_length],
max_value=1,
framework=framework,
dtype=int_dtype,
)

elif input_name == "visual_embeds":
return self.random_float_tensor(
shape=[self.batch_size, self.visual_seq_length, self.visual_embedding_dim],
framework=framework,
dtype=float_dtype,
)
else:
return self.random_float_tensor(
shape=[self.batch_size, self.num_channels, self.height, self.width],
Expand Down