
Triton Kernel Rejects NamedTupleVariable Arguments #148289

@cora-codes

Description


🚀 The feature, motivation and pitch

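PyTorch's TorchDynamo fails when a typing.NamedTuple instance (which Dynamo traces as a NamedTupleVariable) is passed to a Triton kernel inside a torch.compile'd function, raising "Unexpected argument type for a Triton kernel". It would be nice to support named tuple arguments, since grouping related arguments into (possibly nested) named tuples makes writing Triton kernels far cleaner.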

import torch
import typing
import triton

class T1(typing.NamedTuple):
    foo: None = None
    bar: None = None
class T2(typing.NamedTuple):
    foo: T1 = T1()
    bar: T1 = T1()
class T3(typing.NamedTuple):
    foo: T2 = T2()
    bar: T2 = T2()
class T4(typing.NamedTuple):
    foo: T3 = T3()
    bar: T3 = T3()
class T5(typing.NamedTuple):
    foo: T4 = T4()
    bar: T4 = T4()

@triton.jit
def test(t5: T5):
    pass

if __name__ == "__main__":
    t5 = T5()

    @torch.compile(mode="max-autotune-no-cudagraphs", fullgraph=True)
    def main():
        for _ in range(100):
            # Fails while tracing: "Unexpected argument type for a Triton kernel"
            test[(1,)](t5)
    main()
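Until named tuples are accepted, one possible workaround is to flatten the (nested) named tuple into its leaf values on the Python side and pass those as individual kernel arguments. The sketch below is stdlib-only and assumes leaves are anything that is not a named tuple; the helpers is_namedtuple, flatten_namedtuple and unflatten_namedtuple are hypothetical names, not part of PyTorch or Triton, and it has not been checked against Dynamo tracing.

```python
import typing
from typing import Any, List, Tuple


def is_namedtuple(obj: Any) -> bool:
    # NamedTuple instances are tuple subclasses that expose a `_fields` attribute.
    return isinstance(obj, tuple) and hasattr(obj, "_fields")


def flatten_namedtuple(obj: Any) -> Tuple[List[Any], Any]:
    """Return (leaves, spec): leaves in depth-first field order, spec to rebuild obj."""
    if not is_namedtuple(obj):
        return [obj], None  # a leaf, e.g. a tensor, an int, or None
    leaves: List[Any] = []
    child_specs = []
    for value in obj:
        child_leaves, child_spec = flatten_namedtuple(value)
        leaves.extend(child_leaves)
        # Remember how many leaves each field consumed so it can be rebuilt.
        child_specs.append((len(child_leaves), child_spec))
    return leaves, (type(obj), child_specs)


def unflatten_namedtuple(leaves: List[Any], spec: Any) -> Any:
    """Inverse of flatten_namedtuple: rebuild the nested named tuple from leaves."""
    if spec is None:
        return leaves[0]
    cls, child_specs = spec
    values = []
    offset = 0
    for count, child_spec in child_specs:
        values.append(unflatten_namedtuple(leaves[offset:offset + count], child_spec))
        offset += count
    return cls(*values)


# Small versions of the classes from the repro, redefined so this block runs alone.
class T1(typing.NamedTuple):
    foo: Any = None
    bar: Any = None


class T2(typing.NamedTuple):
    foo: T1 = T1()
    bar: T1 = T1()


t2 = T2(T1(1, 2), T1(3, 4))
leaves, spec = flatten_namedtuple(t2)
print(leaves)  # [1, 2, 3, 4]
assert unflatten_namedtuple(leaves, spec) == t2
```

The flat leaves could then be launched as test[(1,)](*leaves) against a kernel that lists every field as its own parameter. For the T5 in the repro that means 32 separate parameters, which is exactly the boilerplate this feature request would remove.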

Alternatives

No response

Additional context

No response

cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @chauhang @penguinwu @voznesenskym @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @amjames @oulgen @aakhundov @davidberard98
