🐛 Describe the bug
Sample:
# Autograd cache stuff
remote = should_use_remote_autograd_cache()
local = should_use_local_autograd_cache()
if local or remote:
    compiled_fn = AOTAutogradCache.load(
        dispatch_and_compile,
        mod,
        fake_flat_args,
        aot_config,
        cudagraphs,
        local,
        remote,
    )
else:
    compiled_fn = dispatch_and_compile()
I can kind of see why it's written this way: you want to avoid exercising the AOTAutogradCache.load codepath at all, which makes it less risky. But I think this is a false economy. Just do the test inside load(). This gives way better information hiding, since managing whether the cache is enabled happens entirely inside the AOTAutogradCache class, and it's two fewer functions I have to expose to external callers.
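Concretely, something like this sketch (hypothetical code, not the real AOTAutogradCache API; the actual lookup/save logic is elided):

class AOTAutogradCache:
    @classmethod
    def load(cls, dispatch_and_compile, mod, fake_flat_args, aot_config, cudagraphs):
        # The enablement check lives inside load(); callers invoke it
        # unconditionally and the should_use_* helpers can go private.
        local = should_use_local_autograd_cache()
        remote = should_use_remote_autograd_cache()
        if not (local or remote):
            # Caching disabled: compile directly, without touching
            # any of the cache machinery.
            return dispatch_and_compile()
        # ... existing cache lookup / compile / save logic here,
        # driven by `local` and `remote` ...

The call site then collapses to a single unconditional line:

compiled_fn = AOTAutogradCache.load(
    dispatch_and_compile, mod, fake_flat_args, aot_config, cudagraphs
)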
Inductor also suffers from this antipattern, which I guess is where the pattern was copied from:
if (
    not config.force_disable_caches
    and (config.fx_graph_cache or fx_graph_remote_cache)
    and not aot_mode
):
    for i, input in enumerate(example_inputs):
        if (
            isinstance(input, torch.Tensor)
            and input.device.type == "cuda"
            and i in static_input_idxs
        ):
            input._is_inductor_static = True  # type: ignore[attr-defined]

    compiled_graph = FxGraphCache.load(
        codegen_and_compile,
        gm,
        example_inputs,
        graph_kwargs,
        inputs_to_check,
        local=config.fx_graph_cache,
        remote=fx_graph_remote_cache,
    )
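The same fix applies here: if FxGraphCache.load() internalized the force_disable_caches / aot_mode / cache-config checks and fell back to codegen_and_compile when caching is disabled (an assumption on my part about the fallback; the static-input marking would also need to move inside load() or keep its own guard), the call site would shrink to something like:

compiled_graph = FxGraphCache.load(
    codegen_and_compile,
    gm,
    example_inputs,
    graph_kwargs,
    inputs_to_check,
)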
cc @chauhang @penguinwu @jamesjwu @oulgen
Versions
main