🐛 Describe the bug
When we first implemented AOTAutogradCache, we decided to punt on this issue, but now that we're releasing it to OSS, it's worth getting more eyes on it, since it looks like a simple fix.
Consider this simple test:
import torch
import torch._functorch.config

torch._functorch.config.view_replay_for_aliased_outputs = True
torch._functorch.config.strict_autograd_cache = True

def f(a):
    tmp = a.detach()
    a.mul_(2)
    return a, tmp

with torch.autograd._force_original_view_tracking(True):
    fn = torch.compile(f)
    out = fn(torch.rand(2, 3))
    print(out)
This fails when AOTAutogradCache attempts to serialize the ViewAndMutationMeta associated with the function. The view_replay logic stores a FunctionalTensor in it that represents an alias and has no underlying storage. When we try to serialize it, we attempt to access that underlying storage, which doesn't exist, so we fail. As a result, whenever there's an alias mutation on an output with view_replay turned on, AOTAutogradCache cannot serialize/save the cached object, leading to more cache misses than necessary.
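For context, the alias in the repro comes from detach(), which returns a view sharing storage with its input; view replay is what lets functionalization reconstruct that alias relationship at runtime. A minimal eager-mode illustration:

import torch

a = torch.rand(2, 3)
tmp = a.detach()   # detach() returns a view: tmp shares storage with a
a.mul_(2)          # the in-place mutation is visible through tmp
assert tmp.data_ptr() == a.data_ptr()
assert torch.equal(tmp, a)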
Stack trace:
W1203 08:57:32.604000 1374434 torch/_functorch/_aot_autograd/autograd_cache.py:837] [0/0] AOTAutograd cache unable to serialize compiled graph: Attempted to access the data pointer on an invalid python storage.
Traceback (most recent call last):
File "/home/jjwu/test.py", line 14, in <module>
out = fn(torch.rand(2,3))
File "/data/users/jjwu/a/pytorch/torch/_dynamo/eval_frame.py", line 573, in _fn
return fn(*args, **kwargs)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 1379, in __call__
return self._torchdynamo_orig_callable(
File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 1163, in __call__
result = self._inner_convert(
File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/data/users/jjwu/a/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 2864, in run
super().run()
File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 1053, in run
while self.step():
File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 963, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 3044, in RETURN_VALUE
self._return(inst)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 3029, in _return
self.output.compile_subgraph(
File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1118, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1359, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1409, in call_user_compiler
return self._call_user_compiler(gm)
File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1460, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1439, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/data/users/jjwu/a/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/data/users/jjwu/a/pytorch/torch/__init__.py", line 2308, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/data/users/jjwu/a/pytorch/torch/_inductor/compile_fx.py", line 1812, in compile_fx
return aot_autograd(
File "/data/users/jjwu/a/pytorch/torch/_dynamo/backends/common.py", line 73, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/data/users/jjwu/a/pytorch/torch/_functorch/aot_autograd.py", line 1092, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
File "/data/users/jjwu/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 744, in load
compiled_fn = dispatch_and_compile()
File "/data/users/jjwu/a/pytorch/torch/_functorch/aot_autograd.py", line 1078, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/data/users/jjwu/a/pytorch/torch/_functorch/aot_autograd.py", line 527, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/data/users/jjwu/a/pytorch/torch/_functorch/aot_autograd.py", line 777, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/data/users/jjwu/a/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 238, in aot_dispatch_base
AOTAutogradCache.save(
File "/data/users/jjwu/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 839, in save
raise e
File "/data/users/jjwu/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 835, in save
content = pickle.dumps(entry)
File "/data/users/jjwu/a/pytorch/torch/storage.py", line 1237, in __reduce__
torch.save(self, b, _use_new_zipfile_serialization=False)
File "/data/users/jjwu/a/pytorch/torch/serialization.py", line 944, in save
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
File "/data/users/jjwu/a/pytorch/torch/serialization.py", line 1093, in _legacy_save
pickler.dump(obj)
File "/data/users/jjwu/a/pytorch/torch/serialization.py", line 1090, in persistent_id
return persistent_id(obj)
File "/data/users/jjwu/a/pytorch/torch/serialization.py", line 1008, in persistent_id
if storage.data_ptr() != 0:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Attempted to access the data pointer on an invalid python storage.
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
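For reference, here is a minimal sketch of the save path implied by the trace (simplified, not the actual autograd_cache.py implementation): pickling the entry raises when it reaches the storage-less FunctionalTensor, and strict_autograd_cache re-raises instead of treating it as a cache bypass.

import logging
import pickle

log = logging.getLogger(__name__)

def save_entry(entry, strict):
    # Simplified: pickling the cache entry fails when it reaches a stored
    # FunctionalTensor whose underlying storage is invalid.
    try:
        return pickle.dumps(entry)
    except Exception as e:
        log.warning(
            "AOTAutograd cache unable to serialize compiled graph: %s", e
        )
        if strict:
            raise  # strict_autograd_cache surfaces the failure, as in this repro
        return None  # otherwise the save is skipped and we just miss the cache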
This is the object we store (pytorch/torch/_functorch/_aot_autograd/functional_utils.py, lines 365 to 381 in ed4831b):
class FunctionalTensorMetadataEq:
    def __init__(self, tensor: torch.Tensor) -> None:
        assert torch._is_functional_tensor(tensor)
        self.tensor = tensor

    def __eq__(self, other: object) -> bool:
        # If other is None, then it probably means that we weren't able to
        # recreate the FunctionalTensorMetadataEq. One of these cases is when
        # we update the view metadata by calling: create_synthetic_base_metadata.
        if other is None:
            return True

        # Comparison against any other type is not implemented.
        if not isinstance(other, FunctionalTensorMetadataEq):
            return NotImplemented

        return has_same_metadata(self.tensor, other.tensor)
Unfortunately, even though we only need the metadata, we store the whole functional tensor. Since the tensor represents an alias and has no data_ptr(), we fail to serialize it.
There are a few simple enough fixes to this problem:
- We could adjust the serialization to handle FunctionalTensors with no underlying storage. This seems good overall, since there may be other cases that need to be able to save functional tensors.
- We could also just change how view_replay uses functional tensors. It doesn't actually need the functional tensor to begin with; it only uses the view_metas field:

  return impl->apply_view_metas(base);

  So we could just store the view metas and apply them, instead of taking in a whole tensor (a sketch of this direction follows the list).
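A minimal sketch of that second option from the Python side (illustrative only: the class name TensorMetadataOnlyEq, the from_tensor helper, and the choice of shape/stride/dtype as the compared fields are assumptions; a real fix would need to capture everything has_same_metadata compares, and ideally the serialized view metas themselves):

import torch

class TensorMetadataOnlyEq:
    # Hypothetical replacement for FunctionalTensorMetadataEq that keeps only
    # plain metadata, so pickling never touches tensor storage.
    def __init__(self, shape, stride, dtype):
        self.shape = shape
        self.stride = stride
        self.dtype = dtype

    @classmethod
    def from_tensor(cls, t: torch.Tensor) -> "TensorMetadataOnlyEq":
        return cls(tuple(t.shape), tuple(t.stride()), t.dtype)

    def __eq__(self, other: object) -> bool:
        if other is None:
            return True  # mirrors the existing None-handling behavior
        if not isinstance(other, TensorMetadataOnlyEq):
            return NotImplemented
        return (self.shape, self.stride, self.dtype) == (
            other.shape,
            other.stride,
            other.dtype,
        )

Since only plain tuples and a dtype are stored, pickle.dumps on a cache entry containing this object never needs to touch a data_ptr().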
Versions
nightly
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bdhirsh @oulgen @masnesral @yf225