Skip to content

[mlir][vector] ND vectors linearization pass #81159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 13, 2024

Conversation

Hardcode84
Copy link
Contributor

Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as array<array<... vector>> and SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidim vectors as linearized 1D. Add pass to do this. Only constants and simple elementwise ops are supported for now.

@krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place.

Also, add ConversionPattern class operating on traits.

Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors as `array<array<... vector>>` and SPIR-V doesn't handle them as all at the moment.
Sometime it's prefferable to tread multidim vectors as linearized. Add pass to do this.
Only constants and simple elementwise ops are supported for now.

Also, move generic op return type utility to common place and add ConversionPattern operating on traits.
@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Ivan Butygin (Hardcode84)

Changes

Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as array&lt;array&lt;... vector&gt;&gt; and SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidim vectors as linearized 1D. Add pass to do this. Only constants and simple elementwise ops are supported for now.

@krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place.

Also, add ConversionPattern class operating on traits.


Full diff: https://github.com/llvm/llvm-project/pull/81159.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.td (+9)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+6)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+23)
  • (modified) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (+7-13)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+122)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21)
  • (added) mlir/test/Dialect/Vector/linearize.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..32b4363be00949 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
+def VectorLinearize : Pass<"vector-linearize"> {
+  let summary = "Linearize ND vectors into 1D";
+  let description = [{
+    Linearizes ND vectors for N >= 2 into 1D vectors.
+  }];
+  let dependentDialects = ["vector::VectorDialect"];
+ }
+
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..45f54fc70e3261 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
 
 namespace mlir {
+class ConversionTarget;
 class RewritePatternSet;
+class TypeConverter;
 
 namespace arith {
 class AndIOp;
@@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+void populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..5081b4c06a617e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
   using ConversionPattern::matchAndRewrite;
 };
 
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
+/// for matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+  OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+  OpTraitConversionPattern(const TypeConverter &typeConverter,
+                           MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Generic utility to convert op result types according to type converter
+/// without knowing exact op type.
+/// Clones existing op with new result types and returns it.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+                     const TypeConverter &converter,
+                     ConversionPatternRewriter &rewriter);
+
 /// Add a pattern to the given pattern list to convert the signature of a
 /// FunctionOpInterface op with the given type converter. This only supports
 /// ops which use FunctionType to represent their type.
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   const TypeConverter *converter = getTypeConverter();
-  if (converter->isLegal(op))
-    return rewriter.notifyMatchFailure(loc, "op already legal");
-  OperationState newOp(loc, op->getName());
-  newOp.addOperands(operands);
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
 
-  SmallVector<Type> newResultTypes;
-  if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
-    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
-  newOp.addTypes(newResultTypes);
-  newOp.addAttributes(op->getAttrs());
-  Operation *legalized = rewriter.create(newOp);
-  SmallVector<Value> results = legalized->getResults();
-  for (auto [result, newType, origType] :
-       llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
     if (newType != origType)
       result = rewriter.create<arith::TruncFOp>(loc, origType, result);
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorEmulateMaskedLoadStore.cpp
   VectorEmulateNarrowType.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
+  VectorLinearize.cpp
   VectorTransferOpTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..7602e8c1976a9a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,122 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::vector {
+#define GEN_PASS_DEF_VECTORLINEARIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+using namespace mlir;
+
+namespace {
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = constOp.getLoc();
+    auto resType =
+        getTypeConverter()->convertType<VectorType>(constOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!dstElementsAttr)
+      return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+
+    dstElementsAttr = dstElementsAttr.reshape(resType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
+                                                   dstElementsAttr);
+    return success();
+  }
+};
+
+struct LinearizeVectorizable final
+    : OpTraitConversionPattern<OpTrait::Vectorizable> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FailureOr<Operation *> newOp =
+        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(op, (*newOp)->getResults());
+    return success();
+  }
+};
+
+struct VectorLinearizePass final
+    : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
+  using VectorLinearizeBase::VectorLinearizeBase;
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+                                                              patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+    // Ignore scalable vectors for now.
+    if (type.getRank() <= 1 || type.isScalable())
+      return type;
+
+    return VectorType::get(type.getNumElements(), type.getElementType());
+  });
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+        !isa<VectorType>(type))
+      return nullptr;
+
+    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+  };
+  typeConverter.addArgumentMaterialization(materializeCast);
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+
+  target.markUnknownOpDynamicallyLegal(
+      [&](Operation *op) -> std::optional<bool> {
+        if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+          return typeConverter.isLegal(op);
+
+        return std::nullopt;
+      });
+
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+                                                         patterns.getContext());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..bfccef7cfe574b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
 };
 } // namespace
 
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+                           const TypeConverter &converter,
+                           ConversionPatternRewriter &rewriter) {
+  assert(op && "Invalid op");
+  Location loc = op->getLoc();
+  if (converter.isLegal(op))
+    return rewriter.notifyMatchFailure(loc, "op already legal");
+
+  OperationState newOp(loc, op->getName());
+  newOp.addOperands(operands);
+
+  SmallVector<Type> newResultTypes;
+  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+  newOp.addTypes(newResultTypes);
+  newOp.addAttributes(op->getAttrs());
+  return rewriter.create(newOp);
+}
+
 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
     StringRef functionLikeOpName, RewritePatternSet &patterns,
     const TypeConverter &converter) {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..e0fac81199bc8d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled in generic way, check some of them
+//       CHECK: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+  %1 = math.sin %arg0 : vector<2x2xf32>
+//       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}

@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

Changes

Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as array&lt;array&lt;... vector&gt;&gt; and SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidim vectors as linearized 1D. Add pass to do this. Only constants and simple elementwise ops are supported for now.

@krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place.

Also, add ConversionPattern class operating on traits.


Full diff: https://github.com/llvm/llvm-project/pull/81159.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.td (+9)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+6)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+23)
  • (modified) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (+7-13)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+122)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21)
  • (added) mlir/test/Dialect/Vector/linearize.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..32b4363be00949 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
+def VectorLinearize : Pass<"vector-linearize"> {
+  let summary = "Linearize ND vectors into 1D";
+  let description = [{
+    Linearizes ND vectors for N >= 2 into 1D vectors.
+  }];
+  let dependentDialects = ["vector::VectorDialect"];
+ }
+
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..45f54fc70e3261 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
 
 namespace mlir {
+class ConversionTarget;
 class RewritePatternSet;
+class TypeConverter;
 
 namespace arith {
 class AndIOp;
@@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+void populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..5081b4c06a617e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
   using ConversionPattern::matchAndRewrite;
 };
 
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
+/// for matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+  OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+  OpTraitConversionPattern(const TypeConverter &typeConverter,
+                           MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Generic utility to convert op result types according to type converter
+/// without knowing exact op type.
+/// Clones existing op with new result types and returns it.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+                     const TypeConverter &converter,
+                     ConversionPatternRewriter &rewriter);
+
 /// Add a pattern to the given pattern list to convert the signature of a
 /// FunctionOpInterface op with the given type converter. This only supports
 /// ops which use FunctionType to represent their type.
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   const TypeConverter *converter = getTypeConverter();
-  if (converter->isLegal(op))
-    return rewriter.notifyMatchFailure(loc, "op already legal");
-  OperationState newOp(loc, op->getName());
-  newOp.addOperands(operands);
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
 
-  SmallVector<Type> newResultTypes;
-  if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
-    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
-  newOp.addTypes(newResultTypes);
-  newOp.addAttributes(op->getAttrs());
-  Operation *legalized = rewriter.create(newOp);
-  SmallVector<Value> results = legalized->getResults();
-  for (auto [result, newType, origType] :
-       llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
     if (newType != origType)
       result = rewriter.create<arith::TruncFOp>(loc, origType, result);
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorEmulateMaskedLoadStore.cpp
   VectorEmulateNarrowType.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
+  VectorLinearize.cpp
   VectorTransferOpTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..7602e8c1976a9a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,122 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::vector {
+#define GEN_PASS_DEF_VECTORLINEARIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+using namespace mlir;
+
+namespace {
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = constOp.getLoc();
+    auto resType =
+        getTypeConverter()->convertType<VectorType>(constOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!dstElementsAttr)
+      return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+
+    dstElementsAttr = dstElementsAttr.reshape(resType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
+                                                   dstElementsAttr);
+    return success();
+  }
+};
+
+struct LinearizeVectorizable final
+    : OpTraitConversionPattern<OpTrait::Vectorizable> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FailureOr<Operation *> newOp =
+        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(op, (*newOp)->getResults());
+    return success();
+  }
+};
+
+struct VectorLinearizePass final
+    : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
+  using VectorLinearizeBase::VectorLinearizeBase;
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+                                                              patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+    // Ignore scalable vectors for now.
+    if (type.getRank() <= 1 || type.isScalable())
+      return type;
+
+    return VectorType::get(type.getNumElements(), type.getElementType());
+  });
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+        !isa<VectorType>(type))
+      return nullptr;
+
+    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+  };
+  typeConverter.addArgumentMaterialization(materializeCast);
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+
+  target.markUnknownOpDynamicallyLegal(
+      [&](Operation *op) -> std::optional<bool> {
+        if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+          return typeConverter.isLegal(op);
+
+        return std::nullopt;
+      });
+
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+                                                         patterns.getContext());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..bfccef7cfe574b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
 };
 } // namespace
 
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+                           const TypeConverter &converter,
+                           ConversionPatternRewriter &rewriter) {
+  assert(op && "Invalid op");
+  Location loc = op->getLoc();
+  if (converter.isLegal(op))
+    return rewriter.notifyMatchFailure(loc, "op already legal");
+
+  OperationState newOp(loc, op->getName());
+  newOp.addOperands(operands);
+
+  SmallVector<Type> newResultTypes;
+  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+  newOp.addTypes(newResultTypes);
+  newOp.addAttributes(op->getAttrs());
+  return rewriter.create(newOp);
+}
+
 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
     StringRef functionLikeOpName, RewritePatternSet &patterns,
     const TypeConverter &converter) {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..e0fac81199bc8d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled in generic way, check some of them
+//       CHECK: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+  %1 = math.sin %arg0 : vector<2x2xf32>
+//       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}

@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-mlir-core

Author: Ivan Butygin (Hardcode84)

Changes

Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as array&lt;array&lt;... vector&gt;&gt; and SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidim vectors as linearized 1D. Add pass to do this. Only constants and simple elementwise ops are supported for now.

@krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place.

Also, add ConversionPattern class operating on traits.


Full diff: https://github.com/llvm/llvm-project/pull/81159.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.td (+9)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+6)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+23)
  • (modified) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (+7-13)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+122)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21)
  • (added) mlir/test/Dialect/Vector/linearize.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..32b4363be00949 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
+def VectorLinearize : Pass<"vector-linearize"> {
+  let summary = "Linearize ND vectors into 1D";
+  let description = [{
+    Linearizes ND vectors for N >= 2 into 1D vectors.
+  }];
+  let dependentDialects = ["vector::VectorDialect"];
+ }
+
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..45f54fc70e3261 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
 
 namespace mlir {
+class ConversionTarget;
 class RewritePatternSet;
+class TypeConverter;
 
 namespace arith {
 class AndIOp;
@@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+void populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..5081b4c06a617e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
   using ConversionPattern::matchAndRewrite;
 };
 
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
+/// for matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+  OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+  OpTraitConversionPattern(const TypeConverter &typeConverter,
+                           MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Generic utility to convert op result types according to type converter
+/// without knowing exact op type.
+/// Clones existing op with new result types and returns it.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+                     const TypeConverter &converter,
+                     ConversionPatternRewriter &rewriter);
+
 /// Add a pattern to the given pattern list to convert the signature of a
 /// FunctionOpInterface op with the given type converter. This only supports
 /// ops which use FunctionType to represent their type.
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   const TypeConverter *converter = getTypeConverter();
-  if (converter->isLegal(op))
-    return rewriter.notifyMatchFailure(loc, "op already legal");
-  OperationState newOp(loc, op->getName());
-  newOp.addOperands(operands);
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
 
-  SmallVector<Type> newResultTypes;
-  if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
-    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
-  newOp.addTypes(newResultTypes);
-  newOp.addAttributes(op->getAttrs());
-  Operation *legalized = rewriter.create(newOp);
-  SmallVector<Value> results = legalized->getResults();
-  for (auto [result, newType, origType] :
-       llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
     if (newType != origType)
       result = rewriter.create<arith::TruncFOp>(loc, origType, result);
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorEmulateMaskedLoadStore.cpp
   VectorEmulateNarrowType.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
+  VectorLinearize.cpp
   VectorTransferOpTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..7602e8c1976a9a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,122 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::vector {
+#define GEN_PASS_DEF_VECTORLINEARIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+using namespace mlir;
+
+namespace {
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = constOp.getLoc();
+    auto resType =
+        getTypeConverter()->convertType<VectorType>(constOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!dstElementsAttr)
+      return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+
+    dstElementsAttr = dstElementsAttr.reshape(resType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
+                                                   dstElementsAttr);
+    return success();
+  }
+};
+
+struct LinearizeVectorizable final
+    : OpTraitConversionPattern<OpTrait::Vectorizable> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FailureOr<Operation *> newOp =
+        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(op, (*newOp)->getResults());
+    return success();
+  }
+};
+
+struct VectorLinearizePass final
+    : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
+  using VectorLinearizeBase::VectorLinearizeBase;
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+                                                              patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+    // Ignore scalable vectors for now.
+    if (type.getRank() <= 1 || type.isScalable())
+      return type;
+
+    return VectorType::get(type.getNumElements(), type.getElementType());
+  });
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+        !isa<VectorType>(type))
+      return nullptr;
+
+    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+  };
+  typeConverter.addArgumentMaterialization(materializeCast);
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+
+  target.markUnknownOpDynamicallyLegal(
+      [&](Operation *op) -> std::optional<bool> {
+        if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+          return typeConverter.isLegal(op);
+
+        return std::nullopt;
+      });
+
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+                                                         patterns.getContext());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..bfccef7cfe574b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
 };
 } // namespace
 
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+                           const TypeConverter &converter,
+                           ConversionPatternRewriter &rewriter) {
+  assert(op && "Invalid op");
+  Location loc = op->getLoc();
+  if (converter.isLegal(op))
+    return rewriter.notifyMatchFailure(loc, "op already legal");
+
+  OperationState newOp(loc, op->getName());
+  newOp.addOperands(operands);
+
+  SmallVector<Type> newResultTypes;
+  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+  newOp.addTypes(newResultTypes);
+  newOp.addAttributes(op->getAttrs());
+  return rewriter.create(newOp);
+}
+
 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
     StringRef functionLikeOpName, RewritePatternSet &patterns,
     const TypeConverter &converter) {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..e0fac81199bc8d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled in generic way, check some of them
+//       CHECK: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+  %1 = math.sin %arg0 : vector<2x2xf32>
+//       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, thanks for the contribution!

I think we should look at applying vector unrolling instead or linearization. It's currently happening at LLVM level but we have always talked about moving it to an earlier stage so that the lowering to LLVM and SPIR-V are aligned.

Collapsing some vector dimensions is definitely beneficial for some cases, esp. when the trailing dimension can't fill a full 1-D physical vector register, such as in your test (I think @hanhanW was working on something for these cases). However, I'm not sure we can apply it in a general way. For example, we couldn't turn an arbitrary vector<2x2xf32> load/store into a vector<4xf32> one because elements might not be contiguous in memory across the dimensions. Also, by linearizing an n-D vector instead of unrolling it, we are losing control on the actual vector length used and relying on the hardware backend to do what we expected, which is something we have been trying to avoid.

Would vector unrolling work for your case?

@Hardcode84
Copy link
Contributor Author

In our case underlying API/ABI (SPIR-V intrinsics) expects 1D vectors, but it interprets them as ND and we represent them (and surrounding arith ops) as ND at the intermediate steps, so we need to linearize it at some point. Also, vector unrolling and linearization are not mutually exclusive. We can have some partial unroll first and than run linearization on result (and potentially, some legalization pass after, breaking big 1D vectors into smaller ones, supported by HW). Regarding loads/stores and similar potentially non-linearizable ops, they are not covered by the patch and will stay untouched + shape_cast, but we can unroll them and reconstruct linear vector if we cannot statically determine they are safe.

I understand this transform is quite usecase-specific, so don't want to force it into default vector/llvm/spir-v pipeline.

@dcaballe
Copy link
Contributor

dcaballe commented Feb 8, 2024

Ok, got it. Would it make sense then to remove the pass in Vector/Transforms and just expose this transformation through the populate... API? We follow this approach for cases like this. You can move that pass under mlir/test/lib/Dialect/Vector for testing. In that way, the pass is not prescribing how the patterns should be used and leave that decision to the user.

@dcaballe
Copy link
Contributor

dcaballe commented Feb 8, 2024

Perhaps we can also add a callback function so that the user can decide when an op should be linearized or not. That should give us the level of flexibility that we need.

@Hardcode84
Copy link
Contributor Author

You can move that pass under mlir/test/lib/Dialect/Vector for testing. In that way, the pass is not prescribing how the patterns should be used and leave that decision to the user.

Makes sense, I'll do it.

Perhaps we can also add a callback function so that the user can decide when an op should be linearized or not. That should give us the level of flexibility that we need.

User already have some control over it by overriding target.mark(Unknown)OpDynamicallyLegal.

@rengolin
Copy link
Member

rengolin commented Feb 9, 2024

Not directly applicable, but our solution to a similar problem in memref was to linearize all but the last dimensions (for D>2).

This would keep the inner dimension intact and not make assumptions about the vector length, but would potentially break noncontiguous higher dimensions if the layout isn't changed accordingly?

But this does require the underlying user (print in out case) to change behaviour, which is not always a generic property.

Comment on lines 380 to 381
/// Linearizes ND vectors (N >= 2) into 1D.
void populateVectorLinearizeTypeConversionsAndLegality(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove Legality or describe what it means in the doc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the doc

// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// Arith and math ops are handled in generic way, check some of them
// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a few more tests? Perhaps some where the transformation shouldn't trigger?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more checks for shape_cast on the boundaries and check for return op unchanged, not sure what else to check.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants