Skip to content

Commit cdd652e

Browse files
authored
[MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion (#119675)
This PR adds support for the `bf16` and `i1` data types when converting `gpu::shuffle` to the `LLVMSPV` dialect, by inserting a `bitcast` to/from `i16` (for `bf16`) and zero-extending to/truncating from `i8` (for `i1`).
1 parent c05fc9b commit cdd652e

File tree

2 files changed

+63
-12
lines changed

2 files changed

+63
-12
lines changed

mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,13 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
262262
.Default([](auto) { return std::nullopt; });
263263
}
264264

265-
static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
266-
StringRef baseName = getBaseName(op.getMode());
267-
std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
265+
static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
266+
Type type) {
267+
StringRef baseName = getBaseName(mode);
268+
std::optional<StringRef> typeMangling = getTypeMangling(type);
268269
if (!typeMangling)
269270
return std::nullopt;
270-
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
271+
return llvm::formatv("_Z{}{}{}", baseName.size(), baseName,
271272
typeMangling.value());
272273
}
273274

@@ -286,33 +287,70 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
286287
val == getSubgroupSize(op);
287288
}
288289

290+
/// Prepares the shuffled value before the shuffle call is emitted.
///
/// Dispatches on the type of `oldVal`:
///  - `bf16` values are bitcast to `i16`;
///  - `i1` values are zero-extended to `i8`;
///  - every other value (including wider integers) is returned unchanged.
///
/// Any new operation is created at `loc` through `rewriter`. The returned
/// value is the one the caller should pass to the shuffle builtin.
// NOTE(review): presumably bf16 and i1 have no type mangling of their own for
// the shuffle builtins, hence the detour through i16/i8 -- confirm against
// getTypeMangling.
static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
291+
ConversionPatternRewriter &rewriter) {
292+
return TypeSwitch<Type, Value>(oldVal.getType())
293+
// bf16: reinterpret the same 16 bits as an integer.
.Case([&](BFloat16Type) {
294+
return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
295+
oldVal);
296+
})
297+
// Integers: only the 1-bit case is widened; other widths pass through.
.Case([&](IntegerType intTy) -> Value {
298+
if (intTy.getWidth() == 1)
299+
return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
300+
oldVal);
301+
return oldVal;
302+
})
303+
// Any other type is forwarded untouched.
.Default(oldVal);
304+
}
305+
306+
/// Converts the result of the shuffle call to the type `newTy`.
///
/// Dispatches on `newTy` (not on the type of `oldVal`):
///  - if `newTy` is `bf16`, `oldVal` is bitcast to `newTy`;
///  - if `newTy` is `i1`, `oldVal` is truncated to `newTy`;
///  - otherwise `oldVal` is returned unchanged.
///
/// Any new operation is created at `loc` through `rewriter`.
// NOTE(review): assumes `oldVal` has the type produced by
// bitcastOrExtBeforeShuffle for a value of type `newTy` (i16 for bf16, i8 for
// i1); nothing here checks that.
static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
307+
Location loc,
308+
ConversionPatternRewriter &rewriter) {
309+
return TypeSwitch<Type, Value>(newTy)
310+
// bf16 target: reinterpret the shuffled bits as bf16.
.Case([&](BFloat16Type) {
311+
return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
312+
})
313+
// Integer target: only the 1-bit case needs narrowing.
.Case([&](IntegerType intTy) -> Value {
314+
if (intTy.getWidth() == 1)
315+
return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
316+
return oldVal;
317+
})
318+
// Any other target type: the shuffle result is already in the right type.
.Default(oldVal);
319+
}
320+
289321
LogicalResult
290322
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
291323
ConversionPatternRewriter &rewriter) const final {
292324
if (!hasValidWidth(op))
293325
return rewriter.notifyMatchFailure(
294326
op, "shuffle width and subgroup size mismatch");
295327

296-
std::optional<std::string> funcName = getFuncName(op);
328+
Location loc = op->getLoc();
329+
Value inValue =
330+
bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
331+
std::optional<std::string> funcName =
332+
getFuncName(op.getMode(), inValue.getType());
297333
if (!funcName)
298334
return rewriter.notifyMatchFailure(op, "unsupported value type");
299335

300336
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
301337
assert(moduleOp && "Expecting module");
302-
Type valueType = adaptor.getValue().getType();
338+
Type valueType = inValue.getType();
303339
Type offsetType = adaptor.getOffset().getType();
304340
Type resultType = valueType;
305341
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
306342
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
307343
/*isMemNone=*/false, /*isConvergent=*/true);
308344

309-
Location loc = op->getLoc();
310-
std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
345+
std::array<Value, 2> args{inValue, adaptor.getOffset()};
311346
Value result =
312347
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
348+
Value resultOrConversion =
349+
bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);
350+
313351
Value trueVal =
314352
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
315-
rewriter.replaceOp(op, {result, trueVal});
353+
rewriter.replaceOp(op, {resultOrConversion, trueVal});
316354
return success();
317355
}
318356
};

mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,17 @@ gpu.module @shuffles {
279279
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
280280
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
281281
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
282-
// CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32)
282+
// CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
283+
// CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
283284
llvm.func @gpu_shuffles(%i8_val: i8,
284285
%i16_val: i16,
285286
%i32_val: i32,
286287
%i64_val: i64,
287288
%f16_val: f16,
288289
%f32_val: f32,
289290
%f64_val: f64,
291+
%bf16_val: bf16,
292+
%i1_val: i1,
290293
%offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
291294
%width = arith.constant 16 : i32
292295
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -303,13 +306,23 @@ gpu.module @shuffles {
303306
// CHECK: llvm.mlir.constant(true) : i1
304307
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
305308
// CHECK: llvm.mlir.constant(true) : i1
309+
// CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
310+
// CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
311+
// CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16
312+
// CHECK: llvm.mlir.constant(true) : i1
313+
// CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
314+
// CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%[[I1_ZEXT]], %[[OFFSET]])
315+
// CHECK: llvm.trunc %[[I1_CALL]] : i8 to i1
316+
// CHECK: llvm.mlir.constant(true) : i1
306317
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
307318
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
308319
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
309320
%shuffleResult3, %valid3 = gpu.shuffle xor %i64_val, %offset, %width : i64
310321
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
311322
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
312323
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
324+
%shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
325+
%shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
313326
llvm.return
314327
}
315328
}
@@ -344,10 +357,10 @@ gpu.module @shuffles_mismatch {
344357
// Cannot convert due to value type not being supported by the conversion
345358

346359
gpu.module @not_supported_lowering {
347-
llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
360+
llvm.func @gpu_shuffles(%val: f128, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
348361
%width = arith.constant 32 : i32
349362
// expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
350-
%shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
363+
%shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : f128
351364
llvm.return
352365
}
353366
}

0 commit comments

Comments
 (0)