Fix batched inference/generation, position_ids creation, falcon alibi, gpt_bigcode multi-query,.. #2326
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
this was wrong
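For context, a minimal sketch (with made-up sizes) of the runtime shapes those dynamic axes describe: with a KV cache, the attention mask spans past plus current tokens, and the current chunk is not necessarily a single token.

```python
import torch

# Hypothetical sizes, only to illustrate what the dynamic axes refer to at runtime.
batch_size, past_length, sequence_length = 2, 10, 3

input_ids = torch.randint(0, 1000, (batch_size, sequence_length))
# With past key/values as inputs, the mask must cover past *and* current tokens,
# e.g. for batched inputs with padding or a chunked prefill.
attention_mask = torch.ones(batch_size, past_length + sequence_length, dtype=torch.int64)

assert attention_mask.shape[1] == past_length + sequence_length
```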
if self._normalized_config.multi_query:
    # No dim for `n_head` when using multi-query attention
    inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 1: decoder_sequence_name}
else:
    inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 2: decoder_sequence_name}
support for both multi_query=True and multi_query=False for gpt_bigcode
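Roughly, the two cache layouts these axes map to, assuming the usual GPT-BigCode convention of a fused `key_value` tensor (sizes are illustrative, not taken from the PR):

```python
import torch

batch_size, seq_len, n_head, head_dim = 2, 5, 8, 64  # made-up sizes

# multi_query=True: a single key/value pair shared by all heads, fused along the
# last dim -> no n_head axis, so the sequence axis is dim 1.
key_value_mqa = torch.zeros(batch_size, seq_len, 2 * head_dim)

# multi_query=False: one key/value pair per head -> the sequence axis moves to dim 2.
key_value_mha = torch.zeros(batch_size, n_head, seq_len, 2 * head_dim)
```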
# No-op bfloat16 casting to avoid issues with legacy ONNX export which cast to complex128
def noop_bfloat16_casting(self):
    return self
this is for falcon with alibi, and more generally for any modeling code that calls .bfloat16() on a tensor (not supported by the ONNX exporter)
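As an illustration only (not the actual patcher wiring), an export-time patch that swaps in such a no-op could look like this:

```python
import contextlib

import torch


@contextlib.contextmanager
def bfloat16_as_noop():
    """Temporarily make Tensor.bfloat16() return the tensor unchanged while exporting."""
    original = torch.Tensor.bfloat16
    torch.Tensor.bfloat16 = lambda self: self  # no-op cast, keeps the original dtype
    try:
        yield
    finally:
        torch.Tensor.bfloat16 = original
```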
@_onnx_symbolic("aten::__ior_")
@symbolic_helper.parse_args("v", "v")
def __ior_(g: jit_utils.GraphContext, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value:
    return g.op("Or", self, other)
this adds the missing symbolic for the in-place or op (`|=`).
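The kind of masking code that hits `aten::__ior_` during tracing (illustrative):

```python
import torch

padding_mask = torch.tensor([[False, False, True]])  # True = masked position
causal_mask = torch.tensor([[False, True, True]])
padding_mask |= causal_mask  # in-place boolean OR, traced as aten::__ior_
```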
global IMAGE
if IMAGE is None:
    # Load a sample image from the Hugging Face Hub
    IMAGE = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
    )
image = IMAGE.resize((width, height))
caches the image to avoid load_image failing: the helper is very error-prone when many parallel calls are made to the same URL, so the image is downloaded once and reused.
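An equivalent way to express the same once-per-process caching, assuming `load_image` is the `diffusers.utils` helper (just an illustration of the idea, not the code used here):

```python
from functools import lru_cache

from diffusers.utils import load_image  # assumed import; adjust to the test utilities in use


@lru_cache(maxsize=None)
def get_test_image():
    # The download happens once per process; every later call reuses the cached image.
    return load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
    )
```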
def test_all_models_requiring_postion_ids(self):
    for model_type in TasksManager.get_supported_model_type_for_task(task=self.TASK, exporter="onnx"):
        model_type_requires_position_ids = model_type in MODEL_TYPES_REQUIRING_POSITION_IDS
        onnx_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["onnx"][self.TASK].func
        onnx_config_class_with_position_ids = issubclass(onnx_config_class, TextDecoderWithPositionIdsOnnxConfig)

        if model_type_requires_position_ids ^ onnx_config_class_with_position_ids:
            raise ValueError(
                f"Model type {model_type} {'requires' if model_type_requires_position_ids else 'does not require'} position ids, "
                f"but the ONNX config class {onnx_config_class} {'is' if onnx_config_class_with_position_ids else 'is not'} "
                f"subclassed from TextDecoderWithPositionIdsOnnxConfig.\n"
            )
so that they're always in sync (found a couple models that weren't added)
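The `^` check flags a mismatch between the two sources of truth in either direction:

```python
# Truth table of the `^` check above: the test only passes when both signals agree.
assert (True ^ True) is False    # listed as requiring position_ids AND config subclasses it -> OK
assert (False ^ False) is False  # neither -> OK
assert (True ^ False) is True    # listed but config does not subclass -> raises
assert (False ^ True) is True    # subclasses but not listed -> raises
```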
Huuuge work, thank you so much @IlyasMoutawwakil 🔥🔥
What does this PR do?
The batched inference issue and numerical mismatch have persisted in the ORT class for a very long time. Even with #1381, position ids were only created in the generation tests rather than inside the inference class, and for generation `prepare_inputs_for_generation` takes care of creating them, so a simple forward pass was still broken.

This PR is the result of a rabbit hole I went into when I realized that my decoder testing refactoring removed the only batched generation check we had 🥲. So I enabled batched inference/generation by default and, to my surprise, all models that require position ids were failing batched inference/generation. Simply installing transformers==4.52 made them pass, so the problem clearly came from something in the transformers 4.53 refactoring.
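For reference, creating position ids inside the forward pass from the attention mask typically looks like this, so that left-padded batches stay correct (a sketch of the standard pattern, not necessarily the exact code added here):

```python
import torch

# Left-padded batch: zeros mark padding.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```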
Starting from transformers 4.53, the modeling code uses boolean 4D masks, which are not "officially" supported by the torch ONNX export (it's not really a "masked operation"): the boolean mask is simply converted to a 0/-inf filled tensor, see https://github.com/pytorch/pytorch/blob/f8fafdc7a6d260cea6c145643f4cf73631c81460/torch/onnx/symbolic_opset14.py#L187
For padded batched inputs, this results in the softmax returning NaNs, which pollute the logits of the whole sequence (the entire padded sequences return NaNs as logits). This used to be avoided by not calling `_unmask_unattended` during ONNX export. Instead of going down that path again for the new masking methods (which also resulted in small numerical mismatches), we fix it by patching the torch ONNX exporter directly and overloading the symbolic it uses for `aten::scaled_dot_product_attention`.
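A minimal repro of the failure mode (illustrative, not taken from the PR): once a padded row of scores is all `-inf`, softmax divides zero by zero and the NaNs propagate into the attention output and logits.

```python
import torch

scores = torch.full((1, 4), float("-inf"))  # a fully-masked (padded) row of attention scores
print(torch.softmax(scores, dim=-1))        # tensor([[nan, nan, nan, nan]])
```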
A lot of fixes and changes, but hey, we were able to un-bloat the model patcher, get matching logits even for the masked tokens, and get better support and testing of batched inference/generation across transformers versions.
P.S.: I tested every transformers version from 4.36 to 4.53, and also tested these changes on torch 2.1.
Before submitting
Who can review?