Description
Error:
RuntimeError: Cannot set version_counter for inference tensor
Minimal repro:

import torch
from torchao.quantization.quantize_.workflows import Float8Tensor

x = Float8Tensor.from_hp(torch.randn(3, 4))

# This is fine
x[0:1]

# This fails
with torch.inference_mode():
    x[0:1]
Float8Tensor's slice dispatches to this handler, which runs to completion without problems. The error seems to be triggered after we slice the tensor. This also happens with a few other tensor subclasses I've tried, and doesn't happen with a non-subclassed torch.Tensor.
Full CPP stack trace: https://gist.github.com/andrewor14/062b753e72d419d2c1e2a9d4e142b1fa
My question is: why are we trying to increment the version counter for an op that's not in-place? Is this expected to work?
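For context on the error message: tensors created under inference_mode are "inference tensors" and skip version-counter bookkeeping entirely, so any code path that tries to bump a version counter on one fails. A minimal sketch of this documented behavior with plain tensors (no subclass involved), to contrast with the repro above:

```python
import torch

# A tensor created inside inference_mode is an inference tensor:
# it has no version counter at all.
with torch.inference_mode():
    t = torch.ones(3)

print(t.is_inference())  # True

# Because there is no version counter to bump, PyTorch forbids
# in-place updates to an inference tensor outside the guard.
try:
    t.add_(1)
except RuntimeError as e:
    print("RuntimeError:", e)
```

The subclass case is surprising precisely because slicing is not in-place, yet something in the dispatch path for the subclass still touches the version counter.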
cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @ezyang @albanD