Skip to content

Commit 4b56e2e

Browse files
[mlir][Analysis][NFC] Remove code duplication around getFlattenedAffineExprs
Remove code duplication in `addLowerOrUpperBound` and `composeMatchingMap`. Differential Revision: https://reviews.llvm.org/D107814
1 parent 9e6e081 commit 4b56e2e

File tree

2 files changed

+44
-42
lines changed

2 files changed

+44
-42
lines changed

mlir/include/mlir/Analysis/AffineStructures.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,19 @@ class FlatAffineConstraints {
602602
/// must already have a corresponding dim/symbol in this constraint system.
603603
AffineMap computeAlignedMap(AffineMap map, ValueRange operands) const;
604604

605+
/// Given an affine map that is aligned with this constraint system:
606+
/// * Flatten the map.
607+
/// * Add newly introduced local columns at the beginning of this constraint
608+
/// system (local column pos 0).
609+
/// * Add equalities that define the new local columns to this constraint
610+
/// system.
611+
/// * Return the flattened expressions via `flattenedExprs`.
612+
///
613+
/// Note: This is a shared helper function of `addLowerOrUpperBound` and
614+
/// `composeMatchingMap`.
615+
LogicalResult flattenAlignedMapAndMergeLocals(
616+
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs);
617+
605618
// Eliminates a single identifier at 'position' from equality and inequality
606619
// constraints. Returns 'success' if the identifier was eliminated, and
607620
// 'failure' otherwise.

mlir/lib/Analysis/AffineStructures.cpp

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -400,54 +400,35 @@ LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
400400
assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
401401

402402
std::vector<SmallVector<int64_t, 8>> flatExprs;
403-
FlatAffineConstraints localCst;
404-
if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
405-
LLVM_DEBUG(llvm::dbgs()
406-
<< "composition unimplemented for semi-affine maps\n");
403+
if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
407404
return failure();
408-
}
409405
assert(flatExprs.size() == other.getNumResults());
410406

411-
// Add localCst information.
412-
if (localCst.getNumLocalIds() > 0) {
413-
unsigned numLocalIds = getNumLocalIds();
414-
// Insert local dims of localCst at the beginning.
415-
for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; ++l)
416-
addLocalId(0);
417-
// Insert local dims of `this` at the end of localCst.
418-
for (unsigned l = 0; l < numLocalIds; ++l)
419-
localCst.addLocalId(localCst.getNumLocalIds());
420-
// Dimensions of localCst and this constraint set match. Append localCst to
421-
// this constraint set.
422-
append(localCst);
423-
}
424-
425407
// Add dimensions corresponding to the map's results.
426408
for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
427409
addDimId(0);
428410
}
429411

430412
// We add one equality for each result connecting the result dim of the map to
431413
// the other identifiers.
432-
// For eg: if the expression is 16*i0 + i1, and this is the r^th
414+
// E.g.: if the expression is 16*i0 + i1, and this is the r^th
433415
// iteration/result of the value map, we are adding the equality:
434-
// d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
435-
// add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
416+
// d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
417+
// add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
436418
for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
437419
const auto &flatExpr = flatExprs[r];
438420
assert(flatExpr.size() >= other.getNumInputs() + 1);
439421

440-
// eqToAdd is the equality corresponding to the flattened affine expression.
441422
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
442423
// Set the coefficient for this result to one.
443424
eqToAdd[r] = 1;
444425

445426
// Dims and symbols.
446427
for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
447-
// Negate 'eq[r]' since the newly added dimension will be set to this one.
428+
// Negate `eq[r]` since the newly added dimension will be set to this one.
448429
eqToAdd[e + i] = -flatExpr[i];
449430
}
450-
// Local vars common to eq and localCst are at the beginning.
431+
// Local columns of `eq` are at the beginning.
451432
unsigned j = getNumDimIds() + getNumSymbolIds();
452433
unsigned end = flatExpr.size() - 1;
453434
for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
@@ -1872,27 +1853,14 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
18721853
}
18731854
}
18741855

1875-
LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
1876-
AffineMap boundMap,
1877-
bool eq, bool lower) {
1878-
assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
1879-
assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
1880-
assert(pos < getNumDimAndSymbolIds() && "invalid position");
1881-
1882-
// Equality follows the logic of lower bound except that we add an equality
1883-
// instead of an inequality.
1884-
assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
1885-
if (eq)
1886-
lower = true;
1887-
1888-
std::vector<SmallVector<int64_t, 8>> flatExprs;
1856+
LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals(
1857+
AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
18891858
FlatAffineConstraints localCst;
1890-
if (failed(getFlattenedAffineExprs(boundMap, &flatExprs, &localCst))) {
1859+
if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
18911860
LLVM_DEBUG(llvm::dbgs()
18921861
<< "composition unimplemented for semi-affine maps\n");
18931862
return failure();
18941863
}
1895-
assert(flatExprs.size() == boundMap.getNumResults());
18961864

18971865
// Add localCst information.
18981866
if (localCst.getNumLocalIds() > 0) {
@@ -1908,6 +1876,27 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
19081876
append(localCst);
19091877
}
19101878

1879+
return success();
1880+
}
1881+
1882+
LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
1883+
AffineMap boundMap,
1884+
bool eq, bool lower) {
1885+
assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
1886+
assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
1887+
assert(pos < getNumDimAndSymbolIds() && "invalid position");
1888+
1889+
// Equality follows the logic of lower bound except that we add an equality
1890+
// instead of an inequality.
1891+
assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
1892+
if (eq)
1893+
lower = true;
1894+
1895+
std::vector<SmallVector<int64_t, 8>> flatExprs;
1896+
if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
1897+
return failure();
1898+
assert(flatExprs.size() == boundMap.getNumResults());
1899+
19111900
// Add one (in)equality for each result.
19121901
for (const auto &flatExpr : flatExprs) {
19131902
SmallVector<int64_t> ineq(getNumCols(), 0);
@@ -1921,7 +1910,7 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
19211910
if (ineq[pos] != 0)
19221911
continue;
19231912
ineq[pos] = lower ? 1 : -1;
1924-
// Local vars common to eq and localCst are at the beginning.
1913+
// Local columns of `ineq` are at the beginning.
19251914
unsigned j = getNumDimIds() + getNumSymbolIds();
19261915
unsigned end = flatExpr.size() - 1;
19271916
for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {

0 commit comments

Comments
 (0)