
AOTI FXIR is flaky when using multiple threads #162607

@angelayi

Description


🐛 Describe the bug

Filing an issue to track what was done in #162472

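Repro with the following command (--flake-finder and --flake-runs are provided by the pytest-flakefinder plugin):

pytest --flake-runs 10 --flake-finder -v test/inductor/test_fxir_backend.py -k test_aoti_fx_add --pdb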
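If pytest-flakefinder is not installed, a rough substitute is to rerun the same test in a loop and count failing runs. The helper below is a minimal sketch, assuming it is launched from the root of a PyTorch checkout; `count_failures` is a hypothetical name, not part of the PyTorch test suite. Unlike --flake-finder, which repeats the test inside a single pytest process, this loop starts a fresh interpreter for every run, so it may miss failures that depend on state accumulated across repetitions within one process.

```python
import subprocess
import sys


def count_failures(cmd: list[str], runs: int) -> int:
    """Run `cmd` `runs` times and return how many runs exited with a non-zero code."""
    failures = 0
    for _ in range(runs):
        # capture_output keeps the loop quiet; inspect result.stdout when a run fails
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            failures += 1
    return failures


if __name__ == "__main__":
    # Same test selection as the flakefinder command, without --pdb so that a
    # failing run exits instead of waiting in the debugger.
    cmd = [
        sys.executable, "-m", "pytest", "-q",
        "test/inductor/test_fxir_backend.py", "-k", "test_aoti_fx_add",
    ]
    runs = 10
    print(f"{count_failures(cmd, runs)}/{runs} runs failed")
```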

The stack trace looks something like this:

Traceback (most recent call last):
  File "/home/angelayi/.conda/envs/pytorch-312/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
    yield
  File "/home/angelayi/.conda/envs/pytorch-312/lib/python3.12/unittest/case.py", line 634, in run
    self._callTestMethod(testMethod)
  File "/home/angelayi/.conda/envs/pytorch-312/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
    if method() is not None:
       ^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper
    method(*args, **kwargs)
  File "/data/users/angelayi/pytorch/test/inductor/test_fxir_backend.py", line 712, in test_aoti_fx_add
    self.check(M(), inp)
  File "/data/users/angelayi/pytorch/test/inductor/test_fxir_backend.py", line 697, in check
    self.assertTrue(torch.allclose(model(*inp), gm(*inp)))
                                                ^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/data/users/angelayi/pytorch/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.159", line 6, in forward
    triton_kernel_wrapper_mutation = torch.ops.higher_order.triton_kernel_wrapper_mutation(kernel_idx = 9, constant_args_idx = 9, grid = [(1, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': arg0_1, 'in_ptr1': arg1_1, 'out_ptr0': buf0, 'xnumel': 3, 'XBLOCK': 4});  arg0_1 = arg1_1 = triton_kernel_wrapper_mutation = None
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 980, in __call__
    return super().__call__(
           ^^^^^^^^^^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/_ops.py", line 536, in __call__
    return wrapper()
           ^^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/_ops.py", line 532, in wrapper
    return self.dispatch(
           ^^^^^^^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/_ops.py", line 381, in dispatch
    return kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/angelayi/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 1094, in triton_kernel_wrapper_mutation_dense
    kernel[grid_fn](*args, **kwargs, **constant_args)
  File "/home/angelayi/.conda/envs/pytorch-312/lib/python3.12/site-packages/triton/runtime/jit.py", line 390, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/angelayi/.conda/envs/pytorch-312/lib/python3.12/site-packages/triton/runtime/jit.py", line 594, in run
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/angelayi/.conda/envs/pytorch-312/lib/python3.12/site-packages/triton/compiler/compiler.py", line 339, in compile
    module = src.make_ir(options, codegen_fns, module_map, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/angelayi/.conda/envs/pytorch-312/lib/python3.12/site-packages/triton/compiler/compiler.py", line 83, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/angelayi/.conda/envs/pytorch-312/lib/python3.12/site-packages/triton/runtime/jit.py", line 935, in get_jit_fn_file_line
    file_name = base_fn.fn.__code__.co_filename
                ^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute '__code__'
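The final frame fails because `base_fn.fn` is None when Triton tries to read `__code__` from it. The issue does not say which code paths run on multiple threads, so as a hypothetical way to probe this, the sketch below calls a function from several threads at once and collects any tracebacks. `stress` and its parameters are illustrative names, not an existing PyTorch utility; a `threading.Barrier` releases all workers together so that their calls overlap as much as possible.

```python
import threading
import traceback
from typing import Callable


def stress(fn: Callable[[], object], num_threads: int = 8, iters: int = 10) -> list[str]:
    """Call `fn` `iters` times on each of `num_threads` threads; return formatted tracebacks."""
    errors: list[str] = []
    lock = threading.Lock()
    # All workers block here until every thread has started, then run concurrently.
    start = threading.Barrier(num_threads)

    def worker() -> None:
        start.wait()
        for _ in range(iters):
            try:
                fn()
            except Exception:
                # The lock guards the shared list against concurrent appends.
                with lock:
                    errors.append(traceback.format_exc())

    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return errors
```

Inside the test's check helper, something like `errors = stress(lambda: gm(*inp))` would run the compiled graph module concurrently; whether that reproduces the AttributeError above has not been verified.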

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @suo @ydwu4 @desertfire @chenyang78 @yushangdi @benjaminglass1 @blaine-rister

Versions

main

Metadata

Assignees

No one assigned

Labels

export-triaged, module: aotinductor, oncall: export, oncall: pt2, triaged

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests