Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions jax/experimental/mosaic/gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,8 @@ def _launch(
c(profiler_start, index),
lowering_semantics,
)
if lowering_semantics == LoweringSemantics.Warpgroup:
prof_smem = dialect.with_transforms(prof_smem, ir.ArrayAttr.get([]))
prof = profiler.OnDeviceProfiler(
profiler_spec, prof_smem, maybe_prof_buffer
)
Expand Down
7 changes: 5 additions & 2 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,6 @@ def _mgpu_slice_smem_op_lowering_rule(
) -> Sequence[ir.Value]:
del ctx
sliced_ref = _slice_smem(op.result.type, op.offset)

memref_ty = ir.MemRefType(sliced_ref.type)
if (
memref_ty.element_type == ir.Type.parse("!mosaic_gpu.barrier")
Expand Down Expand Up @@ -1380,7 +1379,11 @@ def _memref_subview_op_lowering_rule(
del ctx

in_transforms = inference_utils.in_transforms(op)[0]
out_transforms = inference_utils.out_transforms(op)[0]
if inference_utils.is_transformable_smem_memref(op.result):
out_transforms = inference_utils.out_transforms(op)[0]
else:
# This can happen, e.g., for a memref of rank 0, which is not a transformable SMEM ref.
out_transforms = ir.ArrayAttr.get([])

if in_transforms != out_transforms:
raise NotImplementedError(
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/mosaic/gpu/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def is_transformable_smem_memref(v: ir.Value) -> bool:
# barriers have no business being transformed
and v.type.element_type != barrier_ty # pylint: disable=attribute-error
and utils.is_smem_ref(v)
and v.type.rank != 0
)


Expand Down
35 changes: 21 additions & 14 deletions jax/experimental/mosaic/gpu/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import gpu
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import scf
import numpy as np

from .utils import * # noqa: F403
Expand Down Expand Up @@ -315,31 +316,30 @@ def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Va
self.entries_per_wg,
),
)
self.smem_buffer_ptr = memref_ptr(self.smem_buffer, memory_space=3)
self.gmem_buffer = gmem_buffer
self.is_profiling_thread = arith.cmpi(
arith.CmpIPredicate.eq,
arith.remui(thread_idx(), c(WARPGROUP_SIZE, i32)),
c(0, i32),
)
# Hopefully mem2reg will remove the allocation.
self.offset = memref.alloca(ir.MemRefType.get((), i32), [], [])
memref.store(c(0, i32), self.offset, [])
self.offset = memref.alloca(ir.MemRefType.get((), index), [], [])
memref.store(c(0, index), self.offset, [])

@contextlib.contextmanager
def record(self, name: str):
i32 = ir.IntegerType.get_signless(32)
index = ir.IndexType.get()
name_id = self.spec.intern_name(name)
def store(modifier):
# smem_buffer[offset] = modifier | name_id
# smem_buffer[offset + 1] = %clock
# offset += 2
offset = memref.load(self.offset, [])
base_ref = memref_slice(self.smem_buffer, offset)
base_ptr = memref_ptr(base_ref, memory_space=3)
i64 = ir.IntegerType.get_signless(64)
base_addr = arith.addi(
llvm.ptrtoint(i64, self.smem_buffer_ptr),
arith.extui(i64, arith.muli(offset, c(4, i32))),
)
base_addr = llvm.ptrtoint(i64, base_ptr)
llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[self.is_profiling_thread, base_addr, c(modifier | name_id, i32)],
Expand All @@ -349,7 +349,7 @@ def store(modifier):
"b,l,r",
has_side_effects=True,
)
new_offset = arith.addi(offset, c(2, i32))
new_offset = arith.addi(offset, c(2, index))
memref.store(new_offset, self.offset, [])
store(ProfilerSpec.ENTER)
yield
Expand Down Expand Up @@ -379,11 +379,18 @@ def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]):
with when(self.is_profiling_thread):
memref.store(self.start, wg_gmem_buffer, [c(0, index)])
memref.store(smid(), wg_gmem_buffer, [c(1, index)])
num_traces = memref.load(self.offset, [])
num_traces = arith.index_cast(i32, memref.load(self.offset, []))
memref.store(num_traces, wg_gmem_buffer, [c(2, index)])
traces = vector.load(
ir.VectorType.get((self.entries_per_wg - 3,), i32),
self.smem_buffer,
[c(0, index)],
for_op = scf.ForOp(
c(0, index),
c(self.entries_per_wg - 3, index),
c(1, index),
)
vector.store(traces, wg_gmem_buffer, [c(3, index)])
with ir.InsertionPoint(for_op.body):
x = memref.load(self.smem_buffer, [for_op.induction_variable])
memref.store(
x,
wg_gmem_buffer,
[arith.addi(for_op.induction_variable, c(3, index))],
)
scf.yield_([])
17 changes: 8 additions & 9 deletions jax/experimental/mosaic/gpu/transform_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from collections.abc import Callable
from functools import partial
import math
from typing import cast

from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -106,9 +105,11 @@ def _resolve_transforms(


def _transforms_from_uses(op: ir.OpView) -> ir.ArrayAttr | None:
transforms = None
if not inference_utils.is_transformable_smem_memref(op.result):
return None

for result_use in cast(ir.OpResult, op.result).uses:
transforms = None
for result_use in ir.OpResult(op.result).uses:
consumer = result_use.owner
op_user = consumer.operands[result_use.operand_number]
user_transforms = inference_utils.in_transforms_for_operand(
Expand Down Expand Up @@ -314,7 +315,7 @@ def _infer_memref_subview_transforms(
in_transforms = inference_utils.value_transforms(op.source)
transforms = _resolve_transforms(transforms, in_transforms)

if transforms is None:
if not transforms:
return None

# Here, we have some transforms to propagate one way or the other. For now,
Expand Down Expand Up @@ -407,14 +408,12 @@ def _infer_memref_transpose_transforms(
return [ir.ArrayAttr.get(in_transforms)], [out_transforms]


# `memref.load` is used to load barrier phases---the rule needn't do anything
# interesting, but we need to have it in order to avoid crashing on it.
@partial(_add_transform_inference_rule, memref.LoadOp)
def _infer_memref_load_transforms(op: memref.LoadOp) -> OptionalTransforms:
if not ir.MemRefType(op.memref.type).shape:
# memref.load returns a scalar, so there is nothing interesting to do here.
in_transforms = inference_utils.value_transforms(op.memref)
if in_transforms is None:
return None
raise NotImplementedError("Non-scalar memref.load transforms")
return [in_transforms], []


@partial(_add_transform_inference_rule, memref.CastOp)
Expand Down
30 changes: 30 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4566,6 +4566,36 @@ def create_kernel():
x[slicing].reshape(sub_shape),
)

def test_profiler(self):
  """Checks that named regions show up in the profiler dump (Warpgroup semantics).

  Builds a kernel that copies its input to its output, wrapping the load and
  the store in separate named regions, and then inspects the profile dumped
  into a temporary directory.
  """

  def body(ctx, src, dst, scratch):
    # No scratch memory is requested (`smem_scratch_shape=[]`), so the scratch
    # argument is unused.
    del scratch
    with ctx.named_region("load"):
      reg = vector_load(src)
    with ctx.named_region("store"):
      vector_store(reg, dst)

  dtype = jnp.bfloat16
  shape = (128, 128)
  jax_shape = jax.ShapeDtypeStruct(shape, dtype)
  with tempfile.TemporaryDirectory() as tmpdir:
    kernel = mgpu.as_gpu_kernel(
        body,
        grid=(1, 1, 1),
        block=(128, 1, 1),
        # A single (non-tuple) input shape; the original spelled this as
        # `(jax_shape)`, which is the same value but reads like a tuple.
        in_shape=jax_shape,
        out_shape=jax_shape,
        smem_scratch_shape=[],
        prof_spec=profiler.ProfilerSpec(1024, dump_path=tmpdir),
        thread_semantics=mgpu.LoweringSemantics.Warpgroup,
    )
    param = self.prng.uniform(-1, 1, shape).astype(dtype)
    # The kernel is a plain copy, so the output must equal the input.
    self.assertArraysEqual(kernel(param), param)
    # The unpacking fails unless exactly one file was written to the dump
    # directory.
    [name] = os.listdir(tmpdir)
    with open(os.path.join(tmpdir, name)) as f:
      data = f.read()
    # Each region name is expected twice in the dump.
    # NOTE(review): presumably one begin and one end event per region --
    # confirm against the profiler's dump format.
    self.assertEqual(data.count('"name": "load"'), 2)
    self.assertEqual(data.count('"name": "store"'), 2)


class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase):

Expand Down
2 changes: 0 additions & 2 deletions tests/pallas/mosaic_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,8 +1719,6 @@ def kernel(o_ref):
np.testing.assert_array_equal(kernel(), x)

def test_profiler(self):
self.skip_if_wg_semantics() # Transform inference not implemented.

def kernel(x_ref, o_ref):
with jax.named_scope("add"):
with jax.named_scope("load"):
Expand Down
Loading