eqy (Collaborator) commented Aug 29, 2025

For #161822

Basically, #125888 appears to introduce measurable CPU overhead due to TF32 precision-setting checks. Unfortunately, one of the checks is in getCurrentCUDABlasHandle, which is called before every cuBLAS matmul (and in a few other places such as workspace setup). We don't need to do this for non-float32 matmuls, so this PR alleviates the performance hit where it hurts the most (smaller dtypes that have faster matmuls), at the cost of some copypasta.
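The idea can be sketched roughly as follows: only consult the float32 precision setting when the matmul is actually float32, so smaller dtypes never pay for the string comparison. This is a minimal sketch; Dtype, float32Precision, and shouldUseTF32 are illustrative stand-ins, not the actual ATen API.

```cpp
#include <cassert>
#include <string>

// Illustrative dtype tag (stand-in for at::ScalarType).
enum class Dtype { Float32, BFloat16, Float16 };

// Stand-in for the comparatively expensive context lookup that
// #125888 introduced (string-valued setting plus validation).
inline std::string float32Precision() {
  return "tf32";
}

inline bool shouldUseTF32(Dtype dtype) {
  // Fast path: non-float32 matmuls never consult the precision
  // setting, so they skip the string comparison entirely.
  if (dtype != Dtype::Float32) {
    return false;
  }
  // Slow path (fp32 only): do the string-valued lookup as before.
  return float32Precision() == "tf32";
}
```

The bfloat16 benchmark below exercises only the fast path, which is where the microbenchmark improvement comes from.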

Some microbenchmark runs with this PR (microseconds per matmul):

7.812199214640714 
8.07804053692962 
7.865882366786536
7.898942214978888
8.018492849259928

and without:

8.222943563396257
8.129948014357069
8.184711361991504
8.28010104214627
8.266569921033806

script:

import torch
import time

warmup = 128
iters = 16384

a = torch.zeros(512, 512, device='cuda', dtype=torch.bfloat16)
for _ in range(warmup):
    torch.matmul(a, a)

torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(iters):
    torch.matmul(a, a)
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"{1e6 * (t1 - t0)/iters}")

Longer term, we'd prefer to make float32Precision itself faster (better data structures, less validation, etc.).

cc @ptrblck @msaroufim @jerryzh168 @csarofeen @xwang233 @zasdfgbnm

eqy requested a review from syed-ahmed as a code owner, August 29, 2025.
eqy added labels Aug 29, 2025: module: cuda, module: cublas, open source, module: tf32, topic: not user facing.
pytorch-bot commented Aug 29, 2025

See artifacts and rendered test results at hud.pytorch.org/pr/161823

Note: Links to docs will display an error until the docs builds have been completed.

1 new failure as of commit b0b4c4b with merge base 93c5112:

NEW FAILURE - The following job has failed:

  • pull / linux-jammy-rocm-py3.10 / build (gh)
    /var/lib/jenkins/workspace/aten/src/ATen/hip/CublasHandlePool.cpp:354:55: error: ‘CUBLAS_DEFAULT_MATH’ was not declared in this scope; did you mean ‘HIPBLAS_DEFAULT_MATH’?
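The error suggests the new code uses a CUDA-only constant that the ROCm build does not define, so the reference needs a backend guard (or the hipBLAS name). A minimal sketch of the guard pattern; the zero stub values and the USE_ROCM macro usage here are assumptions for illustration, not the real enum values:

```cpp
#include <cassert>

// Stub values standing in for the real cuBLAS/hipBLAS enums (illustrative).
#define CUBLAS_DEFAULT_MATH 0
#define HIPBLAS_DEFAULT_MATH 0

// Shared code must pick the constant per backend; the build failure above
// indicates CUBLAS_DEFAULT_MATH was not translated for the HIP build.
#if defined(USE_ROCM)
static const int kDefaultMathMode = HIPBLAS_DEFAULT_MATH;
#else
static const int kDefaultMathMode = CUBLAS_DEFAULT_MATH;
#endif
```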

This comment was automatically generated by Dr. CI and updates every 15 minutes.

// See Note [Writing Nondeterministic Operations]
globalContext().alertCuBLASConfigNotDeterministic();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
if (!NoTF32Guard::should_disable_tf32() &&
A collaborator commented on the code above:

Why not just delegate this to an inline function instead of copy-paste?

malfet (Contributor) left a comment:
This feels mildly wrong... i.e., replacing one check with 30+ in different places is bound to cause errors. Where is the overhead coming from? String comparison? Then it should be replaced by an enum.

eqy (Collaborator, author) commented Aug 29, 2025

This feels mildly wrong... i.e., replacing one check with 30+ in different places is bound to cause errors. Where is the overhead coming from? String comparison? Then it should be replaced by an enum.

I brought this up in the original PR (#125888): there are multiple reasons string comparison is used rather than an enum, but it's not that simple either, as there are multiple levels of string comparison (e.g., for backend and for precision), along with excessive validation checks and reference chasing.

I think with the changes to inline it, this would be cleaner.
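malfet's enum suggestion could look roughly like this: parse the user-facing string once, at set time, so the per-call hot path is a plain integer compare rather than (possibly several levels of) string comparison. Float32Precision and parseFloat32Precision are hypothetical names, not the actual PyTorch API, and a real version would also need to handle the per-backend level eqy mentions.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical internal representation of the float32 precision setting.
enum class Float32Precision { IEEE, TF32 };

// Validate and convert the user-facing string exactly once, when the
// setting changes; the result would be cached in the context.
inline Float32Precision parseFloat32Precision(const std::string& s) {
  if (s == "tf32") return Float32Precision::TF32;
  if (s == "ieee") return Float32Precision::IEEE;
  throw std::invalid_argument("unknown float32 precision: " + s);
}
```

The hot path (e.g., inside getCurrentCUDABlasHandle) would then compare the cached enum value, which is a single integer comparison per call.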

ngimel (Collaborator) commented Aug 31, 2025

Should we revert #125888 until we come up with a better design? It seems like it created a lot of problems without solving any.

malfet (Contributor) commented Sep 1, 2025

Should we revert #125888 until we come up with a better design? It seems like it created a lot of problems without solving any.

I think reverting is tricky, as we've technically released 2.8 with it, but I agree, maybe it's better to just revert.
@albanD : what do you think?

ngimel (Collaborator) commented Sep 1, 2025

Also, I've looked at the discussion on #125888, and it seems there were no serious arguments for strings. They want strings at the Python level, fine, but that's no reason to have a string comparison for every cuBLAS call.

zou3519 added the triaged label Sep 3, 2025.