Skip to content

[MLIR][Mem2Reg][LLVM] Enhance partial load support #89094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
}],
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
"::mlir::RewriterBase &":$rewriter)
"::mlir::RewriterBase &":$rewriter,
"const ::mlir::DataLayout &":$dataLayout)
>,
InterfaceMethod<[{
Checks that this operation can be promoted to no longer use the provided
Expand Down Expand Up @@ -172,7 +173,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
(ins "const ::mlir::MemorySlot &":$slot,
"const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
"::mlir::RewriterBase &":$rewriter,
"::mlir::Value":$reachingDefinition)
"::mlir::Value":$reachingDefinition,
"const ::mlir::DataLayout &":$dataLayout)
>,
];
}
Expand Down
173 changes: 134 additions & 39 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {

bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }

Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}

Expand All @@ -122,37 +123,121 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}

/// Checks that two types are the same or can be cast into one another.
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
/// Checks if `type` can be used in any kind of conversion sequences.
static bool isSupportedTypeForConversion(Type type) {
// Aggregate types are not bitcastable.
if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
return false;

// LLVM vector types are only used for either pointers or target specific
// types. These types cannot be casted in the general case, thus the memory
// optimizations do not support them.
if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
return false;

// Scalable types are not supported.
if (auto vectorType = dyn_cast<VectorType>(type))
return !vectorType.isScalable();
return true;
}

/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
/// truncations.
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
Type srcType) {
if (targetType == srcType)
return true;

if (!isSupportedTypeForConversion(targetType) ||
!isSupportedTypeForConversion(srcType))
return false;

// Pointer casts will only be sane when the bitsize of both pointer types is
// the same.
if (isa<LLVM::LLVMPointerType>(targetType) &&
isa<LLVM::LLVMPointerType>(srcType))
return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);

return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
}

/// Checks if `dataLayout` describes a little endian layout.
static bool isBigEndian(const DataLayout &dataLayout) {
auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
return endiannessStr && endiannessStr == "big";
}

/// The size of a byte in bits.
constexpr const static uint64_t kBitsInByte = 8;

/// Constructs operations that convert `inputValue` into a new value of type
/// `targetType`. Assumes that this conversion is possible.
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
Value inputValue, Type targetType) {
if (inputValue.getType() == targetType)
return inputValue;

if (!isa<LLVM::LLVMPointerType>(targetType) &&
!isa<LLVM::LLVMPointerType>(inputValue.getType()))
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
Value srcValue, Type targetType,
const DataLayout &dataLayout) {
// Get the types of the source and target values.
Type srcType = srcValue.getType();
assert(areConversionCompatible(dataLayout, targetType, srcType) &&
"expected that the compatibility was checked before");

uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);

// Nothing has to be done if the types are already the same.
if (srcType == targetType)
return srcValue;

// In the special case of casting one pointer to another, we want to generate
// an address space cast. Bitcasts of pointers are not allowed and using
// pointer to integer conversions are not equivalent due to the loss of
// provenance.
if (isa<LLVM::LLVMPointerType>(targetType) &&
isa<LLVM::LLVMPointerType>(srcType))
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
srcValue);

IntegerType valueSizeInteger =
rewriter.getIntegerType(srcTypeSize * kBitsInByte);
Value replacement = srcValue;

// First, cast the value to a same-sized integer type.
if (isa<LLVM::LLVMPointerType>(srcType))
replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
replacement);
else if (replacement.getType() != valueSizeInteger)
replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
replacement);

// Truncate the integer if the size of the target is less than the value.
if (targetTypeSize != srcTypeSize) {
if (isBigEndian(dataLayout)) {
uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getIntegerAttr(srcType, shiftAmount));
replacement =
rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
}

if (!isa<LLVM::LLVMPointerType>(targetType))
return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
replacement = rewriter.create<LLVM::TruncOp>(
loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
replacement);
}

if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
// Now cast the integer to the actual target type if required.
if (isa<LLVM::LLVMPointerType>(targetType))
replacement =
rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
else if (replacement.getType() != targetType)
replacement =
rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);

return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
inputValue);
return replacement;
}

Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
return createConversionSequence(rewriter, getLoc(), getValue(),
slot.elemType);
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
const DataLayout &dataLayout) {
return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
dataLayout);
}

bool LLVM::LoadOp::canUsesBeRemoved(
Expand All @@ -167,17 +252,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
// be removed (provided it loads the exact stored value and is not
// volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
areConversionCompatible(dataLayout, getResult().getType(),
slot.elemType) &&
!getVolatile_();
}

DeletionKind LLVM::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
RewriterBase &rewriter, Value reachingDefinition,
const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
Value newResult = createConversionSequence(
rewriter, getLoc(), reachingDefinition, getResult().getType());
Value newResult =
createConversionSequence(rewriter, getLoc(), reachingDefinition,
getResult().getType(), dataLayout);
rewriter.replaceAllUsesWith(getResult(), newResult);
return DeletionKind::Delete;
}
Expand All @@ -194,13 +282,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
// store OF the slot pointer, only INTO the slot pointer.
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
areConversionCompatible(dataLayout, slot.elemType,
getValue().getType()) &&
!getVolatile_();
}

DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
RewriterBase &rewriter, Value reachingDefinition,
const DataLayout &dataLayout) {
return DeletionKind::Delete;
}

Expand Down Expand Up @@ -747,8 +837,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
}

Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
const DataLayout &dataLayout) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](IntegerType intType) -> Value {
Expand Down Expand Up @@ -802,7 +892,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(

DeletionKind LLVM::MemsetOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
RewriterBase &rewriter, Value reachingDefinition,
const DataLayout &dataLayout) {
return DeletionKind::Delete;
}

Expand Down Expand Up @@ -1059,8 +1150,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
return memcpyStoresTo(*this, slot);
}

Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}

Expand All @@ -1074,7 +1165,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(

DeletionKind LLVM::MemcpyOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
RewriterBase &rewriter, Value reachingDefinition,
const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
Expand Down Expand Up @@ -1109,7 +1201,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
}

Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
RewriterBase &rewriter,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}

Expand All @@ -1123,7 +1216,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(

DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
RewriterBase &rewriter, Value reachingDefinition,
const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
Expand Down Expand Up @@ -1159,8 +1253,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
return memcpyStoresTo(*this, slot);
}

Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}

Expand All @@ -1174,7 +1268,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(

DeletionKind LLVM::MemmoveOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
RewriterBase &rewriter, Value reachingDefinition,
const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
Expand Down
14 changes: 8 additions & 6 deletions mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {

bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }

Value memref::LoadOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}

Expand All @@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(

DeletionKind memref::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
RewriterBase &rewriter, Value reachingDefinition,
const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
Expand Down Expand Up @@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
return getMemRef() == slot.ptr;
}

Value memref::StoreOp::getStored(const MemorySlot &slot,
RewriterBase &rewriter) {
Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
const DataLayout &dataLayout) {
return getValue();
}

Expand All @@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(

DeletionKind memref::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
RewriterBase &rewriter, Value reachingDefinition,
const DataLayout &dataLayout) {
return DeletionKind::Delete;
}

Expand Down
Loading