
Tensor subclass slice during inference_mode fails #164872

@andrewor14

Description

Error:

```
RuntimeError: Cannot set version_counter for inference tensor
```

Minimal repro:

```python
import torch
from torchao.quantization.quantize_.workflows import Float8Tensor

x = Float8Tensor.from_hp(torch.randn(3, 4))

# This is fine
x[0:1]

# This fails
with torch.inference_mode():
    x[0:1]
```

Float8Tensor's slice dispatches to this handler, which runs to completion without problems; the error seems to be triggered only after the slice itself returns. The same failure occurs with a few other tensor subclasses I've tried, but not with a plain (non-subclassed) torch.Tensor, as the sketch below illustrates.
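For reference, here is a torchao-free sketch of the kind of subclass I mean. `TrivialTensor` is a hypothetical, untested wrapper subclass (not one of the subclasses I actually tried); assuming the failure is generic to wrapper subclasses that re-wrap their outputs in `__torch_dispatch__`, slicing it under inference_mode should hit the same error:

```python
import torch

# Hypothetical minimal wrapper subclass (untested sketch, not from torchao):
# forwards every op through __torch_dispatch__ and re-wraps tensor outputs.
class TrivialTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap subclass arguments, run the real op, then re-wrap the result
        unwrapped = [a.elem if isinstance(a, TrivialTensor) else a for a in args]
        out = func(*unwrapped, **kwargs)
        return TrivialTensor(out) if isinstance(out, torch.Tensor) else out

y = TrivialTensor(torch.randn(3, 4))
y[0:1]  # fine outside inference_mode

with torch.inference_mode():
    y[0:1]  # presumably raises the same RuntimeError, if the bug is subclass-generic
```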

Full C++ stack trace: https://gist.github.com/andrewor14/062b753e72d419d2c1e2a9d4e142b1fa

My question is: why are we incrementing the version counter for an op that is not in-place? Is this expected to work?
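One data point: a possible workaround sketch (untested), assuming the problem is specific to inference tensors. torch.no_grad() disables gradient tracking without allocating inference tensors, so the sliced result should still carry a version counter that can be bumped:

```python
import torch
from torchao.quantization.quantize_.workflows import Float8Tensor

x = Float8Tensor.from_hp(torch.randn(3, 4))

# Untested workaround sketch: no_grad avoids inference tensors entirely,
# so the version-counter write on the sliced result should be legal.
with torch.no_grad():
    x[0:1]
```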

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @ezyang @albanD

Metadata

Labels

- inference mode (Everything related to InferenceMode guard)
- module: dispatch (DispatchStub, Type, void pointer table, c10 dispatch)
- oncall: quantization (Quantization support in PyTorch)
- tensor subclass (Related to tensor subclasses)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
