Skip to content

Commit 9e3af90

Browse files
committed
[MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion
1 parent eae30a2 commit 9e3af90

File tree

2 files changed

+91
-11
lines changed

2 files changed

+91
-11
lines changed

mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
262262
.Default([](auto) { return std::nullopt; });
263263
}
264264

265-
static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
266-
StringRef baseName = getBaseName(op.getMode());
267-
std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
265+
static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
266+
Type type) {
267+
StringRef baseName = getBaseName(mode);
268+
std::optional<StringRef> typeMangling = getTypeMangling(type);
268269
if (!typeMangling)
269270
return std::nullopt;
270271
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
271272
typeMangling.value());
272273
}
273274

275+
static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
276+
return getFuncName(op.getMode(), op.getType(0));
277+
}
278+
274279
/// Get the subgroup size from the target or return a default.
275280
static std::optional<int> getSubgroupSize(Operation *op) {
276281
auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,30 +291,94 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
286291
val == getSubgroupSize(op);
287292
}
288293

294+
static bool needsBitCastOrExt(gpu::ShuffleOp op) {
295+
Type type = op.getType(0);
296+
return isa<BFloat16Type>(type) || type.isInteger(1);
297+
}
298+
299+
static Type getBitCastOrExtTy(Type oldTy,
300+
ConversionPatternRewriter &rewriter) {
301+
return TypeSwitch<Type, Type>(oldTy)
302+
.Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
303+
.Case<IntegerType>([&](auto intTy) -> Type {
304+
if (intTy.getWidth() == 1)
305+
return rewriter.getIntegerType(8);
306+
return Type{};
307+
})
308+
.Default([](auto) { return Type{}; });
309+
}
310+
311+
static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
312+
ConversionPatternRewriter &rewriter) {
313+
return TypeSwitch<Type, Value>(oldVal.getType())
314+
.Case<BFloat16Type>([&](auto) {
315+
return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
316+
})
317+
.Case<IntegerType>([&](auto intTy) -> Value {
318+
if (intTy.getWidth() == 1)
319+
return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
320+
return Value{};
321+
})
322+
.Default([](auto) { return Value{}; });
323+
}
324+
325+
static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
326+
ConversionPatternRewriter &rewriter) {
327+
return TypeSwitch<Type, Value>(newTy)
328+
.Case<BFloat16Type>([&](auto) {
329+
return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
330+
})
331+
.Case<IntegerType>([&](auto intTy) -> Value {
332+
if (intTy.getWidth() == 1)
333+
return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
334+
return Value{};
335+
})
336+
.Default([](auto) { return Value{}; });
337+
}
338+
289339
LogicalResult
290340
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
291341
ConversionPatternRewriter &rewriter) const final {
292342
if (!hasValidWidth(op))
293343
return rewriter.notifyMatchFailure(
294344
op, "shuffle width and subgroup size mismatch");
295345

296-
std::optional<std::string> funcName = getFuncName(op);
346+
Location loc = op->getLoc();
347+
Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
348+
std::optional<std::string> funcName;
349+
Value inValue;
350+
if (bitcastOrExtDestTy) {
351+
Value newVal =
352+
doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
353+
assert(newVal && "Unhandled op type in bitcastorext");
354+
funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
355+
inValue = newVal;
356+
} else {
357+
funcName = getFuncName(op);
358+
inValue = adaptor.getValue();
359+
}
297360
if (!funcName)
298361
return rewriter.notifyMatchFailure(op, "unsupported value type");
299362

300363
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
301364
assert(moduleOp && "Expecting module");
302-
Type valueType = adaptor.getValue().getType();
365+
Type valueType = inValue.getType();
303366
Type offsetType = adaptor.getOffset().getType();
304367
Type resultType = valueType;
305368
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
306369
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
307370
/*isMemNone=*/false, /*isConvergent=*/true);
308371

309-
Location loc = op->getLoc();
310-
std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
372+
std::array<Value, 2> args{inValue, adaptor.getOffset()};
311373
Value result =
312374
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
375+
if (bitcastOrExtDestTy) {
376+
Value newVal =
377+
doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
378+
assert(newVal && "Unhandled op type in bitcastortrunc");
379+
result = newVal;
380+
}
381+
313382
Value trueVal =
314383
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
315384
rewriter.replaceOp(op, {result, trueVal});

mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,14 +277,17 @@ gpu.module @shuffles {
277277
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
278278
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
279279
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
280-
// CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32)
280+
// CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
281+
// CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
281282
llvm.func @gpu_shuffles(%i8_val: i8,
282283
%i16_val: i16,
283284
%i32_val: i32,
284285
%i64_val: i64,
285286
%f16_val: f16,
286287
%f32_val: f32,
287288
%f64_val: f64,
289+
%bf16_val: bf16,
290+
%i1_val: i1,
288291
%offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
289292
%width = arith.constant 16 : i32
290293
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,13 +304,23 @@ gpu.module @shuffles {
301304
// CHECK: llvm.mlir.constant(true) : i1
302305
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
303306
// CHECK: llvm.mlir.constant(true) : i1
307+
// CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
308+
// CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
309+
// CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16
310+
// CHECK: llvm.mlir.constant(true) : i1
311+
// CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
312+
// CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9)
313+
// CHECK: llvm.trunc %[[I1_CALL:.*]] : i8 to i1
314+
// CHECK: llvm.mlir.constant(true) : i1
304315
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
305316
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
306317
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
307318
%shuffleResult3, %valid3 = gpu.shuffle xor %i64_val, %offset, %width : i64
308319
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
309320
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
310321
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
322+
%shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
323+
%shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
311324
llvm.return
312325
}
313326
}
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
342355
// Cannot convert due to value type not being supported by the conversion
343356

344357
gpu.module @not_supported_lowering {
345-
llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
358+
llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
346359
%width = arith.constant 32 : i32
347-
// expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
348-
%shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
349360
llvm.return
350361
}
351362
}

0 commit comments

Comments
 (0)