[dynamo] Add bytecode source attribution to VariableTrackers #162857

@williamwen42

Description

We can extend VariableTrackers (VTs) to record which bytecode instruction generated each VT. Because bytecode instructions carry source position information, it would then be possible to output the source code corresponding to each VT, which could be helpful for debugging.

For example,

import torch


@torch.compile(backend="eager")
def fn(x):
    y = x + 1
    z = x + y
    return z


fn(torch.ones(3))

Today, running this script with TORCH_LOGS=trace_bytecode gives:

TRACE RESUME 0 []
TRACE LOAD_FAST x []
TRACE LOAD_CONST 1 [LazyVariableTracker()]
TRACE BINARY_OP 0 [LazyVariableTracker(), ConstantVariable(int: 1)]
TRACE STORE_FAST y [TensorVariable()]
TRACE LOAD_FAST x []
TRACE LOAD_FAST y [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE STORE_FAST z [TensorVariable()]
TRACE LOAD_FAST z []
TRACE RETURN_VALUE None [TensorVariable()]

The TensorVariable created by the first BINARY_OP instruction (TRACE BINARY_OP 0 [LazyVariableTracker(), ConstantVariable(int: 1)]) could carry source attribution such as:

  File "/data/users/williamwen/pytorch/playground3.py", line 6, in fn
    y = x + 1
        ~~^~~

We could then print this information whenever the VT is involved in an error, e.g. a graph break.

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @Lucaskabela

Metadata

Assignees

No one assigned

    Labels

    dynamo-variable-tracker
    feature: A request for a proper, new feature.
    module: compile ux: UX issues related to torch.compile
    module: dynamo
    oncall: pt2
    triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests