Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,11 @@ def PreCanonicalizationOptimizationPass : Pass<"pre-canonicalization-optimizatio
"::mlir::vector::VectorDialect",
"::mlir::tpu::TPUDialect",
];
let options = [
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"6", "">,
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
];
}

#endif // TPU_ATTRS
4 changes: 3 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu_dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{});

std::unique_ptr<OperationPass<func::FuncOp>>
createPreCanonicalizationOptimizationPass();
createPreCanonicalizationOptimizationPass(
int hardware_generation = -1,
std::array<int64_t, 2> target_shape = {8, 128});

std::unique_ptr<OperationPass<func::FuncOp>>
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
Expand Down
60 changes: 0 additions & 60 deletions jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1422,66 +1422,6 @@ FailureOr<Value> canonicalize_vector_transpose(const CanonicalizeContext &ctx,
return new_op;
}

// Finds the split point for a reshape between a multi-dimensional shape and a
// shape where a suffix has been collapsed into a single dimension.
//
// This function checks if `src_shape` and `tgt_shape` follow the pattern:
//   src_shape: (P..., S_1, S_2, ..., S_N)
//   tgt_shape: (P..., T_collapsed)
// where `P` is a common prefix and `product(S_1..S_N) == T_collapsed`.
//
// It handles a differing number of leading 1s in the prefix by stripping them
// from both shapes before comparison. The common prefix is matched greedily,
// so a collapse whose suffix starts with a dim equal to T_collapsed (i.e. one
// involving trailing 1s) is not recognized and yields std::nullopt.
//
// This utility is used for two inverse patterns:
//   1. Collapse (e.g., `load` -> `reshape`): The function is called directly,
//      where `src_shape` is the multi-dimensional pre-reshape vector shape.
//   2. Expand (e.g., `reshape` -> `store`): The function is called with swapped
//      arguments, where `src_shape` is the multi-dimensional *post-reshape*
//      vector shape.
//
// Returns:
//   - A pair containing:
//     1. The index in `src_shape` where the collapsing suffix begins.
//     2. The product of the collapsed dimensions excluding the innermost one
//        (i.e., product(S_1..S_{N-1})), used as the "sublane product".
//   - `std::nullopt` if the shapes do not match the pattern, if the collapsed
//     suffix of `src_shape` would be empty, or if its innermost dim is 0.
std::optional<std::pair<int, int>> findSplitPoint(ArrayRef<int64_t> src_shape,
                                                  ArrayRef<int64_t> tgt_shape) {
  const int64_t src_rank = static_cast<int64_t>(src_shape.size());
  const int64_t tgt_rank = static_cast<int64_t>(tgt_shape.size());
  // An empty target has no collapsed dim (and would underflow `tgt_rank - 1`
  // if this were computed on size_t).
  if (tgt_rank == 0) {
    return std::nullopt;
  }

  int64_t s = 0;
  int64_t t = 0;
  // Drop leading 1s.
  while (s < src_rank && src_shape[s] == 1) {
    ++s;
  }
  while (t < tgt_rank && tgt_shape[t] == 1) {
    ++t;
  }

  // Advance past the common prefix of both shapes (ignoring leading 1s).
  while (s < src_rank && t < tgt_rank && src_shape[s] == tgt_shape[t]) {
    ++s;
    ++t;
  }

  // After the common prefix the target must have exactly one dimension left
  // (the collapsed one), and the source must have a non-empty suffix to fold
  // into it; otherwise `src_shape.back()` below would be a prefix dim.
  if (t != tgt_rank - 1 || s == src_rank) {
    return std::nullopt;
  }

  int64_t src_prod = 1;
  for (int64_t i = s; i < src_rank; ++i) {
    src_prod *= src_shape[i];
  }
  // Guard the division below against a zero innermost dim.
  if (tgt_shape.back() != src_prod || src_shape.back() == 0) {
    return std::nullopt;
  }
  return std::make_pair(static_cast<int>(s),
                        static_cast<int>(src_prod / src_shape.back()));
}

FailureOr<Value> canonicalize_shape_cast(const CanonicalizeContext& ctx,
Operation& raw_op) {
CanonicalBuilder builder(ctx, raw_op.getLoc(), &raw_op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@ limitations under the License.
==============================================================================*/

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
Expand All @@ -29,9 +35,11 @@ limitations under the License.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"

namespace mlir::tpu {

Expand All @@ -41,6 +49,184 @@ namespace mlir::tpu {

namespace {

void CanonicalizeStore(int hardware_generation,
std::array<int64_t, 2> target_shape, Operation& raw_op) {
// Fuses a vector.shape_cast (that expands dimensions) into a subsequent
// vector.store or dense tpu.vector_store. This is the inverse of the
// canonicalize_reshape func.
Value value_to_store;
TypedValue<MemRefType> base;
ValueRange indices;

Operation* store_op;

if (auto store = dyn_cast<vector::StoreOp>(raw_op)) {
store_op = store.getOperation();
value_to_store = store.getValueToStore();
base = store.getBase();
indices = store.getIndices();
} else if (auto store = dyn_cast<tpu::VectorStoreOp>(raw_op)) {
store_op = store.getOperation();
value_to_store = store.getValueToStore();
base = store.getBase();
indices = store.getIndices();
if (!store.getStrides().empty() || store.getMask() || store.getAdd()) {
return;
}
} else {
return;
}

// Look for vector::ShapeCastOp feeding the store
auto shape_cast_op =
dyn_cast_if_present<vector::ShapeCastOp>(value_to_store.getDefiningOp());
if (!shape_cast_op || !shape_cast_op.getResult().hasOneUse()) {
return;
}

auto src_ty = shape_cast_op.getSource().getType();
auto tgt_ty = shape_cast_op.getResult().getType();
auto memref_ty = base.getType();

if (tgt_ty.getShape() != memref_ty.getShape()) {
return;
}
if (!isContiguousMemref(base)) {
return;
}
if (src_ty.getRank() > tgt_ty.getRank()) {
return;
}
auto last_src_lanes = src_ty.getShape().back();
if (last_src_lanes % target_shape[1] != 0) {
return;
}
std::optional<std::pair<int64_t, int64_t>> split_opt =
findSplitPoint(tgt_ty.getShape(), src_ty.getShape());
if (!split_opt) {
return;
}
auto [split_point, sublane_prod] = *split_opt;

int64_t bitwidth = src_ty.getElementTypeBitWidth();
int64_t packing = 32 / bitwidth;
if (hardware_generation < 4 && packing > 1) {
return;
}
if (sublane_prod % packing != 0) {
return;
}

ImplicitLocOpBuilder b(store_op->getLoc(), store_op);
auto loc = store_op->getLoc();
auto i32_type = b.getI32Type();
int64_t num_i32_rows = sublane_prod / packing;

SmallVector<int64_t> mem_shape;
if (split_point == 0) {
mem_shape.push_back(sublane_prod);
} else {
mem_shape.assign(memref_ty.getShape().begin(),
memref_ty.getShape().begin() + split_point);
int64_t prev_dim = mem_shape.back();
int64_t new_dim = prev_dim * sublane_prod;
if (sublane_prod != 0 && new_dim / sublane_prod != prev_dim) {
return;
}
mem_shape.back() = new_dim;
}

auto lane_dim = memref_ty.getShape().back();
if (lane_dim != target_shape[1]) {
return;
}
mem_shape.push_back(lane_dim);
Value reshaped_ref = b.create<tpu::MemRefReshapeOp>(
MemRefType::get(mem_shape, memref_ty.getElementType()), base);

*(mem_shape.end() - 2) /= packing;
Value i32_view = b.create<tpu::MemRefBitcastOp>(
MemRefType::get(mem_shape, i32_type), reshaped_ref);

Value src_vec = shape_cast_op.getSource();
SmallVector<int64_t> slice_sizes(src_ty.getShape());
slice_sizes.back() = lane_dim;
SmallVector<int64_t> unit_strides(src_ty.getRank(), 1);

auto i32_view_shape = cast<MemRefType>(i32_view.getType()).getShape();

SmallVector<Value> store_indices;
Value split_base_idx;
int64_t stride_dim;

if (split_point == 0) {
// No common prefix - create indices for entire i32_view shape
split_base_idx = IdxConst(0, b, loc);
for (size_t i = 0; i < i32_view_shape.size(); ++i) {
store_indices.push_back(IdxConst(0, b, loc));
}
stride_dim = 0;
} else {
// Common prefix exists - use it
store_indices.assign(indices.begin(), indices.begin() + split_point);
split_base_idx = store_indices.back();
// Add remaining indices to match i32_view rank
while (store_indices.size() < i32_view_shape.size()) {
store_indices.push_back(IdxConst(0, b, loc));
}
stride_dim = split_point - 1;
}
SmallVector<int32_t> strides(i32_view_shape.size(), 1);
strides[stride_dim] = num_i32_rows;
for (int64_t i = 0; i < num_i32_rows; ++i) {
SmallVector<int64_t> offsets(src_ty.getRank(), 0);
offsets.back() = i * packing * lane_dim;
Value slice = b.create<vector::ExtractStridedSliceOp>(
src_vec, offsets, slice_sizes, unit_strides);

auto i_chunk_ty =
VectorType::get(cast<VectorType>(slice.getType()).getShape(),
b.getIntegerType(bitwidth));
auto i32_chunk_ty =
VectorType::get(cast<VectorType>(slice.getType()).getShape(), i32_type);
Value packed_chunk;
if (packing > 1) {
Value acc = b.create<arith::ExtUIOp>(
i32_chunk_ty, b.create<arith::BitcastOp>(i_chunk_ty, slice));
for (int64_t p = 1; p < packing; ++p) {
offsets.back() = (i * packing + p) * lane_dim;
slice = b.create<vector::ExtractStridedSliceOp>(
src_vec, offsets, slice_sizes, unit_strides);
Value sj_i32 = b.create<arith::ExtUIOp>(
i32_chunk_ty, b.create<arith::BitcastOp>(i_chunk_ty, slice));
Value sh = I32Const(p * bitwidth, i32_chunk_ty.getShape(), b, loc);
acc = b.create<arith::OrIOp>(acc, b.create<arith::ShLIOp>(sj_i32, sh));
}
packed_chunk = acc;
} else {
packed_chunk = b.create<arith::BitcastOp>(i32_chunk_ty, slice);
}

auto packed_shape = cast<VectorType>(packed_chunk.getType()).getShape();
Value chunk_to_store = packed_chunk;
if (i32_view_shape.size() > packed_shape.size()) {
SmallVector<int64_t> reshape_vec_shape(
i32_view_shape.size() - packed_shape.size(), 1);
reshape_vec_shape.append(packed_shape.begin(), packed_shape.end());
auto reshape_type = VectorType::get(reshape_vec_shape, i32_type);
chunk_to_store = b.create<tpu::ReshapeOp>(reshape_type, packed_chunk);
}
store_indices[stride_dim] =
b.create<arith::AddIOp>(split_base_idx, IdxConst(i, b, loc));

b.create<tpu::StridedStoreOp>(chunk_to_store, i32_view, store_indices,
strides);
}

store_op->erase();
shape_cast_op->erase();
}

struct RhsTraversalResult {
tpu::TransposeOp transpose_op = nullptr;
vector::ExtractStridedSliceOp slice_op = nullptr;
Expand Down Expand Up @@ -180,7 +366,13 @@ tryFuseRhsTranspose(tpu::MatmulOp op, ImplicitLocOpBuilder& builder) {
struct PreCanonicalizationOptimizationPass
: impl::PreCanonicalizationOptimizationPassBase<
PreCanonicalizationOptimizationPass> {
PreCanonicalizationOptimizationPass(int hardware_generation,
std::array<int64_t, 2> target_shape)
: hardware_generation_(hardware_generation),
target_shape_(target_shape) {}

void runOnOperation() override {
// Calculate target shape from pass parameters
getOperation().walk([&](tpu::MatmulOp op) {
// We only attempt this fusion if dimension numbers are present.
if (!op.getDimensionNumbers().has_value()) {
Expand All @@ -196,14 +388,31 @@ struct PreCanonicalizationOptimizationPass
op.setDimensionNumbersAttr(new_dnums);
}
});

// Apply store canonicalization
getOperation().walk([&](vector::StoreOp op) {
CanonicalizeStore(hardware_generation_, target_shape_,
*op.getOperation());
});

getOperation().walk([&](tpu::VectorStoreOp op) {
CanonicalizeStore(hardware_generation_, target_shape_,
*op.getOperation());
});
}

private:
int64_t hardware_generation_;
std::array<int64_t, 2> target_shape_;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createPreCanonicalizationOptimizationPass() {
return std::make_unique<PreCanonicalizationOptimizationPass>();
createPreCanonicalizationOptimizationPass(int hardware_generation,
std::array<int64_t, 2> target_shape) {
return std::make_unique<PreCanonicalizationOptimizationPass>(
hardware_generation, target_shape);
}

} // namespace mlir::tpu
} // namespace mlir::tpu
32 changes: 32 additions & 0 deletions jaxlib/mosaic/dialect/tpu/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,38 @@ FailureOr<SmallVector<int>> computeSqueezedDimsChecked(
return squeezed;
}

// Finds the split point for a reshape between a multi-dimensional shape and a
// shape where a suffix has been collapsed into a single dimension.
//
// Checks whether `src_shape` and `tgt_shape` follow the pattern
//   src_shape: (P..., S_1, S_2, ..., S_N)
//   tgt_shape: (P..., T_collapsed)
// where `P` is a common prefix (compared after stripping any leading 1s from
// each shape independently) and `product(S_1..S_N) == T_collapsed`.
//
// The prefix is matched greedily, so a collapse whose suffix starts with a
// dim equal to T_collapsed (i.e. one involving trailing 1s) is not recognized
// and yields std::nullopt.
//
// Returns:
//   - {split, sublane_prod}: `split` is the index in `src_shape` where the
//     collapsed suffix begins; `sublane_prod` is product(S_1..S_{N-1}), the
//     suffix product excluding its innermost dim.
//   - std::nullopt if the shapes do not match the pattern, if the collapsed
//     suffix of `src_shape` would be empty, or if its innermost dim is 0.
std::optional<std::pair<int64_t, int64_t>> findSplitPoint(
    ArrayRef<int64_t> src_shape, ArrayRef<int64_t> tgt_shape) {
  const int64_t src_rank = static_cast<int64_t>(src_shape.size());
  const int64_t tgt_rank = static_cast<int64_t>(tgt_shape.size());
  // An empty target has no collapsed dim (and would underflow `tgt_rank - 1`
  // if this were computed on size_t).
  if (tgt_rank == 0) {
    return std::nullopt;
  }

  int64_t s = 0;
  int64_t t = 0;
  // Drop leading 1s from both shapes.
  while (s < src_rank && src_shape[s] == 1) {
    ++s;
  }
  while (t < tgt_rank && tgt_shape[t] == 1) {
    ++t;
  }

  // Advance past the common prefix.
  while (s < src_rank && t < tgt_rank && src_shape[s] == tgt_shape[t]) {
    ++s;
    ++t;
  }

  // The target must have exactly one (collapsed) dim left, and the source
  // must have a non-empty suffix to fold into it; otherwise
  // `src_shape.back()` below would be a prefix dim.
  if (t != tgt_rank - 1 || s == src_rank) {
    return std::nullopt;
  }

  int64_t src_prod = 1;
  for (int64_t i = s; i < src_rank; ++i) {
    src_prod *= src_shape[i];
  }
  // Guard the division below against a zero innermost dim.
  if (tgt_shape.back() != src_prod || src_shape.back() == 0) {
    return std::nullopt;
  }
  return std::make_pair(s, src_prod / src_shape.back());
}

std::optional<std::pair<bool, bool>> isTransposedMatmul(
DotDimensionNumbersAttr dim_numbers) {
auto lhs_contracting_dims = dim_numbers.getLhsContractingDims();
Expand Down
Loading
Loading