Skip to content

Commit 63777d8

Browse files
committed
[MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics.
The Mem2Reg handling of memset was lacking a bitcast from the shifted integer value into a float. Non-struct types other than integers and floats will still not be Mem2Reg'ed. This change also adds special handling for constant memset values so that they are emitted as a constant directly rather than relying on follow-up canonicalization patterns (a `memset` of zero is a case that can appear in the wild).
1 parent 8bc0f87 commit 63777d8

File tree

2 files changed

+91
-44
lines changed

2 files changed

+91
-44
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,30 +1051,53 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
10511051
template <class MemsetIntr>
10521052
static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
10531053
OpBuilder &builder) {
1054-
// TODO: Support non-integer types.
1055-
return TypeSwitch<Type, Value>(slot.elemType)
1056-
.Case([&](IntegerType intType) -> Value {
1057-
if (intType.getWidth() == 8)
1058-
return op.getVal();
1059-
1060-
assert(intType.getWidth() % 8 == 0);
1061-
1062-
// Build the memset integer by repeatedly shifting the value and
1063-
// or-ing it with the previous value.
1064-
uint64_t coveredBits = 8;
1065-
Value currentValue =
1066-
builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
1067-
while (coveredBits < intType.getWidth()) {
1068-
Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
1069-
coveredBits);
1070-
Value shifted =
1071-
builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
1072-
currentValue =
1073-
builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
1074-
coveredBits *= 2;
1075-
}
1054+
/// Returns an integer value that is `width` bits wide representing the value
1055+
/// assigned to the slot by memset.
1056+
auto buildMemsetValue = [&](unsigned width) -> Value {
1057+
assert(width % 8 == 0);
1058+
auto intType = IntegerType::get(op.getContext(), width);
1059+
1060+
// If we know the pattern at compile time, we can compute and assign a
1061+
// constant directly.
1062+
IntegerAttr constantPattern;
1063+
if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
1064+
// The pattern must fit in a byte.
1065+
assert(constantPattern.getValue().getActiveBits() <= 8);
1066+
APInt memsetVal(/*numBits=*/width, /*val=*/0);
1067+
for (unsigned loBit = 0; loBit < width; loBit += 8)
1068+
memsetVal.insertBits(constantPattern.getValue(), loBit);
1069+
return builder.create<LLVM::ConstantOp>(
1070+
op.getLoc(), IntegerAttr::get(intType, memsetVal));
1071+
}
1072+
1073+
// If the output is a single byte, we can return the pattern directly.
1074+
if (width == 8)
1075+
return op.getVal();
1076+
1077+
// Otherwise build the memset integer at runtime by repeatedly shifting the
1078+
// value and or-ing it with the previous value.
1079+
uint64_t coveredBits = 8;
1080+
Value currentValue =
1081+
builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
1082+
while (coveredBits < width) {
1083+
Value shiftBy =
1084+
builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits);
1085+
Value shifted =
1086+
builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
1087+
currentValue =
1088+
builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
1089+
coveredBits *= 2;
1090+
}
10761091

1077-
return currentValue;
1092+
return currentValue;
1093+
};
1094+
return TypeSwitch<Type, Value>(slot.elemType)
1095+
.Case([&](IntegerType type) -> Value {
1096+
return buildMemsetValue(type.getWidth());
1097+
})
1098+
.Case([&](FloatType type) -> Value {
1099+
Value intVal = buildMemsetValue(type.getWidth());
1100+
return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal);
10781101
})
10791102
.Default([](Type) -> Value {
10801103
llvm_unreachable(
@@ -1088,11 +1111,10 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
10881111
const SmallPtrSetImpl<OpOperand *> &blockingUses,
10891112
SmallVectorImpl<OpOperand *> &newBlockingUses,
10901113
const DataLayout &dataLayout) {
1091-
// TODO: Support non-integer types.
10921114
bool canConvertType =
10931115
TypeSwitch<Type, bool>(slot.elemType)
1094-
.Case([](IntegerType intType) {
1095-
return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
1116+
.Case<IntegerType, FloatType>([](auto type) {
1117+
return type.getWidth() % 8 == 0 && type.getWidth() > 0;
10961118
})
10971119
.Default([](Type) { return false; });
10981120
if (!canConvertType)

mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,30 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
2323

2424
// -----
2525

26+
// CHECK-LABEL: llvm.func @memset_float
27+
// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
28+
llvm.func @memset_float(%memset_value: i8) -> f32 {
29+
%one = llvm.mlir.constant(1 : i32) : i32
30+
%alloca = llvm.alloca %one x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
31+
%memset_len = llvm.mlir.constant(4 : i32) : i32
32+
"llvm.intr.memset"(%alloca, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
33+
// CHECK-NOT: "llvm.intr.memset"
34+
// CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
35+
// CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
36+
// CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
37+
// CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
38+
// CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
39+
// CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
40+
// CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
41+
// CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
42+
// CHECK-NOT: "llvm.intr.memset"
43+
%load = llvm.load %alloca {alignment = 4 : i64} : !llvm.ptr -> f32
44+
// CHECK: llvm.return %[[VALUE_FLOAT]] : f32
45+
llvm.return %load : f32
46+
}
47+
48+
// -----
49+
2650
// CHECK-LABEL: llvm.func @basic_memset_inline
2751
// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
2852
llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
@@ -53,36 +77,37 @@ llvm.func @basic_memset_constant() -> i32 {
5377
%memset_len = llvm.mlir.constant(4 : i32) : i32
5478
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
5579
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
56-
// CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
57-
// CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
58-
// CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
59-
// CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
60-
// CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
61-
// CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
62-
// CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
63-
// CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
64-
// CHECK: llvm.return %[[RES]] : i32
80+
// CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
81+
// CHECK: llvm.return %[[CONSTANT_VAL]] : i32
6582
llvm.return %2 : i32
6683
}
6784

6885
// -----
6986

87+
// CHECK-LABEL: llvm.func @memset_one_byte_constant
88+
llvm.func @memset_one_byte_constant() -> i8 {
89+
%one = llvm.mlir.constant(1 : i32) : i32
90+
%alloca = llvm.alloca %one x i8 : (i32) -> !llvm.ptr
91+
// CHECK: %{{.+}} = llvm.mlir.constant(42 : i8) : i8
92+
%value = llvm.mlir.constant(42 : i8) : i8
93+
"llvm.intr.memset"(%alloca, %value, %one) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
94+
%load = llvm.load %alloca : !llvm.ptr -> i8
95+
// CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(42 : i8) : i8
96+
// CHECK: llvm.return %[[CONSTANT_VAL]] : i8
97+
llvm.return %load : i8
98+
}
99+
100+
// -----
101+
70102
// CHECK-LABEL: llvm.func @basic_memset_inline_constant
71103
llvm.func @basic_memset_inline_constant() -> i32 {
72104
%0 = llvm.mlir.constant(1 : i32) : i32
73105
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
74106
%memset_value = llvm.mlir.constant(42 : i8) : i8
75107
"llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
76108
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
77-
// CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
78-
// CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
79-
// CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
80-
// CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
81-
// CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
82-
// CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
83-
// CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
84-
// CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
85-
// CHECK: llvm.return %[[RES]] : i32
109+
// CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
110+
// CHECK: llvm.return %[[CONSTANT_VAL]] : i32
86111
llvm.return %2 : i32
87112
}
88113

0 commit comments

Comments
 (0)