-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics #131621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Johannes de Fine Licht (definelicht) ChangesThis was lacking a bitcast from the shifted integer type into a float. Other non-struct types than integers and floats will still not be Mem2Reg'ed. Also adds special handling for constants to be emitted as a constant directly rather than relying on followup canonicalization patterns (`memset` of zero is a case that can appear in the wild). Full diff: https://github.com/llvm/llvm-project/pull/131621.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 655316cc5d66d..16109b5c59f7e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1051,30 +1051,53 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
template <class MemsetIntr>
static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
OpBuilder &builder) {
- // TODO: Support non-integer types.
- return TypeSwitch<Type, Value>(slot.elemType)
- .Case([&](IntegerType intType) -> Value {
- if (intType.getWidth() == 8)
- return op.getVal();
-
- assert(intType.getWidth() % 8 == 0);
-
- // Build the memset integer by repeatedly shifting the value and
- // or-ing it with the previous value.
- uint64_t coveredBits = 8;
- Value currentValue =
- builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
- while (coveredBits < intType.getWidth()) {
- Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
- coveredBits);
- Value shifted =
- builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
- currentValue =
- builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
- coveredBits *= 2;
- }
+ /// Returns an integer value that is `width` bits wide representing the value
+ /// assigned to the slot by memset.
+ auto buildMemsetValue = [&](unsigned width) -> Value {
+ if (width == 8)
+ return op.getVal();
+
+ assert(width % 8 == 0);
+
+ auto intType = IntegerType::get(op.getContext(), width);
+
+ // If we know the pattern at compile time, we can compute and assign a
+ // constant directly.
+ IntegerAttr constantPattern;
+ if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
+ APInt memsetVal(/*numBits=*/width, /*val=*/0);
+ unsigned patternWidth = op.getVal().getType().getWidth();
+ for (unsigned loBit = 0; loBit + patternWidth <= width;
+ loBit += patternWidth)
+ memsetVal.insertBits(constantPattern.getValue(), loBit);
+ return builder.create<LLVM::ConstantOp>(
+ op.getLoc(), IntegerAttr::get(intType, memsetVal));
+ }
+
+ // Otherwise build the memset integer at runtime by repeatedly shifting the
+ // value and or-ing it with the previous value.
+ uint64_t coveredBits = 8;
+ Value currentValue =
+ builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
+ while (coveredBits < width) {
+ Value shiftBy =
+ builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits);
+ Value shifted =
+ builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
+ currentValue =
+ builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
+ coveredBits *= 2;
+ }
- return currentValue;
+ return currentValue;
+ };
+ return TypeSwitch<Type, Value>(slot.elemType)
+ .Case([&](IntegerType type) -> Value {
+ return buildMemsetValue(type.getWidth());
+ })
+ .Case([&](FloatType type) -> Value {
+ Value intVal = buildMemsetValue(type.getWidth());
+ return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal);
})
.Default([](Type) -> Value {
llvm_unreachable(
@@ -1088,11 +1111,10 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
- // TODO: Support non-integer types.
bool canConvertType =
TypeSwitch<Type, bool>(slot.elemType)
- .Case([](IntegerType intType) {
- return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+ .Case<IntegerType, FloatType>([](auto type) {
+ return type.getWidth() % 8 == 0 && type.getWidth() > 0;
})
.Default([](Type) { return false; });
if (!canConvertType)
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 646667505a373..f3dca45265082 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -23,6 +23,30 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
// -----
+// CHECK-LABEL: llvm.func @memset_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_float(%memset_value: i8) -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %memset_len = llvm.mlir.constant(4 : i32) : i32
+ "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ // CHECK-NOT: "llvm.intr.memset"
+ // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+ // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+ // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+ // CHECK-NOT: "llvm.intr.memset"
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+ // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+ llvm.return %2 : f32
+}
+
+// -----
+
// CHECK-LABEL: llvm.func @basic_memset_inline
// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
@@ -45,6 +69,29 @@ llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
// -----
+// CHECK-LABEL: llvm.func @memset_inline_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_inline_float(%memset_value: i8) -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4 : i32}> : (!llvm.ptr, i8) -> ()
+ // CHECK-NOT: "llvm.intr.memset.inline"
+ // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+ // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+ // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+ // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+ // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+ // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+ // CHECK-NOT: "llvm.intr.memset.inline"
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+ // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+ llvm.return %2 : f32
+}
+
+// -----
+
// CHECK-LABEL: llvm.func @basic_memset_constant
llvm.func @basic_memset_constant() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
@@ -53,15 +100,8 @@ llvm.func @basic_memset_constant() -> i32 {
%memset_len = llvm.mlir.constant(4 : i32) : i32
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
- // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
- // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
- // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
- // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
- // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
- // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
- // CHECK: llvm.return %[[RES]] : i32
+ // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+ // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
llvm.return %2 : i32
}
@@ -74,15 +114,8 @@ llvm.func @basic_memset_inline_constant() -> i32 {
%memset_value = llvm.mlir.constant(42 : i8) : i8
"llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
- // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
- // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
- // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
- // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
- // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
- // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
- // CHECK: llvm.return %[[RES]] : i32
+ // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+ // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
llvm.return %2 : i32
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the improvement. I only have a question about the assert, the rest LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
I would probably assert patternWidth == 8 or even use 8 directly when computing the constant. That way it is more obvious there cannot be an overflow (which triggers an assert IUC).
e126ed8
to
63777d8
Compare
This was lacking a bitcast from the shifted integer type into a float. Other non-struct types than integers and floats will still not be Mem2Reg'ed. Also adds special handling for constants to be emitted as a constant directly rather than relying on followup canonicalization patterns (`memset` of zero is a case that can appear in the wild).
63777d8
to
ce9af74
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This was lacking a bitcast from the shifted integer type into a float. Other non-struct types than integers and floats will still not be Mem2Reg'ed.
Also adds special handling for constants to be emitted as a constant directly rather than relying on followup canonicalization patterns (`memset` of zero is a case that can appear in the wild).