
torch.softmax(inp, dtype=torch.float32).to(torch.float16) is not equivalent to torch.softmax(inp) for fp16 input #123911

@fxmarty


🐛 Describe the bug

Hi,

While investigating why a model implementation using SDPA did not yield exactly the same output as the non-SDPA implementation in fp16 with the math backend, I pinned the discrepancy down to torch.softmax(inp, dtype=torch.float32).to(torch.float16) behaving differently from torch.softmax(inp) for float16 inputs.

I am sharing here a reproduction where the difference is small. However, it appears to accumulate over layers, and in my case results in an absolute error of up to ~0.1-0.2 after 20-30 layers in a real-size model.

import torch
import math

# fp16 query/key/value states saved from the model (shape [1, 32, 1238, 128], see output below)
query = torch.load("query_states_sdpa_3.pt")
key = torch.load("key_states_sdpa_3.pt")
value = torch.load("value_states_sdpa_3.pt")

print("query", query.device, query.dtype, query.is_contiguous(), query.shape)
print("key", key.device, key.dtype, key.is_contiguous(), key.shape)
print("value", value.device, value.dtype, value.is_contiguous(), value.shape)

torch.set_printoptions(threshold=1000000, precision=6)

def scaled_dot_product_attention(query, key, value, is_causal: bool, custom_cast: bool):
    # Split the 1/sqrt(head_dim) scaling across query and key before the matmul.
    scale_factor = math.sqrt(1 / math.sqrt(query.size(-1)))
    assert not is_causal

    # Pre-softmax attention scores, computed in fp16.
    softmax_inp = ((query * scale_factor) @ ((key * scale_factor).transpose(-2, -1)))

    if custom_cast:
        # Softmax computed in fp32 (input cast to fp32), then cast back to the input dtype.
        attn_weight = torch.softmax(
            softmax_inp,
            dim=-1,
            dtype=torch.float32
        ).to(query.dtype)
    else:
        # Softmax computed natively in fp16.
        attn_weight = torch.softmax(softmax_inp, dim=-1)

    return attn_weight @ value, softmax_inp

is_causal = False

with torch.no_grad():
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        res_sdpa = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=None)

        res_eager_cast, softmax_inp = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=True)

        res_eager_no_cast, _ = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=False)

res_softmax_0 = torch.softmax(softmax_inp, dim=-1)
res_softmax_1 = torch.softmax(softmax_inp, dim=-1, dtype=torch.float32).to(torch.float16)

print("-----")

absdiff = (res_softmax_0 - res_softmax_1).abs()
print("max absdiff softmax", absdiff.max())
print("median absdiff softmax", absdiff.median())

print("-----")

# These casts do not seem to matter for the comparison below.
res_sdpa = res_sdpa.to(torch.float32)
res_eager_cast = res_eager_cast.to(torch.float32)
res_eager_no_cast = res_eager_no_cast.to(torch.float32)

absdiff_nocast = (res_sdpa - res_eager_no_cast).abs()
absdiff_cast = (res_sdpa - res_eager_cast).abs()
print("SDPA max absdiff (no cast):", absdiff_nocast.max())
print("SDPA max absdiff (with cast):", absdiff_cast.max())
print("argwhere absdiff no cast > 0.0001", torch.argwhere(absdiff_nocast > 1e-4))
print("argwhere absdiff with cast > 0.0001", torch.argwhere(absdiff_cast > 1e-4))

Resulting in:

query cuda:0 torch.float16 False torch.Size([1, 32, 1238, 128])
key cuda:0 torch.float16 False torch.Size([1, 32, 1238, 128])
value cuda:0 torch.float16 False torch.Size([1, 32, 1238, 128])
-----
max absdiff softmax tensor(0.000488, device='cuda:0', dtype=torch.float16)
median absdiff softmax tensor(0., device='cuda:0', dtype=torch.float16)
-----
SDPA max absdiff (no cast): tensor(0., device='cuda:0')
SDPA max absdiff (with cast): tensor(0.000122, device='cuda:0')
argwhere absdiff no cast > 0.0001 tensor([], device='cuda:0', size=(0, 4), dtype=torch.int64)
argwhere absdiff with cast > 0.0001 tensor([[   0,    3,   72,   68],
        [   0,   11,   82,    0],
        [   0,   11,   82,   14],
        [   0,   11,   82,  100],
        [   0,   11,   82,  113],
        [   0,   21,   33,  109],
        [   0,   21, 1084,   31],
        [   0,   30,   44,   16],
        [   0,   31,   10,   40]], device='cuda:0')
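
If I am reading the magnitudes correctly, both reported maxima line up with a single fp16 ULP at the respective value ranges; a back-of-the-envelope check (not a proof that the rounding location is the only source of the gap):

# fp16 has a 10-bit mantissa: adjacent values in [0.5, 1) differ by 2**-11,
# adjacent values in [0.125, 0.25) differ by 2**-13.
print(2**-11)  # 0.00048828125   -> printed above as 0.000488 (max absdiff softmax)
print(2**-13)  # 0.0001220703125 -> printed above as 0.000122 (SDPA max absdiff, with cast)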

This difference is surprising to me, as I was under the impression, following #103167 (comment), that torch.softmax(inp, dtype=torch.float32).to(torch.float16) and torch.softmax(inp) were numerically equivalent. cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @drisspg @cpuhrsch @gante
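
The gap does not depend on the saved tensors; below is a minimal self-contained sketch that, in my understanding, should exhibit the same kind of one-ULP discrepancies (random fp16 logits with the same pre-softmax shape, so the exact maxima will differ from the numbers above):

import torch

# Standalone variant of the repro above: random fp16 logits instead of the saved
# query/key/value states, same [1, 32, 1238, 1238] pre-softmax shape.
torch.manual_seed(0)
logits = torch.randn(1, 32, 1238, 1238, dtype=torch.float16, device="cuda")

native_fp16 = torch.softmax(logits, dim=-1)
fp32_then_cast = torch.softmax(logits, dim=-1, dtype=torch.float32).to(torch.float16)

absdiff = (native_fp16 - fp32_then_cast).abs()
print("max absdiff", absdiff.max())
print("median absdiff", absdiff.median())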

https://pytorch.org/docs/2.2/generated/torch.nn.functional.softmax.html#torch.nn.functional.softmax explains that the input is cast to fp32 when dtype=torch.float32 is passed. I am not sure I am reading it correctly, but the CUDA kernel seems to take a different path for the native fp16 case (half_to_float == false):

if (!half_to_float) {
  if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
    auto output_ptr = output.mutable_data_ptr<scalar_t>();
    auto input_ptr = input.const_data_ptr<scalar_t>();
    int64_t remaining = outer_size;
    int64_t chunk_size = (1L << 30L) / dim_size;
    while(remaining > 0) {
      dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, is_log_softmax, false>(
        output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */);
      input_ptr += chunk_size * dim_size;
      output_ptr += chunk_size * dim_size;
      remaining -= chunk_size;
    }
  } else {
    constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
        output.mutable_data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), dim_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
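
If I read the quoted dispatch correctly, with the shapes above (softmax over the last dimension of size 1238, scalar_t = half) the native fp16 case fails the first condition and falls through to cunn_SoftMaxForward; a quick sanity check of that condition (my reading of the code, not verified with a profiler):

# dim_size is the softmax dimension, 1238 for the pre-softmax scores above; sizeof(half) == 2
dim_size = 1238
print(dim_size <= 1024 and dim_size * 2 <= 4096)  # False -> else branch (cunn_SoftMaxForward)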

Thank you!

Versions

PyTorch version: 2.2.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.24.3
Libc version: glibc-2.31

Python version: 3.9.13 (main, Oct 13 2022, 21:15:33)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.3.52
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      43 bits physical, 48 bits virtual
CPU(s):                             128
On-line CPU(s) list:                0-127
Thread(s) per core:                 2
Core(s) per socket:                 64
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          AuthenticAMD
CPU family:                         23
Model:                              49
Model name:                         AMD EPYC 7742 64-Core Processor
Stepping:                           0
Frequency boost:                    enabled
CPU MHz:                            1847.732
CPU max MHz:                        2250.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           4491.21
Virtualization:                     AMD-V
L1d cache:                          2 MiB
L1i cache:                          2 MiB
L2 cache:                           32 MiB
L3 cache:                           256 MiB
NUMA node0 CPU(s):                  0-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.4
[pip3] onnx==1.15.0
[pip3] onnx-graphsurgeon==0.3.27
[pip3] onnx-tool==0.9.0
[pip3] onnxruntime==1.17.1
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.2.2
[pip3] torch-tb-profiler==0.4.0
[pip3] torchaudio==2.2.2
[pip3] torchvision==0.17.2
[pip3] triton==2.2.0
[conda] numpy                     1.23.4                   pypi_0    pypi
[conda] pytorch-triton            3.0.0+989adb9a29          pypi_0    pypi
[conda] torch                     2.2.2                    pypi_0    pypi
[conda] torch-tb-profiler         0.4.0                    pypi_0    pypi
[conda] torchaudio                2.2.2                    pypi_0    pypi
[conda] torchvision               0.17.2                   pypi_0    pypi
[conda] triton                    2.2.0                    pypi_0    pypi

    Labels

    module: half, module: nn, module: numerical-stability, module: sdpa, triaged
