
AOTAutogradCache doesn't support view_replay #141974

@jamesjwu

Description

🐛 Describe the bug

When we first implemented AOTAutogradCache, we decided to punt on this, but now that we're trying to release it to OSS, I think it's good to get more eyes on it, since it looks like a simple fix.

Consider this simple test:

import torch
import torch._functorch.config

# Regenerate aliased outputs by replaying the original view ops (instead of as_strided)
torch._functorch.config.view_replay_for_aliased_outputs = True
# Make cache serialization failures raise instead of silently skipping the cache
torch._functorch.config.strict_autograd_cache = True

def f(a):
    tmp = a.detach()  # alias of the input
    a.mul_(2)         # mutate the input, which the alias above observes
    return a, tmp

with torch.autograd._force_original_view_tracking(True):
    fn = torch.compile(f)
    out = fn(torch.rand(2,3))

print(out)

This fails when AOTAutogradCache attempts to serialize the ViewAndMutationMeta associated with the function. The view_replay logic stores a FunctionalTensor in it that represents an alias and has no underlying storage; when we try to serialize it, we attempt to access that storage, which doesn't exist, and fail. This leads to more cache misses than necessary when running AOTAutogradCache: whenever there's an alias mutation on an output with view_replay turned on, AOTAutogradCache cannot serialize/save the cached object.
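
On the "more cache misses" point: paraphrasing the save path that appears in the trace below (a rough sketch, not the exact autograd_cache.py source), the pickling failure is swallowed unless strict_autograd_cache is set, so the entry is simply never cached:

import logging
import pickle
from typing import Optional

import torch._functorch.config as functorch_config

log = logging.getLogger(__name__)

def try_save(entry) -> Optional[bytes]:
    try:
        return pickle.dumps(entry)  # fails on the stored FunctionalTensor
    except Exception as e:
        log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e)
        if functorch_config.strict_autograd_cache:
            raise  # this is what the repro above surfaces
        return None  # entry is silently not cached: a miss on every future run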

Stack trace:

W1203 08:57:32.604000 1374434 torch/_functorch/_aot_autograd/autograd_cache.py:837] [0/0] AOTAutograd cache unable to serialize compiled graph: Attempted to access the data pointer on an invalid python storage.
Traceback (most recent call last):
  File "/home/jjwu/test.py", line 14, in <module>
    out = fn(torch.rand(2,3))
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/eval_frame.py", line 573, in _fn
    return fn(*args, **kwargs)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 1379, in __call__
    return self._torchdynamo_orig_callable(
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 1163, in __call__
    result = self._inner_convert(
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/data/users/jjwu/a/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 2864, in run
    super().run()
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 1053, in run
    while self.step():
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 963, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 3044, in RETURN_VALUE
    self._return(inst)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/symbolic_convert.py", line 3029, in _return
    self.output.compile_subgraph(
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1118, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1359, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1409, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1460, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/output_graph.py", line 1439, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/data/users/jjwu/a/pytorch/torch/__init__.py", line 2308, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/data/users/jjwu/a/pytorch/torch/_inductor/compile_fx.py", line 1812, in compile_fx
    return aot_autograd(
  File "/data/users/jjwu/a/pytorch/torch/_dynamo/backends/common.py", line 73, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/data/users/jjwu/a/pytorch/torch/_functorch/aot_autograd.py", line 1092, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
  File "/data/users/jjwu/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 744, in load
    compiled_fn = dispatch_and_compile()
  File "/data/users/jjwu/a/pytorch/torch/_functorch/aot_autograd.py", line 1078, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/data/users/jjwu/a/pytorch/torch/_functorch/aot_autograd.py", line 527, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/data/users/jjwu/a/pytorch/torch/_functorch/aot_autograd.py", line 777, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/data/users/jjwu/a/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 238, in aot_dispatch_base
    AOTAutogradCache.save(
  File "/data/users/jjwu/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 839, in save
    raise e
  File "/data/users/jjwu/a/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 835, in save
    content = pickle.dumps(entry)
  File "/data/users/jjwu/a/pytorch/torch/storage.py", line 1237, in __reduce__
    torch.save(self, b, _use_new_zipfile_serialization=False)
  File "/data/users/jjwu/a/pytorch/torch/serialization.py", line 944, in save
    _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  File "/data/users/jjwu/a/pytorch/torch/serialization.py", line 1093, in _legacy_save
    pickler.dump(obj)
  File "/data/users/jjwu/a/pytorch/torch/serialization.py", line 1090, in persistent_id
    return persistent_id(obj)
  File "/data/users/jjwu/a/pytorch/torch/serialization.py", line 1008, in persistent_id
    if storage.data_ptr() != 0:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Attempted to access the data pointer on an invalid python storage.

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

This is the object we store:

class FunctionalTensorMetadataEq:
    def __init__(self, tensor: torch.Tensor) -> None:
        assert torch._is_functional_tensor(tensor)
        self.tensor = tensor

    def __eq__(self, other: object) -> bool:
        # If other is None, then it probably means that we weren't able to recreate
        # the FunctionalTensorMetadataEq. One of these cases is when we update the
        # view metadata by calling: create_synthetic_base_metadata.
        if other is None:
            return True

        # Comparison against any other type is not implemented.
        if not isinstance(other, FunctionalTensorMetadataEq):
            return NotImplemented

        return has_same_metadata(self.tensor, other.tensor)

Unfortunately, even though we only need the metadata, we store the whole functional tensor. Since the tensor has no data_ptr() as it represents an alias, we fail to serialize it.
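
To make that concrete: __eq__ only ever calls has_same_metadata, so a storage-less stand-in with the same layout would be enough for the comparison. A hypothetical sketch of metadata-only pickling; the rebuild helpers below are mine, not PyTorch's, and edge cases (negative strides, zero-sized dims) are ignored:

import torch
# Assumes the class quoted above; at time of writing it lives in
# torch._functorch._aot_autograd.functional_utils.
from torch._functorch._aot_autograd.functional_utils import FunctionalTensorMetadataEq

def _rebuild_meta_stand_in(size, stride, storage_offset, dtype):
    # Smallest buffer that can back this layout; meta tensors carry
    # sizes/strides/offset but own no real storage, so unpickling never
    # touches data_ptr().
    needed = storage_offset + 1
    for s, st in zip(size, stride):
        needed += (s - 1) * st
    base = torch.empty(needed, dtype=dtype, device="meta")
    return base.as_strided(size, stride, storage_offset)

def _rebuild_metadata_eq(size, stride, storage_offset, dtype):
    # Bypass __init__ (and its _is_functional_tensor assert); the stand-in
    # only has to satisfy has_same_metadata().
    obj = FunctionalTensorMetadataEq.__new__(FunctionalTensorMetadataEq)
    obj.tensor = _rebuild_meta_stand_in(size, stride, storage_offset, dtype)
    return obj

def _reduce_metadata_eq(self):
    # Pickle only the layout metadata, never the functional tensor itself.
    t = self.tensor
    return (
        _rebuild_metadata_eq,
        (tuple(t.size()), tuple(t.stride()), t.storage_offset(), t.dtype),
    )

FunctionalTensorMetadataEq.__reduce__ = _reduce_metadata_eq  # monkeypatch, for the sketch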

There are a few simple enough fixes to this problem:

  • We could adjust the serialization to handle FunctionalTensors with no underlying storage. This seems good overall, since there may be other cases that need to be able to save functional tensors.
  • We could also just change how view_replay uses functional tensors. It doesn't actually need the functional tensor to begin with; it only uses the view_metas field:
    return impl->apply_view_metas(base);

So we could just store the view metas and apply them, instead of taking in a whole tensor; a sketch of what that could look like follows.
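
From the Python side, option (2) might look like the sketch below. The two underscored bindings are hypothetical and do not exist today; apply_view_metas only exists on the C++ FunctionalTensorWrapper, so the real change would involve exposing (and serializing) the ViewMeta chain:

import torch

class SerializableViewReplay:
    """Cache-friendly replacement for holding a whole FunctionalTensor:
    keep only the chain of view metas and replay it onto a base at runtime."""

    def __init__(self, functional_tensor: torch.Tensor) -> None:
        assert torch._is_functional_tensor(functional_tensor)
        # Hypothetical binding: extract the ViewMeta chain from the wrapper.
        self.view_metas = torch._functionalize_get_view_metas(functional_tensor)

    def regenerate(self, base: torch.Tensor) -> torch.Tensor:
        # Hypothetical binding: the Python analogue of the C++ line quoted
        # above, `impl->apply_view_metas(base)`.
        return torch._functionalize_apply_view_metas(self.view_metas, base)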

Versions

nightly

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bdhirsh @oulgen @masnesral @yf225

Labels

actionable, compile-cache, high priority, module: aotdispatch, module: pt2-dispatcher, oncall: pt2, triaged, vllm-compile
