Skip to content

Commit ab9e447

Browse files
authored
[MLIR][NVVM] Add support for mapa MLIR Ops (#124514)
Adds the `nvvm.mapa` MLIR Op, which lowers to either the `mapa` or the `mapa.shared.cluster` intrinsic depending on the address space of the source pointer. `mapa` - Maps the address of a shared variable to the corresponding address in the target CTA. - `mapa` - the source is a register containing a generic address that points to shared memory. - `mapa.shared.cluster` - the source is a shared memory variable or a register containing a valid shared memory address. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-mapa
1 parent 1f38d38 commit ab9e447

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2539,6 +2539,28 @@ def NVVM_GriddepcontrolLaunchDependentsOp
25392539
}];
25402540
}
25412541

2542+
//===----------------------------------------------------------------------===//
2543+
// NVVM Mapa Op
2544+
//===----------------------------------------------------------------------===//
2545+
2546+
// nvvm.mapa: takes a pointer `a` and an i32 `b` and produces a pointer `res`.
// The TypesMatchWith trait below requires `res` to have exactly the type of
// `a`; a mismatch is reported with the message given in the trait.
// NOTE(review): per the PTX `mapa` instruction, `b` presumably identifies the
// target CTA whose shared memory window the address is mapped into -- confirm
// against the PTX specification.
def NVVM_MapaOp: NVVM_Op<"mapa",
2547+
[TypesMatchWith<"`res` and `a` should have the same type",
2548+
"a", "res", "$_self">]> {
2549+
// Both the result and the source pointer are constrained to
// LLVM_PointerGeneric or LLVM_PointerShared; `b` is an I32.
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
2550+
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
2551+
2552+
// Translation to LLVM IR: the address space of `a` is inspected; when it
// equals NVVM::NVVMMemorySpace::kSharedMemorySpace the
// nvvm_mapa_shared_cluster intrinsic is emitted, otherwise nvvm_mapa.
// In both cases the intrinsic is called with the operands {$a, $b}.
string llvmBuilder = [{
2553+
int addrSpace = llvm::cast<LLVMPointerType>(op.getA().getType()).getAddressSpace();
2554+
2555+
bool isSharedMemory = addrSpace == NVVM::NVVMMemorySpace::kSharedMemorySpace;
2556+
2557+
auto intId = isSharedMemory? llvm::Intrinsic::nvvm_mapa_shared_cluster : llvm::Intrinsic::nvvm_mapa;
2558+
$res = createIntrinsicCall(builder, intId, {$a, $b});
2559+
}];
2560+
2561+
// Textual form, e.g. `%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr`.
let assemblyFormat = "$a`,` $b attr-dict `:` type($a) `->` type($res)";
2562+
}
2563+
25422564
def NVVM_Exit : NVVM_Op<"exit"> {
25432565
let summary = "Exit Op";
25442566
let description = [{

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,14 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
11891189

11901190
// -----
11911191

1192+
// Negative test for nvvm.mapa: the source `%a` is a generic pointer
// (!llvm.ptr) while the declared result type is !llvm.ptr<3>, so the
// same-type constraint between `res` and `a` must produce a diagnostic.
func.func @mapa(%a: !llvm.ptr, %b : i32) {
1193+
// expected-error @below {{`res` and `a` should have the same type}}
1194+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
1195+
return
1196+
}
1197+
1198+
// -----
1199+
11921200
func.func @gep_struct_variable(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32) {
11931201
// expected-error @below {{op expected index 1 indexing a struct to be constant}}
11941202
llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr, i32, i32) -> !llvm.ptr, !llvm.struct<(i32)>

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,15 @@ func.func @griddepcontrol_launch_dependents()
522522
return
523523
}
524524

525+
// CHECK-LABEL: @mapa
526+
// Exercises nvvm.mapa in both accepted forms: once with a generic pointer
// (!llvm.ptr) and once with an address-space-3 pointer (!llvm.ptr<3>), each
// time with matching source and result types. The check lines only match the
// op name followed by an SSA value, not the operand or result types.
func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
527+
// CHECK: nvvm.mapa %{{.*}}
528+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
529+
// CHECK: nvvm.mapa %{{.*}}
530+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
531+
return
532+
}
533+
525534
// -----
526535

527536
// Just check these don't emit errors.

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,3 +773,13 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
773773
nvvm.griddepcontrol.launch.dependents
774774
llvm.return
775775
}
776+
777+
// -----
778+
// CHECK-LABEL: @nvvm_mapa
779+
// Translation test for nvvm.mapa: a generic-pointer source is expected to
// become a call to @llvm.nvvm.mapa, and an address-space-3 source a call to
// @llvm.nvvm.mapa.shared.cluster, each taking the pointer and the i32 operand.
// NOTE(review): the directives inside this function use the CHECK-LLVM prefix
// whereas the label directive for this function uses the plain CHECK prefix --
// confirm that the RUN line of this file enables both prefixes, otherwise the
// two call patterns below are never verified.
llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
780+
// CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
781+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
782+
// CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
783+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
784+
llvm.return
785+
}

0 commit comments

Comments
 (0)