-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion #119675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Pietro Ghiglio (PietroGhg) ChangesThis PR adds support to the Full diff: https://github.com/llvm/llvm-project/pull/119675.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
.Default([](auto) { return std::nullopt; });
}
- static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
- StringRef baseName = getBaseName(op.getMode());
- std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+ static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+ Type type) {
+ StringRef baseName = getBaseName(mode);
+ std::optional<StringRef> typeMangling = getTypeMangling(type);
if (!typeMangling)
return std::nullopt;
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
typeMangling.value());
}
+ static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+ return getFuncName(op.getMode(), op.getType(0));
+ }
+
/// Get the subgroup size from the target or return a default.
static std::optional<int> getSubgroupSize(Operation *op) {
auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
val == getSubgroupSize(op);
}
+ static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+ Type type = op.getType(0);
+ return isa<BFloat16Type>(type) || type.isInteger(1);
+ }
+
+ static Type getBitCastOrExtTy(Type oldTy,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Type>(oldTy)
+ .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+ .Case<IntegerType>([&](auto intTy) -> Type {
+ if (intTy.getWidth() == 1)
+ return rewriter.getIntegerType(8);
+ return Type{};
+ })
+ .Default([](auto) { return Type{}; });
+ }
+
+ static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(oldVal.getType())
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
+ static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(newTy)
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
return rewriter.notifyMatchFailure(
op, "shuffle width and subgroup size mismatch");
- std::optional<std::string> funcName = getFuncName(op);
+ Location loc = op->getLoc();
+ Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+ std::optional<std::string> funcName;
+ Value inValue;
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastorext");
+ funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+ inValue = newVal;
+ } else {
+ funcName = getFuncName(op);
+ inValue = adaptor.getValue();
+ }
if (!funcName)
return rewriter.notifyMatchFailure(op, "unsupported value type");
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
- Type valueType = adaptor.getValue().getType();
+ Type valueType = inValue.getType();
Type offsetType = adaptor.getOffset().getType();
Type resultType = valueType;
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
/*isMemNone=*/false, /*isConvergent=*/true);
- Location loc = op->getLoc();
- std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+ std::array<Value, 2> args{inValue, adaptor.getOffset()};
Value result =
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastortrunc");
+ result = newVal;
+ }
+
Value trueVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
- // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32)
+ // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+ // CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
llvm.func @gpu_shuffles(%i8_val: i8,
%i16_val: i16,
%i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
%f16_val: f16,
%f32_val: f32,
%f64_val: f64,
+ %bf16_val: bf16,
+ %i1_val: i1,
%offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
%width = arith.constant 16 : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
// CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+ // CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+ // CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+ // CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+ // CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9)
+ // CHECK: llvm.trunc %[[I1_CALL:.*]] : i8 to i1
+ // CHECK: llvm.mlir.constant(true) : i1
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+ %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+ %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
llvm.return
}
}
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
// Cannot convert due to value type not being supported by the conversion
gpu.module @not_supported_lowering {
- llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+ llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
%width = arith.constant 32 : i32
- // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
- %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
llvm.return
}
}
|
@llvm/pr-subscribers-mlir-gpu Author: Pietro Ghiglio (PietroGhg) ChangesThis PR adds support to the Full diff: https://github.com/llvm/llvm-project/pull/119675.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
.Default([](auto) { return std::nullopt; });
}
- static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
- StringRef baseName = getBaseName(op.getMode());
- std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+ static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+ Type type) {
+ StringRef baseName = getBaseName(mode);
+ std::optional<StringRef> typeMangling = getTypeMangling(type);
if (!typeMangling)
return std::nullopt;
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
typeMangling.value());
}
+ static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+ return getFuncName(op.getMode(), op.getType(0));
+ }
+
/// Get the subgroup size from the target or return a default.
static std::optional<int> getSubgroupSize(Operation *op) {
auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
val == getSubgroupSize(op);
}
+ static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+ Type type = op.getType(0);
+ return isa<BFloat16Type>(type) || type.isInteger(1);
+ }
+
+ static Type getBitCastOrExtTy(Type oldTy,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Type>(oldTy)
+ .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+ .Case<IntegerType>([&](auto intTy) -> Type {
+ if (intTy.getWidth() == 1)
+ return rewriter.getIntegerType(8);
+ return Type{};
+ })
+ .Default([](auto) { return Type{}; });
+ }
+
+ static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(oldVal.getType())
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
+ static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(newTy)
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
return rewriter.notifyMatchFailure(
op, "shuffle width and subgroup size mismatch");
- std::optional<std::string> funcName = getFuncName(op);
+ Location loc = op->getLoc();
+ Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+ std::optional<std::string> funcName;
+ Value inValue;
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastorext");
+ funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+ inValue = newVal;
+ } else {
+ funcName = getFuncName(op);
+ inValue = adaptor.getValue();
+ }
if (!funcName)
return rewriter.notifyMatchFailure(op, "unsupported value type");
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
- Type valueType = adaptor.getValue().getType();
+ Type valueType = inValue.getType();
Type offsetType = adaptor.getOffset().getType();
Type resultType = valueType;
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
/*isMemNone=*/false, /*isConvergent=*/true);
- Location loc = op->getLoc();
- std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+ std::array<Value, 2> args{inValue, adaptor.getOffset()};
Value result =
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastortrunc");
+ result = newVal;
+ }
+
Value trueVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
- // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32)
+ // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+ // CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
llvm.func @gpu_shuffles(%i8_val: i8,
%i16_val: i16,
%i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
%f16_val: f16,
%f32_val: f32,
%f64_val: f64,
+ %bf16_val: bf16,
+ %i1_val: i1,
%offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
%width = arith.constant 16 : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
// CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+ // CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+ // CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+ // CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+ // CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9)
+ // CHECK: llvm.trunc %[[I1_CALL:.*]] : i8 to i1
+ // CHECK: llvm.mlir.constant(true) : i1
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+ %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+ %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
llvm.return
}
}
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
// Cannot convert due to value type not being supported by the conversion
gpu.module @not_supported_lowering {
- llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+ llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
%width = arith.constant 32 : i32
- // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
- %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
llvm.return
}
}
|
Location loc = op->getLoc(); | ||
Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter); | ||
std::optional<std::string> funcName; | ||
Value inValue; | ||
if (bitcastOrExtDestTy) { | ||
Value newVal = | ||
doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter); | ||
assert(newVal && "Unhandled op type in bitcastorext"); | ||
funcName = getFuncName(op.getMode(), bitcastOrExtDestTy); | ||
inValue = newVal; | ||
} else { | ||
funcName = getFuncName(op); | ||
inValue = adaptor.getValue(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Location loc = op->getLoc(); | |
Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter); | |
std::optional<std::string> funcName; | |
Value inValue; | |
if (bitcastOrExtDestTy) { | |
Value newVal = | |
doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter); | |
assert(newVal && "Unhandled op type in bitcastorext"); | |
funcName = getFuncName(op.getMode(), bitcastOrExtDestTy); | |
inValue = newVal; | |
} else { | |
funcName = getFuncName(op); | |
inValue = adaptor.getValue(); | |
} | |
Location loc = op->getLoc(); | |
Value inValue = bitcastOrExtForShuffle(loc, adaptor.getValue(), rewriter); | |
std::optional<std::string> funcName = getFuncName(op.getMode(), inValue.getType()); |
Would this work too with a TypeSwitch
in bitcastOrExtForShuffle
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, thank you
if (bitcastOrExtDestTy) { | ||
Value newVal = | ||
doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter); | ||
assert(newVal && "Unhandled op type in bitcastortrunc"); | ||
result = newVal; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (bitcastOrExtDestTy) { | |
Value newVal = | |
doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter); | |
assert(newVal && "Unhandled op type in bitcastortrunc"); | |
result = newVal; | |
} | |
result = bitcastOrTruncAfterShuffle(loc, result, rewriter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thank you
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some drive by nits
static std::optional<std::string> getFuncName(gpu::ShuffleMode mode, | ||
Type type) { | ||
StringRef baseName = getBaseName(mode); | ||
std::optional<StringRef> typeMangling = getTypeMangling(type); | ||
if (!typeMangling) | ||
return std::nullopt; | ||
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, you don't need these indices, "_Z{}{}{}"
should work too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I've removed the explicit indices
.Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); }) | ||
.Case<IntegerType>([&](auto intTy) -> Type { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I don't see why these need to be generic over the argument type
.Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); }) | |
.Case<IntegerType>([&](auto intTy) -> Type { | |
.Case<BFloat16Type>([&](Type) { return rewriter.getIntegerType(16); }) | |
.Case<IntegerType>([&](IntegerType intTy) -> Type { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure that would compile.
.Case([&](BFloat16Type) { return rewriter.getIntegerType(16); })
.Case([&](IntegerType intTy) -> Type {
Is nicer IMO tho
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I went with @victor-eds 's suggestion
return rewriter.getIntegerType(8); | ||
return Type{}; | ||
}) | ||
.Default([](auto) { return Type{}; }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.Default([](auto) { return Type{}; }); | |
.Default([](Type) { return Type{}; }); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thank you
gpu.module @not_supported_lowering { | ||
llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} { | ||
llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} { | ||
%width = arith.constant 32 : i32 | ||
// expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}} | ||
%shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1 | ||
llvm.return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You have removed the reason for this module/function to exist. Maybe replace the i1
value with another unsupported type like a large float.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I've added a test for f128
@@ -286,30 +291,94 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> { | |||
val == getSubgroupSize(op); | |||
} | |||
|
|||
static bool needsBitCastOrExt(gpu::ShuffleOp op) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unused
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed it, thanks for spotting it
9e3af90
to
09e2222
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a couple NITs. LGTM!
09e2222
to
7ae5428
Compare
7ae5428
to
60ef4fd
Compare
This PR adds support to the
bf16
andi1
data types when convertinggpu::shuffle
to theLLVMSPV
dialect, by insertingbitcast
to/fromi16
(forbf16
) and extending/truncating toi8
(fori1
).