🐛 Describe the bug
Hi,

While investigating why a model implementation using SDPA (math backend) vs. a non-SDPA implementation was not yielding exactly the same output in fp16, I pinned it down to a different behavior of `torch.softmax(inp, dtype=torch.float32).to(torch.float16)` vs. `torch.softmax(inp)` for float16 inputs.

I am sharing a reproduction where the difference is small. However, it appears to accumulate over layers, and in my case it can result in an absolute error of up to ~0.1-0.2 after 20-30 layers of a real-size model.
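For readers without the saved tensors, here is a minimal self-contained sketch of the comparison using random fp16 inputs (hypothetical data; the magnitude of the difference depends on the actual activations, and random inputs may not reproduce it as clearly as the saved tensors below):

```python
import torch

# Hypothetical stand-in data with the same attention-weight shape as the
# reproduction below; random values may or may not trigger the discrepancy.
torch.manual_seed(0)
inp = torch.randn(1, 32, 1238, 1238, dtype=torch.float16, device="cuda")

out_fp16 = torch.softmax(inp, dim=-1)
out_fp32_cast = torch.softmax(inp, dim=-1, dtype=torch.float32).to(torch.float16)

print((out_fp16 - out_fp32_cast).abs().max())
```

The original reproduction with the saved query/key/value states: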
```python
import torch
import math

query = torch.load("query_states_sdpa_3.pt")
key = torch.load("key_states_sdpa_3.pt")
value = torch.load("value_states_sdpa_3.pt")

print("query", query.device, query.dtype, query.is_contiguous(), query.shape)
print("key", key.device, key.dtype, key.is_contiguous(), key.shape)
print("value", value.device, value.dtype, value.is_contiguous(), value.shape)

torch.set_printoptions(threshold=1000000, precision=6)

def scaled_dot_product_attention(query, key, value, is_causal: bool, custom_cast: bool):
    scale_factor = math.sqrt(1 / math.sqrt(query.size(-1)))
    assert not is_causal
    softmax_inp = ((query * scale_factor) @ ((key * scale_factor).transpose(-2, -1)))
    if custom_cast:
        attn_weight = torch.softmax(
            softmax_inp,
            dim=-1,
            dtype=torch.float32
        ).to(query.dtype)
    else:
        attn_weight = torch.softmax(softmax_inp, dim=-1)
    return attn_weight @ value, softmax_inp

is_causal = False

with torch.no_grad():
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        res_sdpa = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=is_causal, attn_mask=None)

    res_eager_cast, softmax_inp = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=True)
    res_eager_no_cast, _ = scaled_dot_product_attention(query, key, value, is_causal=is_causal, custom_cast=False)

    res_softmax_0 = torch.softmax(softmax_inp, dim=-1)
    res_softmax_1 = torch.softmax(softmax_inp, dim=-1, dtype=torch.float32).to(torch.float16)

    print("-----")
    absdiff = (res_softmax_0 - res_softmax_1).abs()
    print("max absdiff softmax", absdiff.max())
    print("median absdiff softmax", absdiff.median())
    print("-----")

    # These casts do not seem to matter.
    res_sdpa = res_sdpa.to(torch.float32)
    res_eager_cast = res_eager_cast.to(torch.float32)
    res_eager_no_cast = res_eager_no_cast.to(torch.float32)

    absdiff_nocast = (res_sdpa - res_eager_no_cast).abs()
    absdiff_cast = (res_sdpa - res_eager_cast).abs()
    print("SDPA max absdiff (no cast):", absdiff_nocast.max())
    print("SDPA max absdiff (with cast):", absdiff_cast.max())
    print("argwhere absdiff no cast > 0.0001", torch.argwhere(absdiff_nocast > 1e-4))
    print("argwhere absdiff with cast > 0.0001", torch.argwhere(absdiff_cast > 1e-4))
```
Resulting in:

```text
query cuda:0 torch.float16 False torch.Size([1, 32, 1238, 128])
key cuda:0 torch.float16 False torch.Size([1, 32, 1238, 128])
value cuda:0 torch.float16 False torch.Size([1, 32, 1238, 128])
-----
max absdiff softmax tensor(0.000488, device='cuda:0', dtype=torch.float16)
median absdiff softmax tensor(0., device='cuda:0', dtype=torch.float16)
-----
SDPA max absdiff (no cast): tensor(0., device='cuda:0')
SDPA max absdiff (with cast): tensor(0.000122, device='cuda:0')
argwhere absdiff no cast > 0.0001 tensor([], device='cuda:0', size=(0, 4), dtype=torch.int64)
argwhere absdiff with cast > 0.0001 tensor([[ 0, 3, 72, 68],
        [ 0, 11, 82, 0],
        [ 0, 11, 82, 14],
        [ 0, 11, 82, 100],
        [ 0, 11, 82, 113],
        [ 0, 21, 33, 109],
        [ 0, 21, 1084, 31],
        [ 0, 30, 44, 16],
        [ 0, 31, 10, 40]], device='cuda:0')
```
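For scale (my own back-of-the-envelope arithmetic, not taken from the kernel): the max softmax difference of 0.000488 is exactly the fp16 spacing for values in [0.5, 1), i.e. the two paths seem to disagree by a single rounding step for probabilities of that magnitude:

```python
import torch

# fp16 machine epsilon is 2**-10; the spacing between representable fp16
# values in [0.5, 1) is eps / 2 = 2**-11 = 0.00048828125, which matches the
# max absdiff printed above.
print(torch.finfo(torch.float16).eps / 2)
```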
This difference is surprising to me, as I was under the impression following #103167 (comment) that `torch.softmax(inp, dtype=torch.float32).to(torch.float16)` and `torch.softmax(inp)` were numerically equivalent. cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @drisspg @cpuhrsch @gante

https://pytorch.org/docs/2.2/generated/torch.nn.functional.softmax.html#torch.nn.functional.softmax explains that the input is cast to fp32 when dtype=torch.float32 is passed. I am not sure I am reading the following correctly:
pytorch/aten/src/ATen/native/cuda/SoftMax.cu
Lines 723 to 742 in cea899c

```cuda
if (!half_to_float) {
  if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) {
    auto output_ptr = output.mutable_data_ptr<scalar_t>();
    auto input_ptr = input.const_data_ptr<scalar_t>();
    int64_t remaining = outer_size;
    int64_t chunk_size = (1L << 30L) / dim_size;
    while(remaining > 0) {
      dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, is_log_softmax, false>(
        output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */);
      input_ptr += chunk_size * dim_size;
      output_ptr += chunk_size * dim_size;
      remaining -= chunk_size;
    }
  } else {
    constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
    dim3 block = SoftMax_getBlockSize(ILP, dim_size);
    cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
      <<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
        output.mutable_data_ptr<scalar_t>(), input.const_data_ptr<scalar_t>(), dim_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
```
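To sanity-check my reading of the docs (that passing dtype=torch.float32 is the same as casting the input to fp32 first), the fused half-to-float path can be compared against an explicit cast on the same `softmax_inp` from the reproduction above. Whether the fused kernel is bit-identical to an explicit cast is exactly what I am unsure about, so this is just a sketch:

```python
# Sketch: compare the fused dtype=torch.float32 path against an explicit
# input cast; per the docs, the two should match. softmax_inp is the fp16
# tensor from the reproduction above.
a = torch.softmax(softmax_inp, dim=-1, dtype=torch.float32)
b = torch.softmax(softmax_inp.to(torch.float32), dim=-1)
print("fused vs explicit-cast max absdiff:", (a - b).abs().max())
```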
Thank you!
Versions
PyTorch version: 2.2.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.24.3
Libc version: glibc-2.31
Python version: 3.9.13 (main, Oct 13 2022, 21:15:33) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.3.52
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA DGX Display
GPU 4: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 1847.732
CPU max MHz: 2250,0000
CPU min MHz: 1500,0000
BogoMIPS: 4491.21
Virtualization: AMD-V
L1d cache: 2 MiB
L1i cache: 2 MiB
L2 cache: 32 MiB
L3 cache: 256 MiB
NUMA node0 CPU(s): 0-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca sme sev sev_es
Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.4
[pip3] onnx==1.15.0
[pip3] onnx-graphsurgeon==0.3.27
[pip3] onnx-tool==0.9.0
[pip3] onnxruntime==1.17.1
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.2.2
[pip3] torch-tb-profiler==0.4.0
[pip3] torchaudio==2.2.2
[pip3] torchvision==0.17.2
[pip3] triton==2.2.0
[conda] numpy 1.23.4 pypi_0 pypi
[conda] pytorch-triton 3.0.0+989adb9a29 pypi_0 pypi
[conda] torch 2.2.2 pypi_0 pypi
[conda] torch-tb-profiler 0.4.0 pypi_0 pypi
[conda] torchaudio 2.2.2 pypi_0 pypi
[conda] torchvision 0.17.2 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi