Skip to content

Commit 9792ced

Browse files
committed
[MLIR][NVVM] Add support for mapa MLIR Ops
Adds `mapa` and `mapa.shared.cluster` MLIR Ops to generate mapa instructions. `mapa` - Map the address of the shared variable in the target CTA. - `mapa` - source is a register containing generic address pointing to shared memory. - `mapa.shared.cluster` - source is a shared memory variable or a register containing a valid shared memory address.
1 parent 2a26292 commit 9792ced

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2512,6 +2512,28 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
25122512
}];
25132513
}
25142514

2515+
//===----------------------------------------------------------------------===//
2516+
// NVVM Mapa Op
2517+
//===----------------------------------------------------------------------===//
2518+
2519+
def NVVM_MapaOp: NVVM_Op<"mapa",
2520+
[TypesMatchWith<"`res` and `a` should have the same type",
2521+
"a", "res", "$_self">]> {
2522+
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
2523+
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
2524+
2525+
string llvmBuilder = [{
2526+
int addrSpace = llvm::cast<LLVMPointerType>(op.getA().getType()).getAddressSpace();
2527+
2528+
bool isSharedMemory = addrSpace == NVVM::NVVMMemorySpace::kSharedMemorySpace;
2529+
2530+
auto intId = isSharedMemory? llvm::Intrinsic::nvvm_mapa_shared_cluster : llvm::Intrinsic::nvvm_mapa;
2531+
$res = createIntrinsicCall(builder, intId, {$a, $b});
2532+
}];
2533+
2534+
let assemblyFormat = "$a`,` $b attr-dict `:` type($a) `->` type($res)";
2535+
}
2536+
25152537
def NVVM_Exit : NVVM_Op<"exit"> {
25162538
let summary = "Exit Op";
25172539
let description = [{

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,14 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
11891189

11901190
// -----
11911191

1192+
func.func @mapa(%a: !llvm.ptr, %b : i32) {
1193+
// expected-error @below {{`res` and `a` should have the same type}}
1194+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
1195+
return
1196+
}
1197+
1198+
// -----
1199+
11921200
func.func @gep_struct_variable(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32) {
11931201
// expected-error @below {{op expected index 1 indexing a struct to be constant}}
11941202
llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr, i32, i32) -> !llvm.ptr, !llvm.struct<(i32)>

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,15 @@ func.func @wgmma_wait_group_sync_aligned() {
509509
return
510510
}
511511

512+
// CHECK-LABEL: @mapa
513+
func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
514+
// CHECK: nvvm.mapa %{{.*}}
515+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
516+
// CHECK: nvvm.mapa %{{.*}}
517+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
518+
return
519+
}
520+
512521
// -----
513522

514523
// Just check these don't emit errors.

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,3 +757,13 @@ llvm.func @nvvm_wgmma_wait_group_aligned() {
757757
nvvm.wgmma.wait.group.sync.aligned 20
758758
llvm.return
759759
}
760+
761+
// -----
762+
// CHECK-LABEL: @nvvm_mapa
763+
llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
764+
// CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
765+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
766+
// CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
767+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
768+
llvm.return
769+
}

0 commit comments

Comments
 (0)