Skip to content

[mlir][memref][NFC] Simplify constifyIndexValues #135940

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 48 additions & 101 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//

/// Helper function that infers the constant values from a list of \p values,
/// a \p memRefTy, and another helper function \p getAttributes.
/// The inferred constant values replace the related `OpFoldResult` in
/// \p values.
/// Helper function that sets values[i] to constValues[i] if the latter is a
/// static value, as indicated by ShapedType::kDynamic.
///
/// \note This function shouldn't be used directly, instead, use the
/// `getConstifiedMixedXXX` methods from the related operations.
///
/// \p getAttributes retuns a list of potentially constant values, as determined
/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
/// many elements as \p values or be empty.
///
/// E.g., consider the following example:
/// ```
/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
/// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
/// ```
/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
/// Now using this helper function with:
/// - `values == [2, %dyn_stride]`,
/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
/// `getStridesAndOffset`), and
/// - `isDynamic == ShapedType::isDynamic`
/// Will yield: `values == [2, 1]`
static void constifyIndexValues(
SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
MLIRContext *ctxt,
llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
llvm::function_ref<bool(int64_t)> isDynamic) {
SmallVector<int64_t> constValues = getAttributes(memRefTy);
Builder builder(ctxt);
for (const auto &it : llvm::enumerate(constValues)) {
int64_t constValue = it.value();
if (!isDynamic(constValue))
values[it.index()] = builder.getIndexAttr(constValue);
}
for (OpFoldResult &ofr : values) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
// FIXME: We shouldn't need to do that, but right now, the static indices
// are created with the wrong type: `i64` instead of `index`.
// As a result, if we were to keep the attribute as is, we may fail to see
// that two attributes are equal because one would have the i64 type and
// the other the index type.
// The alternative would be to create constant indices with getI64Attr in
// this and the previous loop, but it doesn't logically make sense (we are
// dealing with indices here) and would only strenghten the inconsistency
// around how static indices are created (some places use getI64Attr,
// others use getIndexAttr).
// The workaround here is to stick to the IndexAttr type for all the
// values, hence we recreate the attribute even when it is already static
// to make sure the type is consistent.
ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
/// If constValues[i] is dynamic, tries to extract a constant value from
/// value[i] to allow for additional folding opportunities. Also convertes all
/// existing attributes to index attributes. (They may be i64 attributes.)
static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
ArrayRef<int64_t> constValues) {
assert(constValues.size() == values.size() &&
"incorrect number of const values");
for (auto [i, cstVal] : llvm::enumerate(constValues)) {
Builder builder(values[i].getContext());
if (!ShapedType::isDynamic(cstVal)) {
// Constant value is known, use it directly.
values[i] = builder.getIndexAttr(cstVal);
continue;
}
std::optional<int64_t> maybeConstant =
getConstantIntValue(cast<Value>(ofr));
if (maybeConstant)
ofr = builder.getIndexAttr(*maybeConstant);
if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
// Try to extract a constant or convert an existing to index.
values[i] = builder.getIndexAttr(*cst);
}
}
}

/// Wrapper around `getShape` that conforms to the function signature
/// expected for `getAttributes` in `constifyIndexValues`.
static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
ArrayRef<int64_t> sizes = memRefTy.getShape();
return SmallVector<int64_t>(sizes);
}

/// Wrapper around `getStridesAndOffset` that returns only the offset and
/// conforms to the function signature expected for `getAttributes` in
/// `constifyIndexValues`.
static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
SmallVector<int64_t> strides;
int64_t offset;
LogicalResult hasStaticInformation =
memrefType.getStridesAndOffset(strides, offset);
if (failed(hasStaticInformation))
return SmallVector<int64_t>();
return SmallVector<int64_t>(1, offset);
}

/// Wrapper around `getStridesAndOffset` that returns only the strides and
/// conforms to the function signature expected for `getAttributes` in
/// `constifyIndexValues`.
static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
SmallVector<int64_t> strides;
int64_t offset;
LogicalResult hasStaticInformation =
memrefType.getStridesAndOffset(strides, offset);
if (failed(hasStaticInformation))
return SmallVector<int64_t>();
return strides;
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,

SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
constifyIndexValues(values, getSource().getType(), getContext(),
getConstantSizes, ShapedType::isDynamic);
constifyIndexValues(values, getSource().getType().getShape());
return values;
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
constifyIndexValues(values, getSource().getType(), getContext(),
getConstantStrides, ShapedType::isDynamic);
SmallVector<int64_t> staticValues;
int64_t unused;
LogicalResult status =
getSource().getType().getStridesAndOffset(staticValues, unused);
(void)status;
assert(succeeded(status) && "could not get strides from type");
constifyIndexValues(values, staticValues);
return values;
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
SmallVector<OpFoldResult> values(1, offsetOfr);
constifyIndexValues(values, getSource().getType(), getContext(),
getConstantOffset, ShapedType::isDynamic);
SmallVector<int64_t> staticValues, unused;
int64_t offset;
LogicalResult status =
getSource().getType().getStridesAndOffset(unused, offset);
(void)status;
assert(succeeded(status) && "could not get offset from type");
staticValues.push_back(offset);
constifyIndexValues(values, staticValues);
return values[0];
}

Expand Down Expand Up @@ -1975,24 +1914,32 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
SmallVector<OpFoldResult> values = getMixedSizes();
constifyIndexValues(values, getType(), getContext(), getConstantSizes,
ShapedType::isDynamic);
constifyIndexValues(values, getType().getShape());
return values;
}

SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
SmallVector<OpFoldResult> values = getMixedStrides();
constifyIndexValues(values, getType(), getContext(), getConstantStrides,
ShapedType::isDynamic);
SmallVector<int64_t> staticValues;
int64_t unused;
LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
(void)status;
assert(succeeded(status) && "could not get strides from type");
constifyIndexValues(values, staticValues);
return values;
}

OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
SmallVector<OpFoldResult> values = getMixedOffsets();
assert(values.size() == 1 &&
"reinterpret_cast must have one and only one offset");
constifyIndexValues(values, getType(), getContext(), getConstantOffset,
ShapedType::isDynamic);
SmallVector<int64_t> staticValues, unused;
int64_t offset;
LogicalResult status = getType().getStridesAndOffset(unused, offset);
(void)status;
assert(succeeded(status) && "could not get offset from type");
staticValues.push_back(offset);
constifyIndexValues(values, staticValues);
return values[0];
}

Expand Down
Loading