Open
Labels: dynamo-triage-jan2025, module: dynamo, module: fx, module: user triton, oncall: pt2, triaged
Description
🚀 The feature, motivation and pitch
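PyTorch's TorchDynamo fails when a `typing.NamedTuple` (traced by Dynamo as a `NamedTupleVariable`) is passed as an argument to a user-defined Triton kernel. It raises "Unexpected argument type for a Triton kernel". It would be nice to support named-tuple arguments, because grouping related parameters into (possibly nested) named tuples makes Triton kernels far cleaner to write. Minimal reproducer: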
import torch
import typing
import triton

# Nested NamedTuples: each level holds two instances of the level below.
class T1(typing.NamedTuple):
    foo: None = None
    bar: None = None

class T2(typing.NamedTuple):
    foo: T1 = T1()
    bar: T1 = T1()

class T3(typing.NamedTuple):
    foo: T2 = T2()
    bar: T2 = T2()

class T4(typing.NamedTuple):
    foo: T3 = T3()
    bar: T3 = T3()

class T5(typing.NamedTuple):
    foo: T4 = T4()
    bar: T4 = T4()

@triton.jit
def test(t5: T5):
    pass

if __name__ == "__main__":
    t5 = T5()

    @torch.compile(mode="max-autotune-no-cudagraphs", fullgraph=True)
    def main():
        for _ in range(100):
            # Launching the kernel with a NamedTuple argument fails under Dynamo.
            test[(1,)](t5)

    main()
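To make the motivation concrete: without named-tuple support, every leaf of `T5` would have to be passed to the kernel as a separate argument. The sketch below uses hypothetical helpers `is_namedtuple` and `flatten_namedtuple` (not part of PyTorch or Triton) to count those leaves in plain Python. It only mirrors the reproducer's class structure and does not invoke Triton or Dynamo.

```python
from typing import Any, List, NamedTuple


def is_namedtuple(obj: Any) -> bool:
    # Hypothetical helper: NamedTuple instances are tuples whose class defines `_fields`.
    return isinstance(obj, tuple) and hasattr(type(obj), "_fields")


def flatten_namedtuple(obj: Any) -> List[Any]:
    """Hypothetical helper: recursively collect the leaf values of a
    (possibly nested) NamedTuple, depth-first in field order."""
    if is_namedtuple(obj):
        leaves: List[Any] = []
        for name in obj._fields:
            leaves.extend(flatten_namedtuple(getattr(obj, name)))
        return leaves
    # Anything that is not a NamedTuple is treated as a single leaf.
    return [obj]


# Same nesting as the reproducer (no torch/triton imports needed here).
class T1(NamedTuple):
    foo: None = None
    bar: None = None

class T2(NamedTuple):
    foo: T1 = T1()
    bar: T1 = T1()

class T3(NamedTuple):
    foo: T2 = T2()
    bar: T2 = T2()

class T4(NamedTuple):
    foo: T3 = T3()
    bar: T3 = T3()

class T5(NamedTuple):
    foo: T4 = T4()
    bar: T4 = T4()


if __name__ == "__main__":
    leaves = flatten_namedtuple(T5())
    print(len(leaves))  # 32
```

With this nesting, `T5()` expands to 32 leaves (2 per `T1`, doubling at each level). That is the argument list a kernel signature would need if the named tuple had to be flattened by hand.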
Alternatives
No response
Additional context
No response
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @chauhang @penguinwu @voznesenskym @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @amjames @oulgen @aakhundov @davidberard98