Skip to content

Commit 79c96cb

Browse files
committed
[mlir][rocdl] Add AMDGPU-specific cf.assert lowering
1 parent 8865986 commit 79c96cb

File tree

2 files changed

+110
-1
lines changed

2 files changed

+110
-1
lines changed

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
3232
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
3333
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
34+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
3435
#include "mlir/Dialect/Func/IR/FuncOps.h"
3536
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3637
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -195,6 +196,75 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
195196
}
196197
};
197198

199+
/// Based on
200+
/// mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp#AssertOpToAssertfailLowering
201+
/// Lowering of cf.assert into a conditional llvm.intr.trap plus gpu.printf with
202+
/// the metadata (filename, fileline, assert msg).
203+
struct AssertOpToBuiltinTrapLowering
204+
: public ConvertOpToLLVMPattern<cf::AssertOp> {
205+
using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
206+
207+
LogicalResult
208+
matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
209+
ConversionPatternRewriter &rewriter) const override {
210+
Location loc = assertOp.getLoc();
211+
212+
// Split blocks and insert conditional branch.
213+
// ^before:
214+
// ...
215+
// cf.cond_br %condition, ^after, ^assert
216+
// ^assert:
217+
// cf.assert
218+
// cf.br ^after
219+
// ^after:
220+
// ...
221+
Block *beforeBlock = assertOp->getBlock();
222+
Block *assertBlock =
223+
rewriter.splitBlock(beforeBlock, assertOp->getIterator());
224+
Block *afterBlock =
225+
rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
226+
rewriter.setInsertionPointToEnd(beforeBlock);
227+
rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
228+
assertBlock);
229+
rewriter.setInsertionPointToEnd(assertBlock);
230+
rewriter.create<cf::BranchOp>(loc, afterBlock);
231+
232+
// Continue cf.assert lowering.
233+
rewriter.setInsertionPoint(assertOp);
234+
235+
// Populate file name, file number and function name from the location of
236+
// the AssertOp.
237+
StringRef fileName = "(unknown)";
238+
StringRef funcName = "(unknown)";
239+
int32_t fileLine = 0;
240+
if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
241+
fileName = fileLineColLoc.getFilename().strref();
242+
fileLine = fileLineColLoc.getStartLine();
243+
} else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
244+
funcName = nameLoc.getName().strref();
245+
if (auto fileLineColLoc =
246+
dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
247+
fileName = fileLineColLoc.getFilename().strref();
248+
fileLine = fileLineColLoc.getStartLine();
249+
}
250+
}
251+
252+
Value assertLine =
253+
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), fileLine);
254+
// Interpolate the fmt str AOT because current gpu.printf lowering doesn't
255+
// handle %s.
256+
llvm::Twine fmtStr = fileName + ":%u: " + funcName +
257+
" Device-side assertion `" + assertOp.getMsg() +
258+
"' failed.\n";
259+
rewriter.create<gpu::PrintfOp>(assertOp.getLoc(),
260+
rewriter.getStringAttr(fmtStr),
261+
ValueRange{assertLine});
262+
rewriter.replaceOpWithNewOp<LLVM::Trap>(assertOp);
263+
264+
return success();
265+
}
266+
};
267+
198268
/// Import the GPU Ops to ROCDL Patterns.
199269
#include "GPUToROCDL.cpp.inc"
200270

@@ -297,7 +367,7 @@ struct LowerGpuOpsToROCDLOpsPass
297367
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
298368
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
299369
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
300-
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
370+
llvmPatterns.add<AssertOpToBuiltinTrapLowering>(converter);
301371
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
302372
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
303373
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Integration test for the device-side cf.assert lowering on ROCm.
//
// @main launches @kernels::@test_assert on one block of two threads, passing
// a false i1 as the first kernel argument and a true i1 as the second. The
// kernel first asserts on the true value and prints, then asserts on the
// false value and prints again. The expected-output directives below require
// the first print from both threads and the assertion failure message twice,
// and forbid the print that follows the failing assertion.
//
// NOTE(review): the pipeline runs strip-debuginfo before convert-gpu-to-rocdl,
// while the expected output contains this file's path -- confirm that the
// assert op's location still reaches the lowering.
// RUN: mlir-opt %s \
2+
// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-rocdl{index-bitwidth=32 runtime=HIP}),rocdl-attach-target{chip=%chip})' \
3+
// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary \
4+
// RUN: | mlir-cpu-runner \
5+
// RUN: --shared-libs=%mlir_rocm_runtime \
6+
// RUN: --shared-libs=%mlir_runner_utils \
7+
// RUN: --entry-point-result=void 2>&1 \
8+
// RUN: | FileCheck %s
9+
10+
// CHECK-DAG: thread 0: print after passing assertion
11+
// CHECK-DAG: thread 1: print after passing assertion
12+
// CHECK-DAG: mlir/test/Integration/GPU/ROCM/assert.mlir:{{.*}}: (unknown) Device-side assertion `failing assertion' failed.
13+
// CHECK-DAG: mlir/test/Integration/GPU/ROCM/assert.mlir:{{.*}}: (unknown) Device-side assertion `failing assertion' failed.
14+
// CHECK-NOT: print after failing assertion
15+
16+
module attributes {gpu.container_module} {
17+
gpu.module @kernels {
18+
// Kernel under test: %c1 is asserted first, %c0 second.
gpu.func @test_assert(%c0: i1, %c1: i1) kernel {
19+
%0 = gpu.thread_id x
20+
cf.assert %c1, "passing assertion"
21+
gpu.printf "thread %lld: print after passing assertion\n" %0 : index
22+
cf.assert %c0, "failing assertion"
23+
gpu.printf "thread %lld: print after failing assertion\n" %0 : index
24+
gpu.return
25+
}
26+
}
27+
28+
func.func @main() {
29+
%c2 = arith.constant 2 : index
30+
%c1 = arith.constant 1 : index
31+
%c0_i1 = arith.constant 0 : i1
32+
%c1_i1 = arith.constant 1 : i1
33+
// Grid is 1x1x1 blocks of 2x1x1 threads; args are (false, true).
gpu.launch_func @kernels::@test_assert
34+
blocks in (%c1, %c1, %c1)
35+
threads in (%c2, %c1, %c1)
36+
args(%c0_i1 : i1, %c1_i1 : i1)
37+
return
38+
}
39+
}

0 commit comments

Comments
 (0)