|
31 | 31 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
32 | 32 | #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
33 | 33 | #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
| 34 | +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
34 | 35 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
35 | 36 | #include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
36 | 37 | #include "mlir/Dialect/GPU/Transforms/Passes.h"
|
@@ -195,6 +196,75 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
195 | 196 | }
|
196 | 197 | };
|
197 | 198 |
|
| 199 | +/// Based on |
| 200 | +/// mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp#AssertOpToAssertfailLowering |
| 201 | +/// Lowering of cf.assert into a conditional llvm.intr.trap plus gpu.printf with |
| 202 | +/// the metadata (filename, fileline, assert msg). |
| 203 | +struct AssertOpToBuiltinTrapLowering |
| 204 | + : public ConvertOpToLLVMPattern<cf::AssertOp> { |
| 205 | + using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern; |
| 206 | + |
| 207 | + LogicalResult |
| 208 | + matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor, |
| 209 | + ConversionPatternRewriter &rewriter) const override { |
| 210 | + Location loc = assertOp.getLoc(); |
| 211 | + |
| 212 | + // Split blocks and insert conditional branch. |
| 213 | + // ^before: |
| 214 | + // ... |
| 215 | + // cf.cond_br %condition, ^after, ^assert |
| 216 | + // ^assert: |
| 217 | + // cf.assert |
| 218 | + // cf.br ^after |
| 219 | + // ^after: |
| 220 | + // ... |
| 221 | + Block *beforeBlock = assertOp->getBlock(); |
| 222 | + Block *assertBlock = |
| 223 | + rewriter.splitBlock(beforeBlock, assertOp->getIterator()); |
| 224 | + Block *afterBlock = |
| 225 | + rewriter.splitBlock(assertBlock, ++assertOp->getIterator()); |
| 226 | + rewriter.setInsertionPointToEnd(beforeBlock); |
| 227 | + rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock, |
| 228 | + assertBlock); |
| 229 | + rewriter.setInsertionPointToEnd(assertBlock); |
| 230 | + rewriter.create<cf::BranchOp>(loc, afterBlock); |
| 231 | + |
| 232 | + // Continue cf.assert lowering. |
| 233 | + rewriter.setInsertionPoint(assertOp); |
| 234 | + |
| 235 | + // Populate file name, file number and function name from the location of |
| 236 | + // the AssertOp. |
| 237 | + StringRef fileName = "(unknown)"; |
| 238 | + StringRef funcName = "(unknown)"; |
| 239 | + int32_t fileLine = 0; |
| 240 | + if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) { |
| 241 | + fileName = fileLineColLoc.getFilename().strref(); |
| 242 | + fileLine = fileLineColLoc.getStartLine(); |
| 243 | + } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) { |
| 244 | + funcName = nameLoc.getName().strref(); |
| 245 | + if (auto fileLineColLoc = |
| 246 | + dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) { |
| 247 | + fileName = fileLineColLoc.getFilename().strref(); |
| 248 | + fileLine = fileLineColLoc.getStartLine(); |
| 249 | + } |
| 250 | + } |
| 251 | + |
| 252 | + Value assertLine = |
| 253 | + rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), fileLine); |
| 254 | + // interpolate the fmt str AOT because current gpu.printf lowering doesn't |
| 255 | + // handle %s |
| 256 | + llvm::Twine fmtStr = fileName + ":%u: " + funcName + |
| 257 | + " Device-side assertion `" + assertOp.getMsg() + |
| 258 | + "' failed.\n"; |
| 259 | + rewriter.create<gpu::PrintfOp>(assertOp.getLoc(), |
| 260 | + rewriter.getStringAttr(fmtStr), |
| 261 | + ValueRange{assertLine}); |
| 262 | + rewriter.replaceOpWithNewOp<LLVM::Trap>(assertOp); |
| 263 | + |
| 264 | + return success(); |
| 265 | + } |
| 266 | +}; |
| 267 | + |
198 | 268 | /// Import the GPU Ops to ROCDL Patterns.
|
199 | 269 | #include "GPUToROCDL.cpp.inc"
|
200 | 270 |
|
@@ -297,7 +367,7 @@ struct LowerGpuOpsToROCDLOpsPass
|
297 | 367 | populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
|
298 | 368 | populateMathToLLVMConversionPatterns(converter, llvmPatterns);
|
299 | 369 | cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
|
300 |
| - cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns); |
| 370 | + llvmPatterns.add<AssertOpToBuiltinTrapLowering>(converter); |
301 | 371 | populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
|
302 | 372 | populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
|
303 | 373 | populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
|
|
0 commit comments