Skip to content

Commit d8bd7f1

Browse files
authored
[mlir] Support ROCDL::ReadlaneOp (#116593)
Support ROCDL::ReadlaneOp to solve ROCm/triton-internal#411.
1 parent 3a63407 commit d8bd7f1

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,21 @@ def ROCDL_BallotOp :
197197
let assemblyFormat = "$pred attr-dict `:` type($res)";
198198
}
199199

200+
def ROCDL_ReadlaneOp : ROCDL_IntrOp<"readlane", [], [0], [AllTypesMatch<["res", "src0"]>], 1>,
201+
Arguments<(ins LLVM_Type:$src0,
202+
I32:$src1)> {
203+
let results = (outs LLVM_Type:$res);
204+
let summary = "Get the value in the specific lane.";
205+
206+
let description = [{
207+
Get the value in lane `src1` from input `src0`.
208+
}];
209+
210+
let assemblyFormat = [{
211+
$src0 `,` $src1 attr-dict `:` `(` type($src0) `,` type($src1) `)` `->` type($res)
212+
}];
213+
}
214+
200215
//===----------------------------------------------------------------------===//
201216
// Thread index and Block index
202217

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,17 @@ llvm.func @rocdl.s.wait.dscnt() {
388388

389389
// -----
390390

391+
llvm.func @rocdl.readlane(%src : f32) -> f32 {
392+
%cst0 = llvm.mlir.constant(0 : i32) : i32
393+
394+
// CHECK-LABEL: rocdl.readlane
395+
// CHECK: rocdl.readlane %{{.*}} %{{.*}}
396+
%ret = rocdl.readlane %src, %cst0 : (f32, i32) -> f32
397+
llvm.return %ret : f32
398+
}
399+
400+
// -----
401+
391402
// expected-error@below {{attribute attached to unexpected op}}
392403
func.func private @expected_llvm_func() attributes { rocdl.kernel }
393404

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,25 @@ llvm.func @rocdl.ballot64(%pred : i1) -> i64 {
118118
llvm.return %0 : i64
119119
}
120120

121+
llvm.func @rocdl.readlane(%src0 : f32, %src1: f64, %src2: i32, %src3: vector<2 x f32>) -> f32 {
122+
%idx = llvm.mlir.constant(0 : i32) : i32
123+
124+
// CHECK-LABEL: rocdl.readlane
125+
// CHECK: call float @llvm.amdgcn.readlane.f32(float %{{.*}}, i32 0)
126+
%0 = rocdl.readlane %src0, %idx : (f32, i32) -> f32
127+
128+
// CHECK: call double @llvm.amdgcn.readlane.f64(double %{{.*}}, i32 0)
129+
%1 = rocdl.readlane %src1, %idx : (f64, i32) -> f64
130+
131+
// CHECK: call i32 @llvm.amdgcn.readlane.i32(i32 %{{.*}}, i32 0)
132+
%2 = rocdl.readlane %src2, %idx : (i32, i32) -> i32
133+
134+
// CHECK: call <2 x float> @llvm.amdgcn.readlane.v2f32(<2 x float> %{{.*}}, i32 0)
135+
%3 = rocdl.readlane %src3, %idx : (vector<2 x f32>, i32) -> vector<2 x f32>
136+
137+
llvm.return %0 : f32
138+
}
139+
121140
llvm.func @rocdl.waitcnt() {
122141
// CHECK-LABEL: rocdl.waitcnt
123142
// CHECK-NEXT: call void @llvm.amdgcn.s.waitcnt(i32 0)

0 commit comments

Comments
 (0)