Commit 9ebaff1

[mlir][Linalg] Refine how broadcast dims are treated
This PR fixes how broadcast dims (identified as "zero" results in permutation maps) corresponding to a reduction iterator are vectorised in the case of generic Ops. Here's an example:

```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>

func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
  %0 = tensor.empty() : tensor<1x12x197x1xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map1],
                       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
    ins(%arg0 : tensor<1x12x197x197xf32>)
    outs(%0 : tensor<1x12x197x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %818 = arith.addf %in, %out : f32
    linalg.yield %818 : f32
  } -> tensor<1x12x197x1xf32>
  return %1 : tensor<1x12x197x1xf32>
}
```

This is a perfectly valid generic Op, but it currently triggers two issues in the vectoriser. The root cause is this map:

```mlir
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
```

First, this map triggers an assert in `reindexIndexingMap` - that hook incorrectly assumes that every result in the input map is a `dim` expression and that there are no constants. That's not the case in this example. `reindexIndexingMap` is extended to allow maps like the one above. For now, only constant "zero" results are allowed; this can be extended in the future once a good motivating example is available.

Second, the permutation map highlighted above "breaks" mask calculation (at the moment, masks are always computed, even in the presence of static shapes). When applying the following permutation:

```mlir
(d0, d1, d2, d3) -> (d0, d1, d2, 0)
```

to these canonical shapes (corresponding to the example above):

```
(1, 12, 197, 197)
```

we end up with the following error:

```bash
error: vector types must have positive constant sizes but got 1, 12, 197, 0
```

The error makes sense and indicates that we should instead use the reduced permutation map:

```
(d0, d1, d2, d3) -> (d0, d1, d2)
```

which correctly gives the following mask vector type:

```
vector<1x12x197xi1>
```

Fixes #97247
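To make the shape arithmetic above concrete, here is a minimal standalone sketch (illustrative only, not part of the patch; the `main` harness and context setup are assumptions) that replays it via the existing `AffineMap::compose(ArrayRef<int64_t>)`:

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);

  // #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
  AffineMap map1 = AffineMap::get(
      /*dimCount=*/4, /*symbolCount=*/0,
      {d0, d1, d2, getAffineConstantExpr(0, &ctx)}, &ctx);

  // Applying the map to the canonical shape gives (1, 12, 197, 0):
  // the trailing size is 0, which is not a valid vector/mask dimension.
  SmallVector<int64_t, 4> bad = map1.compose({1, 12, 197, 197});

  // With the constant-zero result dropped, i.e.
  // (d0, d1, d2, d3) -> (d0, d1, d2), the same composition gives
  // (1, 12, 197), matching the valid mask type vector<1x12x197xi1>.
  AffineMap reduced = AffineMap::get(4, 0, {d0, d1, d2}, &ctx);
  SmallVector<int64_t, 4> good = reduced.compose({1, 12, 197, 197});

  llvm::errs() << "bad shape rank: " << bad.size()
               << ", trailing size: " << bad.back() << "\n";
  llvm::errs() << "good shape rank: " << good.size() << "\n";
  return 0;
}
```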
1 parent 2e9cbb6 commit 9ebaff1

File tree

5 files changed: +128 -2 lines changed

- mlir/include/mlir/IR/AffineMap.h
- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
- mlir/lib/IR/AffineMap.cpp
- mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
- mlir/test/Dialect/Linalg/vectorization.mlir

mlir/include/mlir/IR/AffineMap.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -354,6 +354,10 @@ class AffineMap {
   /// returns the resulting values. `this` must be symbol-less.
   SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values) const;
 
+  size_t numOfZeroResults() const;
+
+  AffineMap dropZeros();
+
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
   /// symbol-less permutation map. `allowZeroInResults` allows projected
   /// permutation maps with constant zero result expressions.
```

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 16 additions & 2 deletions
```diff
@@ -476,7 +476,7 @@ static AffineMap reindexIndexingMap(AffineMap map) {
   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
          "expected projected permutation");
   auto res = compressUnusedDims(map);
-  assert(res.getNumDims() == res.getNumResults() &&
+  assert(res.getNumDims() == (res.getNumResults() - res.numOfZeroResults()) &&
          "expected reindexed map with same number of dims and results");
   return res;
 }
@@ -639,7 +639,21 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
         loc, value, outputOperand->get(), ValueRange{});
   }
 
-  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+  // The operand map may contain "zero" results, e.g.:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+  // When applied to a canonical vector shape like this:
+  //    (1, 16, 16, 4)
+  // we would get:
+  //    (1, 16, 16, 0)
+  // Instead, we should extract the following map:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2)
+  // This way, the corresponding vector/mask type will be:
+  //    vector<1x16x16xty>
+  // rather than:
+  //    vector<1x16x16x0xty>
+  auto opOperandMapWithoutZeros = opOperandMap.dropZeros();
+  write =
+      state.maskOperation(rewriter, write, linalgOp, opOperandMapWithoutZeros);
 
   // If masked, set in-bounds to true. Masking guarantees that the access will
   // be in-bounds.
```
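For intuition, the effect of the new call site can be approximated by the hypothetical helper below. This is a sketch only, not the actual `maskOperation` implementation; `getMaskTypeSketch` and its parameters are invented for illustration, and it assumes the `dropZeros` introduced by this patch:

```cpp
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Sketch only: derive the mask type for an access governed by
// `opOperandMap`, given the canonical vector shape of the iteration
// space. Dropping constant-zero results first avoids producing an
// illegal 0-sized mask dimension (e.g. vector<1x16x16x0xi1>).
static VectorType getMaskTypeSketch(AffineMap opOperandMap,
                                    ArrayRef<int64_t> canonicalShape,
                                    MLIRContext *ctx) {
  AffineMap mapWithoutZeros = opOperandMap.dropZeros();
  SmallVector<int64_t, 4> maskShape = mapWithoutZeros.compose(canonicalShape);
  // e.g. (d0, d1, d2, d3) -> (d0, d1, d2) over (1, 16, 16, 4)
  // yields (1, 16, 16), and hence vector<1x16x16xi1>.
  return VectorType::get(maskShape, IntegerType::get(ctx, 1));
}
```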

mlir/lib/IR/AffineMap.cpp

Lines changed: 23 additions & 0 deletions
```diff
@@ -553,6 +553,18 @@ AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {
   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
+AffineMap AffineMap::dropZeros() {
+  auto exprs = llvm::to_vector<4>(getResults());
+  SmallVector<AffineExpr, 8> newExprs;
+
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (!constExpr)
+      newExprs.push_back(expr);
+  }
+  return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext());
+}
+
 AffineMap AffineMap::compose(AffineMap map) const {
   assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
   // Prepare `map` by concatenating the symbols and rewriting its exprs.
@@ -592,6 +604,17 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
   return res;
 }
 
+size_t AffineMap::numOfZeroResults() const {
+  size_t res = 0;
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (constExpr && constExpr.getValue() == 0)
+      res++;
+  }
+
+  return res;
+}
+
 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
   if (getNumSymbols() > 0)
     return false;
```
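A small usage sketch for the two new helpers (the function below is illustrative, not part of the patch; it assumes the map from the motivating example):

```cpp
#include <cassert>

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

void exerciseZeroResultHelpers() {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineExpr d2 = getAffineDimExpr(2, &ctx);

  // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
  AffineMap map = AffineMap::get(
      /*dimCount=*/4, /*symbolCount=*/0,
      {d0, d1, d2, getAffineConstantExpr(0, &ctx)}, &ctx);

  // Exactly one result is the constant zero.
  assert(map.numOfZeroResults() == 1);

  // Dropping it leaves (d0, d1, d2, d3) -> (d0, d1, d2); the number of
  // dims is unchanged, which is what the relaxed assert in
  // reindexIndexingMap now accounts for.
  AffineMap dropped = map.dropZeros();
  assert(dropped.getNumResults() == 3 && dropped.getNumDims() == 4);
}
```

Note that, as written, `dropZeros` filters out every constant result, not only zeros; since `isProjectedPermutation(/*allowZeroInResults=*/true)` only admits constant-zero results, the two behaviours coincide for the maps that reach this path.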

mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir

Lines changed: 40 additions & 0 deletions
```diff
@@ -1964,3 +1964,43 @@ module attributes {transform.with_named_sequence} {
 // CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
 // CHECK: vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
 // CHECK: vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+// -----
+
+// Extracted from: https://github.com/llvm/llvm-project/issues/97247
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
+
+func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
+  %0 = tensor.empty() : tensor<1x12x197x1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x12x197x197xf32>) outs(%0 : tensor<1x12x197x1xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %818 = arith.addf %in, %out : f32
+    linalg.yield %818 : f32
+  } -> tensor<1x12x197x1xf32>
+  return %1 : tensor<1x12x197x1xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: #[[$ATTR_32:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @generic_with_reduction_and_broadcast(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x197x197xf32>) -> tensor<1x12x197x1xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<1x12x197x1xf32>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : tensor<1x12x197x197xf32>, vector<1x12x197x197xf32>
+// CHECK: %[[VAL_5:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_32]]} : tensor<1x12x197x1xf32>, vector<1x12x197xf32>
+// CHECK: %[[VAL_6:.*]] = vector.multi_reduction <add>, %[[VAL_4]], %[[VAL_5]] [3] : vector<1x12x197x197xf32> to vector<1x12x197xf32>
+// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_6]] : vector<1x12x197xf32> to vector<1x1x12x197xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 2, 3, 0] : vector<1x1x12x197xf32> to vector<1x12x197x1xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {in_bounds = [true, true, true, true]} : vector<1x12x197x1xf32>, tensor<1x12x197x1xf32>
+// CHECK: return %[[VAL_9]] : tensor<1x12x197x1xf32>
```
mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 45 additions & 0 deletions
```diff
@@ -147,6 +147,51 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0, 0)>
+
+func.func @dynamic_generic_with_reduction_and_broadcast(%arg0: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %0 = linalg.generic { indexing_maps = [#map, #map1],
+                        iterator_types = ["parallel", "reduction"]}
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%init : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL: func.func @dynamic_generic_with_reduction_and_broadcast(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x4xi1>
+// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_7]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x4xf32> } : vector<4x4xi1> -> vector<4x4xf32>
+// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_10]] {in_bounds = [true], permutation_map = #[[$MAP]]} : tensor<?x?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x4xf32> to vector<4xf32> } : vector<4x4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {in_bounds = [true], permutation_map = #[[$MAP]]} : vector<4xf32>, tensor<?x?xf32> } : vector<4xi1> -> tensor<?x?xf32>
+// CHECK: return %[[VAL_15]] : tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @vectorize_dynamic_2d_transpose(%arg0: tensor<?x?xf32>,
                                           %arg1: tensor<?x?xf32>,
                                           %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
```