Highlights
We are excited to announce the 0.15.0 release of torchao! This release adds:
- MXFP8 MoE training demonstrates a 1.2x e2e training speedup with convergence equivalent to bf16, training Llama4 Scout on a 64-node GB200 Crusoe cluster!
- MXFP8 MoE kernels now ship with torchao builds for CUDA 12.8+ (just pip install; no need to build from source!)
- Safetensors enablement
- Quantization with parameter-level targeting
MXFP8 MoE training demonstrates a 1.2x e2e training speedup with convergence equivalent to bf16, training Llama4 Scout on a 64-node GB200 Crusoe cluster
Training runs on a 64-node GB200 cluster with TorchTitan Llama4 Scout demonstrated a 1.2x e2e training speedup with convergence equivalent to the bfloat16 training baseline. In fact, after 3,000 steps the MXFP8 run finishes with slightly lower loss than bfloat16! This is consistent with our scaling experiments with MXFP8 training for dense models.
| Number of GPUs | BF16 tokens/sec | MXFP8 tokens/sec | MXFP8 speedup vs BF16 |
|---|---|---|---|
| 512 | 6169 | 7401 | 1.20x |
See the TorchAO MXFP8 MoE training documentation for more details. You can also check out the TorchTitan MXFP8 documentation to run pretraining jobs with TorchAO MXFP8 by adding a single config.
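For context, MXFP8 stores each contiguous block of 32 values in float8 (e4m3) with a shared power-of-two (E8M0) scale. Below is a minimal numerical sketch of that recipe in plain PyTorch ops; it is illustrative only and is not the fused kernels torchao ships, and the shapes and the ceil-based scale rounding are assumptions for the example:
import torch

BLOCK = 32  # MX block size: 32 values share one scale
x = torch.randn(4, 128, dtype=torch.bfloat16)

blocks = x.float().reshape(-1, BLOCK)
amax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
# E8M0-style scale: a power of two chosen (here with ceil rounding)
# so that the block maximum maps into the fp8 e4m3 range (max ~448).
scale = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
qdata = (blocks / scale).to(torch.float8_e4m3fn)

# Dequantize and compare against the original values.
x_dq = (qdata.float() * scale).reshape_as(x).to(x.dtype)
print((x - x_dq).abs().max())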
Safetensors Enablement
You can now save and load TorchAO model checkpoints using safetensors! This feature is integrated with Hugging Face transformers (starting from v5.0.0) and vLLM (0.13.0) for model inference and serving.
We currently support the following stable configs:
- Float8DynamicActivationFloat8WeightConfig
- Int4WeightOnlyConfig
- IntxWeightOnlyConfig
- Int8DynamicActivationIntxWeightConfig
- Int8WeightOnlyConfig
- Int8DynamicActivationInt8WeightConfig
We will continue to add support for more configs as they become stable.
Example:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig
model_id = "facebook/opt-125m"
quant_config = Float8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Push to hub
MODEL_NAME = model_id.split("/")[-1]
save_to = f"torchao-testing/{MODEL_NAME}-Float8WeightOnlyConfig-v2-0.15.0.dev-safetensors"
quantized_model.push_to_hub(save_to, safe_serialization=True)
tokenizer.push_to_hub(save_to)

To serve a safetensors model checkpoint on vLLM, use the CLI command vllm serve torchao-testing/Qwen3-8B-INT4-0.15.0dev-safetensors.
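For local inference, the pushed checkpoint can also be loaded back through the same transformers APIs. A minimal sketch, continuing the example above (the prompt and generation settings are illustrative):
# Reload the quantized safetensors checkpoint saved above.
from transformers import AutoModelForCausalLM, AutoTokenizer

reloaded = AutoModelForCausalLM.from_pretrained(
    save_to,  # repo pushed above
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(save_to)
inputs = tokenizer("The capital of France is", return_tensors="pt").to(reloaded.device)
print(tokenizer.decode(reloaded.generate(**inputs, max_new_tokens=16)[0]))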
Int4 weight-only quantization with the preshuffled packing format is supported in vLLM (#3245, #26066)
vLLM can now run int4 weight-only quantized models with the preshuffled packing format. Checkpoints can be shipped with the plain packing format (e.g. https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v2-0.14.0.dev), and plain-format tensors are converted to the preshuffled format automatically on SM90+ when fbgemm_gpu_genai is available.
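A hedged sketch of producing such a checkpoint through the transformers integration follows; the group size, model, and target repo name are illustrative, and the exact packing-format defaults depend on your torchao version:
import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

# Quantize to int4 weight-only; the example checkpoint linked above uses the
# plain packing format. On SM90+ with fbgemm_gpu_genai installed, vLLM
# converts plain-format tensors to the preshuffled format at load time.
quant_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128))
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
)
model.push_to_hub("your-org/opt-125m-int4-plain", safe_serialization=True)  # hypothetical repo name
The resulting repo can then be served the same way, e.g. vllm serve your-org/opt-125m-int4-plain.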
Support for quantizing any parameter by its fully qualified name (FQN) via FqnToConfig (#3083)
import torch
from torchao.quantization import quantize_, Float8WeightOnlyConfig, FqnToConfig
class Example(torch.nn.Module):
def __init__(self):
super().__init__()
self.custom_name = torch.nn.Parameter(torch.Tensor([1]))
model = Example()
config = FqnToConfig({"custom_name": Float8WeightOnlyConfig()})
# filter_fn is disabled when quantizing by fqn
quantize_(model, config, filter_fn=None)
print(model.custom_name)
# Float8Tensor(self.act_quant_kwargs=None, self.qdata=tensor([448.], dtype=torch.float8_e4m3fn), self.scale=tensor([0.0022]), self.block_size=[1], self.mm_config=None, self.kernel_preference=<KernelPreference.AUTO: 'auto'>, self.shape=torch.Size([1]), self.device=device(type='cpu'), self.dtype=torch.float32)

This better supports MoE models (Llama4, DeepSeek, gpt-oss), which store their weights under attributes like "down_proj" and "gate_up_proj". Previously we assumed that weights were always stored under the "weight" attribute, since we were primarily focused on quantizing nn.Linear layers.
Currently, parameter quantization is only enabled for Int4WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig, and Float8WeightOnlyConfig.
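As a minimal sketch of parameter-level targeting for an MoE-style layout (the module and parameter names below are made up, and the sketch assumes the config accepts the stacked-expert parameter shapes shown; the same pattern applies to the other supported configs):
import torch
from torchao.quantization import quantize_, Float8WeightOnlyConfig, FqnToConfig

class ToyExperts(torch.nn.Module):
    # Expert weights stored as stacked 3D parameters, as MoE implementations
    # often do, rather than under an nn.Linear "weight" attribute.
    def __init__(self, num_experts=4, dim=64):
        super().__init__()
        self.gate_up_proj = torch.nn.Parameter(torch.randn(num_experts, dim, 2 * dim))
        self.down_proj = torch.nn.Parameter(torch.randn(num_experts, 2 * dim, dim))

model = ToyExperts()

# Target each expert parameter directly by its fully qualified name.
config = FqnToConfig({
    "gate_up_proj": Float8WeightOnlyConfig(),
    "down_proj": Float8WeightOnlyConfig(),
})
quantize_(model, config, filter_fn=None)  # filter_fn is disabled when quantizing by FQN
print(model.gate_up_proj)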
BC Breaking
- Remove config functions like int4_weight_only (#3145)
Before:
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
gemlite_uintx_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
uintx_weight_only,
)
quantize_(model, float8_dynamic_activation_float8_weight())
quantize_(model, float8_static_activation_float8_weight(torch.randn(3)))
quantize_(model, float8_weight_only())
quantize_(model, fpx_weight_only(3, 2))
quantize_(model, gemlite_uintx_weight_only())
quantize_(model, int4_dynamic_activation_int4_weight())
quantize_(model, int4_weight_only())
quantize_(model, int8_dynamic_activation_int4_weight())
quantize_(model, int8_dynamic_activation_int8_weight())
quantize_(model, int8_weight_only())
quantize_(model, uintx_weight_only(torch.uint4))
After:
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
FPXWeightOnlyConfig,
GemliteUIntXWeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
quantize_,
UIntXWeightOnlyConfig,
)
quantize_(model, Float8DynamicActivationFloat8WeightConfig())
quantize_(model, Float8StaticActivationFloat8WeightConfig(torch.randn(3)))
quantize_(model, Float8WeightOnlyConfig())
quantize_(model, FPXWeightOnlyConfig(3, 2))
quantize_(model, GemliteUIntXWeightOnlyConfig())
quantize_(model, Int4DynamicActivationInt4WeightConfig())
quantize_(model, Int4WeightOnlyConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig())
quantize_(model, Int8DynamicActivationInt8WeightConfig())
quantize_(model, Int8WeightOnlyConfig())
quantize_(model, UIntXWeightOnlyConfig(torch.uint4))
- Add quantize_ nn.Parameter support (#3083)
Before:
model = torch.nn.Sequential(
torch.nn.Linear(128, 128),
torch.nn.Linear(128, 128),
torch.nn.Conv2d(128, 128, 3, 1, 1),
).cuda().to(torch.bfloat16)
config = ModuleFqnToConfig({
"0": Float8DynamicActivationFloat8WeightConfig(),
})
# these are equivalent
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config, filter_fn=None)
quantize_(model, config)
After:
# VALID: user must specify None
quantize_(model, config, filter_fn=None)
# INVALID: these now error!
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config)
- Remove old TORCH_VERSION variables (#3146)
Before:
from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
if TORCH_VERSION_AT_LEAST_2_8:
print("PyTorch version was 2.8+")
After: (no replacement; see the version-check sketch after this list)
- 4/x: mx cleanup: use kernel_preference instead of gemm_kernel_choice (#3385)
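For the TORCH_VERSION_AT_LEAST_* removal above, a plain comparison on torch.__version__ covers the same use case. A user-side sketch (this is not a torchao API; it assumes the packaging package is installed, which it commonly is):
import torch
from packaging.version import Version

# Equivalent runtime gate without torchao's removed TORCH_VERSION_* flags.
if Version(torch.__version__.split("+")[0]) >= Version("2.8"):
    print("PyTorch version is 2.8+")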
Deprecations
- Add deprecation warnings for various inference configs (#3294)
- 2/x mx cleanup: remove pack_fp6 (#3382)
New Features
- Add learnable_fake_quantize in pt2e (#3135)
- Add learnable_fake_quantize to pt2e flow (#3170)
- Add per tensor fp8 quantization support for conv3d (#3215)
- Add a_1_128_w_128_128 (DeepSeek) float8 scaling for inference (#3257)
- Support int8 output for scaled_embedding_bag (#3231)
- Add per tensor fp8 conv2d support (#3315)
- add Float8OpaqueTensor for dynamic float8 act float8 weight (#3075)
- Add NPU (Ascend) backend support for INT4 weight-only quantization workflow (#3172)
- Add conv2d support for IntxUnpackedToInt8Tensor (#3371)
- Introduce SINQ calibration-free quantization algorithm (#3156)
- Add 2.9.1 compatibility (#3390)
Improvement
- Build SmoothQuant release pipeline (#3010)
- Fix nvfp4 serialization (#3140)
- Add parq utility to create an optimizer (#3165)
- Drop old quantization flows (#3115)
- Fix TORCHAO_SKIP_LOADING_SO_FILES behavior (#3189)
- Mxtensor: make scale shape match qdata (#3198)
- Mxtensor: add pre-swizzle support (#3200)
- Fix mxfp8 matmul benchmark (#3221)
- Mxfp8 inference roofline: add fusion to observed (#3223)
- Enable custom MKN in inference roofline script (#3224)
- Mx: support inference_mode and rank 3+ (#3238)
- Nvfp4: support inference_mode and rank 3 (#3240)
- Move float8 blockwise kernels out of prototype (#3256)
- Add bias handling for a_1_128_w_128_128 float8 scaling (#3259)
- Updating flatten/unflatten functions (#3282)
- [mxfp8 moe training] compute prefix sum of group sizes inside kernel instead of precomputing (#3285)
- Make float8 a1x128_w128x128 granularity serializable (#3279)
- Update Float8Tensor for GRPO training in unsloth (#3158)
- Skip quantization when channels_out / channels_in are not multiple of 16 (#3309)
- Enable PerRow(axis) to support axes other than -1 (#3303)
- Add str to FqnToConfig to make printing more readable (#3323)
- Add support for e2e benchmark for conv2d/conv3d (#3329)
- Add support for torch.chunk to Float8Tensor (#3334)
- [CI] Update docker image for XPU test (#3336)
- [pt2e] Fix QAT annotations for special qspecs (#3337)
- Modify unflatten for vllm (#3297)
- Align memory_format for conv2d/3d in Float8Tensor with hp Tensor (#3352)
- Replace raise by warning to unblock unrecognized torchao version (#3338)
- Add PerBlock to safe globals (#3370)
- Unbreak BC, add back PerRow, PerTensor imports (#3396)
- Fix fbcode loading cpp extensions (#3372)
- Extend TestQAT module with xpu testcases (#3177)
- FP8 Blockwise Training: triton_op for dense model (#3402)
- 5/x mx cleanup: rename mx inf config to MXDynamicActivationMXWeightConfig (#3386)
- 7/x mx cleanup: standardize nvfp4 config names (#3398)
- Introduce int8 quantization api (version 2) (#3391)
- Add mxfp8 and nvfp4 to Llama eval scripts (#3394)
- Flip mx inference scaling setting to RCEIL (#3428)
- Int8Tensor migration cleanup (#3407)
- Enable optim SR test (#3055)
Bug Fixes
- Mxfp4 and nvfp4: align fp4 packing to PyTorch Core definition (#3123)
- Fix broken test_affine_quantized_tensor_parallel test after DeviceMesh (#3151)
- Properly thread through use_triton_kernel (#3155)
- Fix AQT test (#3208)
- Require --no-build-isolation in torchao builds (#3209)
- Fix nf4 test that is failing in CI (#3216)
- Fix Wq size check (#3228)
- Makes fallback float8 1x128 by 128x128 gemm output bfloat16 (#3265)
- Float8 inference: fix bmm semantics (#3296)
- Add generic TorchAOTensor extra_repr for nn.Modules (#3328)
- [mxfp8 moe training] fix bug introduced in #3385 (#3417)
- Revert "[xpu][test] Port 2 test/prototype/test_{parq, quantized_training} UT files to intel XPU" (#3432)
- Bump python version in tutorial ci workflow (#3439)
Performance
- Call named_modules once per model prepare (#3159)
Documentation
- Add mx and nv inference example to README.md (#3141)
- Fix setuptools version for docs build (#3150)
- Small updates to main torchao README.md (#3160)
- Remove LLaMa 2 from quantization README.md (#3161)
- [moe training] update readme (#3163)
- [moe training] update benchmarks using more reliable b200 gpu (#3174)
- Simplify quantization README.md (#3162)
- Mx_formats: Update README.md (#3210)
- Mx_formats - add compile to training example (#3211)
- [mxfp8] update readme with mxfp8 moe training prototype and mxfp8 training blog (#3207)
- Update TorchAO README inference section before PTC (#3206)
- Update QAT README before PTC (#3214)
- Update README.md (#3225)
- Add Unsloth + QAT blog to latest news (#3227)
- Docs: fix qat description in README.md (#3212)
- [moe training] update readme with links, cleanup (#3239)
- Update torchao + unsloth integration on README (#3267)
- [mxfp8 moe training][BE] add docs showing equivalent convergence to bf16 at scale (#3312)
- Update main README.md with recommended workflows (#3364)
- Make support table in README.md more concise (#3369)
- [mxfp8 moe training] update readme with rooflines and benchmarks (#3399)
- Update to new PT Theme (#2361)
- Add an example for quantizing LLaMa 4 Scout (#3408)
Developers
Not User Facing
- Add INT8-INT4-HQQ to model release script (#3127)
- Add outlier in AWQ test cases (#3106)
- [Inductor][float8] Support qlinear for float8 in inductor (#2565)
- Revert D82355346 (#3132)
- Include multi-modal eval in eval scripts (#3133)
- Update version to 0.15.0 (#3130)
- Revert "add learnable_fake_quantize in pt2e" (#3142)
- Fix rocm CI (#3136)
- [BE] Remove Float8Linear from quant_api.py (#3085)
- [Reland][CPU] Add ops for float8 linear (#3100)
- Fix pt2e test_qat_preserve_source_fn_stack (#3149)
- Update roofline benchmark with mxfp4 (#3152)
- [BE] [moe training] generic bench script for torchtitan models (#3124)
- [BE] [moe training] bench script for single device moe layer (#3126)
- Rename MXTensor's _scale_e8m0 to scale (#3164)
- Rename NVFP4Tensor's _scale_e4m3 field to scale (#3166)
- Rename NVFP4Tensor's _per_tensor_scale and _act_per_tensor_scale fields (#3168)
- Rename MXTensor and NVFP4Tensor's to_dtype to dequantize (#3169)
- [FP8 SDPA] Enable FP8 SDPA pattern match (#3076)
- Fixup missing callsites for _scale_e8m0 -> scale rename (#3173)
- [benchmarks] Add inference-only roofline for float8 (#3167)
- [mxfp8 moe training] add triton kernel for mxfp8 quantization along dim0 (#3128)
- Revert "[Reland][CPU] Add ops for float8 linear" (#3179)
- Fix SyntaxWarning during installation / import (#3184)
- Update compatibility matrix (#3178)
- Add nvfp4 cast benchmarks (#3188)
- [mxfp8 moe training] integrate mxfp8 dim0 cast triton kernel (#3186)
- Revert "Remove config functions like
int4_weight_only(#3145)" (#3192) - [CI] Use single-GPU runners for ROCm MI3xx CI (#3138)
- Enable SmoothQuant Test on Intel GPU (#3185)
- [mxfp8 moe training] update benchmarks to force load balancing (#3193)
- Extend mxfp8 roofline with more recipes (#3190)
- Extend inference roofline with real benchmarks (#3194)
- Add option to save profiling traces in inference roofline script (#3196)
- Add missing eval_mm_quality.sh (#3204)
- Fix typo in AQT test (#3205)
- Fix regression_test_aarch64 (#3217)
- Fix is metadata func (#3220)
- [moe training] change api _scaled_grouped_mm -> _quantize_then_scaled_grouped_mm (#3218)
- [mxfp8 moe training] make compile vs triton for dim0 cast configurable (#3219)
- Only convert to int4 preshuffled tensor in H100 (#3245)
- [mxfp8 moe training] add triton kernel for mxfp8 dequantization (#3195)
- [mxfp8 moe training] integrate triton quant/dequant kernels into mxfp8 all to all (#3197)
- [mxfp8 moe training] simplify e8m0 -> fp32 calc (#3201)
- [mxfp8 moe training] bench and profile mxfp8 a2a fwd and bwd separately (#3203)
- [mxfp8 moe training] initialize zero tensor differently to avoid d2h sync (#3253)
- Improve INT8 SDPA template (#3230)
- [mxfp8 moe training] update readme and tests (#3260)
- [mxfp8] fix test nan != nan issue (#3273)
- Set model parameters to requires_grad=False (#3272)
- [mxfp8 moe training] make scaling mode configurable and make rceil default (#3271)
- Use nn_module_stack instead (#3268)
- Fix unit test to use no grad (#3283)
- Fix setup.py to skip CPU kernels on Windows (#3187)
- [mxfp8 moe training] update benchmarks and tests; simplify per group blocked swizzle ref function (#3286)
- Revert "Revert "[Reland][CPU] Add ops for float8 linear" (#3179)" (#3254)
- Fix issue: No module named 'fbgemm_gpu.experimental' (#3292)
- Make fqn_matches_fqn_config public and fix module device loading (#3302)
- Fix module name extraction logic in quant_api.py (#3298)
- Pin pytest==8.4.2 (#3321)
- Update common used toy linear model (#3275)
- Use conda libgcc-ng 11.2 (#3327)
- Use conda libgcc-ng 11.2 for nightly tests (#3326)
- Add torch 2.9 in regression tests (#3311)
- Fix microbenchmarking run (#3346)
- Add enable_fusion_modeling for conv2d and conv3d (#3343)
- [test] Port 2 test/quantization_{gptq, quant_primitive} UT files to intel XPU (#3350)
- Refactor Hadamard matrices loading (#3344)
- Move float8_opaque_tensor to prototype (#3365)
- Fixes accuracy error for mxfp8 linear (#3357)
- Re: #3290 FP8 Blockwise Training Tracker, quantization benchmarks (#3306)
- Gemm benchmark for #3290: replaced torch._scaled_mm with torch.nn.functional.scaled_mm (#3342)
- Update model name in dashboard to reflect api name and shape (#3358)
- Move Int4OpaqueTensor to prototype (#3378)
- [PT2E][X86] Add Inductor fusion passes of float8 qconv for X86Inductor backend (#3261)
- Fix style after #3261 (#3397)
- [mxfp8 moe training] add roofline script (#3388)
- 1/x mx cleanup: mxtensor: rename _block_size to block_size (#3381)
- 3/x mx cleanup: speed up local test suite (#3383)
- 6/x mx cleanup: make NVFP4Tensor use base implements (#3387)
- Revert "Fix style after #3261" (#3412)
- Revert "[PT2E][X86] Add Inductor fusion passes of float8 qconv for X86Inductor backend" (#3413)
- [Reland][PT2E][X86] Add Inductor fusion passes of float8 qconv for X8… (#3418)
- Add test for constant folding in pt2e quant (#3420)
- Revert "[Reland][PT2E][X86] Add Inductor fusion passes of float8 qconv for X8…" (#3427)
- Add CLAUDE.local.md to gitignore (#3437)
- Reland qconv fp8 fusion passes (#3433)
- [test] Port 2 test/quantization/pt2e/test_{quantize_pt2e, quantize_pt2e_qat} UT files to intel XPU (#3405)
- Skip certain mxfp8 tests for cuda < 12.8 (#3443)
New Contributors
- @Krishn1412 made their first contribution in #2866
- @Mingming-Ding made their first contribution in #3132
- @DiweiSun made their first contribution in #3027
- @KaiserLeave made their first contribution in #3338
- @agolajko made their first contribution in #3306
- @jiayisunx made their first contribution in #3261
Full Changelog: v0.14.0-rc3...v0.15.0-rc2