
Commit 6d11494

[mlir][Linalg] Refine how broadcast dims are treated (#99015)
This PR fixes how broadcast dims (identified as "zero" results in permutation maps) corresponding to a reduction iterator are vectorised in the case of generic Ops. Here's an example:

```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>

func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
  %0 = tensor.empty() : tensor<1x12x197x1xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map1],
                       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
    ins(%arg0 : tensor<1x12x197x197xf32>)
    outs(%0 : tensor<1x12x197x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %818 = arith.addf %in, %out : f32
    linalg.yield %818 : f32
  } -> tensor<1x12x197x1xf32>
  return %1 : tensor<1x12x197x1xf32>
}
```

This is a perfectly valid Generic Op, but it currently triggers two issues in the vectoriser. The root cause is this map:

```mlir
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
```

First, this map triggers an assert in `reindexIndexingMap`: that hook incorrectly assumes that every result in the input map is a `dim` expression and that there are no constants. That's not the case in this example. `reindexIndexingMap` is extended to allow maps like the one above. For now, only constant "zero" results are allowed. This can be extended in the future once a good motivating example is available.

Separately, the permutation map highlighted above "breaks" mask calculation (at the moment, masks are always computed, even in the presence of static shapes). When applying the following permutation:

```mlir
(d0, d1, d2, d3) -> (d0, d1, d2, 0)
```

to the canonical shape corresponding to the example above:

```
(1, 12, 197, 197)
```

we end up with the following error:

```bash
error: vector types must have positive constant sizes but got 1, 12, 197, 0
```

The error makes sense and indicates that we should update the permutation map used for masking to:

```
(d0, d1, d2, d3) -> (d0, d1, d2)
```

This would correctly give the following mask vector type:

```
vector<1x12x197xi1>
```

Fixes #97247
1 parent d781df2 commit 6d11494
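
For illustration, the shape computation described in the commit message can be reproduced with the public `AffineMap` APIs touched by this patch. This is a minimal standalone sketch, not code from the commit; the `main` wrapper and variable names are invented for the example:

```cpp
// Minimal sketch (not part of the patch): reproduces the mask-shape
// computation from the commit message using public AffineMap APIs.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;

  // #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  AffineExpr zero = getAffineConstantExpr(0, &ctx);
  AffineMap map1 = AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
                                  {d0, d1, d2, zero}, &ctx);

  // Applying the raw indexing map to the canonical shape produces a zero
  // dim, which is invalid as a vector (mask) shape:
  //   (1, 12, 197, 197) -> (1, 12, 197, 0)
  SmallVector<int64_t, 4> invalid = map1.compose({1, 12, 197, 197});

  // Dropping the "zero" results first yields the masking map
  //   (d0, d1, d2, d3) -> (d0, d1, d2)
  // and hence the valid mask shape (1, 12, 197), i.e. vector<1x12x197xi1>.
  AffineMap maskingMap = map1.dropZeroResults();
  SmallVector<int64_t, 4> maskShape = maskingMap.compose({1, 12, 197, 197});

  (void)invalid;
  (void)maskShape;
  return 0;
}
```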

5 files changed: +148 −19 lines changed

mlir/include/mlir/IR/AffineMap.h

Lines changed: 18 additions & 0 deletions
```diff
@@ -354,6 +354,24 @@ class AffineMap {
   /// returns the resulting values. `this` must be symbol-less.
   SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values) const;

+  /// Returns the number of "zero" results (constant values == 0) in this map.
+  ///
+  /// Example:
+  ///   * For `(d0, d1) -> (d0, d1, 0)` returns 1
+  ///   * For `(d0, d1, d2) -> (d0, d1)` returns 0
+  ///   * For `(d0, d1, d2) -> (d0, 0, d1, 0, d2)` returns 2
+  size_t getNumOfZeroResults() const;
+
+  /// Returns the AffineMap resulting from removing "zero" results (constant
+  /// values == 0) from this map.
+  ///
+  /// Example:
+  ///   * For `(d0, d1) -> (d0, d1, 0)` returns `(d0, d1) -> (d0, d1)`
+  ///   * For `(d0, d1, d2) -> (d0, d1)` returns `(d0, d1, d2) -> (d0, d1)`
+  ///   * For `(d0, d1, d2) -> (d0, 0, d1, 0, d2)` returns
+  ///     `(d0, d1, d2) -> (d0, d1, d2)`
+  AffineMap dropZeroResults();
+
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
   /// symbol-less permutation map. `allowZeroInResults` allows projected
   /// permutation maps with constant zero result expressions.
```
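
A usage sketch of the two new helpers, mirroring the doc-comment examples above (standalone, illustrative code; the setup boilerplate and names are ours, not part of the patch):

```cpp
// Illustrative check of the documented behaviour; not part of the patch.
#include <cassert>
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  AffineExpr c0 = getAffineConstantExpr(0, &ctx);

  // (d0, d1) -> (d0, d1, 0): one "zero" result; dropping it gives
  // (d0, d1) -> (d0, d1). AffineMaps are uniqued, so == compares structure.
  AffineMap m1 = AffineMap::get(2, 0, {d0, d1, c0}, &ctx);
  assert(m1.getNumOfZeroResults() == 1);
  assert(m1.dropZeroResults() == AffineMap::get(2, 0, {d0, d1}, &ctx));

  // (d0, d1, d2) -> (d0, 0, d1, 0, d2): two "zero" results; dropping them
  // gives (d0, d1, d2) -> (d0, d1, d2).
  AffineMap m2 = AffineMap::get(3, 0, {d0, c0, d1, c0, d2}, &ctx);
  assert(m2.getNumOfZeroResults() == 2);
  assert(m2.dropZeroResults() == AffineMap::get(3, 0, {d0, d1, d2}, &ctx));
  return 0;
}
```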

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 22 additions & 19 deletions
```diff
@@ -224,10 +224,10 @@ struct VectorizationState {
   /// Masks an operation with the canonical vector mask if the operation needs
   /// masking. Returns the masked operation or the original operation if masking
   /// is not needed. If provided, the canonical mask for this operation is
-  /// permuted using `maybeMaskingMap`.
+  /// permuted using `maybeIndexingMap`.
   Operation *
   maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
-                std::optional<AffineMap> maybeMaskingMap = std::nullopt);
+                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

 private:
   /// Initializes the iteration space static sizes using the Linalg op
@@ -422,16 +422,28 @@ Value VectorizationState::getOrCreateMaskFor(
   return mask;
 }

-/// Masks an operation with the canonical vector mask if the operation needs
-/// masking. Returns the masked operation or the original operation if masking
-/// is not needed. If provided, the canonical mask for this operation is
-/// permuted using `maybeMaskingMap`.
 Operation *
 VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                   LinalgOp linalgOp,
-                                  std::optional<AffineMap> maybeMaskingMap) {
+                                  std::optional<AffineMap> maybeIndexingMap) {
   LDBG("Trying to mask: " << *opToMask << "\n");

+  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
+  // The operand indexing map may contain "zero" results, e.g.:
+  //   (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+  // When applied to a canonical vector shape like this:
+  //   (1, 16, 16, 4)
+  // we would get:
+  //   (1, 16, 16, 0)
+  // Instead, we should extract the following permutation map for masking:
+  //   (d0, d1, d2, d3) -> (d0, d1, d2)
+  // This way, the corresponding vector/mask type will be:
+  //   vector<1x16x16xty>
+  // rather than:
+  //   vector<1x16x16x0xty>
+  if (maybeIndexingMap)
+    maybeMaskingMap = maybeIndexingMap->dropZeroResults();
+
   // Create or retrieve mask for this operation.
   Value mask =
       getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
@@ -476,7 +488,8 @@ static AffineMap reindexIndexingMap(AffineMap map) {
   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
          "expected projected permutation");
   auto res = compressUnusedDims(map);
-  assert(res.getNumDims() == res.getNumResults() &&
+  assert(res.getNumDims() ==
+             (res.getNumResults() - res.getNumOfZeroResults()) &&
          "expected reindexed map with same number of dims and results");
   return res;
 }
@@ -1349,16 +1362,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     // permutation map and masking map.
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

-    // Remove zeros from indexing map to use it as masking map.
-    SmallVector<int64_t> zeroPos;
-    auto results = indexingMap.getResults();
-    for (const auto &result : llvm::enumerate(results)) {
-      if (isa<AffineConstantExpr>(result.value())) {
-        zeroPos.push_back(result.index());
-      }
-    }
-    AffineMap maskingMap = indexingMap.dropResults(zeroPos);
-
     AffineMap readMap;
     VectorType readType;
     Type elemType = getElementTypeOrSelf(opOperand->get());
@@ -1388,7 +1391,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     Operation *read = rewriter.create<vector::TransferReadOp>(
         loc, readType, opOperand->get(), indices, readMap,
         ArrayRef<bool>(inBounds));
-    read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
+    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
     Value readValue = read->getResult(0);

     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
```
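
To see why the relaxed assert in `reindexIndexingMap` holds for maps with "zero" results, consider the map from the motivating example. A small standalone sketch under the same illustrative setup as the earlier examples (not code from the patch):

```cpp
// Sketch only: for (d0, d1, d2, d3) -> (d0, d1, d2, 0), d3 is unused, so
// compressUnusedDims yields (d0, d1, d2) -> (d0, d1, d2, 0). The old assert
// (dims == results) would fire (3 != 4); the relaxed one holds: 3 == 4 - 1.
#include <cassert>
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);
  AffineExpr zero = getAffineConstantExpr(0, &ctx);
  AffineMap map = AffineMap::get(4, 0, {d0, d1, d2, zero}, &ctx);

  AffineMap res = compressUnusedDims(map); // (d0, d1, d2) -> (d0, d1, d2, 0)
  assert(res.getNumDims() ==
         res.getNumResults() - res.getNumOfZeroResults());
  return 0;
}
```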

mlir/lib/IR/AffineMap.cpp

Lines changed: 23 additions & 0 deletions
```diff
@@ -592,6 +592,29 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
   return res;
 }

+size_t AffineMap::getNumOfZeroResults() const {
+  size_t res = 0;
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (constExpr && constExpr.getValue() == 0)
+      res++;
+  }
+
+  return res;
+}
+
+AffineMap AffineMap::dropZeroResults() {
+  auto exprs = llvm::to_vector(getResults());
+  SmallVector<AffineExpr> newExprs;
+
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (!constExpr || constExpr.getValue() != 0)
+      newExprs.push_back(expr);
+  }
+  return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext());
+}
+
 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
   if (getNumSymbols() > 0)
     return false;
```

mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir

Lines changed: 40 additions & 0 deletions
```diff
@@ -1964,3 +1964,43 @@ module attributes {transform.with_named_sequence} {
 // CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
 // CHECK: vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
 // CHECK: vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+// -----
+
+// Extracted from: https://github.com/llvm/llvm-project/issues/97247
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
+
+func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
+  %0 = tensor.empty() : tensor<1x12x197x1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x12x197x197xf32>) outs(%0 : tensor<1x12x197x1xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %818 = arith.addf %in, %out : f32
+    linalg.yield %818 : f32
+  } -> tensor<1x12x197x1xf32>
+  return %1 : tensor<1x12x197x1xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: #[[$ATTR_32:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @generic_with_reduction_and_broadcast(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x197x197xf32>) -> tensor<1x12x197x1xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<1x12x197x1xf32>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : tensor<1x12x197x197xf32>, vector<1x12x197x197xf32>
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_32]]} : tensor<1x12x197x1xf32>, vector<1x12x197xf32>
+// CHECK: %[[VAL_6:.*]] = vector.multi_reduction <add>, %[[VAL_4]], %[[VAL_5]] [3] : vector<1x12x197x197xf32> to vector<1x12x197xf32>
+// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_6]] : vector<1x12x197xf32> to vector<1x1x12x197xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 2, 3, 0] : vector<1x1x12x197xf32> to vector<1x12x197x1xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {in_bounds = [true, true, true, true]} : vector<1x12x197x1xf32>, tensor<1x12x197x1xf32>
+// CHECK: return %[[VAL_9]] : tensor<1x12x197x1xf32>
```

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 45 additions & 0 deletions
```diff
@@ -147,6 +147,51 @@ module attributes {transform.with_named_sequence} {

 // -----

+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0, 0)>
+
+func.func @dynamic_generic_with_reduction_and_broadcast(%arg0: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %0 = linalg.generic { indexing_maps = [#map, #map1],
+                        iterator_types = ["parallel", "reduction"]}
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%init : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: func.func @dynamic_generic_with_reduction_and_broadcast(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x4xi1>
+// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_7]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x4xf32> } : vector<4x4xi1> -> vector<4x4xf32>
+// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_10]] {in_bounds = [true], permutation_map = #[[$MAP]]} : tensor<?x?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x4xf32> to vector<4xf32> } : vector<4x4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {in_bounds = [true], permutation_map = #[[$MAP]]} : vector<4xf32>, tensor<?x?xf32> } : vector<4xi1> -> tensor<?x?xf32>
+// CHECK: return %[[VAL_15]] : tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @vectorize_dynamic_2d_transpose(%arg0: tensor<?x?xf32>,
                                           %arg1: tensor<?x?xf32>,
                                           %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
```
