
[export] Data dependent error on slices in autograd::applySlicing #163146

@elin-croft


🐛 Describe the bug

Exporting the model to ONNX fails. The error appears to be related to a dynamic slice whose bound depends on the input data.
Error traceback:

[torch.onnx] Obtain model graph for `Mlp([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `Mlp([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `Mlp([...]` with `torch.export.export(..., strict=True)`...
[torch.onnx] Obtain model graph for `Mlp([...]` with `torch.export.export(..., strict=True)`... ❌
[torch.onnx] Obtain model graph for `Mlp([...]` with `torch.export draft_export`...
W0917 11:24:08.128222 31595 site-packages/torch/fx/experimental/symbolic_shapes.py:7274] propagate_real_tensors evaluate_expr(u1) -> 200
W0917 11:24:08.313687 31595 site-packages/torch/export/_draft_export.py:500] 
W0917 11:24:08.313687 31595 site-packages/torch/export/_draft_export.py:500] ################################################################################################### 
W0917 11:24:08.313687 31595 site-packages/torch/export/_draft_export.py:500] WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
W0917 11:24:08.313687 31595 site-packages/torch/export/_draft_export.py:500] To view the report of failures in an html page, please run the command:
W0917 11:24:08.313687 31595 site-packages/torch/export/_draft_export.py:500]     `tlparse /var/folders/ld/qwk1tvt50nj_xnz08ntyqdlw0000gp/T/export/dedicated_log_torch_trace_3b7s6392.log --export`
W0917 11:24:08.313687 31595 site-packages/torch/export/_draft_export.py:500] Or, you can view the errors in python by inspecting `print(ep._report)`.
W0917 11:24:08.313687 31595 site-packages/torch/export/_draft_export.py:500] #################################################################################################
[torch.onnx] Draft Export report:

###################################################################################################
WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################

1. Data dependent error.
    When exporting, we were unable to evaluate the value of `u1`.
    This was encountered 1 times.
    This occurred at the following user stacktrace: 
        File /miniforge3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py, lineno 1773, in _wrapped_call_impl
        File /miniforge3/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py, lineno 1784, in _call_impl
            selected_item_embedding = item_embedding[:, :max_item_num, :]
        
        Locals:
            item_embedding: ['Tensor(shape: torch.Size([s10, s64, 64]), stride: (64*s64, 64, 1), storage_offset: 0)']
            max_item_num: ['Tensor(shape: torch.Size([]), stride: (), storage_offset: 0)']

        Symbols:
           s10: L['item_embedding'].size()[0]
           s64: L['item_embedding'].size()[1]

    And the following framework stacktrace: 
        File /miniforge3/envs/torch/lib/python3.9/site-packages/torch/_ops.py, lineno 1243, in __call__
            return self._op(*args, **kwargs)

    As a result, it was specialized to a constant (e.g. `200` in the 1st occurrence), and asserts were inserted into the graph.

    Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
    Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.

Code to reproduce:

import torch
import torch.nn as nn


def sequence_mask(lengths, max_len=None):
    """
    Creates a boolean mask for sequences based on their lengths.

    Args:
        lengths (torch.Tensor): A 1D tensor of shape (batch_size,)
                                containing the lengths of each sequence.
        max_len (int, optional): The maximum length of the sequences.
                                 If None, it will be inferred from the
                                 maximum value in 'lengths'. Defaults to None.

    Returns:
        torch.Tensor: A boolean mask tensor of shape (batch_size, max_len).
    """
    if max_len is None:
        max_len = lengths.max().item()

    # Create a tensor representing indices from 0 to max_len-1
    indices = torch.arange(max_len, device=lengths.device).unsqueeze(0)

    # Expand lengths to match the dimensions of indices for broadcasting
    lengths_expanded = lengths.unsqueeze(1)

    # Compare indices with lengths to create the mask
    mask = indices < lengths_expanded
    return mask


class Mlp(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Mlp, self).__init__()
        self.item_encoder = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, user_embedding, item_embedding, selected_last_n):
        selected_last_n = torch.minimum(selected_last_n, torch.tensor(200))
        max_item_num = torch.max(selected_last_n)
        torch._check(
            (max_item_num <= item_embedding.shape[1]).item(),
            "selected_last_n must not exceed item_embedding.shape[1]",
        )
        mask = sequence_mask(
            selected_last_n, max_item_num
        )  # same as tensorflow.sequence_mask
        selected_item_embedding = item_embedding[:, :max_item_num, :]

        item_fusion = torch.sum(
            self.item_encoder(selected_item_embedding) * mask.unsqueeze(-1), dim=1
        )

        embedding = torch.cat([user_embedding, item_fusion], dim=1)
        out = nn.functional.relu(self.fc1(embedding))
        out = nn.functional.relu(self.fc2(out)).squeeze(-1)
        return out

    def export(self):
        user_embedding = torch.randn((2, 64))
        item_embedding = torch.randn((2, 210, 64))
        selected_last_n = torch.tensor([188, 210])
        batch = torch.export.Dim.DYNAMIC
        item_num = torch.export.Dim.DYNAMIC
        torch.onnx.export(
            self,
            (user_embedding, item_embedding, selected_last_n),
            "model.onnx",
            input_names=[
                "user_embedding_input",
                "item_embedding_input",
                "selected_last_n_input",
            ],
            output_names=["output"],
            export_params=True,
            dynamo=True,
            dynamic_shapes={
                "user_embedding": {0: batch},
                "item_embedding": {0: batch, 1: item_num},
                "selected_last_n": {0: batch},
            },
        )


model = Mlp(128 + 64, 64, 1)
model.export()
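For reference, the pooling step in `forward` (slice to the longest sequence, build the length mask, zero out padded positions, sum over items) can be sketched in isolation. `masked_sum` is a hypothetical helper written for illustration: it leaves out `item_encoder`, which the repro applies before masking, and it is meant for eager mode only.

```python
import torch


def masked_sum(items: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Sum the first lengths[i] item vectors of each row (hypothetical helper).

    items: (batch, num_items, dim); lengths: (batch,) with values <= num_items.
    """
    n = lengths.max().item()
    selected = items[:, :n, :]  # (batch, n, dim)
    # Boolean mask (batch, n): True where the position index is below the row's
    # length, the same comparison sequence_mask performs.
    mask = torch.arange(n, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1)
    # The bool mask is promoted to the float dtype, so padded positions become zero.
    return (selected * mask.unsqueeze(-1)).sum(dim=1)  # (batch, dim)
```

With `lengths = torch.tensor([1, 3])`, row 0 of the result equals `items[0, 0]` and row 1 equals `items[1, :3].sum(0)`.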

Versions

PyTorch version: 2.8.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.1.1 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 4.0.2
Libc version: N/A

Python version: 3.9.23 | packaged by conda-forge | (main, Jun 4 2025, 18:02:02) [Clang 18.1.8 ] (64-bit runtime)
Python platform: macOS-15.1.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M4 Pro

Versions of relevant libraries:
[pip3] numpy==2.0.2
[pip3] onnx==1.18.0
[pip3] onnx-ir==0.1.9
[pip3] onnxruntime==1.19.2
[pip3] onnxscript==0.5.1
[pip3] torch==2.8.0
[pip3] torchvision==0.23.0
[conda] numpy 2.0.2 pypi_0 pypi
[conda] torch 2.8.0 pypi_0 pypi
[conda] torchvision 0.23.0 pypi_0 pypi

cc @justinchuby @titaiwangms @bdhirsh @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Labels: module: onnx, needs reproduction, oncall: export, oncall: pt2, triaged
