Skip to content

Commit fbde19a

Browse files
authored
[MLIR][LLVM] Change addressof builders to use opaque pointers (#69215)
This commit changes the builders of the `llvm.mlir.addressof` operations to no longer produce typed pointers. As a consequence, a GPU to NVVM pattern and the toy example LLVM lowerings had to be updated, as they still relied on typed pointers.
1 parent 8ddca6b commit fbde19a

File tree

5 files changed

+34
-37
lines changed

5 files changed

+34
-37
lines changed

mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
117117
/// * `i32 (i8*, ...)`
118118
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
119119
auto llvmI32Ty = IntegerType::get(context, 32);
120-
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
121-
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
120+
auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
121+
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
122122
/*isVarArg=*/true);
123123
return llvmFnType;
124124
}
@@ -162,9 +162,9 @@ class PrintOpLowering : public ConversionPattern {
162162
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
163163
builder.getIndexAttr(0));
164164
return builder.create<LLVM::GEPOp>(
165-
loc,
166-
LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
167-
globalPtr, ArrayRef<Value>({cst0, cst0}));
165+
loc, LLVM::LLVMPointerType::get(builder.getContext()),
166+
IntegerType::get(builder.getContext(), 8), globalPtr,
167+
ArrayRef<Value>({cst0, cst0}));
168168
}
169169
};
170170
} // namespace

mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ class PrintOpLowering : public ConversionPattern {
117117
/// * `i32 (i8*, ...)`
118118
static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) {
119119
auto llvmI32Ty = IntegerType::get(context, 32);
120-
auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
121-
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
120+
auto llvmPtrTy = LLVM::LLVMPointerType::get(context);
121+
auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy,
122122
/*isVarArg=*/true);
123123
return llvmFnType;
124124
}
@@ -162,9 +162,9 @@ class PrintOpLowering : public ConversionPattern {
162162
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
163163
builder.getIndexAttr(0));
164164
return builder.create<LLVM::GEPOp>(
165-
loc,
166-
LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
167-
globalPtr, ArrayRef<Value>({cst0, cst0}));
165+
loc, LLVM::LLVMPointerType::get(builder.getContext()),
166+
IntegerType::get(builder.getContext(), 8), globalPtr,
167+
ArrayRef<Value>({cst0, cst0}));
168168
}
169169
};
170170
} // namespace

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,15 +1071,15 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
10711071
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
10721072
[{
10731073
build($_builder, $_state,
1074-
LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()),
1074+
LLVM::LLVMPointerType::get($_builder.getContext(), global.getAddrSpace()),
10751075
global.getSymName());
10761076
$_state.addAttributes(attrs);
10771077
}]>,
10781078
OpBuilder<(ins "LLVMFuncOp":$func,
10791079
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
10801080
[{
10811081
build($_builder, $_state,
1082-
LLVM::LLVMPointerType::get(func.getFunctionType()), func.getName());
1082+
LLVM::LLVMPointerType::get($_builder.getContext()), func.getName());
10831083
$_state.addAttributes(attrs);
10841084
}]>
10851085
];

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -441,15 +441,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
441441
Location loc = gpuPrintfOp->getLoc();
442442

443443
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
444-
mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
444+
mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
445445

446446
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
447447
// This ensures that global constants and declarations are placed within
448448
// the device code, not the host code
449449
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
450450

451451
auto vprintfType =
452-
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
452+
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
453453
LLVM::LLVMFuncOp vprintfDecl =
454454
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
455455

@@ -473,7 +473,7 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
473473
// Get a pointer to the format string's first element
474474
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
475475
Value stringStart = rewriter.create<LLVM::GEPOp>(
476-
loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
476+
loc, ptrType, ptrType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
477477
SmallVector<Type> types;
478478
SmallVector<Value> args;
479479
// Promote and pack the arguments into a stack allocation.
@@ -490,18 +490,17 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
490490
}
491491
Type structType =
492492
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
493-
Type structPtrType = LLVM::LLVMPointerType::get(structType);
494493
Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
495494
rewriter.getIndexAttr(1));
496-
Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
497-
/*alignment=*/0);
495+
Value tempAlloc =
496+
rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one,
497+
/*alignment=*/0);
498498
for (auto [index, arg] : llvm::enumerate(args)) {
499-
Value ptr = rewriter.create<LLVM::GEPOp>(
500-
loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
501-
ArrayRef<LLVM::GEPArg>{0, index});
499+
Value ptr =
500+
rewriter.create<LLVM::GEPOp>(loc, ptrType, arg.getType(), tempAlloc,
501+
ArrayRef<LLVM::GEPArg>{0, index});
502502
rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
503503
}
504-
tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
505504
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
506505

507506
rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -542,34 +542,32 @@ gpu.module @test_module_28 {
542542
gpu.module @test_module_29 {
543543
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
544544
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
545-
// CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
545+
// CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32
546546

547547
// CHECK-LABEL: func @test_const_printf
548548
gpu.func @test_const_printf() {
549-
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
550-
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
549+
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
550+
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr
551551
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
552-
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
553-
// CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
554-
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
552+
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
553+
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
555554
gpu.printf "Hello, world\n"
556555
gpu.return
557556
}
558557

559558
// CHECK-LABEL: func @test_printf
560559
// CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
561560
gpu.func @test_printf(%arg0: i32, %arg1: f32) {
562-
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
563-
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
561+
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr
562+
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr) -> !llvm.ptr
564563
// CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
565564
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
566-
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
567-
// CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
568-
// CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
569-
// CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
570-
// CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
571-
// CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
572-
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
565+
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr
566+
// CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr) -> !llvm.ptr
567+
// CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : i32, !llvm.ptr
568+
// CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr) -> !llvm.ptr
569+
// CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : f64, !llvm.ptr
570+
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
573571
gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
574572
gpu.return
575573
}

0 commit comments

Comments
 (0)