Skip to content

Commit 9c5c154

Browse files
committed
fixup! [mlir][Linalg] Refine how broadcast dims are treated
Addressing PR comments
1 parent 54676c3 commit 9c5c154

File tree

3 files changed

+32
-17
lines changed

3 files changed

+32
-17
lines changed

mlir/include/mlir/IR/AffineMap.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,23 @@ class AffineMap {
354354
/// returns the resulting values. `this` must be symbol-less.
355355
SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values) const;
356356

357-
size_t numOfZeroResults() const;
357+
/// Returns the number of "zero" results (constant values == 0) in this map.
358+
///
359+
/// Example:
360+
/// * For `(d0, d1) -> (d0, d1, 0)` returns 1
361+
/// * For `(d0, d1, d2) -> (d0, d1)` returns 0
362+
/// * For `(d0, d1, d2) -> (d0, 0, d1, 0, d2)` returns 2
363+
size_t getNumOfZeroResults() const;
358364

359-
AffineMap dropZeros();
365+
/// Returns the AffineMap resulting from removing "zero" results (constant
366+
/// values == 0) from this map.
367+
///
368+
/// Example:
369+
/// * For `(d0, d1) -> (d0, d1, 0)` returns `(d0, d1) -> (d0, d1)`
370+
/// * For `(d0, d1, d2) -> (d0, d1)` returns `(d0, d1, d2) -> (d0, d1)`
371+
/// * For `(d0, d1, d2) -> (d0, 0, d1, 0, d2)` returns
372+
/// `(d0, d1, d2) -> (d0, d1, d2)`
373+
AffineMap dropZeroResults();
360374

361375
/// Returns true if the AffineMap represents a subset (i.e. a projection) of a
362376
/// symbol-less permutation map. `allowZeroInResults` allows projected

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,8 @@ static AffineMap reindexIndexingMap(AffineMap map) {
476476
assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
477477
"expected projected permutation");
478478
auto res = compressUnusedDims(map);
479-
assert(res.getNumDims() == (res.getNumResults() - res.numOfZeroResults()) &&
479+
assert(res.getNumDims() ==
480+
(res.getNumResults() - res.getNumOfZeroResults()) &&
480481
"expected reindexed map with same number of dims and results");
481482
return res;
482483
}
@@ -641,7 +642,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
641642
// vector<1x16x16xty>
642643
// rather than:
643644
// vector<1x16x16x0xty>
644-
auto opOperantMapWithoutZeros = opOperandMap.dropZeros();
645+
AffineMap opOperantMapWithoutZeros = opOperandMap.dropZeroResults();
645646
write =
646647
state.maskOperation(rewriter, write, linalgOp, opOperantMapWithoutZeros);
647648

mlir/lib/IR/AffineMap.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -553,18 +553,6 @@ AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {
553553
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
554554
}
555555

556-
AffineMap AffineMap::dropZeros() {
557-
auto exprs = llvm::to_vector<4>(getResults());
558-
SmallVector<AffineExpr, 8> newExprs;
559-
560-
for (auto expr : getResults()) {
561-
auto constExpr = dyn_cast<AffineConstantExpr>(expr);
562-
if (!constExpr)
563-
newExprs.push_back(expr);
564-
}
565-
return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext());
566-
}
567-
568556
AffineMap AffineMap::compose(AffineMap map) const {
569557
assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
570558
// Prepare `map` by concatenating the symbols and rewriting its exprs.
@@ -604,7 +592,7 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
604592
return res;
605593
}
606594

607-
size_t AffineMap::numOfZeroResults() const {
595+
size_t AffineMap::getNumOfZeroResults() const {
608596
size_t res = 0;
609597
for (auto expr : getResults()) {
610598
auto constExpr = dyn_cast<AffineConstantExpr>(expr);
@@ -615,6 +603,18 @@ size_t AffineMap::numOfZeroResults() const {
615603
return res;
616604
}
617605

606+
AffineMap AffineMap::dropZeroResults() {
607+
auto exprs = llvm::to_vector(getResults());
608+
SmallVector<AffineExpr> newExprs;
609+
610+
for (auto expr : getResults()) {
611+
auto constExpr = dyn_cast<AffineConstantExpr>(expr);
612+
if (!constExpr || constExpr.getValue() != 0)
613+
newExprs.push_back(expr);
614+
}
615+
return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext());
616+
}
617+
618618
bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
619619
if (getNumSymbols() > 0)
620620
return false;

0 commit comments

Comments
 (0)