
aten::squeeze exported to ONNX as an If node #109292

@fxmarty

Description

🐛 Describe the bug

Hi,

I am exporting llama from the transformers implementation to ONNX, and I don't understand why If nodes are inserted for the export of this squeeze operator, whose first two dims are always of size 1: https://github.com/huggingface/transformers/blob/7c63e6fc8c34dcf8b0121eaee776f41ccf3b1137/src/transformers/models/llama/modeling_llama.py#L182
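
For reference, a minimal sketch of the kind of export that hits this (not the original script; the SqueezeFirstDims module, the tensor shape and the dynamic_axes choice are my own assumptions, it just exercises the same squeeze(1).squeeze(0) pattern):

    import io
    import torch

    # Hypothetical stand-in for the rotary-embedding code path: the two leading
    # dims are always 1 and are dropped with squeeze(1).squeeze(0).
    class SqueezeFirstDims(torch.nn.Module):
        def forward(self, cos):
            return cos.squeeze(1).squeeze(0)

    cos = torch.randn(1, 1, 17, 4)
    buf = io.BytesIO()
    torch.onnx.export(
        SqueezeFirstDims(),
        (cos,),
        buf,
        opset_version=12,
        input_names=["cos"],
        # marking the sequence dim dynamic; once the exporter no longer sees
        # static sizes for the squeezed dims it may fall back to the
        # conditional (If-based) squeeze
        dynamic_axes={"cos": {2: "sequence_length"}},
    )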

It appears we are going into this path:

if (dim < 0 and input_rank is None) or dim_size is None:

because the symbolic helper _get_tensor_sizes gives us x_type.varyingSizes() = [None, None, None, None] here.
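
For what it's worth, the static sizes do seem to be recorded at trace time; a small sketch (my own check, not part of the exporter pipeline) that prints the traced graph:

    import torch

    class M(torch.nn.Module):
        def forward(self, cos):
            return cos.squeeze(1).squeeze(0)

    traced = torch.jit.trace(M(), torch.randn(1, 1, 17, 4))
    # the values feeding aten::squeeze should be typed Float(1, 1, 17, 4, ...)
    # here, yet the ONNX symbolic helper later sees their sizes as all None
    print(traced.graph)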

Is there a way to hint that these shapes are constant and hard-coded to 1? The intermediate captured graph is shown below (the aten::sub is not in the original code base; it was added for debugging):

  %4407 : Float(1, 1, 17, 4, strides=[68, 68, 4, 1], requires_grad=0, device=cpu) = aten::sub(%cos.1, %4405, %4406), scope: transformers.models.llama.modeling_llama.LlamaForCausalLM::/transformers.models.llama.modeling_llama.LlamaModel::model/transformers.models.llama.modeling_llama.LlamaDecoderLayer::layers.0/transformers.models.llama.modeling_llama.LlamaAttention::self_attn # /home/fxmarty/hf_internship/transformers/src/transformers/models/llama/modeling_llama.py:183:0
  %4408 : int = prim::Constant[value=1](), scope: transformers.models.llama.modeling_llama.LlamaForCausalLM::/transformers.models.llama.modeling_llama.LlamaModel::model/transformers.models.llama.modeling_llama.LlamaDecoderLayer::layers.0/transformers.models.llama.modeling_llama.LlamaAttention::self_attn # /home/fxmarty/hf_internship/transformers/src/transformers/models/llama/modeling_llama.py:185:0
  %4409 : Float(1, 17, 4, strides=[68, 4, 1], requires_grad=0, device=cpu) = aten::squeeze(%4407, %4408), scope: transformers.models.llama.modeling_llama.LlamaForCausalLM::/transformers.models.llama.modeling_llama.LlamaModel::model/transformers.models.llama.modeling_llama.LlamaDecoderLayer::layers.0/transformers.models.llama.modeling_llama.LlamaAttention::self_attn # /home/fxmarty/hf_internship/transformers/src/transformers/models/llama/modeling_llama.py:185:0

A workaround is to use cos = cos[0, 0]; I am wondering if there is anything better.
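
For concreteness, a small sketch of the workaround's equivalence (the shape is an assumption matching the traced graph above):

    import torch

    cos = torch.randn(1, 1, 17, 4)

    # Index the two leading size-1 dims away instead of squeezing them;
    # constant-int indexing traces to aten::select and typically exports as
    # Gather nodes, with no data-dependent If.
    indexed = cos[0, 0]                    # shape (17, 4)
    squeezed = cos.squeeze(1).squeeze(0)   # shape (17, 4)

    assert torch.equal(indexed, squeezed)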

Thank you!

cc @justinchuby

Versions

torch 2.0.1, opset_version = 12

Metadata

Labels: module: onnx, triaged
Status: Reopened
