Skip to content

[mlir][Linalg] Refine how broadcast dims are treated #99015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 26, 2024

Conversation

banach-space
Copy link
Contributor

This PR fixes how broadcast dims (identified as "zero" results in
permutation maps) corresponding to a reduction iterator are vectorised
in the case of generic Ops. Here's an example:

  #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>

  func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
    %0 = tensor.empty() : tensor<1x12x197x1xf32>

    %1 = linalg.generic {indexing_maps = [#map, #map1],
                        iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor<1x12x197x197xf32>)
      outs(%0 : tensor<1x12x197x1xf32>) {

    ^bb0(%in: f32, %out: f32):
      %818 = arith.addf %in, %out : f32
      linalg.yield %818 : f32
    } -> tensor<1x12x197x1xf32>
    return %1 : tensor<1x12x197x1xf32>
  }

This is a perfectly valid Generic Op, but currently triggers two issues
in the vectoriser. The root cause is this map:

  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>

This map triggers an assert in reindexIndexingMap - this hook
incorrectly assumes that every result in the input map is a dim
expression and that there are no constants. That's not the case in this
example. reindexIndexingMap is extended to allow maps like the one
above. For now, only constant "zero" results are allowed. This can be
extended in the future once a good motivating example is available.

Separately, the permutation map highlighted above "breaks" mask
calculation (ATM masks are always computed, even in the presence of
static shapes). When applying the following permutation:

  (d0, d1, d2, d3) -> (d0, d1, d2, 0)

to these canonical shapes (corresponding to the example above):

  (1, 12, 197, 197)

we end up with the following error:

error: vector types must have positive constant sizes but got 1, 12, 197, 0

The error makes sense and indicates that we should update the
permutation map above to:

  (d0, d1, d2, d3) -> (d0, d1, d2)

This would correctly give the following vector type:

  vector<1x12x197xi1>

Fixes #97247

@llvmbot
Copy link
Member

llvmbot commented Jul 16, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This PR fixes how broadcast dims (identified as "zero" results in
permutation maps) corresponding to a reduction iterator are vectorised
in the case of generic Ops. Here's an example:

  #map = affine_map&lt;(d0, d1, d2, d3) -&gt; (d0, d1, d2, d3)&gt;
  #map1 = affine_map&lt;(d0, d1, d2, d3) -&gt; (d0, d1, d2, 0)&gt;

  func.func @<!-- -->generic_with_reduction_and_broadcast(%arg0: tensor&lt;1x12x197x197xf32&gt;) -&gt; (tensor&lt;1x12x197x1xf32&gt;) {
    %0 = tensor.empty() : tensor&lt;1x12x197x1xf32&gt;

    %1 = linalg.generic {indexing_maps = [#map, #map1],
                        iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor&lt;1x12x197x197xf32&gt;)
      outs(%0 : tensor&lt;1x12x197x1xf32&gt;) {

    ^bb0(%in: f32, %out: f32):
      %818 = arith.addf %in, %out : f32
      linalg.yield %818 : f32
    } -&gt; tensor&lt;1x12x197x1xf32&gt;
    return %1 : tensor&lt;1x12x197x1xf32&gt;
  }

This is a perfectly valid Generic Op, but currently triggers two issues
in the vectoriser. The root cause is this map:

  #map1 = affine_map&lt;(d0, d1, d2, d3) -&gt; (d0, d1, d2, 0)&gt;

This map triggers an assert in reindexIndexingMap - this hook
incorrectly assumes that every result in the input map is a dim
expression and that there are no constants. That's not the case in this
example. reindexIndexingMap is extended to allow maps like the one
above. For now, only constant "zero" results are allowed. This can be
extended in the future once a good motivating example is available.

Separately, the permutation map highlighted above "breaks" mask
calculation (ATM masks are always computed, even in the presence of
static shapes). When applying the following permutation:

  (d0, d1, d2, d3) -&gt; (d0, d1, d2, 0)

to these canonical shapes (corresponding to the example above):

  (1, 12, 197, 197)

we end up with the following error:

error: vector types must have positive constant sizes but got 1, 12, 197, 0

The error makes sense and indicates that we should update the
permutation map above to:

  (d0, d1, d2, d3) -&gt; (d0, d1, d2)

This would correctly give the following vector type:

  vector&lt;1x12x197xi1&gt;

Fixes #97247


Full diff: https://github.com/llvm/llvm-project/pull/99015.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/AffineMap.h (+4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-2)
  • (modified) mlir/lib/IR/AffineMap.cpp (+23)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+40)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+45)
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 264c1c8308e78..866fe01e53665 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -346,6 +346,10 @@ class AffineMap {
   /// returns the resulting values. `this` must be symbol-less.
   SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values) const;
 
+  size_t numOfZeroResults() const;
+
+  AffineMap dropZeros();
+
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
   /// symbol-less permutation map. `allowZeroInResults` allows projected
   /// permutation maps with constant zero result expressions.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a4c0508d0d8fa..288a05559e0b8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -476,7 +476,7 @@ static AffineMap reindexIndexingMap(AffineMap map) {
   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
          "expected projected permutation");
   auto res = compressUnusedDims(map);
-  assert(res.getNumDims() == res.getNumResults() &&
+  assert(res.getNumDims() == (res.getNumResults() - res.numOfZeroResults()) &&
          "expected reindexed map with same number of dims and results");
   return res;
 }
@@ -629,7 +629,21 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
         loc, value, outputOperand->get(), ValueRange{});
   }
 
-  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+  // The operand map may contain "zero" results, e.g.:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+  // When applied to canonical vector shapes like these:
+  //    (1, 16, 16, 4)
+  // we would get:
+  //    (1, 16, 16, 0)
+  // Instead, we should extract the following map:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2)
+  // This way, the corresponding vector/mask type will be:
+  //    vector<1x16x16xty>
+  // rather than:
+  //    vector<1x16x16x0xty>
+  auto opOperantMapWithoutZeros = opOperandMap.dropZeros();
+  write =
+      state.maskOperation(rewriter, write, linalgOp, opOperantMapWithoutZeros);
 
   // If masked, set in-bounds to true. Masking guarantees that the access will
   // be in-bounds.
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 62f595299afe2..0d93c3ad19b0f 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -540,6 +540,18 @@ AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {
   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
+AffineMap AffineMap::dropZeros() {
+  auto exprs = llvm::to_vector<4>(getResults());
+  SmallVector<AffineExpr, 8> newExprs;
+
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (!constExpr)
+      newExprs.push_back(expr);
+  }
+  return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext());
+}
+
 AffineMap AffineMap::compose(AffineMap map) const {
   assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
   // Prepare `map` by concatenating the symbols and rewriting its exprs.
@@ -579,6 +591,17 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
   return res;
 }
 
+size_t AffineMap::numOfZeroResults() const {
+  size_t res = 0;
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (constExpr && constExpr.getValue() == 0)
+      res++;
+  }
+
+  return res;
+}
+
 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
   if (getNumSymbols() > 0)
     return false;
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index d7ff1ded9d933..bf015ef409b81 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1899,3 +1899,43 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
 //       CHECK:     vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
 //       CHECK:     vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+// -----
+
+// Extracted from: https://github.com/llvm/llvm-project/issues/97247
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
+
+func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
+  %0 = tensor.empty() : tensor<1x12x197x1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x12x197x197xf32>) outs(%0 : tensor<1x12x197x1xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %818 = arith.addf %in, %out : f32
+    linalg.yield %818 : f32
+  } -> tensor<1x12x197x1xf32>
+  return %1 : tensor<1x12x197x1xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: #[[$ATTR_32:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @generic_with_reduction_and_broadcast(
+// CHECK-SAME:                                                    %[[VAL_0:.*]]: tensor<1x12x197x197xf32>) -> tensor<1x12x197x1xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<1x12x197x1xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : tensor<1x12x197x197xf32>, vector<1x12x197x197xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_32]]} : tensor<1x12x197x1xf32>, vector<1x12x197xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.multi_reduction <add>, %[[VAL_4]], %[[VAL_5]] [3] : vector<1x12x197x197xf32> to vector<1x12x197xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.broadcast %[[VAL_6]] : vector<1x12x197xf32> to vector<1x1x12x197xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 2, 3, 0] : vector<1x1x12x197xf32> to vector<1x12x197x1xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {in_bounds = [true, true, true, true]} : vector<1x12x197x1xf32>, tensor<1x12x197x1xf32>
+// CHECK:           return %[[VAL_9]] : tensor<1x12x197x1xf32>
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbeccc7fecd68..2464759522c0f 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -147,6 +147,51 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0, 0)>
+
+func.func @dynamic_generic_with_reduction_and_broadcast(%arg0: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %0 = linalg.generic { indexing_maps = [#map, #map1],
+                        iterator_types = ["parallel", "reduction"]}
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%init : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL:   func.func @dynamic_generic_with_reduction_and_broadcast(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x4xi1>
+// CHECK:           %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_7]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x4xf32> } : vector<4x4xi1> -> vector<4x4xf32>
+// CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_10]] {in_bounds = [true], permutation_map = #[[$MAP]]} : tensor<?x?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x4xf32> to vector<4xf32> } : vector<4x4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {in_bounds = [true], permutation_map = #[[$MAP]]} : vector<4xf32>, tensor<?x?xf32> } : vector<4xi1> -> tensor<?x?xf32>
+// CHECK:           return %[[VAL_15]] : tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @vectorize_dynamic_2d_transpose(%arg0: tensor<?x?xf32>,
                                           %arg1: tensor<?x?xf32>,
                                           %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {

@llvmbot
Copy link
Member

llvmbot commented Jul 16, 2024

@llvm/pr-subscribers-mlir-core

Author: Andrzej Warzyński (banach-space)

Changes

This PR fixes how broadcast dims (identified as "zero" results in
permutation maps) corresponding to a reduction iterator are vectorised
in the case of generic Ops. Here's an example:

  #map = affine_map&lt;(d0, d1, d2, d3) -&gt; (d0, d1, d2, d3)&gt;
  #map1 = affine_map&lt;(d0, d1, d2, d3) -&gt; (d0, d1, d2, 0)&gt;

  func.func @<!-- -->generic_with_reduction_and_broadcast(%arg0: tensor&lt;1x12x197x197xf32&gt;) -&gt; (tensor&lt;1x12x197x1xf32&gt;) {
    %0 = tensor.empty() : tensor&lt;1x12x197x1xf32&gt;

    %1 = linalg.generic {indexing_maps = [#map, #map1],
                        iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor&lt;1x12x197x197xf32&gt;)
      outs(%0 : tensor&lt;1x12x197x1xf32&gt;) {

    ^bb0(%in: f32, %out: f32):
      %818 = arith.addf %in, %out : f32
      linalg.yield %818 : f32
    } -&gt; tensor&lt;1x12x197x1xf32&gt;
    return %1 : tensor&lt;1x12x197x1xf32&gt;
  }

This is a perfectly valid Generic Op, but currently triggers two issues
in the vectoriser. The root cause is this map:

  #map1 = affine_map&lt;(d0, d1, d2, d3) -&gt; (d0, d1, d2, 0)&gt;

This map triggers an assert in reindexIndexingMap - this hook
incorrectly assumes that every result in the input map is a dim
expression and that there are no constants. That's not the case in this
example. reindexIndexingMap is extended to allow maps like the one
above. For now, only constant "zero" results are allowed. This can be
extended in the future once a good motivating example is available.

Separately, the permutation map highlighted above "breaks" mask
calculation (ATM masks are always computed, even in the presence of
static shapes). When applying the following permutation:

  (d0, d1, d2, d3) -&gt; (d0, d1, d2, 0)

to these canonical shapes (corresponding to the example above):

  (1, 12, 197, 197)

we end up with the following error:

error: vector types must have positive constant sizes but got 1, 12, 197, 0

The error makes sense and indicates that we should update the
permutation map above to:

  (d0, d1, d2, d3) -&gt; (d0, d1, d2)

This would correctly give the following vector type:

  vector&lt;1x12x197xi1&gt;

Fixes #97247


Full diff: https://github.com/llvm/llvm-project/pull/99015.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/AffineMap.h (+4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-2)
  • (modified) mlir/lib/IR/AffineMap.cpp (+23)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+40)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+45)
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 264c1c8308e78..866fe01e53665 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -346,6 +346,10 @@ class AffineMap {
   /// returns the resulting values. `this` must be symbol-less.
   SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values) const;
 
+  size_t numOfZeroResults() const;
+
+  AffineMap dropZeros();
+
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
   /// symbol-less permutation map. `allowZeroInResults` allows projected
   /// permutation maps with constant zero result expressions.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a4c0508d0d8fa..288a05559e0b8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -476,7 +476,7 @@ static AffineMap reindexIndexingMap(AffineMap map) {
   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
          "expected projected permutation");
   auto res = compressUnusedDims(map);
-  assert(res.getNumDims() == res.getNumResults() &&
+  assert(res.getNumDims() == (res.getNumResults() - res.numOfZeroResults()) &&
          "expected reindexed map with same number of dims and results");
   return res;
 }
@@ -629,7 +629,21 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
         loc, value, outputOperand->get(), ValueRange{});
   }
 
-  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+  // The operand map may contain "zero" results, e.g.:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+  // When applied to canonical vector shapes like these:
+  //    (1, 16, 16, 4)
+  // we would get:
+  //    (1, 16, 16, 0)
+  // Instead, we should extract the following map:
+  //    (d0, d1, d2, d3) -> (d0, d1, d2)
+  // This way, the corresponding vector/mask type will be:
+  //    vector<1x16x16xty>
+  // rather than:
+  //    vector<1x16x16x0xty>
+  auto opOperantMapWithoutZeros = opOperandMap.dropZeros();
+  write =
+      state.maskOperation(rewriter, write, linalgOp, opOperantMapWithoutZeros);
 
   // If masked, set in-bounds to true. Masking guarantees that the access will
   // be in-bounds.
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 62f595299afe2..0d93c3ad19b0f 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -540,6 +540,18 @@ AffineMap AffineMap::dropResults(const llvm::SmallBitVector &positions) const {
   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
 }
 
+AffineMap AffineMap::dropZeros() {
+  auto exprs = llvm::to_vector<4>(getResults());
+  SmallVector<AffineExpr, 8> newExprs;
+
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (!constExpr)
+      newExprs.push_back(expr);
+  }
+  return AffineMap::get(getNumDims(), getNumSymbols(), newExprs, getContext());
+}
+
 AffineMap AffineMap::compose(AffineMap map) const {
   assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
   // Prepare `map` by concatenating the symbols and rewriting its exprs.
@@ -579,6 +591,17 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
   return res;
 }
 
+size_t AffineMap::numOfZeroResults() const {
+  size_t res = 0;
+  for (auto expr : getResults()) {
+    auto constExpr = dyn_cast<AffineConstantExpr>(expr);
+    if (constExpr && constExpr.getValue() == 0)
+      res++;
+  }
+
+  return res;
+}
+
 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
   if (getNumSymbols() > 0)
     return false;
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index d7ff1ded9d933..bf015ef409b81 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1899,3 +1899,43 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 0] : vector<1x4xf32> to vector<4x1xf32>
 //       CHECK:     vector.transfer_write %[[VAL_8]], %{{.*}} {in_bounds = [true, true]} : vector<4x1xf32>, tensor<4x1xf32>
 //       CHECK:     vector.transfer_write %[[VAL_7]], %{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+// -----
+
+// Extracted from: https://github.com/llvm/llvm-project/issues/97247
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
+
+func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
+  %0 = tensor.empty() : tensor<1x12x197x1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x12x197x197xf32>) outs(%0 : tensor<1x12x197x1xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %818 = arith.addf %in, %out : f32
+    linalg.yield %818 : f32
+  } -> tensor<1x12x197x1xf32>
+  return %1 : tensor<1x12x197x1xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK: #[[$ATTR_32:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @generic_with_reduction_and_broadcast(
+// CHECK-SAME:                                                    %[[VAL_0:.*]]: tensor<1x12x197x197xf32>) -> tensor<1x12x197x1xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.empty() : tensor<1x12x197x1xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true, true]} : tensor<1x12x197x197xf32>, vector<1x12x197x197xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_32]]} : tensor<1x12x197x1xf32>, vector<1x12x197xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.multi_reduction <add>, %[[VAL_4]], %[[VAL_5]] [3] : vector<1x12x197x197xf32> to vector<1x12x197xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.broadcast %[[VAL_6]] : vector<1x12x197xf32> to vector<1x1x12x197xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transpose %[[VAL_7]], [1, 2, 3, 0] : vector<1x1x12x197xf32> to vector<1x12x197x1xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {in_bounds = [true, true, true, true]} : vector<1x12x197x1xf32>, tensor<1x12x197x1xf32>
+// CHECK:           return %[[VAL_9]] : tensor<1x12x197x1xf32>
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbeccc7fecd68..2464759522c0f 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -147,6 +147,51 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0, 0)>
+
+func.func @dynamic_generic_with_reduction_and_broadcast(%arg0: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %0 = linalg.generic { indexing_maps = [#map, #map1],
+                        iterator_types = ["parallel", "reduction"]}
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%init : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK-LABEL:   func.func @dynamic_generic_with_reduction_and_broadcast(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x4xi1>
+// CHECK:           %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_7]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x4xf32> } : vector<4x4xi1> -> vector<4x4xf32>
+// CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]], %[[VAL_10]] {in_bounds = [true], permutation_map = #[[$MAP]]} : tensor<?x?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x4xf32> to vector<4xf32> } : vector<4x4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {in_bounds = [true], permutation_map = #[[$MAP]]} : vector<4xf32>, tensor<?x?xf32> } : vector<4xi1> -> tensor<?x?xf32>
+// CHECK:           return %[[VAL_15]] : tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @vectorize_dynamic_2d_transpose(%arg0: tensor<?x?xf32>,
                                           %arg1: tensor<?x?xf32>,
                                           %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {

@MaheshRavishankar
Copy link
Contributor

I have not looked into details of this PR, and there is value in having the lower levels of the stack being robust, but

 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>

is not the canonical representation of broadcast IMO. This is.

 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

So just signalling that here.

Comment on lines 349 to 351
size_t numOfZeroResults() const;

AffineMap dropZeros();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing doc comments

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please check if we have a utility for this already? As I mentioned in another PR, we have so many AffineMap utilities that sometimes it's not easy to find what we need :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw your comment on the other PR and that immediately made me think about this :)

I couldn't find anything that would help for these specific case. However, we do have:

So, there's scope for re-use.

Here's what I suggest:

  • Introduce SmallVector<unsigned> getZeroResults()
    • It would return the positions of zero/bcast results.
  • Simplify isMinorIdentityWithBroadcasting()
    • It would use the newly introduced getZeroResults().
  • Replace dropZeros() (proposed in this PR) with a call to dropResults(getZeroResults()).

I would still have to introduce getZeroResults() (i.e. yet another hook in this file), but would avoid introducing dropZeros() :) WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is nothing that provides what you need I would just introduce what you need and try to reuse/refactor at implementation level.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this! Left a few comments.

I would also revisit the commit description as this seems to be a reduction problem, not a broadcast problem? Or maybe both broadcast and reduction?

Comment on lines 349 to 351
size_t numOfZeroResults() const;

AffineMap dropZeros();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please check if we have a utility for this already? As I mentioned in another PR, we have so many AffineMap utilities that sometimes it's not easy to find what we need :)

// vector<1x16x16x0xty>
auto opOperantMapWithoutZeros = opOperandMap.dropZeros();
write =
state.maskOperation(rewriter, write, linalgOp, opOperantMapWithoutZeros);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with the xfer read counterpart? Should we move this logic into maskOperation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with the xfer read counterpart?

It turns out you've already implemented that :)

// Remove zeros from indexing map to use it as masking map.
SmallVector<int64_t> zeroPos;
auto results = indexingMap.getResults();
for (const auto &result : llvm::enumerate(results)) {
if (isa<AffineConstantExpr>(result.value())) {
zeroPos.push_back(result.index());
}
}
AffineMap maskingMap = indexingMap.dropResults(zeroPos);

Should we move this logic into maskOperation?

Yes! Sending an update shortly.

@banach-space banach-space force-pushed the andrzej/fix_97247 branch 2 times, most recently from c9c9bff to 8dbb4ce Compare July 25, 2024 20:14
@banach-space
Copy link
Contributor Author

I would also revisit the commit description as this seems to be a reduction problem, not a broadcast problem? Or maybe both broadcast and reduction?

Both :) We already test:

However, there's no single test that would cover both.

Copy link

github-actions bot commented Jul 25, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space
Copy link
Contributor Author

Ping :)

This PR fixes how broadcast dims (identified as "zero" results in
permutation maps) corresponding to a reduction iterator are vectorised
in the case of generic Ops. Here's an example:

```mlir
  #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>

  func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
    %0 = tensor.empty() : tensor<1x12x197x1xf32>

    %1 = linalg.generic {indexing_maps = [#map, #map1],
                        iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor<1x12x197x197xf32>)
      outs(%0 : tensor<1x12x197x1xf32>) {

    ^bb0(%in: f32, %out: f32):
      %818 = arith.addf %in, %out : f32
      linalg.yield %818 : f32
    } -> tensor<1x12x197x1xf32>
    return %1 : tensor<1x12x197x1xf32>
  }
```

This is a perfectly valid Generic Op, but currently triggers two issues
in the vectoriser. The root cause is this map:

```mlir
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
```

This map triggers an assert in `reindexIndexingMap` -  this hook
incorrectly assumes that every result in the input map is a `dim`
expression and that there are no constants. That's not the case in this
example. `reindexIndexingMap` is extended to allow maps like the one
above. For now, only constant "zero" results are allowed. This can be
extended in the future once a good motivating example is available.

Separately, the permutation map highlighted above "breaks" mask
calculation (ATM masks are always computed, even in the presence of
static shapes). When applying the following permutation:
```mlir
  (d0, d1, d2, d3) -> (d0, d1, d2, 0)
```

to these canonical shapes (corresponding to the example above):
```
  (1, 12, 197, 197)
```
we end up with the following error:
```bash
error: vector types must have positive constant sizes but got 1, 12, 197, 0
```

The error makes sense and indicates that we should update the
permutation map above to:
```
  (d0, d1, d2, d3) -> (d0, d1, d2)
```

This would correctly give the following vector type:
```
  vector<1x12x197xi1>
```

Fixes llvm#97247
* Move the logic to remove zero from indexing maps to `maskOperation`
* Update the input mask name in `maskOperation` to `maybeIndexingMap` -
  the actual input is always an indexing map extracted from the
  corresponding linalg Op
* Remove the duplicated comment for `maskOperation`
@banach-space
Copy link
Contributor Author

@dcaballe , could you take another look?

Comment on lines +227 to +230
/// permuted using `maybeIndexingMap`.
Operation *
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
std::optional<AffineMap> maybeMaskingMap = std::nullopt);
std::optional<AffineMap> maybeIndexingMap = std::nullopt);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think indexing map would be misleading here as it's a term very tight to linalg ops and here we may apply it to arbitrary ops. Both will evolve in different ways for more complex vectorization scenarios, like those with control flow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reminding me about this!

Right now we are actually passing the indexing map whenever this hook is used - this is just to documenting the state today. So let me leave it as is.

I will send a follow-on PR in which I'll try to clarify the difference between the two maps.

}
}
AffineMap maskingMap = indexingMap.dropResults(zeroPos);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, ok, we are just moving this into maskOperation because we have other entry points where this was not executed... Got it.

@banach-space banach-space merged commit 6d11494 into llvm:main Sep 26, 2024
8 checks passed
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
This PR fixes how broadcast dims (identified as "zero" results in
permutation maps) corresponding to a reduction iterator are vectorised
in the case of generic Ops. Here's an example:

```mlir
  #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>

  func.func @generic_with_reduction_and_broadcast(%arg0: tensor<1x12x197x197xf32>) -> (tensor<1x12x197x1xf32>) {
    %0 = tensor.empty() : tensor<1x12x197x1xf32>

    %1 = linalg.generic {indexing_maps = [#map, #map1],
                        iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%arg0 : tensor<1x12x197x197xf32>)
      outs(%0 : tensor<1x12x197x1xf32>) {

    ^bb0(%in: f32, %out: f32):
      %818 = arith.addf %in, %out : f32
      linalg.yield %818 : f32
    } -> tensor<1x12x197x1xf32>
    return %1 : tensor<1x12x197x1xf32>
  }
```

This is a perfectly valid Generic Op, but currently triggers two issues
in the vectoriser. The root cause is this map:

```mlir
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
```

This map triggers an assert in `reindexIndexingMap` -  this hook
incorrectly assumes that every result in the input map is a `dim`
expression and that there are no constants. That's not the case in this
example. `reindexIndexingMap` is extended to allow maps like the one
above. For now, only constant "zero" results are allowed. This can be
extended in the future once a good motivating example is available.

Separately, the permutation map highlighted above "breaks" mask
calculation (ATM masks are always computed, even in the presence of
static shapes). When applying the following permutation:
```mlir
  (d0, d1, d2, d3) -> (d0, d1, d2, 0)
```

to these canonical shapes (corresponding to the example above):
```
  (1, 12, 197, 197)
```
we end up with the following error:
```bash
error: vector types must have positive constant sizes but got 1, 12, 197, 0
```

The error makes sense and indicates that we should update the
permutation map above to:
```
  (d0, d1, d2, d3) -> (d0, d1, d2)
```

This would correctly give the following vector type:
```
  vector<1x12x197xi1>
```

Fixes llvm#97247
@banach-space banach-space deleted the andrzej/fix_97247 branch September 28, 2024 09:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR][Linalg] Vectorization Fail with reduction
6 participants