Skip to content

[MLIR][LLVM] Handle floats in Mem2Reg of memset intrinsics #131621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2025

Conversation

definelicht
Copy link
Contributor

This was lacking a bitcast from the shifted integer type into a float. Other non-struct types than integers and floats will still not be Mem2Reg'ed.

Also adds special handling for constants to be emitted as a constant directly rather than relying on followup canonicalization patterns (memset of zero is a case that can appear in the wild).

@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Johannes de Fine Licht (definelicht)

Changes

This was lacking a bitcast from the shifted integer type into a float. Other non-struct types than integers and floats will still not be Mem2Reg'ed.

Also adds special handling for constants to be emitted as a constant directly rather than relying on followup canonicalization patterns (memset of zero is a case that can appear in the wild).


Full diff: https://github.com/llvm/llvm-project/pull/131621.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+48-26)
  • (modified) mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir (+51-18)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 655316cc5d66d..16109b5c59f7e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1051,30 +1051,53 @@ static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
 template <class MemsetIntr>
 static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
                              OpBuilder &builder) {
-  // TODO: Support non-integer types.
-  return TypeSwitch<Type, Value>(slot.elemType)
-      .Case([&](IntegerType intType) -> Value {
-        if (intType.getWidth() == 8)
-          return op.getVal();
-
-        assert(intType.getWidth() % 8 == 0);
-
-        // Build the memset integer by repeatedly shifting the value and
-        // or-ing it with the previous value.
-        uint64_t coveredBits = 8;
-        Value currentValue =
-            builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
-        while (coveredBits < intType.getWidth()) {
-          Value shiftBy = builder.create<LLVM::ConstantOp>(op.getLoc(), intType,
-                                                           coveredBits);
-          Value shifted =
-              builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
-          currentValue =
-              builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
-          coveredBits *= 2;
-        }
+  /// Returns an integer value that is `width` bits wide representing the value
+  /// assigned to the slot by memset.
+  auto buildMemsetValue = [&](unsigned width) -> Value {
+    if (width == 8)
+      return op.getVal();
+
+    assert(width % 8 == 0);
+
+    auto intType = IntegerType::get(op.getContext(), width);
+
+    // If we know the pattern at compile time, we can compute and assign a
+    // constant directly.
+    IntegerAttr constantPattern;
+    if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
+      APInt memsetVal(/*numBits=*/width, /*val=*/0);
+      unsigned patternWidth = op.getVal().getType().getWidth();
+      for (unsigned loBit = 0; loBit + patternWidth <= width;
+           loBit += patternWidth)
+        memsetVal.insertBits(constantPattern.getValue(), loBit);
+      return builder.create<LLVM::ConstantOp>(
+          op.getLoc(), IntegerAttr::get(intType, memsetVal));
+    }
+
+    // Otherwise build the memset integer at runtime by repeatedly shifting the
+    // value and or-ing it with the previous value.
+    uint64_t coveredBits = 8;
+    Value currentValue =
+        builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
+    while (coveredBits < width) {
+      Value shiftBy =
+          builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits);
+      Value shifted =
+          builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
+      currentValue =
+          builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
+      coveredBits *= 2;
+    }
 
-        return currentValue;
+    return currentValue;
+  };
+  return TypeSwitch<Type, Value>(slot.elemType)
+      .Case([&](IntegerType type) -> Value {
+        return buildMemsetValue(type.getWidth());
+      })
+      .Case([&](FloatType type) -> Value {
+        Value intVal = buildMemsetValue(type.getWidth());
+        return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal);
       })
       .Default([](Type) -> Value {
         llvm_unreachable(
@@ -1088,11 +1111,10 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
                        const SmallPtrSetImpl<OpOperand *> &blockingUses,
                        SmallVectorImpl<OpOperand *> &newBlockingUses,
                        const DataLayout &dataLayout) {
-  // TODO: Support non-integer types.
   bool canConvertType =
       TypeSwitch<Type, bool>(slot.elemType)
-          .Case([](IntegerType intType) {
-            return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
+          .Case<IntegerType, FloatType>([](auto type) {
+            return type.getWidth() % 8 == 0 && type.getWidth() > 0;
           })
           .Default([](Type) { return false; });
   if (!canConvertType)
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
index 646667505a373..f3dca45265082 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir
@@ -23,6 +23,30 @@ llvm.func @basic_memset(%memset_value: i8) -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_float(%memset_value: i8) -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %memset_len = llvm.mlir.constant(4 : i32) : i32
+  "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+  // CHECK-NOT: "llvm.intr.memset"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+  // CHECK-NOT: "llvm.intr.memset"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+  llvm.return %2 : f32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @basic_memset_inline
 // CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
 llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
@@ -45,6 +69,29 @@ llvm.func @basic_memset_inline(%memset_value: i8) -> i32 {
 
 // -----
 
+// CHECK-LABEL: llvm.func @memset_inline_float
+// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
+llvm.func @memset_inline_float(%memset_value: i8) -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4 : i32}> : (!llvm.ptr, i8) -> ()
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  // CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
+  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
+  // CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
+  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
+  // CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
+  // CHECK: %[[VALUE_FLOAT:.+]] = llvm.bitcast %[[VALUE_32]] : i32 to f32
+  // CHECK-NOT: "llvm.intr.memset.inline"
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  // CHECK: llvm.return %[[VALUE_FLOAT]] : f32
+  llvm.return %2 : f32
+}
+
+// -----
+
 // CHECK-LABEL: llvm.func @basic_memset_constant
 llvm.func @basic_memset_constant() -> i32 {
   %0 = llvm.mlir.constant(1 : i32) : i32
@@ -53,15 +100,8 @@ llvm.func @basic_memset_constant() -> i32 {
   %memset_len = llvm.mlir.constant(4 : i32) : i32
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
-  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
-  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
-  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
-  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
-  // CHECK: llvm.return %[[RES]] : i32
+  // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
   llvm.return %2 : i32
 }
 
@@ -74,15 +114,8 @@ llvm.func @basic_memset_inline_constant() -> i32 {
   %memset_value = llvm.mlir.constant(42 : i8) : i8
   "llvm.intr.memset.inline"(%1, %memset_value) <{isVolatile = false, len = 4}> : (!llvm.ptr, i8) -> ()
   %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
-  // CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]]  : i32
-  // CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]]  : i32
-  // CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]]  : i32
-  // CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]]  : i32
-  // CHECK: llvm.return %[[RES]] : i32
+  // CHECK: %[[CONSTANT_VAL:..*]] = llvm.mlir.constant(707406378 : i32) : i32
+  // CHECK: llvm.return %[[CONSTANT_VAL]] : i32
   llvm.return %2 : i32
 }
 

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the improvement. I only have a question about the assert, the rest LGTM!

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

I would probably assert patternWidth == 8 or even use 8 directly when computing the constant. That way it is more obvious there cannot be an overflow (which triggers an assert IUC).

@definelicht definelicht force-pushed the mem2reg-memset-float branch from e126ed8 to 63777d8 Compare March 17, 2025 16:21
This was lacking a bitcast from the shifted integer type into a float.
Other non-struct types than integers and floats will still not be
Mem2Reg'ed.

Also adds special handling for constants to be emitted as a constant
directly rather than relying on followup canonicalization patterns
(`memset` of zero is a case that can appear in the wild).
@definelicht definelicht force-pushed the mem2reg-memset-float branch from 63777d8 to ce9af74 Compare March 17, 2025 16:26
@definelicht definelicht requested a review from gysit March 17, 2025 16:34
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@definelicht definelicht merged commit c3f7502 into llvm:main Mar 17, 2025
11 checks passed
@definelicht definelicht deleted the mem2reg-memset-float branch March 17, 2025 21:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants