Skip to content

Commit 552d26e

Browse files
authored
[mlir][gpu] Add extra value types for gpu::ShuffleOp (#104605)
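Expand the accepted types for gpu.shuffle to any integer, float, or 1-D vector of integers or floats. Also update the gpu-to-llvm-spv pass to handle the additional scalar types i8, i16 and f16; value types the pass cannot mangle now produce a match failure ("unsupported value type") instead of hitting llvm_unreachable.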
1 parent fd4f952 commit 552d26e

File tree

4 files changed

+84
-32
lines changed

4 files changed

+84
-32
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,17 +1274,11 @@ def GPU_ShuffleMode : I32EnumAttr<"ShuffleMode",
12741274
def GPU_ShuffleModeAttr : EnumAttr<GPU_Dialect, GPU_ShuffleMode,
12751275
"shuffle_mode">;
12761276

1277-
def I32I64F32OrF64 : TypeConstraint<Or<[I32.predicate,
1278-
I64.predicate,
1279-
F32.predicate,
1280-
F64.predicate]>,
1281-
"i32, i64, f32 or f64">;
1282-
12831277
def GPU_ShuffleOp : GPU_Op<
12841278
"shuffle", [Pure, AllTypesMatch<["value", "shuffleResult"]>]>,
1285-
Arguments<(ins I32I64F32OrF64:$value, I32:$offset, I32:$width,
1279+
Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width,
12861280
GPU_ShuffleModeAttr:$mode)>,
1287-
Results<(outs I32I64F32OrF64:$shuffleResult, I1:$valid)> {
1281+
Results<(outs AnyIntegerOrFloatOr1DVector:$shuffleResult, I1:$valid)> {
12881282
let summary = "Shuffles values within a subgroup.";
12891283
let description = [{
12901284
The "shuffle" op moves values across lanes (a.k.a., invocations,

mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,26 +241,34 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
241241
llvm_unreachable("Unhandled shuffle mode");
242242
}
243243

244-
static StringRef getTypeMangling(Type type) {
245-
return TypeSwitch<Type, StringRef>(type)
244+
static std::optional<StringRef> getTypeMangling(Type type) {
245+
return TypeSwitch<Type, std::optional<StringRef>>(type)
246+
.Case<Float16Type>([](auto) { return "Dhj"; })
246247
.Case<Float32Type>([](auto) { return "fj"; })
247248
.Case<Float64Type>([](auto) { return "dj"; })
248-
.Case<IntegerType>([](auto intTy) {
249+
.Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
249250
switch (intTy.getWidth()) {
251+
case 8:
252+
return "cj";
253+
case 16:
254+
return "sj";
250255
case 32:
251256
return "ij";
252257
case 64:
253258
return "lj";
254259
}
255-
llvm_unreachable("Invalid integer width");
256-
});
260+
return std::nullopt;
261+
})
262+
.Default([](auto) { return std::nullopt; });
257263
}
258264

259-
static std::string getFuncName(gpu::ShuffleOp op) {
265+
static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
260266
StringRef baseName = getBaseName(op.getMode());
261-
StringRef typeMangling = getTypeMangling(op.getType(0));
267+
std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
268+
if (!typeMangling)
269+
return std::nullopt;
262270
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
263-
typeMangling);
271+
typeMangling.value());
264272
}
265273

266274
/// Get the subgroup size from the target or return a default.
@@ -284,15 +292,17 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
284292
return rewriter.notifyMatchFailure(
285293
op, "shuffle width and subgroup size mismatch");
286294

287-
std::string funcName = getFuncName(op);
295+
std::optional<std::string> funcName = getFuncName(op);
296+
if (!funcName)
297+
return rewriter.notifyMatchFailure(op, "unsupported value type");
288298

289299
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
290300
assert(moduleOp && "Expecting module");
291301
Type valueType = adaptor.getValue().getType();
292302
Type offsetType = adaptor.getOffset().getType();
293303
Type resultType = valueType;
294304
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
295-
moduleOp, funcName, {valueType, offsetType}, resultType,
305+
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
296306
/*isMemNone=*/false, /*isConvergent=*/true);
297307

298308
Location loc = op->getLoc();

mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,12 @@ gpu.module @shuffles attributes {
317317
// CHECK-SAME-DAG: will_return
318318
// CHECK-NOT: memory_effects = #llvm.memory_effects
319319
// CHECK-SAME: }
320+
// CHECK: llvm.func spir_funccc @_Z20sub_group_shuffle_upDhj(f16, i32) -> f16 attributes {
321+
// CHECK-SAME-DAG: no_unwind
322+
// CHECK-SAME-DAG: convergent
323+
// CHECK-SAME-DAG: will_return
324+
// CHECK-NOT: memory_effects = #llvm.memory_effects
325+
// CHECK-SAME: }
320326
// CHECK: llvm.func spir_funccc @_Z21sub_group_shuffle_xorlj(i64, i32) -> i64 attributes {
321327
// CHECK-SAME-DAG: no_unwind
322328
// CHECK-SAME-DAG: convergent
@@ -329,26 +335,54 @@ gpu.module @shuffles attributes {
329335
// CHECK-SAME-DAG: will_return
330336
// CHECK-NOT: memory_effects = #llvm.memory_effects
331337
// CHECK-SAME: }
338+
// CHECK: llvm.func spir_funccc @_Z21sub_group_shuffle_xorsj(i16, i32) -> i16 attributes {
339+
// CHECK-SAME-DAG: no_unwind
340+
// CHECK-SAME-DAG: convergent
341+
// CHECK-SAME-DAG: will_return
342+
// CHECK-NOT: memory_effects = #llvm.memory_effects
343+
// CHECK-SAME: }
344+
// CHECK: llvm.func spir_funccc @_Z17sub_group_shufflecj(i8, i32) -> i8 attributes {
345+
// CHECK-SAME-DAG: no_unwind
346+
// CHECK-SAME-DAG: convergent
347+
// CHECK-SAME-DAG: will_return
348+
// CHECK-NOT: memory_effects = #llvm.memory_effects
349+
// CHECK-SAME: }
332350

333351
// CHECK-LABEL: gpu_shuffles
334-
// CHECK-SAME: (%[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: i32)
335-
func.func @gpu_shuffles(%val0: i32, %id: i32,
336-
%val1: i64, %mask: i32,
337-
%val2: f32, %delta_up: i32,
338-
%val3: f64, %delta_down: i32) {
352+
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
353+
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
354+
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
355+
// CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32) {
356+
func.func @gpu_shuffles(%i8_val: i8,
357+
%i16_val: i16,
358+
%i32_val: i32,
359+
%i64_val: i64,
360+
%f16_val: f16,
361+
%f32_val: f32,
362+
%f64_val: f64,
363+
%offset: i32) {
339364
%width = arith.constant 16 : i32
340-
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleij(%[[VAL_0]], %[[VAL_1]])
365+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
341366
// CHECK: llvm.mlir.constant(true) : i1
342-
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorlj(%[[VAL_2]], %[[VAL_3]])
367+
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorsj(%[[I16_VAL]], %[[OFFSET]])
343368
// CHECK: llvm.mlir.constant(true) : i1
344-
// CHECK: llvm.call spir_funccc @_Z20sub_group_shuffle_upfj(%[[VAL_4]], %[[VAL_5]])
369+
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleij(%[[I32_VAL]], %[[OFFSET]])
345370
// CHECK: llvm.mlir.constant(true) : i1
346-
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[VAL_6]], %[[VAL_7]])
371+
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorlj(%[[I64_VAL]], %[[OFFSET]])
347372
// CHECK: llvm.mlir.constant(true) : i1
348-
%shuffleResult0, %valid0 = gpu.shuffle idx %val0, %id, %width : i32
349-
%shuffleResult1, %valid1 = gpu.shuffle xor %val1, %mask, %width : i64
350-
%shuffleResult2, %valid2 = gpu.shuffle up %val2, %delta_up, %width : f32
351-
%shuffleResult3, %valid3 = gpu.shuffle down %val3, %delta_down, %width : f64
373+
// CHECK: llvm.call spir_funccc @_Z20sub_group_shuffle_upDhj(%[[F16_VAL]], %[[OFFSET]])
374+
// CHECK: llvm.mlir.constant(true) : i1
375+
// CHECK: llvm.call spir_funccc @_Z20sub_group_shuffle_upfj(%[[F32_VAL]], %[[OFFSET]])
376+
// CHECK: llvm.mlir.constant(true) : i1
377+
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
378+
// CHECK: llvm.mlir.constant(true) : i1
379+
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
380+
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
381+
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
382+
%shuffleResult3, %valid3 = gpu.shuffle xor %i64_val, %offset, %width : i64
383+
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
384+
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
385+
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
352386
return
353387
}
354388
}
@@ -378,6 +412,20 @@ gpu.module @shuffles_mismatch {
378412
}
379413
}
380414

415+
// -----
416+
417+
// Cannot convert due to value type not being supported by the conversion
418+
419+
gpu.module @not_supported_lowering {
420+
func.func @gpu_shuffles(%val: i1, %id: i32) {
421+
%width = arith.constant 32 : i32
422+
// expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
423+
%shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
424+
return
425+
}
426+
}
427+
428+
381429
// -----
382430

383431
gpu.module @kernels {

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ func.func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
446446
// -----
447447

448448
func.func @shuffle_unsupported_type(%arg0 : index, %arg1 : i32, %arg2 : i32) {
449-
// expected-error@+1 {{operand #0 must be i32, i64, f32 or f64}}
449+
// expected-error@+1 {{op operand #0 must be Integer or Float or vector of Integer or Float values of ranks 1, but got 'index'}}
450450
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : index
451451
return
452452
}

0 commit comments

Comments
 (0)