Skip to content

Commit 9832e1a

Browse files
[mlir][Analysis] Add alignAffineMapWithValues
This function aligns an affine map (and operands) with given dims and syms SSA values. This is useful in conjunction with `FlatAffineConstraints::addLowerOrUpperBound`, which requires the `boundMap` to be aligned with the constraint set's dims and syms. Differential Revision: https://reviews.llvm.org/D107728
1 parent e7e3585 commit 9832e1a

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

mlir/include/mlir/Analysis/AffineStructures.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,31 @@ getFlattenedAffineExprs(IntegerSet set,
694694
std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
695695
FlatAffineConstraints *cst = nullptr);
696696

697+
/// Re-indexes the dimensions and symbols of an affine map with given `operands`
698+
/// values to align with `dims` and `syms` values.
699+
///
700+
/// Each dimension/symbol of the map, bound to an operand `o`, is replaced with
701+
/// dimension `i`, where `i` is the position of `o` within `dims`. If `o` is not
702+
/// in `dims`, replace it with symbol `i`, where `i` is the position of `o`
703+
/// within `syms`. If `o` is not in `syms` either, replace it with a new symbol.
704+
///
705+
/// Note: If a value appears multiple times as a dimension/symbol (or both), all
706+
/// corresponding dim/sym expressions are replaced with the first dimension
707+
/// bound to that value (or first symbol if no such dimension exists).
708+
///
709+
/// The resulting affine map has `dims.size()` many dimensions and at least
710+
/// `syms.size()` many symbols.
711+
///
712+
/// The SSA values of the symbols of the resulting map are optionally returned
713+
/// via `newSyms`. This is a concatenation of `syms` with the SSA values of the
714+
/// newly added symbols.
715+
///
716+
/// Note: As part of this re-indexing, dimensions may turn into symbols, or vice
717+
/// versa.
718+
AffineMap alignAffineMapWithValues(AffineMap map, ValueRange operands,
719+
ValueRange dims, ValueRange syms,
720+
SmallVector<Value> *newSyms = nullptr);
721+
697722
} // end namespace mlir.
698723

699724
#endif // MLIR_ANALYSIS_AFFINESTRUCTURES_H

mlir/lib/Analysis/AffineStructures.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3274,3 +3274,48 @@ void FlatAffineConstraints::removeIndependentConstraints(unsigned pos,
32743274
for (auto nbIndex : llvm::reverse(nbEqIndices))
32753275
removeEquality(nbIndex);
32763276
}
3277+
3278+
AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
3279+
ValueRange dims, ValueRange syms,
3280+
SmallVector<Value> *newSyms) {
3281+
assert(operands.size() == map.getNumInputs() &&
3282+
"expected same number of operands and map inputs");
3283+
MLIRContext *ctx = map.getContext();
3284+
Builder builder(ctx);
3285+
SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
3286+
unsigned numSymbols = syms.size();
3287+
SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
3288+
if (newSyms) {
3289+
newSyms->clear();
3290+
newSyms->append(syms.begin(), syms.end());
3291+
}
3292+
3293+
for (auto operand : llvm::enumerate(operands)) {
3294+
// Compute replacement dim/sym of operand.
3295+
AffineExpr replacement;
3296+
auto dimIt = std::find(dims.begin(), dims.end(), operand.value());
3297+
auto symIt = std::find(syms.begin(), syms.end(), operand.value());
3298+
if (dimIt != dims.end()) {
3299+
replacement =
3300+
builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
3301+
} else if (symIt != syms.end()) {
3302+
replacement =
3303+
builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
3304+
} else {
3305+
// This operand is neither a dimension nor a symbol. Add it as a new
3306+
// symbol.
3307+
replacement = builder.getAffineSymbolExpr(numSymbols++);
3308+
if (newSyms)
3309+
newSyms->push_back(operand.value());
3310+
}
3311+
// Add to corresponding replacements vector.
3312+
if (operand.index() < map.getNumDims()) {
3313+
dimReplacements[operand.index()] = replacement;
3314+
} else {
3315+
symReplacements[operand.index() - map.getNumDims()] = replacement;
3316+
}
3317+
}
3318+
3319+
return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
3320+
dims.size(), numSymbols);
3321+
}

0 commit comments

Comments
 (0)