Skip to content

Commit 7317678

Browse files
committed
Redo subgroup id
1 parent 42036bf commit 7317678

File tree

4 files changed

+48
-26
lines changed

4 files changed

+48
-26
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,7 @@ def ROCDL_ReadlaneOp : ROCDL_IntrOp<"readlane", [], [0], [AllTypesMatch<["res",
204204
}];
205205
}
206206

207-
// The LLVM intrinsic function name is rather mouthful,
208-
// so here we opt to use a shorter rocdl name.
209-
def ROCDL_WaveIdOp : LLVM_IntrOpBase<ROCDL_Dialect, "wave_id",
210-
"amdgcn_s_get_waveid_in_workgroup", [], [], [], 1>,
207+
def ROCDL_WaveIdOp : ROCDL_IntrOp<"s.get.waveid.in.workgroup", [], [], [Pure], 1>,
211208
Arguments<(ins)> {
212209
let results = (outs LLVM_Type:$res);
213210
let assemblyFormat = "attr-dict `:` type($res)";

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,16 +214,32 @@ struct GPUSubgroupIdOpToROCDL final
214214
LogicalResult
215215
matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
216216
ConversionPatternRewriter &rewriter) const override {
217-
if (chipset.majorVersion < 10) {
218-
return rewriter.notifyMatchFailure(
219-
op, "SubgroupIdOp is not yet supported on this architecture");
220-
}
221-
222217
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
223-
Value waveIdOp = rewriter.create<ROCDL::WaveIdOp>(op.getLoc(), int32Type);
224-
waveIdOp = truncOrExtToLLVMType(rewriter, op.getLoc(), waveIdOp,
225-
*getTypeConverter());
226-
rewriter.replaceOp(op, {waveIdOp});
218+
auto loc = op.getLoc();
219+
LLVM::IntegerOverflowFlags flags =
220+
LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
221+
// w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) / subgroup_size
222+
Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type);
223+
Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type);
224+
Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type);
225+
Value workitemDimX = rewriter.create<ROCDL::BlockDimXOp>(loc, int32Type);
226+
Value workitemDimY = rewriter.create<ROCDL::BlockDimYOp>(loc, int32Type);
227+
Value dimYxIdZ = rewriter.create<LLVM::MulOp>(loc, int32Type, workitemDimY,
228+
workitemIdZ, flags);
229+
Value dimYxIdZPlusIdY = rewriter.create<LLVM::AddOp>(
230+
loc, int32Type, dimYxIdZ, workitemIdY, flags);
231+
Value dimYxIdZPlusIdYTimesDimX = rewriter.create<LLVM::MulOp>(
232+
loc, int32Type, workitemDimX, dimYxIdZPlusIdY, flags);
233+
Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX =
234+
rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX,
235+
dimYxIdZPlusIdYTimesDimX, flags);
236+
Value subgroupSize = rewriter.create<LLVM::ConstantOp>(
237+
loc, IntegerType::get(rewriter.getContext(), 32), 64);
238+
Value waveIdOp = rewriter.create<LLVM::SDivOp>(
239+
loc, workitemIdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
240+
241+
rewriter.replaceOp(op, {truncOrExtToLLVMType(rewriter, loc, waveIdOp,
242+
*getTypeConverter())});
227243
return success();
228244
}
229245
};

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-chipset.mlir

Lines changed: 0 additions & 13 deletions
This file was deleted.

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,25 @@ gpu.module @test_module {
740740
gpu.module @test_custom_data_layout attributes {llvm.data_layout = "e"} {
741741

742742
}
743+
744+
// -----
745+
746+
gpu.module @test_module {
747+
// CHECK-LABEL: func @gpu_subgroup_id()
748+
func.func @gpu_subgroup_id() -> (index) {
749+
// CHECK: %[[widx:.*]] = rocdl.workitem.id.x : i32
750+
// CHECK: %[[widy:.*]] = rocdl.workitem.id.y : i32
751+
// CHECK: %[[widz:.*]] = rocdl.workitem.id.z : i32
752+
// CHECK: %[[dimx:.*]] = rocdl.workgroup.dim.x : i32
753+
// CHECK: %[[dimy:.*]] = rocdl.workgroup.dim.y : i32
754+
// CHECK: %[[int5:.*]] = llvm.mul %[[dimy]], %[[widz]] overflow<nsw, nuw> : i32
755+
// CHECK: %[[int6:.*]] = llvm.add %[[int5]], %[[widy]] overflow<nsw, nuw> : i32
756+
// CHECK: %[[int7:.*]] = llvm.mul %[[dimx]], %[[int6]] overflow<nsw, nuw> : i32
757+
// CHECK: %[[int8:.*]] = llvm.add %[[widx]], %[[int7]] overflow<nsw, nuw> : i32
758+
// CHECK: %[[ssize:.*]] = llvm.mlir.constant(64 : i32) : i32
759+
// CHECK: = llvm.sdiv %[[int8]], %[[ssize]] : i32
760+
// CHECK: = llvm.sext %10 : i32 to i64
761+
%subgroupId = gpu.subgroup_id : index
762+
func.return %subgroupId : index
763+
}
764+
}

0 commit comments

Comments
 (0)