
Conversation

@IlyasMoutawwakil (Member) commented on Jul 24, 2025

What does this PR do?

The batched inference and numerical mismatch issues have persisted in the ORTModelForCausalLM class for a very long time. Even with #1381, position ids were only created in the generation tests instead of being created inside the inference class itself; and for generation, prepare_inputs_for_generation takes care of creating position ids anyway. So for a simple forward pass, it was still an issue.
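
For reference, a minimal sketch of how position ids can be derived from a padded attention mask inside the forward pass; this is the usual cumsum trick used by transformers decoders, shown here on illustrative inputs:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])  # full sequence

# Cumulative sum gives each attended token its 0-based position;
# padded slots then get a dummy value (they are masked out anyway).
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)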

This PR is the result of a rabbit hole I went into when I realized that my decoder testing refactoring removed the only batched generation check we had 🥲. So I enabled batched inference/generation by default and, to my surprise, all models that require position ids were failing batched inference/generation. But simply installing transformers==4.52 made them pass, so the problem obviously came from something in the transformers 4.53 refactoring.

Starting from transformers 4.53, the modeling code uses boolean 4D masks, which are not "officially" supported by the torch ONNX export (it's not exported as a real "masked operation"): the boolean mask is simply converted to a tensor filled with 0 and -inf, see https://github.com/pytorch/pytorch/blob/f8fafdc7a6d260cea6c145643f4cf73631c81460/torch/onnx/symbolic_opset14.py#L187
In the case of padded batched inputs, this makes the softmax return NaNs, which pollute the logits of the entire sequence (fully padded sequences return only NaNs as logits). This behavior used to be avoided by not calling _unmask_unattended during ONNX export. Instead of going down that path again for the new masking methods (which also results in small numerical mismatches), we fix this by patching the torch ONNX exporter directly, overriding the graph it uses to replace aten::scaled_dot_product_attention.
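
A minimal repro of the failure mode, in plain torch with no export involved: once a fully padded row of the boolean mask is converted to an all--inf additive mask, the softmax divides zero by zero.

import torch

# A row where every position is masked becomes all -inf after the boolean-to-float conversion
scores = torch.full((1, 4), float("-inf"))
print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan]])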

A lot of fixes and changes, but hey, we were able to un-bloat the model patcher, get matching logits even for the masked (padded) tokens, and get better support and testing of batched inference/generation across transformers versions.

P.S.: I tested every version of transformers from 4.36 to 4.53, and also tested these changes on torch 2.1.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil marked this pull request as ready for review July 25, 2025 12:36
@IlyasMoutawwakil removed the request for review from echarlaix July 28, 2025 08:53
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
@IlyasMoutawwakil (Member, Author):

this was wrong (the attention mask axis must be "past_sequence_length + sequence_length", not "past_sequence_length + 1")
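
For context, these axis dicts end up as the dynamic_axes argument of torch.onnx.export. A toy sketch of how they are consumed (the module and file name are illustrative, not the actual optimum export path):

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, attention_mask):
        return input_ids * attention_mask

torch.onnx.export(
    Toy(),
    (torch.ones(2, 5, dtype=torch.long), torch.ones(2, 5, dtype=torch.long)),
    "toy.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        # with cache, the mask spans both past and current tokens
        "attention_mask": {0: "batch_size", 1: "past_sequence_length + sequence_length"},
    },
)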

Comment on lines +561 to +565
if self._normalized_config.multi_query:
    # No dim for `n_head` when using multi-query attention
    inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 1: decoder_sequence_name}
else:
    inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 2: decoder_sequence_name}
@IlyasMoutawwakil (Member, Author):

support for both multi_query=True and multi_query=False for gpt_bigcode
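
In gpt_bigcode the cache is a single fused key_value tensor per layer; with multi-query attention there is no n_head dimension, which moves the sequence axis from 2 to 1. A sketch of the two shapes with illustrative sizes:

import torch

batch_size, n_head, seq_len, head_dim = 2, 8, 5, 64

# multi_query=True: one shared KV head, keys and values fused on the last dim
kv_mqa = torch.zeros(batch_size, seq_len, 2 * head_dim)          # axis 1 is the sequence
# multi_query=False: per-head cache, so the sequence moves to axis 2
kv_mha = torch.zeros(batch_size, n_head, seq_len, 2 * head_dim)  # axis 2 is the sequence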

Comment on lines +257 to +259
# No-op bfloat16 casting to avoid issues with legacy ONNX export which cast to complex128
def noop_bfloat16_casting(self):
    return self
@IlyasMoutawwakil (Member, Author) commented on Jul 28, 2025:

this is for falcon with alibi, and for any modeling method that calls .bfloat16() on a tensor (which is not supported by the ONNX exporter)
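
Presumably the patcher swaps this in for torch.Tensor.bfloat16 for the duration of the export; a minimal sketch of that kind of monkeypatch (the context manager is illustrative, not the actual patcher code):

import contextlib
import torch

@contextlib.contextmanager
def patch_bfloat16_casting():
    # Temporarily replace Tensor.bfloat16 with the no-op defined above
    original = torch.Tensor.bfloat16
    torch.Tensor.bfloat16 = noop_bfloat16_casting
    try:
        yield
    finally:
        torch.Tensor.bfloat16 = original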

Comment on lines 419 to 422
@_onnx_symbolic("aten::__ior_")
@symbolic_helper.parse_args("v", "v")
def __ior_(g: jit_utils.GraphContext, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value:
    return g.op("Or", self, other)
@IlyasMoutawwakil (Member, Author) commented on Jul 28, 2025:

this fixes the missing in-place or (|=) op.
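
This operator shows up whenever modeling code mutates a mask in place; a minimal trigger (illustrative tensors):

import torch

causal = torch.zeros(4, 4, dtype=torch.bool)
padding = torch.ones(4, 4, dtype=torch.bool)
causal |= padding  # traces to aten::__ior_, which previously had no ONNX symbolic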

Comment on lines +90 to +96
global IMAGE
if IMAGE is None:
    # Load a sample image from the Hugging Face Hub
    IMAGE = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
    )
image = IMAGE.resize((width, height))
@IlyasMoutawwakil (Member, Author):

to avoid load_image failing when called multiple times; the method is very error-prone when many calls are made to the same URL in parallel, so the image is fetched once and cached at module level.

Comment on lines +198 to +209
def test_all_models_requiring_postion_ids(self):
    for model_type in TasksManager.get_supported_model_type_for_task(task=self.TASK, exporter="onnx"):
        model_type_requires_position_ids = model_type in MODEL_TYPES_REQUIRING_POSITION_IDS
        onnx_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["onnx"][self.TASK].func
        onnx_config_class_with_position_ids = issubclass(onnx_config_class, TextDecoderWithPositionIdsOnnxConfig)

        if model_type_requires_position_ids ^ onnx_config_class_with_position_ids:
            raise ValueError(
                f"Model type {model_type} {'requires' if model_type_requires_position_ids else 'does not require'} position ids, "
                f"but the ONNX config class {onnx_config_class} {'is' if onnx_config_class_with_position_ids else 'is not'} "
                f"subclassed from TextDecoderWithPositionIdsOnnxConfig.\n"
            )
@IlyasMoutawwakil (Member, Author):

so that they're always in sync (found a couple of models that weren't added); the XOR flags any model type where the MODEL_TYPES_REQUIRING_POSITION_IDS registry and the ONNX config class hierarchy disagree in either direction.

@echarlaix (Collaborator) left a comment:

Huuuge work, thank you so much @IlyasMoutawwakil 🔥🔥

@IlyasMoutawwakil changed the title from "Fix ORTModelForCausalLM batched generation" to "Fix batched inference/generation, position_ids creation, falcon alibi, gpt_bigcode multi-query,.." Jul 29, 2025
@IlyasMoutawwakil merged commit 31d4ea9 into main Jul 30, 2025
57 of 61 checks passed
@IlyasMoutawwakil deleted the fix-ort-batched-generation branch July 30, 2025 12:45
echarlaix pushed a commit to huggingface/optimum-onnx that referenced this pull request Aug 1, 2025
…con alibi, gpt_bigcode multi-query,.. (#28)

same as huggingface/optimum#2326

---------

Co-authored-by: Copilot <[email protected]>