Skip to content

[mlir][vector] Move extract_strided_slice canonicalization to folding #135676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 16, 2025

Conversation

newling
Copy link
Contributor

@newling newling commented Apr 14, 2025

Folders are preferred: https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations

Included here : some missing ----- between lit test file with mlir-opt with -split-input-file flag

@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2025

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

Changes

Folders are preferred: https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations

Included here : some missing ----- between lit test file with mlir-opt with -split-input-file flag


Full diff: https://github.com/llvm/llvm-project/pull/135676.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+58-102)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+7)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..1eb6daa402422 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3717,6 +3717,58 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
+
+  // All subsequent successful folds require a constant input.
+  Attribute foldInput = adaptor.getVector();
+  if (!foldInput)
+    return {};
+
+  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+  if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
+    DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
+
+  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+  if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
+    // TODO: Handle non-unit strides when they become available.
+    if (hasNonUnitStrides())
+      return {};
+
+    VectorType sourceVecTy = getSourceVectorType();
+    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+    VectorType sliceVecTy = getType();
+    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+    int64_t rank = sliceVecTy.getRank();
+
+    // Expand offsets and sizes to match the vector rank.
+    SmallVector<int64_t, 4> offsets(rank, 0);
+    copy(getI64SubArray(getOffsets()), offsets.begin());
+
+    SmallVector<int64_t, 4> sizes(sourceShape);
+    copy(getI64SubArray(getSizes()), sizes.begin());
+
+    // Calculate the slice elements by enumerating all slice positions and
+    // linearizing them. The enumeration order is lexicographic which yields a
+    // sequence of monotonically increasing linearized position indices.
+    const auto denseValuesBegin = dense.value_begin<Attribute>();
+    SmallVector<Attribute> sliceValues;
+    sliceValues.reserve(sliceVecTy.getNumElements());
+    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+    do {
+      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+      assert(linearizedPosition < sourceVecTy.getNumElements() &&
+             "Invalid index");
+      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+    } while (
+        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
+
+    assert(static_cast<int64_t>(sliceValues.size()) ==
+               sliceVecTy.getNumElements() &&
+           "Invalid number of slice elements");
+    return DenseElementsAttr::get(sliceVecTy, sliceValues);
+  }
+
   return {};
 }
 
@@ -3781,98 +3833,6 @@ class StridedSliceConstantMaskFolder final
   }
 };
 
-// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
-class StridedSliceSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-
-    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
-                                          splat.getSplatValue<Attribute>());
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
-// ConstantOp.
-class StridedSliceNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    // The splat case is handled by `StridedSliceSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // TODO: Handle non-unit strides when they become available.
-    if (extractStridedSliceOp.hasNonUnitStrides())
-      return failure();
-
-    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
-    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
-    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
-
-    VectorType sliceVecTy = extractStridedSliceOp.getType();
-    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
-    int64_t sliceRank = sliceVecTy.getRank();
-
-    // Expand offsets and sizes to match the vector rank.
-    SmallVector<int64_t, 4> offsets(sliceRank, 0);
-    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
-
-    SmallVector<int64_t, 4> sizes(sourceShape);
-    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
-
-    // Calculate the slice elements by enumerating all slice positions and
-    // linearizing them. The enumeration order is lexicographic which yields a
-    // sequence of monotonically increasing linearized position indices.
-    auto denseValuesBegin = dense.value_begin<Attribute>();
-    SmallVector<Attribute> sliceValues;
-    sliceValues.reserve(sliceVecTy.getNumElements());
-    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
-    do {
-      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
-      assert(linearizedPosition < sourceVecTy.getNumElements() &&
-             "Invalid index");
-      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
-    } while (
-        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
-
-    assert(static_cast<int64_t>(sliceValues.size()) ==
-               sliceVecTy.getNumElements() &&
-           "Invalid number of slice elements");
-    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
 // BroadcastOp(ExtractStrideSliceOp).
 class StridedSliceBroadcast final
@@ -4016,8 +3976,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
-              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
               StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
       context);
 }
@@ -5657,10 +5616,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(constant) -> constant
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    return DenseElementsAttr::get(resultType,
-                                  splatAttr.getSplatValue<Attribute>());
-  }
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
+    return splatAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6004,10 +5961,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
-  if (auto attr =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
-    if (attr.isSplat())
-      return attr.reshape(getResultVectorType());
+  if (auto splat =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
+    return splat.reshape(getResultVectorType());
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..6556df22e069b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_f16_to_f32
 //              bit pattern: 0x40004000
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
@@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
   return %cast0, %cast1: vector<4xf32>, vector<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_i8_to_i32
 //              bit pattern: 0xA0A0A0A0
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
@@ -1710,6 +1714,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
 }
 
 // -----
+
 // CHECK-LABEL:   func.func @vector_multi_reduction_scalable(
 // CHECK-SAME:     %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
 // CHECK-SAME:     %[[VAL_1:.*]]: vector<1x[4]xf32>,
@@ -2251,6 +2256,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
   return %0 : vector<8x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func @transpose_splat2(
 // CHECK-SAME:                           %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
 // CHECK:           %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>

@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2025

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Folders are preferred: https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations

Included here : some missing ----- between lit test file with mlir-opt with -split-input-file flag


Full diff: https://github.com/llvm/llvm-project/pull/135676.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+58-102)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+7)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..1eb6daa402422 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3717,6 +3717,58 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
+
+  // All subsequent successful folds require a constant input.
+  Attribute foldInput = adaptor.getVector();
+  if (!foldInput)
+    return {};
+
+  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+  if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
+    DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
+
+  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+  if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
+    // TODO: Handle non-unit strides when they become available.
+    if (hasNonUnitStrides())
+      return {};
+
+    VectorType sourceVecTy = getSourceVectorType();
+    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+    VectorType sliceVecTy = getType();
+    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+    int64_t rank = sliceVecTy.getRank();
+
+    // Expand offsets and sizes to match the vector rank.
+    SmallVector<int64_t, 4> offsets(rank, 0);
+    copy(getI64SubArray(getOffsets()), offsets.begin());
+
+    SmallVector<int64_t, 4> sizes(sourceShape);
+    copy(getI64SubArray(getSizes()), sizes.begin());
+
+    // Calculate the slice elements by enumerating all slice positions and
+    // linearizing them. The enumeration order is lexicographic which yields a
+    // sequence of monotonically increasing linearized position indices.
+    const auto denseValuesBegin = dense.value_begin<Attribute>();
+    SmallVector<Attribute> sliceValues;
+    sliceValues.reserve(sliceVecTy.getNumElements());
+    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+    do {
+      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+      assert(linearizedPosition < sourceVecTy.getNumElements() &&
+             "Invalid index");
+      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+    } while (
+        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
+
+    assert(static_cast<int64_t>(sliceValues.size()) ==
+               sliceVecTy.getNumElements() &&
+           "Invalid number of slice elements");
+    return DenseElementsAttr::get(sliceVecTy, sliceValues);
+  }
+
   return {};
 }
 
@@ -3781,98 +3833,6 @@ class StridedSliceConstantMaskFolder final
   }
 };
 
-// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
-class StridedSliceSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-
-    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
-                                          splat.getSplatValue<Attribute>());
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
-// ConstantOp.
-class StridedSliceNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    // The splat case is handled by `StridedSliceSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // TODO: Handle non-unit strides when they become available.
-    if (extractStridedSliceOp.hasNonUnitStrides())
-      return failure();
-
-    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
-    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
-    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
-
-    VectorType sliceVecTy = extractStridedSliceOp.getType();
-    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
-    int64_t sliceRank = sliceVecTy.getRank();
-
-    // Expand offsets and sizes to match the vector rank.
-    SmallVector<int64_t, 4> offsets(sliceRank, 0);
-    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
-
-    SmallVector<int64_t, 4> sizes(sourceShape);
-    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
-
-    // Calculate the slice elements by enumerating all slice positions and
-    // linearizing them. The enumeration order is lexicographic which yields a
-    // sequence of monotonically increasing linearized position indices.
-    auto denseValuesBegin = dense.value_begin<Attribute>();
-    SmallVector<Attribute> sliceValues;
-    sliceValues.reserve(sliceVecTy.getNumElements());
-    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
-    do {
-      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
-      assert(linearizedPosition < sourceVecTy.getNumElements() &&
-             "Invalid index");
-      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
-    } while (
-        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
-
-    assert(static_cast<int64_t>(sliceValues.size()) ==
-               sliceVecTy.getNumElements() &&
-           "Invalid number of slice elements");
-    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
 // BroadcastOp(ExtractStrideSliceOp).
 class StridedSliceBroadcast final
@@ -4016,8 +3976,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
-              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
               StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
       context);
 }
@@ -5657,10 +5616,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(constant) -> constant
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    return DenseElementsAttr::get(resultType,
-                                  splatAttr.getSplatValue<Attribute>());
-  }
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
+    return splatAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6004,10 +5961,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
-  if (auto attr =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
-    if (attr.isSplat())
-      return attr.reshape(getResultVectorType());
+  if (auto splat =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
+    return splat.reshape(getResultVectorType());
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..6556df22e069b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_f16_to_f32
 //              bit pattern: 0x40004000
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
@@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
   return %cast0, %cast1: vector<4xf32>, vector<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_i8_to_i32
 //              bit pattern: 0xA0A0A0A0
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
@@ -1710,6 +1714,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
 }
 
 // -----
+
 // CHECK-LABEL:   func.func @vector_multi_reduction_scalable(
 // CHECK-SAME:     %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
 // CHECK-SAME:     %[[VAL_1:.*]]: vector<1x[4]xf32>,
@@ -2251,6 +2256,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
   return %0 : vector<8x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func @transpose_splat2(
 // CHECK-SAME:                           %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
 // CHECK:           %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!


// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
// TODO: Handle non-unit strides when they become available.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move this long folder to a function and follow the pattern:

if (succeeded(foldXYZ(...)))
  return ...;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

return DenseElementsAttr::get(resultType,
splatAttr.getSplatValue<Attribute>());
}
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: drop llvm::?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grep -r " dyn_cast" lib/Dialect/Vector/IR/VectorOps.cpp | wc -l
26
grep -r "llvm::dyn_cast" lib/Dialect/Vector/IR/VectorOps.cpp | wc -l
49
🤷

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@@ -3712,12 +3712,69 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
return failure();
}

namespace {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style of the rest of the file is to use static instead of an anonymous namespace

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear: It is actually the documented LLVM coding style, nothing specific to this particular file.

@qedawkins qedawkins merged commit c7fae59 into llvm:main Apr 16, 2025
11 checks passed
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants