@@ -214,16 +214,32 @@ struct GPUSubgroupIdOpToROCDL final
214
214
LogicalResult
matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  // Lowers gpu.subgroup_id by linearizing the 3-D workitem id inside the
  // workgroup and dividing the result by the subgroup size:
  //
  //   linear_id   = w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)
  //   subgroup_id = linear_id / subgroup_size
  //
  // All arithmetic is emitted on i32 values; the quotient is then passed
  // through truncOrExtToLLVMType so that it matches whatever integer type the
  // type converter expects for the op's result.
  Location loc = op.getLoc();
  IntegerType i32Ty = IntegerType::get(rewriter.getContext(), 32);

  // Every add/mul below is tagged both nsw and nuw, i.e. the lowering asserts
  // that the linearized id never wraps in either signedness.
  const LLVM::IntegerOverflowFlags noWrap =
      LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;

  // Workitem coordinates and workgroup extents. The op creation order is kept
  // exactly as in the original so the emitted IR is unchanged.
  Value idX = rewriter.create<ROCDL::ThreadIdXOp>(loc, i32Ty);
  Value idY = rewriter.create<ROCDL::ThreadIdYOp>(loc, i32Ty);
  Value idZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, i32Ty);
  Value dimX = rewriter.create<ROCDL::BlockDimXOp>(loc, i32Ty);
  Value dimY = rewriter.create<ROCDL::BlockDimYOp>(loc, i32Ty);

  // w_dim.y * w_id.z
  Value zTerm = rewriter.create<LLVM::MulOp>(loc, i32Ty, dimY, idZ, noWrap);
  // w_dim.y * w_id.z + w_id.y
  Value yzTerm = rewriter.create<LLVM::AddOp>(loc, i32Ty, zTerm, idY, noWrap);
  // w_dim.x * (w_id.y + w_dim.y * w_id.z)
  Value scaledYZ =
      rewriter.create<LLVM::MulOp>(loc, i32Ty, dimX, yzTerm, noWrap);
  // w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)
  Value linearId =
      rewriter.create<LLVM::AddOp>(loc, i32Ty, idX, scaledYZ, noWrap);

  // NOTE(review): the subgroup size is hard-coded to 64 here. If this lowering
  // can be used for targets/modes with a different wavefront size, the
  // constant needs to come from the target instead -- TODO confirm.
  Value subgroupSize = rewriter.create<LLVM::ConstantOp>(loc, i32Ty, 64);

  // NOTE(review): signed division is used although the operands are
  // presumably non-negative; verify whether an unsigned division is intended.
  Value subgroupId =
      rewriter.create<LLVM::SDivOp>(loc, linearId, subgroupSize);

  // Adjust the i32 quotient to the converted result type and replace the op.
  Value converted =
      truncOrExtToLLVMType(rewriter, loc, subgroupId, *getTypeConverter());
  rewriter.replaceOp(op, {converted});
  return success();
}
229
245
};
0 commit comments