[Mosaic GPU] Pass in TMA descriptors through kernel parameters #22175
Merged
Conversation
Force-pushed from 45015cd to 2eb5a48:
As we've established (sigh), we can't pass in TMA descriptors through global memory. The current workaround was to use constant memory instead, but this raises a number of potential concurrency issues. So, instead, we use the freshly added support for grid_constant parameters in upstream LLVM to pass the descriptors as kernel arguments. This seems to work fine and should in fact have lower overhead than both previous methods.

PiperOrigin-RevId: 648744363
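For reference, here is a minimal CUDA sketch of the pattern this change adopts, following the CUDA programming guide: the host encodes the descriptor with `cuTensorMapEncodeTiled` and the kernel receives it by value as a `const __grid_constant__ CUtensorMap` parameter. The array shape, tile size, and float element type below are illustrative assumptions, not Mosaic GPU's actual lowering.

```cuda
// Sketch only: host-side TMA descriptor creation, passed by value as a
// __grid_constant__ kernel parameter. Build with: nvcc -arch=sm_90 -lcuda
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void kernel(const __grid_constant__ CUtensorMap tensor_map) {
  // The descriptor lives in kernel parameter space; taking its address
  // yields a pointer that TMA (cp.async.bulk.tensor) instructions consume.
  const CUtensorMap *desc = &tensor_map;
  (void)desc;  // actual TMA copies omitted in this sketch
}

int main() {
  float *dev_ptr;  // assumed 256x256 row-major float array (illustrative)
  cudaMalloc(&dev_ptr, 256 * 256 * sizeof(float));

  CUtensorMap tensor_map{};
  uint64_t global_dim[2] = {256, 256};
  uint64_t global_strides[1] = {256 * sizeof(float)};  // rank-1 entries
  uint32_t box_dim[2] = {64, 64};                      // SMEM tile shape
  uint32_t elem_strides[2] = {1, 1};
  CUresult res = cuTensorMapEncodeTiled(
      &tensor_map, CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
      /*tensorRank=*/2, dev_ptr, global_dim, global_strides, box_dim,
      elem_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
      CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  if (res != CUDA_SUCCESS) return 1;

  // The 128-byte descriptor is copied into parameter space at launch time:
  // no global-memory staging, and nothing shared mutably across grids.
  kernel<<<1, 128>>>(tensor_map);
  cudaDeviceSynchronize();
  cudaFree(dev_ptr);
  return 0;
}
```

Because parameter space is immutable for the lifetime of a grid, concurrent launches each get their own private copy of the descriptor, which is what sidesteps the concurrency issues of the constant-memory workaround.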
Force-pushed from 2eb5a48 to 265a54d.
embg added a commit to triton-lang/triton that referenced this pull request on Aug 19, 2024:
## Motivation

Currently, Triton passes TMA descriptors by-ref through global memory. This has a number of problems:

* Significant launch overhead (5-10us) for the host-to-device memcpy
* Users must insert fences to flush the TMA descriptor cache (see #4342). When users don't insert these fences correctly, they run into very strange bugs: #4332
* The memcpy makes it nearly impossible to use CUDA graphs

There are two possible solutions:

* [Pass the descriptor by-value](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#using-tma-to-transfer-multi-dimensional-arrays)
* [Create the descriptor on-device](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device)

Because of the tricky memory model for TMA descriptors on H100, creating a descriptor on-device requires moving data back and forth through the L2 cache. This is relatively expensive (hundreds of cycles at least) and requires the user or compiler to insert release/acquire fences correctly. In some cases there is no way to avoid creating the descriptor on-device, but for many use cases it's perfectly fine to set up the descriptor on the host and pass it by value, avoiding both the performance and the correctness issues. This PR implements the by-value functionality.

## User-level API

Whenever the user provides a kernel param which implements the method `tma_desc_cpu_ptr()`, Triton will lower that argument to a `__grid_constant__` by-value param. The existing helper methods `create_[1d/2d]_tma_descriptor` were modified to return such a type, so existing code does not need any changes to take advantage of the new feature.

## Implementation details

When a kernel param with `tma_desc_cpu_ptr()` is detected, we attach an attribute to that param at the TTIR level. The attribute is passed through to TTGIR. When lowering TTGIR to LLIR, we use code ported from Mosaic (jax-ml/jax#22175) to set up the correct LLVM attributes. The runtime is also modified to pass by-value TMA descriptors properly.

## Limitations

This feature is currently broken when compiling an `IRSource` directly (which is useful for editing IR and re-compiling). That would require updating some [regexes](https://github.com/triton-lang/triton/blob/edcc2bcb8dd2e9224c94b689df9cbb7d2986ebea/python/triton/compiler/compiler.py#L52-L53) which infer the function signature from the IR. `IRSource` compilation still works fine for kernels which do not use the new feature. Once the approach I'm taking here is reviewed, I plan to fix that limitation, either in this PR or in a follow-up PR.
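To make the by-ref vs. by-value contrast concrete, here is a hedged CUDA sketch of the two kernel signatures. In the by-ref form, the descriptor sits in global memory and each consuming block must acquire-fence the tensormap proxy before the TMA unit may read it (the missing fence is the class of bug referenced above); the by-value `__grid_constant__` form needs no fence. The fence PTX follows the CUDA programming guide; this is an illustration, not Triton's actual generated code.

```cuda
// Sketch only (sm_90): contrasting by-ref and by-value TMA descriptor params.
#include <cuda.h>

__device__ void tensormap_acquire_fence(const CUtensorMap *gmem_desc) {
  // 128 is the TMA descriptor size in bytes, as required by the PTX ISA.
  asm volatile(
      "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;"
      :: "l"(gmem_desc) : "memory");
}

// By-ref: descriptor was memcpy'd to global memory; a proxy fence is
// required before it can be consumed by TMA instructions.
__global__ void kernel_by_ref(const CUtensorMap *gmem_desc) {
  tensormap_acquire_fence(gmem_desc);
  // ... issue cp.async.bulk.tensor using gmem_desc ...
}

// By-value: descriptor arrives in parameter space at launch; no fence and
// no host-to-device memcpy, which also keeps it CUDA-graph friendly.
__global__ void kernel_by_value(const __grid_constant__ CUtensorMap desc) {
  // ... issue cp.async.bulk.tensor using &desc ...
}
```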
bertmaher pushed a commit to bertmaher/triton that referenced this pull request on Dec 10, 2024:
(Same commit message as the triton-lang/triton commit above, with repository-qualified issue links: triton-lang#4342, triton-lang#4332.)