Skip to content

[mlir][Vector] Clean up populateVectorToLLVMConversionPatterns #119975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 17, 2024

Conversation

matthias-springer
Copy link
Member

Clean up populateVectorToLLVMConversionPatterns so that it populates only conversion patterns. All rewrite patterns that do not lower to LLVM should be populated into a separate greedy pattern rewrite.

The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions.

Depends on #119973.

@llvmbot
Copy link
Member

llvmbot commented Dec 14, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-vector

Author: Matthias Springer (matthias-springer)

Changes

Clean up populateVectorToLLVMConversionPatterns so that it populates only conversion patterns. All rewrite patterns that do not lower to LLVM should be populated into a separate greedy pattern rewrite.

The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions.

Depends on #119973.


Full diff: https://github.com/llvm/llvm-project/pull/119975.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+4)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+12)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+14-13)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+5-1)
  • (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+2-2)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (-5)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 3d643c96b45008..c507b23c6d4de6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                            int64_t targetRank = 1,
                                            PatternBenefit benefit = 1);
 
+/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
+/// n > 1.
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1497d662dcdbdd..2fe3b1302e5e5b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -32,10 +32,12 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Error.h"
@@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 void GpuToLLVMConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
+
+  // Perform progressive lowering of vector transfer operations.
+  {
+    RewritePatternSet patterns(&getContext());
+    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+    vector::populateVectorTransferLoweringPatterns(patterns,
+                                                   /*maxTransferRank=*/1);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+
   LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9a07c323c7358..577b74bb7e0c26 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1475,16 +1475,16 @@ class VectorTypeCastOpConversion
 
 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
 /// Non-scalable versions of this operation are handled in Vector Transforms.
-class VectorCreateMaskOpRewritePattern
-    : public OpRewritePattern<vector::CreateMaskOp> {
+class VectorCreateMaskOpConversion
+    : public OpConversionPattern<vector::CreateMaskOp> {
 public:
-  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                             bool enableIndexOpt)
-      : OpRewritePattern<vector::CreateMaskOp>(context),
+      : OpConversionPattern<vector::CreateMaskOp>(context),
         force32BitVectorIndices(enableIndexOpt) {}
 
-  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
-                                PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
     auto dstType = op.getType();
     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
       return failure();
@@ -1495,7 +1495,7 @@ class VectorCreateMaskOpRewritePattern
         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
-                                                 op.getOperand(0));
+                                                 adaptor.getOperands()[0]);
     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
@@ -1896,16 +1896,19 @@ struct VectorScalableStepOpLowering
 
 } // namespace
 
+void mlir::vector::populateVectorRankReducingFMAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions, bool force32BitVectorIndices) {
+  // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
-  populateVectorInsertExtractStridedSliceTransforms(patterns);
-  populateVectorStepLoweringPatterns(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
+  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1925,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
-  // Transfer ops with rank > 1 are handled by VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 64a9ad8e9bade0..2d94c2f2e85a08 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
 
 void ConvertVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and all contraction
-  // operations. Also materializes masks, applies folding and DCE.
+  // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
+  // applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
                                               force32BitVectorIndices);
+    populateVectorInsertExtractStridedSliceTransforms(patterns);
+    populateVectorStepLoweringPatterns(patterns);
+    populateVectorRankReducingFMAPattern(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 44deb45cd752b4..532a2383cea9ef 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
-  func.func @func(%arg: vector<11xf32>) {
+  func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
     // CHECK: vector.mask
     // CHECK-SAME: vector.yield %arg0
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
-    return
+    return %127 : vector<11xf32>
   }
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ea88fece9e662d..f95e943250bd44 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
 //  CHECK-SAME:     %[[ARG:.*]]: vector<4x[8]xf32>)
 // CHECK:           %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK:           %[[T2:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
 // CHECK:           %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
 // CHECK:           %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
   return %0 : vector<4x4x4xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 
 // -----
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
   return %0 : vector<4x4x[4]xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 
 // -----
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
   return %0 : vector<4x4x4xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 
 // -----
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
   return %0 : vector<4x4x[4]xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 
 // -----

@llvmbot
Copy link
Member

llvmbot commented Dec 14, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Clean up populateVectorToLLVMConversionPatterns so that it populates only conversion patterns. All rewrite patterns that do not lower to LLVM should be populated into a separate greedy pattern rewrite.

The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions.

Depends on #119973.


Full diff: https://github.com/llvm/llvm-project/pull/119975.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+4)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+12)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+14-13)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+5-1)
  • (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+2-2)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (-5)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 3d643c96b45008..c507b23c6d4de6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                            int64_t targetRank = 1,
                                            PatternBenefit benefit = 1);
 
+/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
+/// n > 1.
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1497d662dcdbdd..2fe3b1302e5e5b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -32,10 +32,12 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Error.h"
@@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 void GpuToLLVMConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
+
+  // Perform progressive lowering of vector transfer operations.
+  {
+    RewritePatternSet patterns(&getContext());
+    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+    vector::populateVectorTransferLoweringPatterns(patterns,
+                                                   /*maxTransferRank=*/1);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+
   LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9a07c323c7358..577b74bb7e0c26 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1475,16 +1475,16 @@ class VectorTypeCastOpConversion
 
 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
 /// Non-scalable versions of this operation are handled in Vector Transforms.
-class VectorCreateMaskOpRewritePattern
-    : public OpRewritePattern<vector::CreateMaskOp> {
+class VectorCreateMaskOpConversion
+    : public OpConversionPattern<vector::CreateMaskOp> {
 public:
-  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                             bool enableIndexOpt)
-      : OpRewritePattern<vector::CreateMaskOp>(context),
+      : OpConversionPattern<vector::CreateMaskOp>(context),
         force32BitVectorIndices(enableIndexOpt) {}
 
-  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
-                                PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
     auto dstType = op.getType();
     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
       return failure();
@@ -1495,7 +1495,7 @@ class VectorCreateMaskOpRewritePattern
         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
-                                                 op.getOperand(0));
+                                                 adaptor.getOperands()[0]);
     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
@@ -1896,16 +1896,19 @@ struct VectorScalableStepOpLowering
 
 } // namespace
 
+void mlir::vector::populateVectorRankReducingFMAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions, bool force32BitVectorIndices) {
+  // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
-  populateVectorInsertExtractStridedSliceTransforms(patterns);
-  populateVectorStepLoweringPatterns(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
+  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1925,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
-  // Transfer ops with rank > 1 are handled by VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 64a9ad8e9bade0..2d94c2f2e85a08 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
 
 void ConvertVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and all contraction
-  // operations. Also materializes masks, applies folding and DCE.
+  // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
+  // applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
                                               force32BitVectorIndices);
+    populateVectorInsertExtractStridedSliceTransforms(patterns);
+    populateVectorStepLoweringPatterns(patterns);
+    populateVectorRankReducingFMAPattern(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 44deb45cd752b4..532a2383cea9ef 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
-  func.func @func(%arg: vector<11xf32>) {
+  func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
     // CHECK: vector.mask
     // CHECK-SAME: vector.yield %arg0
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
-    return
+    return %127 : vector<11xf32>
   }
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ea88fece9e662d..f95e943250bd44 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
 //  CHECK-SAME:     %[[ARG:.*]]: vector<4x[8]xf32>)
 // CHECK:           %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK:           %[[T2:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
 // CHECK:           %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
 // CHECK:           %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
   return %0 : vector<4x4x4xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 
 // -----
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
   return %0 : vector<4x4x[4]xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 
 // -----
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
   return %0 : vector<4x4x4xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 
 // -----
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
   return %0 : vector<4x4x[4]xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 
 // -----

Copy link

github-actions bot commented Dec 14, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The non-GPU changes LGTM. The CHECK lines removed in tests were just dead code, so thanks for the clean-up!

The GPU parts look reasonable, but it might be worth waiting a few days in case someone more experienced wants to take a look. If there are no comments, I would just land this.

Copy link
Contributor

@kurapov-peter kurapov-peter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks OK to me

Base automatically changed from users/matthias-springer/mask_mat to main December 17, 2024 10:26
@matthias-springer matthias-springer force-pushed the users/matthias-springer/vec_to_llvm_conv branch from e5926b6 to 03fcf33 Compare December 17, 2024 10:28
@matthias-springer matthias-springer merged commit 0693b9e into main Dec 17, 2024
5 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/vec_to_llvm_conv branch December 17, 2024 10:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants