eformer (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a comprehensive collection of tools for distributed computing, custom data structures, numerical optimization, and high-performance operations. Eformer aims to make it easier to build, scale, and optimize models efficiently while leveraging JAX's capabilities for high-performance computing.
The library is organized into several core modules:
- `aparser`: Advanced argument parsing utilities with dataclass integration
- `common_types`: Shared type definitions and sharding constants
- `escale`: Distributed sharding and parallelism utilities
- `executor`: Execution management and hardware-specific optimizations
- `jaximus`: Custom PyTree implementations and structured array utilities
- `mpric`: Mixed precision training and dynamic scaling infrastructure
- `optimizers`: Flexible optimizer configuration and factory patterns
- `pytree`: Enhanced tree manipulation and transformation utilities
The `mpric` module provides advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling, enabling faster training and a reduced memory footprint.
The `escale` module provides tools for efficient sharding and distributed computation in JAX, allowing you to scale your models across multiple devices with various sharding strategies (see the sketch after this list):

- Data Parallelism (`DP`)
- Fully Sharded Data Parallel (`FSDP`)
- Tensor Parallelism (`TP`)
- Expert Parallelism (`EP`)
- Sequence Parallelism (`SP`)
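As a rough illustration of what these strategies mean at the device-mesh level, here is a minimal sketch using plain `jax.sharding` rather than `escale`'s own helpers; the mesh shape and axis names below are assumptions made for the example.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the available devices into a 2D mesh: one axis for data parallelism,
# one for tensor parallelism (axis names are illustrative, not escale's own).
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("dp", "tp"))

# Shard a weight matrix column-wise across the "tp" axis and replicate it
# across "dp" -- the basic building block of tensor parallelism.
weights = jax.numpy.ones((1024, 1024))
weights = jax.device_put(weights, NamedSharding(mesh, P(None, "tp")))
```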
The `jaximus` module offers enhanced utilities for creating custom PyTrees and `ArrayValue` objects, adapted from Equinox, providing flexible data structures for your models.
The `optimizers` module is a flexible factory for creating and configuring optimizers such as AdamW, Adafactor, Lion, and RMSProp, making it easy to experiment with different optimization strategies.
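To make the idea concrete, the hedged sketch below shows the kind of name-to-optimizer mapping such a factory performs, written directly against `optax`; the `build_optimizer` helper, its registry, and its defaults are illustrative assumptions, not eformer's actual API.

```python
import optax

# Hypothetical helper mapping an optimizer name to an optax transformation.
def build_optimizer(name: str, learning_rate: float = 1e-3) -> optax.GradientTransformation:
    registry = {
        "adamw": lambda lr: optax.adamw(lr, weight_decay=0.01),
        "adafactor": lambda lr: optax.adafactor(lr),
        "lion": lambda lr: optax.lion(lr),
        "rmsprop": lambda lr: optax.rmsprop(lr),
    }
    return registry[name](learning_rate)

tx = build_optimizer("adamw", learning_rate=3e-4)
```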
For detailed API references and usage examples, see:

- Argument Parser (`aparser`)
- Sharding Utilities (`escale`)
- Execution Management (`executor`)
- Mixed Precision Infrastructure (`mpric`)
You can install `eformer` via pip:

```bash
pip install eformer
```
Create a precision handler with `mpric`:

```python
from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8, output in f32
    use_dynamic_scale=True,
)
```
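For context, dynamic loss scaling (what `use_dynamic_scale=True` turns on) scales the loss up before differentiation so low-precision gradients do not underflow, then adjusts that scale whenever overflow is detected. The sketch below shows the general idea in plain JAX; it is not `mpric`'s internal implementation, and `scaled_grads` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def scaled_grads(loss_fn, params, batch, loss_scale):
    # Scale the loss before differentiation so low-precision gradients do not underflow.
    grads = jax.grad(lambda p: loss_fn(p, batch) * loss_scale)(params)
    # Undo the scaling to recover the true gradients.
    grads = jax.tree_util.tree_map(lambda g: g / loss_scale, grads)
    # On overflow (non-finite gradients) shrink the scale sharply; otherwise grow it slowly.
    finite = jnp.all(jnp.array([jnp.all(jnp.isfinite(g)) for g in jax.tree_util.tree_leaves(grads)]))
    new_scale = jnp.where(finite, loss_scale * 1.0002, loss_scale * 0.5)
    return grads, new_scale, finite
```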
Define a custom quantized array type with `jaximus`:

```python
import jax
from eformer.jaximus import ArrayValue, implicit

class Array8B(ArrayValue):
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        # Store the array as int8 weights plus per-row scales (q8_0-style quantization).
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Reconstruct a dense array from the quantized representation.
        return dequantize_row_q8_0(self.weight, self.scale)

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")
qarray = Array8B(array)
```
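The snippet above assumes `quantize_row_q8_0` and `dequantize_row_q8_0` helpers are in scope. A minimal per-row absmax int8 version could look like the following sketch, which is an illustration rather than eformer's implementation.

```python
import jax
import jax.numpy as jnp

def quantize_row_q8_0(array: jax.Array) -> tuple[jax.Array, jax.Array]:
    # One scale per row, chosen so the largest magnitude maps to 127.
    scale = jnp.maximum(jnp.max(jnp.abs(array), axis=-1, keepdims=True) / 127.0, 1e-8)
    scale = scale.astype(jnp.float32)
    weight = jnp.round(array / scale).astype(jnp.int8)
    return weight, scale

def dequantize_row_q8_0(weight: jax.Array, scale: jax.Array) -> jax.Array:
    # Reconstruct an approximate float array from int8 weights and row scales.
    return weight.astype(jnp.float32) * scale
```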
We welcome contributions! Please read our Contributing Guidelines to get started.
This project is licensed under the Apache License 2.0. See the LICENSE file for details.