🐛 Describe the bug
Hi,
I am exporting llama from the transformers implementation to ONNX, and I don't understand why If nodes are inserted for the export of this squeeze operator, where the first two dims are always of size 1: https://github.com/huggingface/transformers/blob/7c63e6fc8c34dcf8b0121eaee776f41ccf3b1137/src/transformers/models/llama/modeling_llama.py#L182
It appears we are going into this path:
pytorch/torch/onnx/symbolic_opset11.py
Line 937 in b6a1d3f
if (dim < 0 and input_rank is None) or dim_size is None:
because the symbolic helper _get_tensor_sizes gives us x_type.varyingSizes() = [None, None, None, None] here.
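For reference, here is a minimal standalone sketch I put together (not the transformers code; the module and file names below are made up) that seems to reproduce the behavior: when the exporter cannot prove that the squeezed dim has size 1, it takes the dynamic path above and the exported graph contains If nodes. Whether the If nodes actually appear may depend on the torch version and on the dims being marked dynamic.

```python
import torch
import onnx


class SqueezeModel(torch.nn.Module):
    def forward(self, x):
        # Same pattern as the rotary embedding: drop two leading size-1 dims.
        return x.squeeze(1).squeeze(0)


torch.onnx.export(
    SqueezeModel(),
    (torch.randn(1, 1, 17, 4),),
    "squeeze_repro.onnx",
    opset_version=12,
    input_names=["x"],
    # Marking the leading dims as dynamic is meant to mimic the llama export,
    # where _get_tensor_sizes returns [None, None, None, None] for this tensor.
    dynamic_axes={"x": {0: "a", 1: "b"}},
)

model = onnx.load("squeeze_repro.onnx")
# Count top-level If nodes in the exported graph.
print("If nodes:", sum(node.op_type == "If" for node in model.graph.node))
```

With the leading dims left dynamic I would expect one If node per squeeze; with fully static shapes the op should fold into a plain Squeeze instead.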
Is there a way to hint that these shapes are constant, hard-coded to 1? The intermediate captured graph is (the aten::sub is not in the original code base, it was added for debugging):
%4407 : Float(1, 1, 17, 4, strides=[68, 68, 4, 1], requires_grad=0, device=cpu) = aten::sub(%cos.1, %4405, %4406), scope: transformers.models.llama.modeling_llama.LlamaForCausalLM::/transformers.models.llama.modeling_llama.LlamaModel::model/transformers.models.llama.modeling_llama.LlamaDecoderLayer::layers.0/transformers.models.llama.modeling_llama.LlamaAttention::self_attn # /home/fxmarty/hf_internship/transformers/src/transformers/models/llama/modeling_llama.py:183:0
%4408 : int = prim::Constant[value=1](), scope: transformers.models.llama.modeling_llama.LlamaForCausalLM::/transformers.models.llama.modeling_llama.LlamaModel::model/transformers.models.llama.modeling_llama.LlamaDecoderLayer::layers.0/transformers.models.llama.modeling_llama.LlamaAttention::self_attn # /home/fxmarty/hf_internship/transformers/src/transformers/models/llama/modeling_llama.py:185:0
%4409 : Float(1, 17, 4, strides=[68, 4, 1], requires_grad=0, device=cpu) = aten::squeeze(%4407, %4408), scope: transformers.models.llama.modeling_llama.LlamaForCausalLM::/transformers.models.llama.modeling_llama.LlamaModel::model/transformers.models.llama.modeling_llama.LlamaDecoderLayer::layers.0/transformers.models.llama.modeling_llama.LlamaAttention::self_attn # /home/fxmarty/hf_internship/transformers/src/transformers/models/llama/modeling_llama.py:185:0
A workaround is to use cos = cos[0, 0] - I am wondering if there is anything better.
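To illustrate why the workaround is safe, here is a quick sanity check on a stand-in tensor with the same (1, 1, seq_len, dim) layout as the cached cos (the tensor below is just an example, not the transformers buffer): indexing the two leading singleton dims with constant 0 is numerically identical, but it is traced as Gather ops rather than a shape-dependent squeeze, so no runtime shape check (If node) should be needed.

```python
import torch

# Stand-in for the cached cos tensor, shape (1, 1, seq_len, dim).
cos = torch.randn(1, 1, 17, 4)

# Identical results when the leading dims are 1, but only the indexing form
# avoids the dynamic squeeze path (and hence the If nodes) at export time.
assert torch.equal(cos.squeeze(1).squeeze(0), cos[0, 0])
```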
Thank you!
cc @justinchuby
Versions
torch 2.0.1, opset_version = 12