
[mlir] Re-land Loosen restrictions on folding dynamic reshapes #142827


Merged

Conversation

IanWood1
Contributor

@IanWood1 IanWood1 commented Jun 4, 2025

The original PR #137963 had an NVIDIA bot failure. This appears to be a flaky test, because rerunning the build was successful.

This change needs #142663 to fix incorrect usage of `getReassociationIndicesForCollapse`.

Reverts #142639

@IanWood1
Contributor Author

IanWood1 commented Jun 4, 2025

@matthias-springer @joker-eph This looks like a flaky test. The original build with the NVIDIA bot failed, but after rerunning, the test passed. Here are the runs:

In case the logs get deleted, here's a link to the comment on the original PR reporting the failure: #137963 (comment)

@IanWood1 IanWood1 marked this pull request as ready for review June 4, 2025 18:22
@llvmbot
Member

llvmbot commented Jun 4, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Ian Wood (IanWood1)

Changes

The original PR #137963 had an NVIDIA bot failure. This appears to be a flaky test, because rerunning the build was successful.

This change needs #142663 to fix incorrect usage of `getReassociationIndicesForCollapse`.

Reverts llvm/llvm-project#142639


Patch is 32.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142827.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+319-53)
  • (modified) mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+35-4)
  • (modified) mlir/unittests/Dialect/Utils/CMakeLists.txt (+1)
  • (added) mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp (+203)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 1a04d702e0559..3b1fdb69e8ef1 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -10,6 +10,10 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 
 #include <numeric>
 #include <optional>
@@ -28,67 +32,329 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   return std::nullopt;
 }
 
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                         ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
-    return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval.
+/// It's designed to be feasibly minimal, so the call sites should manage the
+/// validity of the range manually.
+struct ReassociationIndexRange {
+  /// FIXME: Signed type is used for consistency with ReassociationIndices.
+  /// We should consider refactoring all reassociation utilities to use unsigned
+  /// types.
+  int64_t leftIdx = 0, rightIdx = 0;
+
+  /// Util for manual checks of the range's validity
+  LogicalResult verify() const {
+    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+  }
+
+  /// Checks range's containment within another range. Treats the edges
+  /// non-exclusively.
+  bool isInRange(const ReassociationIndexRange &outerRange) const {
+    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+  }
+
+  unsigned size() const {
+    assert(succeeded(verify()));
+    return rightIdx - leftIdx + 1;
+  }
+  bool containsSingleIndex() const { return size() == 1; }
+
+  /// Collects indices that do not overlap between this and another range.
+  ReassociationIndices
+  getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
+    if (rightIdx < rhs.leftIdx) {
+      // The intervals do not overlap - concatenate the indices from both.
+      auto jointFullIndices = getFullIndices();
+      jointFullIndices.append(rhs.getFullIndices());
+      return jointFullIndices;
+    }
+    ReassociationIndices result;
+    // Handle the chunk left of the overlapping range.
+    int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
+    int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
+    llvm::append_range(result, llvm::seq(leftStart, leftEnd));
+    // Handle the chunk right of the overlapping range. Symmetrically, we should
+    // skip the edge of the overlap AND include the rightmost index.
+    int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
+    int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
+    if (rightStart < rightEnd)
+      llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
+    return result;
+  }
+
+  /// Converts the range into ReassociationIndices.
+  ReassociationIndices getFullIndices() const {
+    ReassociationIndices result;
+    for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
+      result.push_back(idx);
+    }
+    return result;
+  }
+};
+} // namespace
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence that can be collapsed into a dynamic dimension (at least one must
+/// be present in the source).
+/// By default, lazily returns once the first dynamic dimension has been found.
+/// Setting `matchGreedily` as `true` will also mark all subsequent
+/// source dimensions for collapsing into the target.
+static FailureOr<ReassociationIndexRange>
+findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
+                                    int64_t sourceStartIdx,
+                                    bool matchGreedily = false) {
+  const unsigned numSourceDims = sourceShape.size();
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
+
+  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+  for (; iterationRange.isInRange(sourceShapeAsRange);
+       iterationRange.rightIdx++) {
+    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+    if (sourceSize == ShapedType::kDynamic) {
+      resultRange = iterationRange;
+      break;
+    }
+  }
+  if (!resultRange)
+    return failure();
+  if (matchGreedily)
+    resultRange->rightIdx = sourceShapeAsRange.rightIdx;
+  return *resultRange;
+}
 
-  ReassociationIndices currIndices;
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence of static dimensions such that their product matches `targetSize`.
+/// By default, lazily returns once the product matches the target size. Setting
+/// `matchGreedily` as `true` will append all neighboring unit dimensions
+/// (dimensions of 1) to the match.
+static FailureOr<ReassociationIndexRange>
+findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
+                              int64_t sourceStartIdx, int64_t targetSize,
+                              bool matchGreedily = false) {
+  const unsigned numSourceDims = sourceShape.size();
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
+
+  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
   int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
+  while (iterationRange.isInRange(sourceShapeAsRange)) {
+    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+    if (sourceSize == ShapedType::kDynamic) {
+      // Reassociation for a static dim cannot include a dynamic dim. Reset
+      // induction variables to essentially restart the loop from the next
+      // source dimension.
+      prodOfCollapsedDims = 1;
+      iterationRange = {iterationRange.rightIdx + 1,
+                        iterationRange.rightIdx + 1};
+      continue;
+    }
+    prodOfCollapsedDims *= sourceSize;
+    // If the target size has been exceeded without matching, we need to shift
+    // the range start right. From the start of the range, roll back the
+    // multiplication until the target size exceeds the product again.
+    while (prodOfCollapsedDims > targetSize &&
+           !iterationRange.containsSingleIndex()) {
+      int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
+      prodOfCollapsedDims /= frontSourceSize;
+      // Shrink the range rightwards
+      iterationRange.leftIdx++;
+    }
+    // We could've reached the target size with the current dimension,
+    // also as a result of the above shift to right.
+    if (prodOfCollapsedDims == targetSize) {
+      resultRange = iterationRange;
       break;
+    }
+    // Increment the iteration range
+    iterationRange.rightIdx++;
+  }
+  if (!resultRange)
+    return failure();
+  if (matchGreedily) {
+    // We now want to collect all unit dimensions directly after the target
+    // product match. Advance the iterator to avoid OOB when the product match
+    // happens at the last element.
+    iterationRange.rightIdx++;
+    while (iterationRange.isInRange(sourceShapeAsRange) &&
+           sourceShape[iterationRange.rightIdx] == 1) {
+      resultRange = iterationRange;
+      iterationRange.rightIdx++;
+    }
+  }
+  return *resultRange;
+}
 
-    int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
-      prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+/// Attempts to find a valid collapsing reassociation of `sourceShape` into
+/// `targetShape` through a simple traversal. If successful, an array of source
+/// index ranges is returned, correspondingly to each dimension in the target
+/// shape. The resulting indices shall fully cover the `sourceShape` without
+/// overlaps.
+///
+/// The algorithm is essentially a lazy one, searching for non-greedy matches -
+/// it will only yield a greedy match for the last target dimension.
+/// FIXME: The algorithm can only backtrack when it needs to append an offset
+/// for a static target dimension to the preceding dynamic one (this retains the
+/// linear complexity). As feasible, consider adding further backtracking
+/// routines to enable more reassociations, e.g.:
+/// - ?x2x?x2 into ?x2
+static FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape) {
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  assert(numSourceDims > numTargetDims);
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+
+  SmallVector<ReassociationIndexRange> reassocRanges;
+  reassocRanges.reserve(numTargetDims);
+  // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
+  // cases, e.g.:
+  // - ?x2x3x5 into ?x15
+  std::optional<int64_t> prevTargetSize = std::nullopt;
+  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
+       targetDimIdx < numTargetDims; ++targetDimIdx) {
+    int64_t targetSize = targetShape[targetDimIdx];
+    // Simply check if there are any subsequent target dimensions left - if not,
+    // the match must be made greedily.
+    bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
+    FailureOr<ReassociationIndexRange> sourceRange;
+    if (targetSize == ShapedType::kDynamic) {
+      sourceRange = findReassociationRangeForDynamicDim(
+          sourceShape, sourceDimIdx, shouldMatchGreedily);
+    } else {
+      sourceRange = findReassociationRangeForSize(
+          sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
     }
 
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
-      return std::nullopt;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return std::nullopt;
-
-    currIndices.push_back(sourceDim++);
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
+    // Run sanity checks on the returned index range.
+    if (failed(sourceRange) || failed(sourceRange->verify()) ||
+        !sourceRange->isInRange(sourceShapeAsRange))
+      return failure();
+    if (sourceRange->leftIdx > sourceDimIdx) {
+      // If some source dimensions had to be skipped in order to find a match,
+      // they must be collapsed into the directly preceding dynamic dimension.
+      if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
+        return failure();
+      reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
+    }
+
+    // Store the gathered information as required for the next iteration.
+    prevTargetSize = targetSize;
+    sourceDimIdx = sourceRange->rightIdx + 1;
+    reassocRanges.push_back(*sourceRange);
   }
-  // All the dimensions in the target must have been processed.
-  if (reassociationMap.size() != targetShape.size())
+  // Fail if the source shape wasn't a full match for the target shape. We only
+  // need to check the last recorded index - any other gaps should have been
+  // mended by the main loop.
+  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
+    return failure();
+  return reassocRanges;
+}
+
+/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
+/// the shapes right-to-left.
+static FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape,
+                                   bool iterateRightToLeft) {
+  if (!iterateRightToLeft)
+    return findReassociationRangesForCollapse(sourceShape, targetShape);
+  // NB: To iterate right-to-left, we currently reverse the shapes and then
+  // reverse the result back. The reversed shapes must not be temporary, as
+  // we're passing through an ArrayRef.
+  // FIXME: It would be preferable to avoid the expensive copies. At the moment,
+  // this approach is chosen for readability of the main implementation.
+  std::vector<int64_t> sourceToReverse = sourceShape.vec(),
+                       targetToReverse = targetShape.vec();
+  std::reverse(sourceToReverse.begin(), sourceToReverse.end());
+  std::reverse(targetToReverse.begin(), targetToReverse.end());
+  auto invertedRanges =
+      findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
+  if (failed(invertedRanges))
+    return failure();
+  SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
+  unsigned numSourceDims = sourceShape.size();
+  // We have received the ranges for inverted shapes. Now we have to invert
+  // the ranges back to correspond with the original source shape.
+  for (auto &range : rangesToInvert) {
+    int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
+    range.leftIdx = numSourceDims - 1 - invRightIdx;
+    range.rightIdx = numSourceDims - 1 - invLeftIdx;
+  }
+  // Also invert the ordering of the ranges to correspond with the original
+  // target shape.
+  std::reverse(rangesToInvert.begin(), rangesToInvert.end());
+  return rangesToInvert;
+}
+
+std::optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+                                         ArrayRef<int64_t> targetShape) {
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  // We're supposed to search for a collapsing reassociation. If the sizes
+  // match, there's no actual collapsing taking place - it's either a no-op or a
+  // `tensor.reshape`-style reassociation (that would be beyond the scope of
+  // this utility).
+  if (numSourceDims <= numTargetDims)
+    return std::nullopt;
+  // Early handling for scalar target types.
+  if (numTargetDims == 0) {
+    ReassociationIndices allSourceIndices;
+    allSourceIndices.reserve(numSourceDims);
+    for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
+         ++sourceDimIdx) {
+      int64_t sourceSize = sourceShape[sourceDimIdx];
+      // All source dimensions must be unit or dynamic.
+      if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
+        return std::nullopt;
+      allSourceIndices.push_back(sourceDimIdx);
+    }
+    return SmallVector<ReassociationIndices>{allSourceIndices};
+  }
+
+  // Collect source ranges by iterating over the target shape left-to-right.
+  FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
+      findReassociationRangesForCollapse(sourceShape, targetShape);
+  if (failed(maybeForwardRanges))
+    return std::nullopt;
+  auto &ranges = *maybeForwardRanges;
+  // Now do the same in reverse. We need to get another valid reassociation
+  // through some other strategy, and then compare the results in order to
+  // disambiguate mixed subshapes, such as:
+  // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
+  // This leads us to lose some of the reassociation opportunities that can only
+  // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
+  // backtracking, the algorithm will fail right-to-left. However, this is the
+  // best way to preserve correctness.
+  FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
+      findReassociationRangesForCollapse(sourceShape, targetShape,
+                                         /*iterateRightToLeft=*/true);
+  if (failed(maybeReverseRanges))
+    return std::nullopt;
+  auto &reverseRanges = *maybeReverseRanges;
+
+  if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
     return std::nullopt;
-  // Process any remaining entries in the source shape. They all need to be
-  // 1 or dynamic.
-  for (; sourceDim < sourceShape.size(); sourceDim++) {
-    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
-        sourceShape[sourceDim] != 1)
-      return std::nullopt;
-    // The map is empty when the target type is a scalar.
-    if (!reassociationMap.empty())
-      reassociationMap.back().push_back(sourceDim);
+  // Now we can check for ambiguity of each target dimension's reassociation. If
+  // successful, we put the full indices into our result map for the target
+  // shape.
+  SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
+  for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
+       ++targetDimIdx) {
+    ReassociationIndexRange &range = ranges[targetDimIdx];
+    ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
+    // Get non-overlapping indices between the ranges
+    ReassociationIndices nonMatchingIndices =
+        range.getNonOverlappingIndicesWith(reverseRange);
+    // Unit dimensions can be collapsed wherever - this is the only ambiguity
+    // that we allow.
+    for (int64_t sourceDimIdx : nonMatchingIndices) {
+      if (sourceShape[sourceDimIdx] != 1)
+        return std::nullopt;
+    }
+    reassociationMap[targetDimIdx] = range.getFullIndices();
   }
   return reassociationMap;
 }
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT:     tensor.collapse
-// CHECK:         linalg.unpack
+// CHECK:     tensor.collapse
+// CHECK-NOT:         linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0abec7e01d184..646b2197d9aa6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1117,7 +1117,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
 
 // -----
 
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1125,12 +1125,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
 //   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.coll...
[truncated]
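For readers skimming the truncated diff: the patch replaces the old single-pass matcher with lazy, range-based matching over the source shape. As a rough illustration only (not the PR's code), here is a forward-only Python sketch of that matching, where `DYN` stands in for `ShapedType::kDynamic` and all names are invented for the sketch; the real C++ implementation additionally runs a mirrored right-to-left pass and compares both results to reject ambiguous groupings of non-unit dimensions.

```python
DYN = -1  # stand-in for mlir::ShapedType::kDynamic

def reassociation_for_collapse(src, tgt):
    """Forward-only sketch of the patch's lazy range matching."""
    if len(src) <= len(tgt):
        return None
    if not tgt:  # scalar target: every source dim must be unit or dynamic
        return [list(range(len(src)))] if all(d in (1, DYN) for d in src) else None
    ranges, start, prev = [], 0, None
    for ti, t in enumerate(tgt):
        greedy = ti == len(tgt) - 1  # the last target dim absorbs the tail
        rng = (_dynamic_range(src, start, greedy) if t == DYN
               else _static_range(src, start, t, greedy))
        if rng is None:
            return None
        left, right = rng
        if left > start:
            # Skipped source dims must fold into the preceding dynamic group.
            if prev != DYN or not ranges:
                return None
            ranges[-1] = (ranges[-1][0], left - 1)
        ranges.append((left, right))
        start, prev = right + 1, t
    if ranges[-1][1] != len(src) - 1:  # source must be fully covered
        return None
    return [list(range(l, r + 1)) for l, r in ranges]

def _dynamic_range(src, start, greedy):
    # Extend rightwards until the group contains a dynamic source dim.
    for j in range(start, len(src)):
        if src[j] == DYN:
            return (start, len(src) - 1) if greedy else (start, j)
    return None

def _static_range(src, start, size, greedy):
    # Sliding-window product match over contiguous static dims.
    left = right = start
    prod, found = 1, None
    while right < len(src):
        if src[right] == DYN:  # a static group cannot span a dynamic dim
            prod, left, right = 1, right + 1, right + 1
            continue
        prod *= src[right]
        while prod > size and left < right:  # shift the window start right
            prod //= src[left]
            left += 1
        if prod == size:
            found = (left, right)
            break
        right += 1
    if found is None:
        return None
    if greedy:  # absorb trailing unit dims
        r = found[1] + 1
        while r < len(src) and src[r] == 1:
            found = (found[0], r)
            r += 1
    return found
```

For example, `reassociation_for_collapse([DYN, 2, 3, 5], [DYN, 15])` yields `[[0, 1], [2, 3]]` (the `?x2x3x5` into `?x15` case mentioned in the patch comments), while `?x2x?x2` into `?x2` fails, matching the FIXME about the algorithm's limited backtracking.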

@IanWood1
Contributor Author

IanWood1 commented Jun 4, 2025

cc @AGindinson

@AGindinson
Contributor

LGTM, thanks Ian for taking the lead.

@MaheshRavishankar
Contributor

If it passes tests, I'd say it's OK to land. If the nvidia-bot fails again, we can chalk it up to flakiness. I don't know who owns the nvidia-bot, but they will have to help triage the error.

@joker-eph
Collaborator

joker-eph commented Jun 5, 2025

I own this bot; it is documented in the bot info on buildbot. I'm always happy to help with a repro if there is an issue.
(it runs on GCP with a very standard ubuntu container, should not be hard)

However, a flaky test is much more likely to be about the test itself: so while I can help with reproducing, we need someone to take ownership of the test and the failure. We can also try a git blame of the test or of the pass it is testing and find recent changes.

@MaheshRavishankar
Contributor

Thanks @joker-eph. Do you suggest we land the "re-land" and check, or is there a way to trigger the nvidia-bot on this PR?

@AGindinson AGindinson requested a review from joker-eph June 11, 2025 21:34
@MaheshRavishankar
Contributor

MaheshRavishankar commented Jun 12, 2025

I suggest landing this change

@AGindinson AGindinson merged commit 6e5a142 into main Jun 12, 2025
7 checks passed
@AGindinson AGindinson deleted the revert-142639-revert-137963-reassoc-expand-of-collapse branch June 12, 2025 08:28
@GleasonK
Contributor

GleasonK commented Jun 13, 2025

Hello! I'm integrating this change into StableHLO now and am hitting some issues with tensor scalars (tensor<f32> -> tensor<1xf32> and vice versa).

For a given stablehlo.reshape %0 : tensor<f32> -> tensor<1xf32>, it seems I need to nullify the reassociation map in order to lower to the tensor dialect:

std::optional<SmallVector<ReassociationIndices>> reassociationMap =
            getReassociationIndicesForReshape(operandType, resultType);
// reassociationMap = {{0}}  -- used to be {}

// Generate expand operation.
if (operandType.getRank() == 0) {
  // Seems to be needed now?? Else hit:
  // error: 'tensor.expand_shape' op expected collapsed rank (0) to equal the number of reassociation maps (1).
  reassociationMap->clear();
}
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
    reshapeOp, resultType, operand, *reassociationMap);

Is this intentional? These APIs used to return {} (ref). Looking at other uses of getReassociationIndicesForReshape, it looks like the intent is for the result to be forwarded to tensor expand/collapse ops, which it seems it no longer can be for scalars. (Apologies if my understanding is entirely wrong, I've inherited this code :))

Code pointer in case it's helpful: StablehloLegalizeToLinalg.cpp#L1129.

@AGindinson
Contributor

AGindinson commented Jun 13, 2025

Hi @GleasonK, thanks for the code pointers!

I believe in my original change I'd accidentally made the map emit "collapse full dimensions into none" for scalars, and then we kind of stuck with it because it made sense at the time. The way we had it before, with an empty map, feels even less correct semantically; however, it does allow distinguishing between the scenarios.

Either way, this decision should be driven by the IR representation, and we're clearly failing the validation with the current approach. Looking at SHLO tests, right now an empty reassociation map is indeed employed. What should expand_shape / collapse_shape look like when they happen between tensor & scalar types? @IanWood1, could you please chime in?

I can make a quick PR to adjust the early exit logic for scalars back to what it was before. Otherwise, on the SHLO end the scalars would have to be handled separately.

@AGindinson
Contributor

AGindinson commented Jun 13, 2025

#144118 is the draft PR. However, I'm genuinely unsure which is the more expressive representation for scalars/pseudo-tensors:

  • tensor.collapse_shape %op [] : tensor<?x?x?xf32> into tensor<f32>, or
  • tensor.collapse_shape %op [[0, 1, 2]] : tensor<?x?x?xf32> into tensor<f32>

...because the essence is similar to tensor<1x1xT> into tensor<1xT>. Obviously, I don't want to assert the "new" representation is better before proposing adjustments to the validation routine.

@IanWood1
Contributor Author

IanWood1 commented Jun 13, 2025

I'd argue tensor.collapse_shape %op [] : tensor<?x?x?xf32> into tensor<f32> is better simply because it's the status quo (not a very good reason). The current implementation ensures the invariant that size(reassociation) == rank(dst), but the other ensures num_elems(reassociation) == rank(source). I'm not sure which is more important to have.
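To spell the two competing invariants out concretely (the function names below are invented for this illustration, not utilities from MLIR): with the empty-map form, the number of reassociation groups matches the collapsed-side (scalar) rank, while the `[[0, 1, 2]]` form makes the flattened group indices cover the expanded-side rank. A quick Python sketch, using the `tensor<?x?x?xf32>` into `tensor<f32>` collapse as the running example:

```python
def groups_match_collapsed_rank(reassoc, collapsed_rank):
    # One reassociation group per collapsed dimension.
    return len(reassoc) == collapsed_rank

def indices_cover_expanded_rank(reassoc, expanded_rank):
    # Flattened group indices enumerate every expanded dimension exactly once.
    return [i for g in reassoc for i in g] == list(range(expanded_rank))

# tensor<?x?x?xf32> into tensor<f32>: each candidate map satisfies exactly
# one invariant, so scalar reshapes force a choice between the two.
assert groups_match_collapsed_rank([], 0)             # status quo: []
assert not indices_cover_expanded_rank([], 3)
assert indices_cover_expanded_rank([[0, 1, 2]], 3)    # alternative form
assert not groups_match_collapsed_rank([[0, 1, 2]], 0)
```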

Also, @AGindinson my bad with the incorrect test I added. I forgot an empty reassociation is the correct representation for scalar reshapes.

@AGindinson
Contributor

AGindinson commented Jun 13, 2025

Thanks Ian, I agree the status quo aspect is important, and a potential change should be properly staged. Going forward with my PR then.

@GleasonK
Contributor

SGTM, appreciate the quick turnaround 🙂 !
I'll land my patch in stablehlo for now and then clean it up on a future llvm bump - thanks again!

tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
…m#142827)

The original PR llvm#137963 had a
nvidia bot failure. This appears to be a flaky test because rerunning
the build was successful.

This change needs commit 6f2ba47 to fix incorrect usage of
`getReassociationIndicesForCollapse`.

Reverts llvm#142639

Co-authored-by: Artem Gindinson <[email protected]>
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
…m#142827)

The original PR llvm#137963 had a
nvidia bot failure. This appears to be a flaky test because rerunning
the build was successful.

This change needs commit 6f2ba47 to fix incorrect usage of
`getReassociationIndicesForCollapse`.

Reverts llvm#142639

Co-authored-by: Artem Gindinson <[email protected]>