Skip to content

Commit 0e8a17a

Browse files
authored
XLA Allocator stats (#517)
* allocator * allocator2 * throw error when unsupported * single GetAllocatorStats call * format * fixup
1 parent c664414 commit 0e8a17a

File tree

3 files changed

+97
-7
lines changed

3 files changed

+97
-7
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@
9090
#include "xla/python/pjrt_ifrt/pjrt_topology.h"
9191
#include "xla/python/pjrt_ifrt/pjrt_tuple.h"
9292

93-
#include "triton/Dialect/Triton/IR/Dialect.h"
9493
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
94+
#include "triton/Dialect/Triton/IR/Dialect.h"
9595

9696
using namespace mlir;
9797
using namespace llvm;
@@ -325,6 +325,40 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client,
325325
client->LookupAddressableDevice(PjRtLocalDeviceId(device_id)));
326326
}
327327

328+
// Plain-data record used to hand allocator statistics across the C ABI.
// To keep in sync with JLAllocatorStats in src/XLA.jl: both definitions must
// list the same eleven 64-bit fields in the same order.
//
// Fields marked "optional" are filled by PjRtDeviceGetAllocatorStats with
// std::numeric_limits<int64_t>::min() when the underlying value is absent.
struct JLAllocatorStats {
  int64_t num_allocs;
  int64_t bytes_in_use;
  int64_t peak_bytes_in_use;
  int64_t largest_alloc_size;
  int64_t bytes_limit; // optional
  int64_t bytes_reserved;
  int64_t peak_bytes_reserved;
  int64_t bytes_reservable_limit; // optional
  int64_t largest_free_block_bytes;
  int64_t pool_bytes;      // optional
  int64_t peak_pool_bytes; // optional
};

// The Julia side reads this struct by its raw memory layout, so refuse to
// build if it ever stops being exactly eleven tightly packed int64_t slots.
static_assert(sizeof(JLAllocatorStats) == 11 * sizeof(int64_t),
              "JLAllocatorStats must stay in sync with src/XLA.jl");
342+
343+
// Queries the allocator statistics of `device` once and copies every field
// into the caller-provided `jlstats` record.
//
// The result of device->GetAllocatorStats() is passed through MyValueOrThrow
// (defined elsewhere in this file; presumably it reports an error instead of
// returning when the query fails -- confirm against its definition).
//
// Optional statistics that carry no value are written as
// std::numeric_limits<int64_t>::min(); the reader in src/XLA.jl compares
// against typemin(Int64) to turn them back into `nothing`.
extern "C" void PjRtDeviceGetAllocatorStats(PjRtDevice *device,
                                            JLAllocatorStats *jlstats) {
  auto allocStats = MyValueOrThrow(device->GetAllocatorStats());

  const int64_t sentinel = std::numeric_limits<int64_t>::min();
  // Collapses an optional statistic to its value, or to the sentinel when it
  // is unset.
  const auto orSentinel = [sentinel](const auto &maybe) -> int64_t {
    return maybe.value_or(sentinel);
  };

  JLAllocatorStats &out = *jlstats;

  // Always-present counters.
  out.num_allocs = allocStats.num_allocs;
  out.bytes_in_use = allocStats.bytes_in_use;
  out.peak_bytes_in_use = allocStats.peak_bytes_in_use;
  out.largest_alloc_size = allocStats.largest_alloc_size;
  out.bytes_reserved = allocStats.bytes_reserved;
  out.peak_bytes_reserved = allocStats.peak_bytes_reserved;
  out.largest_free_block_bytes = allocStats.largest_free_block_bytes;

  // Optional statistics.
  out.bytes_limit = orSentinel(allocStats.bytes_limit);
  out.bytes_reservable_limit = orSentinel(allocStats.bytes_reservable_limit);
  out.pool_bytes = orSentinel(allocStats.pool_bytes);
  out.peak_pool_bytes = orSentinel(allocStats.peak_pool_bytes);
}
361+
328362
extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) { delete exec; }
329363

330364
extern "C" PjRtDevice *BufferToDevice(PjRtBuffer *Buffer) {
@@ -443,7 +477,7 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
443477
if (ReactantThrowError) {
444478
llvm::errs() << lmod << "\n";
445479
ReactantThrowError(err_str.c_str());
446-
return wrap((mlir::ModuleOp)nullptr);
480+
return wrap((mlir::ModuleOp) nullptr);
447481
}
448482
}
449483
mlir::MLIRContext &context = *unwrap(cctx);
@@ -642,8 +676,8 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
642676

643677
if (auto func = dyn_cast<FunctionOpInterface>(op.getOperation())) {
644678
if (func.isExternal()) {
645-
shouldRemove = true;
646-
return success();
679+
shouldRemove = true;
680+
return success();
647681
}
648682
}
649683

@@ -678,13 +712,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
678712
}
679713

680714
bool shouldRemove = false;
681-
if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) {
715+
if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID,
716+
shouldRemove))) {
682717
assert(0 && "failed to update all uses");
683718
}
684719
if (shouldRemove)
685-
op.erase();
720+
op.erase();
686721
else
687-
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
722+
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
688723
}
689724
prevMod.getBody()->getOperations().splice(
690725
prevMod.getBody()->getOperations().end(),

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ cc_library(
411411
"-Wl,-exported_symbol,_ClientProcessIndex",
412412
"-Wl,-exported_symbol,_ClientGetDevice",
413413
"-Wl,-exported_symbol,_ClientGetAddressableDevice",
414+
"-Wl,-exported_symbol,_PjRtDeviceGetAllocatorStats",
414415
"-Wl,-exported_symbol,_ExecutableFree",
415416
"-Wl,-exported_symbol,_BufferToDevice",
416417
"-Wl,-exported_symbol,_BufferToClient",

src/XLA.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,60 @@ function client(device::Device)
231231
end
232232
end
233233

234+
# Raw record that the native `PjRtDeviceGetAllocatorStats` call writes into.
# To keep in sync with JLAllocatorStats in ReactantExtra/API.cpp: the fields
# below must appear in the same order and all be 64-bit integers.
#
# `bytes_limit`, `bytes_reservable_limit`, `pool_bytes` and `peak_pool_bytes`
# hold `typemin(Int64)` when the native side had no value for them.
struct JLAllocatorStats
    num_allocs::Int64
    bytes_in_use::Int64
    peak_bytes_in_use::Int64
    largest_alloc_size::Int64
    bytes_limit::Int64
    bytes_reserved::Int64
    peak_bytes_reserved::Int64
    bytes_reservable_limit::Int64
    largest_free_block_bytes::Int64
    pool_bytes::Int64
    peak_pool_bytes::Int64
end
248+
249+
# Allocator statistics as returned by `allocatorstats`.
# Same fields as `JLAllocatorStats`, except that the four statistics which may
# be missing are typed `Union{Nothing,Int64}` and are `nothing` when unset.
struct AllocatorStats
    num_allocs::Int64
    bytes_in_use::Int64
    peak_bytes_in_use::Int64
    largest_alloc_size::Int64
    bytes_limit::Union{Nothing,Int64}
    bytes_reserved::Int64
    peak_bytes_reserved::Int64
    bytes_reservable_limit::Union{Nothing,Int64}
    largest_free_block_bytes::Int64
    pool_bytes::Union{Nothing,Int64}
    peak_pool_bytes::Union{Nothing,Int64}
end
262+
263+
# Fetch the allocator statistics of `device` (by default the device obtained
# from `ClientGetDevice(default_backend[], default_device_idx[])`) with one
# call into the native `PjRtDeviceGetAllocatorStats`, and return them as an
# `AllocatorStats`.
#
# The native side marks missing optional statistics with `typemin(Int64)`;
# those are translated to `nothing` here.
function allocatorstats(
    device::Device=ClientGetDevice(default_backend[], default_device_idx[])
)
    raw_ref = Ref{JLAllocatorStats}()
    @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats(
        device.device::Ptr{Cvoid}, raw_ref::Ptr{Cvoid}
    )::Cvoid
    raw = raw_ref[]

    unset = typemin(Int64)
    # Map the sentinel back to `nothing`, pass every other value through.
    maybe(value) = value == unset ? nothing : value

    return AllocatorStats(
        raw.num_allocs,
        raw.bytes_in_use,
        raw.peak_bytes_in_use,
        raw.largest_alloc_size,
        maybe(raw.bytes_limit),
        raw.bytes_reserved,
        raw.peak_bytes_reserved,
        maybe(raw.bytes_reservable_limit),
        raw.largest_free_block_bytes,
        maybe(raw.pool_bytes),
        maybe(raw.peak_pool_bytes),
    )
end
287+
234288
# https://github.com/openxla/xla/blob/4bfb5c82a427151d6fe5acad8ebe12cee403036a/xla/xla_data.proto#L29
235289
@inline primitive_type(::Type{Bool}) = 1
236290

0 commit comments

Comments
 (0)