Changes from all commits (28 commits)
463b7ed  WIP bring in TransferManager (renerocksai, Feb 21, 2025)
ce99e22  WIP tests pass (renerocksai, Feb 21, 2025)
7c2c19b  WIP in-progress stuff (renerocksai, Feb 26, 2025)
0a72529  more wip (renerocksai, Feb 26, 2025)
8a853da  wip compiles (renerocksai, Feb 26, 2025)
3685491  runs on cpu (renerocksai, Feb 26, 2025)
10ae2fd  wip 1 buffer (renerocksai, Feb 26, 2025)
4c98a10  wip TransferManager (renerocksai, Feb 26, 2025)
c1a1363  wip fix (renerocksai, Feb 26, 2025)
04864f0  wip TransferManager.transferDataSlices() (renerocksai, Feb 26, 2025)
1ae9926  wip (renerocksai, Feb 26, 2025)
a94351a  platform.TransferManager.progress() (renerocksai, Feb 27, 2025)
7da6265  wip, nearly there (renerocksai, Mar 2, 2025)
6a31c2b  wip example with many largish buffers (renerocksai, Mar 2, 2025)
bc80be2  wip xferman example with mmapped file (renerocksai, Mar 2, 2025)
9c77c4f  wip checking for overlaps (renerocksai, Mar 2, 2025)
8f66015  wip investigating (renerocksai, Mar 3, 2025)
3d054ba  wip repro transfermanager bug (renerocksai, Mar 4, 2025)
e366134  rebase on master (renerocksai, Mar 4, 2025)
8c2dbc0  xferman/mmap* : auto-select memory kind based on platform (renerocksai, Mar 4, 2025)
b799acc  unfck cuda pjrt plugin build (renerocksai, Mar 4, 2025)
e1ed32d  switch xferman:mmap to bf16, add more logging (steeve, Mar 4, 2025)
ab704e1  wip: log device buffer sizes (renerocksai, Mar 4, 2025)
b77181b  wip log shape_specs (renerocksai, Mar 4, 2025)
8b7e375  wip: dump pjrt buffer dims, honor metadata offset of safetensors file… (renerocksai, Mar 4, 2025)
58fba19  wip fix dims ptr bug (renerocksai, Mar 4, 2025)
03b8b0c  wip removed extensive logging, fixed examples/loader (renerocksai, Mar 5, 2025)
543a595  fixed loader/xferman segfault: wait for load events before deinit buf… (renerocksai, Mar 5, 2025)
3,379 changes: 67 additions & 3,312 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

3,367 changes: 61 additions & 3,306 deletions examples/MODULE.bazel.lock

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions examples/loader/BUILD.bazel
@@ -8,4 +8,26 @@ zig_cc_binary(
"@zml//stdx",
"@zml//zml",
],
args = [
"$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
],
data = [
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
],
)

zig_cc_binary(
name = "safetensors-xferman",
main = "xferman.zig",
deps = [
"@zml//async",
"@zml//stdx",
"@zml//zml",
],
args = [
"$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
],
data = [
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
],
)
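Based on the target definition above and the usage hint printed by xferman.zig, the new binary can presumably be run as follows (exact invocation assumed from the target name; it is not spelled out in the PR):

bazel run -c opt //loader:safetensors-xferman -- /path/to/model.safetensors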
10 changes: 8 additions & 2 deletions examples/loader/main.zig
@@ -53,13 +53,19 @@ pub fn asyncMain() !void {
std.debug.print("\nStart to read {d} buffers from store..\n", .{buffer_store.buffers.count()});

while (it.next()) |entry| : (i += 1) {
const host_buffer = entry.value_ptr.*;
const buffer_entry = entry.value_ptr.*;
const host_buffer = zml.HostBuffer.fromBytes(buffer_entry.shape, buffer_entry.data);
total_bytes += host_buffer.data.len;
std.debug.print("Buffer: {s} ({any} / {any})\n", .{ entry.key_ptr.*, i + 1, buffer_store.buffers.count() });
buffers[i] = try zml.Buffer.from(platform, host_buffer);
}

const stop = timer.read();

// Now print after taking the timing
i = 0;
it = buffer_store.buffers.iterator();
while (it.next()) |entry| : (i += 1) {
std.debug.print("Buffer: {s} ({any} / {any})\n", .{ entry.key_ptr.*, i + 1, buffer_store.buffers.count() });
}
const time_in_s = stdx.math.divFloat(f64, stop, std.time.ns_per_s);
const mbs = stdx.math.divFloat(f64, total_bytes, 1024 * 1024);

122 changes: 122 additions & 0 deletions examples/loader/xferman.zig
@@ -0,0 +1,122 @@
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");
const asynk = @import("async");

const asyncc = asynk.asyncc;

pub fn main() !void {
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}

fn checkSlicesForOverlaps(alloc: std.mem.Allocator, entire_buffer: []const u8, subslices: [][]const u8) !void {
std.log.info("Checking for overlaps...", .{});
var bytefield = try alloc.alloc(bool, entire_buffer.len);
defer alloc.free(bytefield);

for (0..bytefield.len) |i| {
bytefield[i] = false;
}

for (subslices, 0..) |sub, idx| {
const start: usize = @intFromPtr(sub.ptr) - @intFromPtr(entire_buffer.ptr);
if (start + sub.len > entire_buffer.len) {
std.log.err("Error: subslice {d} reaches outside of mmapped file: file(0..{d}), subslice({d}..{d})", .{
idx,
entire_buffer.len,
start,
start + sub.len,
});
return error.Overflow;
}

for (start..start + sub.len) |index| {
if (bytefield[index] == true) {
return error.Overlap;
}
bytefield[index] = true;
}
}
std.log.info("Checking for overlaps...done", .{});
}

pub fn asyncMain() !void {
// Short lived allocations
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
const allocator = gpa.allocator();

var args = std.process.args();
// Skip executable path
_ = args.next().?;

const file = if (args.next()) |path| blk: {
std.debug.print("File path: {s}\n", .{path});
break :blk path;
} else {
std.debug.print("Missing file path argument\n", .{});
std.debug.print("Try: bazel run -c opt //loader:safetensors -- /path/to/mymodel.safetensors or /path/to/model.safetensors.index.json \n", .{});
std.process.exit(0);
};

var context = try zml.Context.init();
defer context.deinit();

const platform = context.autoPlatform(.{});
context.printAvailablePlatforms(platform);

var buffer_store = try zml.aio.safetensors.open(allocator, file);
defer buffer_store.deinit();

var total_bytes: usize = 0;
var timer = try std.time.Timer.start();

var bit = buffer_store.buffers.iterator();
var slice_list = std.ArrayList([]const u8).init(allocator);
defer slice_list.deinit();
while (bit.next()) |item| {
const buffer_entry = item.value_ptr;
try slice_list.append(buffer_entry.data);
}

// try checkSlicesForOverlaps(allocator, buffer_store.files[0].data, slice_list.items);

const memory_kind: zml.pjrt.Memory.Kind = switch (platform.target) {
.cpu => .unpinned_host,
else => .device,
};
const events = try buffer_store.starTransferToDevice(platform, memory_kind);
const DO_AWAIT_EVENTS = false;
const prefix = if (DO_AWAIT_EVENTS) "A" else "NOT a";
std.debug.print("{s} waiting {d} events\n", .{ prefix, events.len });
if (DO_AWAIT_EVENTS) {
for (events) |event| {
while (event.isReady(platform.pjrt_api) == false) {
// spin
}
}
}
const stop = timer.read();

var it = buffer_store.buffers.iterator();
var i: usize = 0;
std.debug.print("\nStart to read {d} buffers from store..\n", .{buffer_store.buffers.count()});

while (it.next()) |entry| : (i += 1) {
total_bytes += entry.value_ptr.*.data.len;
std.debug.print("Buffer: {s} ({any} / {any})\n", .{ entry.key_ptr.*, i + 1, buffer_store.buffers.count() });
}

const time_in_s = stdx.math.divFloat(f64, stop, std.time.ns_per_s);
const mbs = stdx.math.divFloat(f64, total_bytes, 1024 * 1024);

std.debug.print("\nLoading speed: {d:.2} MB/s\n\n", .{mbs / time_in_s});

if (!DO_AWAIT_EVENTS) {
for (events) |event| {
while (event.isReady(platform.pjrt_api) == false) {
// spin
}
}
}
}
16 changes: 16 additions & 0 deletions examples/mnist/mnist.zig
@@ -47,6 +47,13 @@ pub fn asyncMain() !void {
var context = try zml.Context.init();
defer context.deinit();

// log.info("Sleeping for 15 seconds (attach debugger)", .{});
// for (0..15) |i| {
// std.time.sleep(1 * std.time.ns_per_s);
// std.debug.print("{d} .. ", .{i});
// }
// std.debug.print("\n", .{});

// log.info("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});

// // Auto-select platform
@@ -70,6 +77,15 @@
defer buffer_store.deinit();

const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);
const events = try buffer_store.starTransferToDevice(platform, .unpinned_host);

// just to make sure we have buffers
for (events) |event| {
log.info("Awaiting event {}", .{event});
while (event.isReady(platform.pjrt_api) == false) {
std.time.sleep(500 * std.time.ns_per_ms);
}
}
log.info("Reading model shapes from PyTorch file {s}...", .{pt_model});

// Start compiling
54 changes: 54 additions & 0 deletions examples/xferman/BUILD.bazel
@@ -0,0 +1,54 @@

load("@zml//bazel:zig.bzl", "zig_cc_binary")

zig_cc_binary(
name = "xferman",
main = "main.zig",
deps = [
"@zml//async",
"@zml//stdx",
"@zml//zml",
],
)

zig_cc_binary(
name = "manymany",
main = "manymany.zig",
deps = [
"@zml//async",
"@zml//stdx",
"@zml//zml",
],
)

zig_cc_binary(
name = "mmap",
main = "mmap.zig",
deps = [
"@zml//async",
"@zml//stdx",
"@zml//zml",
],
args = [
"$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
],
data = [
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
],
)

zig_cc_binary(
name = "mmap2",
main = "mmap2.zig",
deps = [
"@zml//async",
"@zml//stdx",
"@zml//zml",
],
args = [
"$(location @Meta-Llama-3.2-1B-Instruct//:model.safetensors)",
],
data = [
"@Meta-Llama-3.2-1B-Instruct//:model.safetensors",
],
)
123 changes: 123 additions & 0 deletions examples/xferman/main.zig
@@ -0,0 +1,123 @@
const asynk = @import("async");
const clap = @import("clap");
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml");

const log = std.log.scoped(.xferman);

pub fn main() !void {
try asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}

pub fn asyncMain() !void {
log.info(" Compiled with {}", .{@import("builtin").mode});

const allocator = std.heap.c_allocator;

var context = try zml.Context.init();
defer context.deinit();

// Our weights and bias to use
const weights = [_]f16{4} ** (16 * 1024);
const bias = [_]f16{4} ** (16 * 1024);
const input_shape = zml.Shape.init(.{(&weights).len}, .f16);

const platform = context.autoPlatform(.{});
const api = platform.pjrt_api;

const shapes: []const zml.Shape = &.{ input_shape, input_shape };

log.debug("input_shapes = {any}", .{shapes});
var manager = try zml.platform.TransferManager.init(
allocator,
platform,
.unpinned_host,
shapes,
);
defer manager.deinit();
const buffer_count = try manager.pjrt_transfer_manager.bufferCount(platform.pjrt_api);
log.debug("transfer manager has {d} buffers", .{buffer_count});

const weights_buffer = std.mem.sliceAsBytes(&weights);
const bias_buffer = std.mem.sliceAsBytes(&bias);

const start_time = std.time.nanoTimestamp();
var event_cycle_counter: usize = 0;

// transfer both slices in one call
if (true) {
const events = try manager.transferDataMany(&.{ weights_buffer, bias_buffer }, .{});
for (events) |event| {
while (!event.isReady(api)) : (event_cycle_counter += 1) {
// this is faster than event.awaitt()
}
}
}

// transfer both buffers individually, but using transferDataMany to check
// continuation via opts: start_buffer_index, last_data_is_last_transfer
if (false) {
// first
{
// const event = try manager.transferDataSingle(0, weights_buffer, 0, true);
const events = try manager.transferDataMany(&.{weights_buffer}, .{
.last_data_is_last_transfer = false,
});

for (events) |event| {
while (!event.isReady(api)) : (event_cycle_counter += 1) {
// this is faster than event.awaitt()
}
}
}
// second
{
const events = try manager.transferDataMany(&.{bias_buffer}, .{
.start_buffer_index = 1,
.last_data_is_last_transfer = true, // true is default but we are explicit here
});

for (events) |event| {
while (!event.isReady(api)) : (event_cycle_counter += 1) {
// this is faster than event.awaitt()
}
}
}
}

// transfer all buffers as slices of one big buffer (as would be the case
// with an mmapped file)
if (false) {
var big_buf = try allocator.alloc(u8, weights_buffer.len + bias_buffer.len);
@memcpy(big_buf[0..weights_buffer.len], weights_buffer);
@memcpy(big_buf[weights_buffer.len..], bias_buffer);

const slice_specs: []const zml.platform.TransferManager.TransferDataSlicesSpec =
&.{
.{ .offset = 0, .len = weights_buffer.len },
.{ .offset = weights_buffer.len, .len = bias_buffer.len },
};
const events = try manager.transferDataSlices(big_buf, slice_specs);
_ = events; // we don't need them as we're going to query .progress()

var dt: i128 = undefined;
while (true) {
event_cycle_counter += 1;
dt = std.time.nanoTimestamp() - start_time;
const progress = try manager.progress();
log.debug("After {d} ns: {}", .{ dt, progress });
if (progress.transferred_buffers == progress.total_buffers) {
break;
}
}
}

const end_time = std.time.nanoTimestamp();
log.info("Transferred {d} buffers ({d} bytes) in {d} cycles = {d} ns", .{
buffer_count,
weights_buffer.len + bias_buffer.len,
event_cycle_counter,
end_time - start_time,
});
}