|
90 | 90 | #include "xla/python/pjrt_ifrt/pjrt_topology.h"
|
91 | 91 | #include "xla/python/pjrt_ifrt/pjrt_tuple.h"
|
92 | 92 |
|
93 |
| -#include "triton/Dialect/Triton/IR/Dialect.h" |
94 | 93 | #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
| 94 | +#include "triton/Dialect/Triton/IR/Dialect.h" |
95 | 95 |
|
96 | 96 | using namespace mlir;
|
97 | 97 | using namespace llvm;
|
@@ -325,6 +325,40 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client,
|
325 | 325 | client->LookupAddressableDevice(PjRtLocalDeviceId(device_id)));
|
326 | 326 | }
|
327 | 327 |
|
| 328 | +// To keep in sync with JLAllocatorStats in src/XLA.jl |
| 329 | +struct JLAllocatorStats { |
| 330 | + int64_t num_allocs; |
| 331 | + int64_t bytes_in_use; |
| 332 | + int64_t peak_bytes_in_use; |
| 333 | + int64_t largest_alloc_size; |
| 334 | + int64_t bytes_limit; |
| 335 | + int64_t bytes_reserved; |
| 336 | + int64_t peak_bytes_reserved; |
| 337 | + int64_t bytes_reservable_limit; |
| 338 | + int64_t largest_free_block_bytes; |
| 339 | + int64_t pool_bytes; |
| 340 | + int64_t peak_pool_bytes; |
| 341 | +}; |
| 342 | + |
| 343 | +extern "C" void PjRtDeviceGetAllocatorStats(PjRtDevice *device, |
| 344 | + JLAllocatorStats *jlstats) { |
| 345 | + auto stats = MyValueOrThrow(device->GetAllocatorStats()); |
| 346 | + int64_t optnull = std::numeric_limits<int64_t>::min(); |
| 347 | + |
| 348 | + jlstats->num_allocs = stats.num_allocs; |
| 349 | + jlstats->bytes_in_use = stats.bytes_in_use; |
| 350 | + jlstats->peak_bytes_in_use = stats.peak_bytes_in_use; |
| 351 | + jlstats->largest_alloc_size = stats.largest_alloc_size; |
| 352 | + jlstats->bytes_limit = stats.bytes_limit.value_or(optnull); |
| 353 | + jlstats->bytes_reserved = stats.bytes_reserved; |
| 354 | + jlstats->peak_bytes_reserved = stats.peak_bytes_reserved; |
| 355 | + jlstats->bytes_reservable_limit = |
| 356 | + stats.bytes_reservable_limit.value_or(optnull); |
| 357 | + jlstats->largest_free_block_bytes = stats.largest_free_block_bytes; |
| 358 | + jlstats->pool_bytes = stats.pool_bytes.value_or(optnull); |
| 359 | + jlstats->peak_pool_bytes = stats.peak_pool_bytes.value_or(optnull); |
| 360 | +} |
| 361 | + |
328 | 362 | extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) { delete exec; }
|
329 | 363 |
|
330 | 364 | extern "C" PjRtDevice *BufferToDevice(PjRtBuffer *Buffer) {
|
@@ -443,7 +477,7 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
|
443 | 477 | if (ReactantThrowError) {
|
444 | 478 | llvm::errs() << lmod << "\n";
|
445 | 479 | ReactantThrowError(err_str.c_str());
|
446 |
| - return wrap((mlir::ModuleOp)nullptr); |
| 480 | + return wrap((mlir::ModuleOp) nullptr); |
447 | 481 | }
|
448 | 482 | }
|
449 | 483 | mlir::MLIRContext &context = *unwrap(cctx);
|
@@ -642,8 +676,8 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
|
642 | 676 |
|
643 | 677 | if (auto func = dyn_cast<FunctionOpInterface>(op.getOperation())) {
|
644 | 678 | if (func.isExternal()) {
|
645 |
| - shouldRemove = true; |
646 |
| - return success(); |
| 679 | + shouldRemove = true; |
| 680 | + return success(); |
647 | 681 | }
|
648 | 682 | }
|
649 | 683 |
|
@@ -678,13 +712,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
|
678 | 712 | }
|
679 | 713 |
|
680 | 714 | bool shouldRemove = false;
|
681 |
| - if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) { |
| 715 | + if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, |
| 716 | + shouldRemove))) { |
682 | 717 | assert(0 && "failed to update all uses");
|
683 | 718 | }
|
684 | 719 | if (shouldRemove)
|
685 |
| - op.erase(); |
| 720 | + op.erase(); |
686 | 721 | else
|
687 |
| - SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); |
| 722 | + SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); |
688 | 723 | }
|
689 | 724 | prevMod.getBody()->getOperations().splice(
|
690 | 725 | prevMod.getBody()->getOperations().end(),
|
|
0 commit comments