-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][LLVM][SROA] Support incorrectly typed memory accesses #85813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][LLVM][SROA] Support incorrectly typed memory accesses #85813
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Christian Ulmann (Dinistro) Changes: This commit relaxes the assumption of type consistency for LLVM dialect load and store operations in SROA. Instead, there is now a check that loads and stores are in the bounds specified by the sub-slot they access. This commit additionally removes the corresponding patterns from the type consistency pass, as they are no longer necessary. Note: It will be necessary to extend Mem2Reg with the logic for differently sized accesses as well. This is nonetheless a strict upgrade for productive flows, as the type consistency pass can produce invalid IR for some odd cases. Patch is 33.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85813.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b523374f6c06b5..f8f9264b3889be 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -323,7 +323,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
}
def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ [DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
dag args = (ins LLVM_AnyPointer:$addr,
@@ -402,7 +403,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
}
def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ [DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
dag args = (ins LLVM_LoadableType:$value,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
index b32ac56d7079c6..cacb241bfd7a10 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -29,18 +29,6 @@ namespace LLVM {
/// interpret pointee types as consistently as possible.
std::unique_ptr<Pass> createTypeConsistencyPass();
-/// Transforms uses of pointers to a whole struct to uses of pointers to the
-/// first element of a struct. This is achieved by inserting a GEP to the first
-/// element when possible.
-template <class User>
-class AddFieldGetterToStructDirectUse : public OpRewritePattern<User> {
-public:
- using OpRewritePattern<User>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(User user,
- PatternRewriter &rewriter) const override;
-};
-
/// Canonicalizes GEPs of which the base type and the pointer's type hint do not
/// match. This is done by replacing the original GEP into a GEP with the type
/// hint as a base type when an element of the hinted type aligns with the
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
index 56e5e96aecd13c..87db1aaf39dea2 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -26,8 +26,8 @@ struct MemorySlot {
/// Memory slot attached with information about its destructuring procedure.
struct DestructurableMemorySlot : public MemorySlot {
- /// Maps an index within the memory slot to the type of the pointer that
- /// will be generated to access the element directly.
+ /// Maps an index within the memory slot to the element type of the pointer
+ /// that will be generated to access the element directly.
DenseMap<Attribute, Type> elementPtrs;
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 00b4559658fd4d..f9662789025764 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -13,10 +13,8 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -71,12 +69,8 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
if (!destructuredType)
return {};
- DenseMap<Attribute, Type> allocaTypeMap;
- for (Attribute index : llvm::make_first_range(destructuredType.value()))
- allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
-
- return {
- DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}};
+ return {DestructurableMemorySlot{{getResult(), getElemType()},
+ *destructuredType}};
}
DenseMap<Attribute, MemorySlot>
@@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
return DeletionKind::Delete;
}
+/// Checks if `slot` can be accessed through the provided access type.
+static bool isValidAccessType(const MemorySlot &slot, Type accessType,
+ const DataLayout &dataLayout) {
+ return dataLayout.getTypeSize(accessType) <=
+ dataLayout.getTypeSize(slot.elemType);
+}
+
LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
- return success(getAddr() != slot.ptr || getType() == slot.elemType);
+ return success(getAddr() != slot.ptr ||
+ isValidAccessType(slot, getType(), dataLayout));
}
LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(getAddr() != slot.ptr ||
- getValue().getType() == slot.elemType);
+ isValidAccessType(slot, getValue().getType(), dataLayout));
+}
+
+/// Returns the subslot's type at the requested index.
+static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
+ Attribute index) {
+ auto subelementIndexMap =
+ slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
+ if (!subelementIndexMap)
+ return {};
+ assert(!subelementIndexMap->empty());
+
+ // Note: Returns a null-type when no entry was found.
+ return subelementIndexMap->lookup(index);
+}
+
+bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ if (getVolatile_())
+ return false;
+
+ // A load always accesses the first element of the destructured slot.
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ Type subslotType = getTypeAtIndex(slot, index);
+ if (!subslotType)
+ return false;
+
+ // The access can only be replaced when the subslot is read within its bounds.
+ if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
+ return false;
+
+ usedIndices.insert(index);
+ return true;
+}
+
+DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ auto it = subslots.find(index);
+ assert(it != subslots.end());
+
+ rewriter.modifyOpInPlace(
+ *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+ return DeletionKind::Keep;
+}
+
+bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
+ SmallPtrSetImpl<Attribute> &usedIndices,
+ SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
+ const DataLayout &dataLayout) {
+ if (getVolatile_())
+ return false;
+
+  // A store always accesses the first element of the destructured slot.
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ Type subslotType = getTypeAtIndex(slot, index);
+ if (!subslotType)
+ return false;
+
+  // The access can only be replaced when the subslot is written within its
+  // bounds.
+ if (dataLayout.getTypeSize(getValue().getType()) >
+ dataLayout.getTypeSize(subslotType))
+ return false;
+
+ usedIndices.insert(index);
+ return true;
+}
+
+DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
+ DenseMap<Attribute, MemorySlot> &subslots,
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
+ auto it = subslots.find(index);
+ assert(it != subslots.end());
+
+ rewriter.modifyOpInPlace(
+ *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
+ return DeletionKind::Keep;
}
//===----------------------------------------------------------------------===//
@@ -384,16 +468,17 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
// dynamic indices can never be properly rewired.
if (!getDynamicIndices().empty())
return false;
+ //// TODO: This is not necessary, I think.
+ // if (slot.elemType != getElemType())
+ // return false;
Type reachedType = getResultPtrElementType();
if (!reachedType || getIndices().size() < 2)
return false;
auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
if (!firstLevelIndex)
return false;
- assert(slot.elementPtrs.contains(firstLevelIndex));
- if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
- return false;
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+ assert(slot.elementPtrs.contains(firstLevelIndex));
usedIndices.insert(firstLevelIndex);
return true;
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
index b25c831bc7172a..3d700fe94e3b9c 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -49,104 +49,6 @@ static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
}
-//===----------------------------------------------------------------------===//
-// AddFieldGetterToStructDirectUse
-//===----------------------------------------------------------------------===//
-
-/// Gets the type of the first subelement of `type` if `type` is destructurable,
-/// nullptr otherwise.
-static Type getFirstSubelementType(Type type) {
- auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
- if (!destructurable)
- return nullptr;
-
- Type subelementType = destructurable.getTypeAtIndex(
- IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0));
- if (subelementType)
- return subelementType;
-
- return nullptr;
-}
-
-/// Extracts a pointer to the first field of an `elemType` from the address
-/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer
-/// instead.
-template <class MemOp>
-static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
- Type elemType) {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- rewriter.setInsertionPointAfterValue(op.getAddr());
- SmallVector<GEPArg> firstTypeIndices{0, 0};
-
- Value properPtr = rewriter.create<GEPOp>(
- op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
- op.getAddr(), firstTypeIndices);
-
- rewriter.modifyOpInPlace(op,
- [&]() { op.getAddrMutable().assign(properPtr); });
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
- LoadOp load, PatternRewriter &rewriter) const {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- Type inconsistentElementType =
- isElementTypeInconsistent(load.getAddr(), load.getType());
- if (!inconsistentElementType)
- return failure();
- Type firstType = getFirstSubelementType(inconsistentElementType);
- if (!firstType)
- return failure();
- DataLayout layout = DataLayout::closest(load);
- if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
- return failure();
-
- insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
-
- // If the load does not use the first type but a type that can be casted from
- // it, add a bitcast and change the load type.
- if (firstType != load.getResult().getType()) {
- rewriter.setInsertionPointAfterValue(load.getResult());
- BitcastOp bitcast = rewriter.create<BitcastOp>(
- load->getLoc(), load.getResult().getType(), load.getResult());
- rewriter.modifyOpInPlace(load,
- [&]() { load.getResult().setType(firstType); });
- rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
- bitcast);
- }
-
- return success();
-}
-
-template <>
-LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
- StoreOp store, PatternRewriter &rewriter) const {
- PatternRewriter::InsertionGuard guard(rewriter);
-
- Type inconsistentElementType =
- isElementTypeInconsistent(store.getAddr(), store.getValue().getType());
- if (!inconsistentElementType)
- return failure();
- Type firstType = getFirstSubelementType(inconsistentElementType);
- if (!firstType)
- return failure();
-
- DataLayout layout = DataLayout::closest(store);
- // Check that the first field has the right type or can at least be bitcast
- // to the right type.
- if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
- return failure();
-
- insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
-
- rewriter.modifyOpInPlace(
- store, [&]() { store.getValueMutable().assign(store.getValue()); });
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// CanonicalizeAlignedGep
//===----------------------------------------------------------------------===//
@@ -684,9 +586,6 @@ struct LLVMTypeConsistencyPass
: public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
void runOnOperation() override {
RewritePatternSet rewritePatterns(&getContext());
- rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext());
- rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
- &getContext());
rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
rewritePatterns.add<BitcastStores>(&getContext());
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 7be4056fb2fc80..6c5250d527ade8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -120,11 +120,8 @@ memref::AllocaOp::getDestructurableSlots() {
if (!destructuredType)
return {};
- DenseMap<Attribute, Type> indexMap;
- for (auto const &[index, type] : *destructuredType)
- indexMap.insert({index, MemRefType::get({}, type)});
-
- return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
+ return {
+ DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
}
DenseMap<Attribute, MemorySlot>
diff --git a/mlir/test/Dialect/LLVMIR/sroa.mlir b/mlir/test/Dialect/LLVMIR/sroa.mlir
index 02d25f27f978a6..73666afaf66b27 100644
--- a/mlir/test/Dialect/LLVMIR/sroa.mlir
+++ b/mlir/test/Dialect/LLVMIR/sroa.mlir
@@ -215,3 +215,94 @@ llvm.func @no_nested_dynamic_indexing(%arg: i32) -> i32 {
// CHECK: llvm.return %[[RES]] : i32
llvm.return %3 : i32
}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_first_field
+llvm.func @store_first_field(%arg: i32) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+ // CHECK: llvm.store %{{.*}}, %[[ALLOCA]] : i32
+ llvm.store %arg, %1 : i32, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_first_field_different_type
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @store_first_field_different_type(%arg: f32) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+ // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+ llvm.store %arg, %1 : f32, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_sub_field
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+llvm.func @store_sub_field(%arg: f32) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32)> : (i32) -> !llvm.ptr
+ // CHECK: llvm.store %[[ARG]], %[[ALLOCA]] : f32
+ llvm.store %arg, %1 : f32, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_first_field
+llvm.func @load_first_field() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> i32
+ %2 = llvm.load %1 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_first_field_different_type
+llvm.func @load_first_field_different_type() -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i32
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr -> f32
+ %2 = llvm.load %1 : !llvm.ptr -> f32
+ // CHECK: llvm.return %[[RES]] : f32
+ llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_sub_field
+llvm.func @load_sub_field() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x i64 : (i32) -> !llvm.ptr
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.struct<(i64, i32)> : (i32) -> !llvm.ptr
+ // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+ %res = llvm.load %1 : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[RES]] : i32
+ llvm.return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @vector_store_type_mismatch
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
+llvm.func @vector_store_type_mismatch(%arg: vector<4xi32>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x vector<4xf32>
+ %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
+ // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+ llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 021151b929d8e2..a6176142f17463 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -26,63 +26,6 @@ llvm.func @same_address_keep_inbounds(%arg: i32) {
// -----
-// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field
-llvm.func @struct_store_instead_of_first_field(%arg: i32) {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
- // CHECK: llvm.store %{{.*}}, %[[GEP]] : i32
- llvm.store %arg, %1 : i32, !llvm.ptr
- llvm.return
-}
-
-// -----
-
-// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field_same_size
-// CHECK-SAME: (%[[ARG:.*]]: f32)
-llvm.func @struct_store_instead_of_first_field_same_size(%arg: f32) {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)>
- %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr
- // CHECK-DAG: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)>
- // CHECK-DAG: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
- // CHECK: llvm.store %[[BITCAST]], %[[GEP]] : i32
- llvm.store %arg, %1 ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks LGTM!
This commit relaxes the assumption of type consistency for LLVM dialect load and store operations in SROA. Instead, there is now a check that loads and stores are in the bounds specified by the sub-slot they access. This commit additionally removes the corresponding patterns from the type consistency pass, as they are no longer necessary.
9751e1a
to
1d720cd
Compare
) This commit relaxes the assumption of type consistency for LLVM dialect load and store operations in SROA. Instead, there is now a check that loads and stores are in the bounds specified by the sub-slot they access. This commit additionally removes the corresponding patterns from the type consistency pass, as they are no longer necessary. Note: It will be necessary to extend Mem2Reg with the logic for differently sized accesses as well. This is nonetheless a strict upgrade for productive flows, as the type consistency pass can produce invalid IR for some odd cases.
This commit relaxes the assumption of type consistency for LLVM dialect load and store operations in SROA. Instead, there is now a check that loads and stores are in the bounds specified by the sub-slot they access.
This commit additionally removes the corresponding patterns from the type consistency pass, as they are no longer necessary.
Note: It will be necessary to extend Mem2Reg with the logic for differently sized accesses as well. This is nonetheless a strict upgrade for productive flows, as the type consistency pass can produce invalid IR for some odd cases.