
Inexact results of VMap operation due to optimization in linalg.solve #151440

@Flamefire

Description


🐛 Describe the bug

I've investigated #151113 and #114868 and traced the issue to _linalg_solve_ex_out.

The error is only large enough to fail the test on AMD CPUs, not on Intel CPUs. It occurs with both OpenBLAS and MKL, although the magnitudes of the differences vary slightly.

There is an "optimization" using a transposed input in some cases.

TL;DR: Disabling the optimization and its counterpart on the solve side resolved both issues.

The failing test cases run the linalg.(tensor_)solve function twice: first directly, and then with the same input duplicated into a batch of 2 via vmap.

  • _linalg_solve_ex_out is called in both cases with the same inputs (except for the batched duplication in the vmap case)
  • it first calls linalg_lu_factor_ex_out and then linalg_lu_solve_out
  • The result is supposed to be the same but there are slight differences, e.g. (regular vs vmap):
- -15.8471, -12.4022, -17.0307, -12.6871, 29.1342, -13.0953, -6.9707,  -14.4058, 24.0526, 5.87875, 2.9288,  -7.22714, 
+ -15.8453, -12.4006, -17.0288, -12.6856, 29.1309, -13.0939, -6.96982, -14.4041, 24.0499, 5.87819, 2.92857, -7.22624, 

This later propagates into larger differences; e.g. the largest absolute difference is 492.4144 vs. 492.3525, which fails the test's tolerance of at most 1e-4.
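The comparison the test performs can be sketched as follows. This is a minimal reconstruction, not the actual test code; the shapes and seed are illustrative assumptions:

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
# Hypothetical shapes; the real tests use OpInfo sample inputs.
A = torch.randn(12, 12)
B = torch.randn(12, 3)

# Plain solve.
X_plain = torch.linalg.solve(A, B)

# The same input duplicated into a batch of 2 and solved under vmap.
A_batched = A.expand(2, -1, -1)
B_batched = B.expand(2, -1, -1)
X_vmap = vmap(torch.linalg.solve)(A_batched, B_batched)

# Ideally zero; on the affected AMD CPUs this exceeds the 1e-4 tolerance.
print((X_plain - X_vmap[0]).abs().max().item())
```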

I think the optimization can safely be removed, as it seems outdated:

Possible optimization: Compute the LU factorization of A^T if A is contiguous
Then we solve A^T X = B with adjoint=True
This saves a copy as A doesn't need to be copied into an F-contig matrix in lu_factor
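For real inputs this trick relies on the identity (A^T)^T X = A X = B: when A is C-contiguous, A.mT is already F-contiguous, so lu_factor can consume it without a copy, and lu_solve with adjoint=True then solves against the transpose. A minimal sketch of the equivalence (not the actual ATen code):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)  # C-contiguous
B = torch.randn(4, 2, dtype=torch.float64)

# Reference path: factor A directly (internally needs an F-contiguous copy).
X_ref = torch.linalg.solve(A, B)

# Transpose trick: A.mT is F-contiguous, so no copy is needed in lu_factor;
# adjoint=True then solves (A^T)^H X = A X = B for real A.
LU, pivots = torch.linalg.lu_factor(A.mT)
X_trick = torch.linalg.lu_solve(LU, pivots, B, adjoint=True)

# Mathematically identical, but the two paths can differ in the last bits,
# which is exactly the discrepancy this issue describes.
print(torch.allclose(X_ref, X_trick))
```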

But in linalg_lu_factor_ex_out the only copy is made when !LU.is_same(A), and LU is a freshly allocated Tensor (at least in this code path). Even if it weren't, I don't think A.mT() could ever be the same tensor as LU, could it?

There is another potential copy in linalg_lu_solve_out, conditioned on LU.mT().is_contiguous(). But in all cases of this test, with and without the optimization, LU.mT() is always contiguous.
If that holds in general, or at least usually, the "optimization" can be removed without cost to ensure better results.
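That contiguity claim is easy to check from Python: lu_factor hands the input to LAPACK, which works on column-major data, so the returned LU is in Fortran layout and its transposed view is contiguous. A quick check (shapes are illustrative):

```python
import torch

A = torch.randn(5, 5)
LU, pivots = torch.linalg.lu_factor(A)

# The LU factors come back in column-major order, so the transposed view
# is contiguous; this is the condition linalg_lu_solve_out tests.
print(LU.mT.is_contiguous())

# The same holds for a batched input.
Ab = torch.randn(2, 5, 5)
LUb, _ = torch.linalg.lu_factor(Ab)
print(LUb.mT.is_contiguous())
```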

Versions

Pretty much all recent PyTorch versions, independent of other library versions, but only on AMD CPUs.

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7702 64-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 69%
CPU max MHz: 2183.5930
CPU min MHz: 1500.0000
BogoMIPS: 4000.22
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization: AMD-V

cc @jianyuh @nikitaved @mruberry @walterddr @xwang233 @lezcano @zou3519 @Chillee @samdow @kshitij12345

Metadata


    Labels

    module: functorch (Pertaining to torch.func or pytorch/functorch)
    module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
    module: numerical-stability (Problems related to numerical stability of operations)
    module: vmap
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
