Skip to content

Commit 46d1cb8

Browse files
authored
[mlir] GPUToROCDL: Add support for non-i32/f32 shuffle types (#136320)
Use recently added repacking utilities to support other datatypes. Also, tighten `gpu.shuffle` verification to reject scalable vectors
1 parent 7da385d commit 46d1cb8

File tree

5 files changed

+46
-32
lines changed

5 files changed

+46
-32
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ def GPU_LaunchOp : GPU_Op<"launch", [
840840
- a variadic number of Private memory attributions.
841841

842842
The `kernelFunc` and `kernelModule` attributes are optional and specifies
843-
the kernel name and a module in which the kernel should be outlined.
843+
the kernel name and a module in which the kernel should be outlined.
844844

845845
Syntax:
846846

@@ -1201,7 +1201,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
12011201
}
12021202

12031203
def AnyIntegerOrFloatOr1DVector :
1204-
AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
1204+
AnyTypeOf<[AnyIntegerOrFloat, FixedVectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
12051205

12061206
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
12071207
let summary = "Reduce values among subgroup.";

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
6262
return canBeBare;
6363
}
6464

65-
Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
66-
const unsigned indexBitwidth) {
65+
static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
66+
const unsigned indexBitwidth) {
6767
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
6868
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
6969
Value minus1 = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
@@ -138,10 +138,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
138138
Location loc = op->getLoc();
139139
Value initShflValue = adaptor.getValue();
140140
Type shflType = initShflValue.getType();
141-
// TODO: Add support for non 32-bit shuffle values.
142-
if (!shflType.isIntOrFloat() || shflType.getIntOrFloatBitWidth() != 32)
143-
return rewriter.notifyMatchFailure(
144-
op, "only 32-bit int/float types are supported");
145141

146142
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
147143
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
@@ -179,15 +175,17 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
179175
Value two = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 2);
180176
Value dwordAlignedDstLane =
181177
rewriter.create<LLVM::ShlOp>(loc, int32Type, selectDstLane, two);
182-
if (shflType.isF32()) {
183-
initShflValue =
184-
rewriter.create<LLVM::BitcastOp>(loc, int32Type, initShflValue);
185-
}
186-
Value shflValue = rewriter.create<ROCDL::DsBpermuteOp>(
187-
loc, int32Type, dwordAlignedDstLane, initShflValue);
188-
if (shflType.isF32()) {
189-
shflValue = rewriter.create<LLVM::BitcastOp>(loc, shflType, shflValue);
178+
179+
SmallVector<Value> decomposed =
180+
LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
181+
SmallVector<Value> swizzled;
182+
for (Value v : decomposed) {
183+
Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
184+
dwordAlignedDstLane, v);
185+
swizzled.emplace_back(res);
190186
}
187+
Value shflValue =
188+
LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
191189
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
192190
return success();
193191
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl-unsupported.mlir

Lines changed: 0 additions & 13 deletions
This file was deleted.

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,27 @@ gpu.module @test_module {
710710
%shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
711711
func.return %shfl, %shfli, %shfld : f32, f32, f32
712712
}
713+
714+
// CHECK-LABEL: func @gpu_shuffle_vec
715+
// CHECK-SAME: (%[[ARG:.*]]: vector<4xf16>, %{{.*}}: i32, %{{.*}}: i32)
716+
func.func @gpu_shuffle_vec(%arg0: vector<4xf16>, %arg1: i32, %arg2: i32) -> vector<4xf16> {
717+
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG]] : vector<4xf16> to vector<2xi32>
718+
// CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
719+
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %13[%[[IDX0]] : i32] : vector<2xi32>
720+
// CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
721+
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %13[%[[IDX1]] : i32] : vector<2xi32>
722+
// CHECK: %[[PERM0:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM0]] : (i32, i32) -> i32
723+
// CHECK: %[[PERM1:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM1]] : (i32, i32) -> i32
724+
// CHECK: %[[V0:.*]] = llvm.mlir.poison : vector<2xi32>
725+
// CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
726+
// CHECK: %[[V1:.*]] = llvm.insertelement %[[PERM0]], %[[V0]][%[[IDX0]] : i32] : vector<2xi32>
727+
// CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
728+
// CHECK: %[[V2:.*]] = llvm.insertelement %[[PERM1]], %[[V1]][%[[IDX1]] : i32] : vector<2xi32>
729+
// CHECK: %[[RES:.*]] = llvm.bitcast %[[V2]] : vector<2xi32> to vector<4xf16>
730+
// CHECK: llvm.return %[[RES]] : vector<4xf16>
731+
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<4xf16>
732+
func.return %shfl : vector<4xf16>
733+
}
713734
}
714735

715736
// -----

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,15 +367,15 @@ func.func @subgroup_reduce_cluster_stride_without_size(%arg0 : vector<4xf32>) {
367367
// -----
368368

369369
func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
370-
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
370+
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or fixed-length vector of}}
371371
%res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
372372
return
373373
}
374374

375375
// -----
376376

377377
func.func @subgroup_reduce_bad_type_scalable(%arg0 : vector<[2]xf32>) {
378-
// expected-error@+1 {{is not compatible with scalable vector types}}
378+
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or fixed-length vector of}}
379379
%res = gpu.subgroup_reduce add %arg0 : (vector<[2]xf32>) -> vector<[2]xf32>
380380
return
381381
}
@@ -463,13 +463,21 @@ func.func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
463463
// -----
464464

465465
func.func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
466-
// expected-error@+1 {{op operand #0 must be Integer or Float or vector of Integer or Float values of ranks 1, but got 'index'}}
466+
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
467467
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : index
468468
return
469469
}
470470

471471
// -----
472472

473+
func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %arg2 : i32) {
474+
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
475+
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<[4]xf32>
476+
return
477+
}
478+
479+
// -----
480+
473481
module {
474482
gpu.module @gpu_funcs {
475483
// expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}

0 commit comments

Comments
 (0)