-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] ND vectors linearization pass #81159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors as `array<array<... vector>>` and SPIR-V doesn't handle them as all at the moment. Sometime it's prefferable to tread multidim vectors as linearized. Add pass to do this. Only constants and simple elementwise ops are supported for now. Also, move generic op return type utility to common place and add ConversionPattern operating on traits.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-math Author: Ivan Butygin (Hardcode84) ChangesCommon backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as @krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place. Also, add ConversionPattern class operating on traits. Full diff: https://github.com/llvm/llvm-project/pull/81159.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..32b4363be00949 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
let constructor = "mlir::vector::createLowerVectorMaskPass()";
}
+def VectorLinearize : Pass<"vector-linearize"> {
+ let summary = "Linearize ND vectors into 1D";
+ let description = [{
+ Linearizes ND vectors for N >= 2 into 1D vectors.
+ }];
+ let dependentDialects = ["vector::VectorDialect"];
+ }
+
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..45f54fc70e3261 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
namespace mlir {
+class ConversionTarget;
class RewritePatternSet;
+class TypeConverter;
namespace arith {
class AndIOp;
@@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+void populateVectorLinearizeTypeConversionsAndLegality(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..5081b4c06a617e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
using ConversionPattern::matchAndRewrite;
};
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
+/// for matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+ OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+ : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+ TypeID::get<TraitType>(), benefit, context) {}
+ OpTraitConversionPattern(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+ TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Generic utility to convert op result types according to type converter
+/// without knowing exact op type.
+/// Clones existing op with new result types and returns it.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+ const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter);
+
/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionOpInterface op with the given type converter. This only supports
/// ops which use FunctionType to represent their type.
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
const TypeConverter *converter = getTypeConverter();
- if (converter->isLegal(op))
- return rewriter.notifyMatchFailure(loc, "op already legal");
- OperationState newOp(loc, op->getName());
- newOp.addOperands(operands);
+ FailureOr<Operation *> legalized =
+ convertOpResultTypes(op, operands, *converter, rewriter);
+ if (failed(legalized))
+ return failure();
- SmallVector<Type> newResultTypes;
- if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
- return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
- newOp.addTypes(newResultTypes);
- newOp.addAttributes(op->getAttrs());
- Operation *legalized = rewriter.create(newOp);
- SmallVector<Value> results = legalized->getResults();
- for (auto [result, newType, origType] :
- llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+ SmallVector<Value> results = (*legalized)->getResults();
+ for (auto [result, newType, origType] : llvm::zip_equal(
+ results, (*legalized)->getResultTypes(), op->getResultTypes())) {
if (newType != origType)
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorEmulateMaskedLoadStore.cpp
VectorEmulateNarrowType.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
+ VectorLinearize.cpp
VectorTransferOpTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..7602e8c1976a9a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,122 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::vector {
+#define GEN_PASS_DEF_VECTORLINEARIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+using namespace mlir;
+
+namespace {
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = constOp.getLoc();
+ auto resType =
+ getTypeConverter()->convertType<VectorType>(constOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+ auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!dstElementsAttr)
+ return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+
+ dstElementsAttr = dstElementsAttr.reshape(resType);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
+ dstElementsAttr);
+ return success();
+ }
+};
+
+struct LinearizeVectorizable final
+ : OpTraitConversionPattern<OpTrait::Vectorizable> {
+ using OpTraitConversionPattern::OpTraitConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ FailureOr<Operation *> newOp =
+ convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+ if (failed(newOp))
+ return failure();
+
+ rewriter.replaceOp(op, (*newOp)->getResults());
+ return success();
+ }
+};
+
+struct VectorLinearizePass final
+ : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
+ using VectorLinearizeBase::VectorLinearizeBase;
+
+ void runOnOperation() override {
+ auto *context = &getContext();
+
+ TypeConverter typeConverter;
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+
+ vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+ patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target) {
+ typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+ // Ignore scalable vectors for now.
+ if (type.getRank() <= 1 || type.isScalable())
+ return type;
+
+ return VectorType::get(type.getNumElements(), type.getElementType());
+ });
+
+ auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+ !isa<VectorType>(type))
+ return nullptr;
+
+ return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+ };
+ typeConverter.addArgumentMaterialization(materializeCast);
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) -> std::optional<bool> {
+ if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+ return typeConverter.isLegal(op);
+
+ return std::nullopt;
+ });
+
+ patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+ patterns.getContext());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..bfccef7cfe574b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
};
} // namespace
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+ const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter) {
+ assert(op && "Invalid op");
+ Location loc = op->getLoc();
+ if (converter.isLegal(op))
+ return rewriter.notifyMatchFailure(loc, "op already legal");
+
+ OperationState newOp(loc, op->getName());
+ newOp.addOperands(operands);
+
+ SmallVector<Type> newResultTypes;
+ if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+ newOp.addTypes(newResultTypes);
+ newOp.addAttributes(op->getAttrs());
+ return rewriter.create(newOp);
+}
+
void mlir::populateFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, RewritePatternSet &patterns,
const TypeConverter &converter) {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..e0fac81199bc8d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled in generic way, check some of them
+// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
+ %1 = math.sin %arg0 : vector<2x2xf32>
+// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+ %2 = arith.addf %arg0, %0 : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
|
@llvm/pr-subscribers-mlir-vector Author: Ivan Butygin (Hardcode84) ChangesCommon backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as @krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place. Also, add ConversionPattern class operating on traits. Full diff: https://github.com/llvm/llvm-project/pull/81159.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..32b4363be00949 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
let constructor = "mlir::vector::createLowerVectorMaskPass()";
}
+def VectorLinearize : Pass<"vector-linearize"> {
+ let summary = "Linearize ND vectors into 1D";
+ let description = [{
+ Linearizes ND vectors for N >= 2 into 1D vectors.
+ }];
+ let dependentDialects = ["vector::VectorDialect"];
+ }
+
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..45f54fc70e3261 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
namespace mlir {
+class ConversionTarget;
class RewritePatternSet;
+class TypeConverter;
namespace arith {
class AndIOp;
@@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+void populateVectorLinearizeTypeConversionsAndLegality(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..5081b4c06a617e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
using ConversionPattern::matchAndRewrite;
};
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
+/// for matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+ OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+ : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+ TypeID::get<TraitType>(), benefit, context) {}
+ OpTraitConversionPattern(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+ TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Generic utility to convert op result types according to type converter
+/// without knowing exact op type.
+/// Clones existing op with new result types and returns it.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+ const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter);
+
/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionOpInterface op with the given type converter. This only supports
/// ops which use FunctionType to represent their type.
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
const TypeConverter *converter = getTypeConverter();
- if (converter->isLegal(op))
- return rewriter.notifyMatchFailure(loc, "op already legal");
- OperationState newOp(loc, op->getName());
- newOp.addOperands(operands);
+ FailureOr<Operation *> legalized =
+ convertOpResultTypes(op, operands, *converter, rewriter);
+ if (failed(legalized))
+ return failure();
- SmallVector<Type> newResultTypes;
- if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
- return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
- newOp.addTypes(newResultTypes);
- newOp.addAttributes(op->getAttrs());
- Operation *legalized = rewriter.create(newOp);
- SmallVector<Value> results = legalized->getResults();
- for (auto [result, newType, origType] :
- llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+ SmallVector<Value> results = (*legalized)->getResults();
+ for (auto [result, newType, origType] : llvm::zip_equal(
+ results, (*legalized)->getResultTypes(), op->getResultTypes())) {
if (newType != origType)
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorEmulateMaskedLoadStore.cpp
VectorEmulateNarrowType.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
+ VectorLinearize.cpp
VectorTransferOpTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..7602e8c1976a9a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,122 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::vector {
+#define GEN_PASS_DEF_VECTORLINEARIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+using namespace mlir;
+
+namespace {
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = constOp.getLoc();
+ auto resType =
+ getTypeConverter()->convertType<VectorType>(constOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+ auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!dstElementsAttr)
+ return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+
+ dstElementsAttr = dstElementsAttr.reshape(resType);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
+ dstElementsAttr);
+ return success();
+ }
+};
+
+struct LinearizeVectorizable final
+ : OpTraitConversionPattern<OpTrait::Vectorizable> {
+ using OpTraitConversionPattern::OpTraitConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ FailureOr<Operation *> newOp =
+ convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+ if (failed(newOp))
+ return failure();
+
+ rewriter.replaceOp(op, (*newOp)->getResults());
+ return success();
+ }
+};
+
+struct VectorLinearizePass final
+ : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
+ using VectorLinearizeBase::VectorLinearizeBase;
+
+ void runOnOperation() override {
+ auto *context = &getContext();
+
+ TypeConverter typeConverter;
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+
+ vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+ patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target) {
+ typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+ // Ignore scalable vectors for now.
+ if (type.getRank() <= 1 || type.isScalable())
+ return type;
+
+ return VectorType::get(type.getNumElements(), type.getElementType());
+ });
+
+ auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+ !isa<VectorType>(type))
+ return nullptr;
+
+ return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+ };
+ typeConverter.addArgumentMaterialization(materializeCast);
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) -> std::optional<bool> {
+ if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+ return typeConverter.isLegal(op);
+
+ return std::nullopt;
+ });
+
+ patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+ patterns.getContext());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..bfccef7cfe574b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
};
} // namespace
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+ const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter) {
+ assert(op && "Invalid op");
+ Location loc = op->getLoc();
+ if (converter.isLegal(op))
+ return rewriter.notifyMatchFailure(loc, "op already legal");
+
+ OperationState newOp(loc, op->getName());
+ newOp.addOperands(operands);
+
+ SmallVector<Type> newResultTypes;
+ if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+ newOp.addTypes(newResultTypes);
+ newOp.addAttributes(op->getAttrs());
+ return rewriter.create(newOp);
+}
+
void mlir::populateFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, RewritePatternSet &patterns,
const TypeConverter &converter) {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..e0fac81199bc8d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled in generic way, check some of them
+// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
+ %1 = math.sin %arg0 : vector<2x2xf32>
+// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+ %2 = arith.addf %arg0, %0 : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
|
@llvm/pr-subscribers-mlir-core Author: Ivan Butygin (Hardcode84) ChangesCommon backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as @krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place. Also, add ConversionPattern class operating on traits. Full diff: https://github.com/llvm/llvm-project/pull/81159.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..32b4363be00949 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
let constructor = "mlir::vector::createLowerVectorMaskPass()";
}
+def VectorLinearize : Pass<"vector-linearize"> {
+ let summary = "Linearize ND vectors into 1D";
+ let description = [{
+ Linearizes ND vectors for N >= 2 into 1D vectors.
+ }];
+ let dependentDialects = ["vector::VectorDialect"];
+ }
+
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..45f54fc70e3261 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
namespace mlir {
+class ConversionTarget;
class RewritePatternSet;
+class TypeConverter;
namespace arith {
class AndIOp;
@@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+void populateVectorLinearizeTypeConversionsAndLegality(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..5081b4c06a617e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
using ConversionPattern::matchAndRewrite;
};
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
+/// for matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+ OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+ : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+ TypeID::get<TraitType>(), benefit, context) {}
+ OpTraitConversionPattern(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+ TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Generic utility to convert op result types according to type converter
+/// without knowing exact op type.
+/// Clones existing op with new result types and returns it.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+ const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter);
+
/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionOpInterface op with the given type converter. This only supports
/// ops which use FunctionType to represent their type.
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
const TypeConverter *converter = getTypeConverter();
- if (converter->isLegal(op))
- return rewriter.notifyMatchFailure(loc, "op already legal");
- OperationState newOp(loc, op->getName());
- newOp.addOperands(operands);
+ FailureOr<Operation *> legalized =
+ convertOpResultTypes(op, operands, *converter, rewriter);
+ if (failed(legalized))
+ return failure();
- SmallVector<Type> newResultTypes;
- if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
- return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
- newOp.addTypes(newResultTypes);
- newOp.addAttributes(op->getAttrs());
- Operation *legalized = rewriter.create(newOp);
- SmallVector<Value> results = legalized->getResults();
- for (auto [result, newType, origType] :
- llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+ SmallVector<Value> results = (*legalized)->getResults();
+ for (auto [result, newType, origType] : llvm::zip_equal(
+ results, (*legalized)->getResultTypes(), op->getResultTypes())) {
if (newType != origType)
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorEmulateMaskedLoadStore.cpp
VectorEmulateNarrowType.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
+ VectorLinearize.cpp
VectorTransferOpTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..7602e8c1976a9a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,122 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::vector {
+#define GEN_PASS_DEF_VECTORLINEARIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+using namespace mlir;
+
+namespace {
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = constOp.getLoc();
+ auto resType =
+ getTypeConverter()->convertType<VectorType>(constOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+ auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!dstElementsAttr)
+ return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+
+ dstElementsAttr = dstElementsAttr.reshape(resType);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
+ dstElementsAttr);
+ return success();
+ }
+};
+
+struct LinearizeVectorizable final
+ : OpTraitConversionPattern<OpTrait::Vectorizable> {
+ using OpTraitConversionPattern::OpTraitConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ FailureOr<Operation *> newOp =
+ convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+ if (failed(newOp))
+ return failure();
+
+ rewriter.replaceOp(op, (*newOp)->getResults());
+ return success();
+ }
+};
+
+struct VectorLinearizePass final
+ : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
+ using VectorLinearizeBase::VectorLinearizeBase;
+
+ void runOnOperation() override {
+ auto *context = &getContext();
+
+ TypeConverter typeConverter;
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+
+ vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+ patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target) {
+ typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+ // Ignore scalable vectors for now.
+ if (type.getRank() <= 1 || type.isScalable())
+ return type;
+
+ return VectorType::get(type.getNumElements(), type.getElementType());
+ });
+
+ auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+ !isa<VectorType>(type))
+ return nullptr;
+
+ return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+ };
+ typeConverter.addArgumentMaterialization(materializeCast);
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) -> std::optional<bool> {
+ if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+ return typeConverter.isLegal(op);
+
+ return std::nullopt;
+ });
+
+ patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+ patterns.getContext());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..bfccef7cfe574b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
};
} // namespace
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+ const TypeConverter &converter,
+ ConversionPatternRewriter &rewriter) {
+ assert(op && "Invalid op");
+ Location loc = op->getLoc();
+ if (converter.isLegal(op))
+ return rewriter.notifyMatchFailure(loc, "op already legal");
+
+ OperationState newOp(loc, op->getName());
+ newOp.addOperands(operands);
+
+ SmallVector<Type> newResultTypes;
+ if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+ newOp.addTypes(newResultTypes);
+ newOp.addAttributes(op->getAttrs());
+ return rewriter.create(newOp);
+}
+
void mlir::populateFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, RewritePatternSet &patterns,
const TypeConverter &converter) {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..e0fac81199bc8d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled in generic way, check some of them
+// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
+ %1 = math.sin %arg0 : vector<2x2xf32>
+// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+ %2 = arith.addf %arg0, %0 : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, thanks for the contribution!
I think we should look at applying vector unrolling instead or linearization. It's currently happening at LLVM level but we have always talked about moving it to an earlier stage so that the lowering to LLVM and SPIR-V are aligned.
Collapsing some vector dimensions is definitely beneficial for some cases, esp. when the trailing dimension can't fill a full 1-D physical vector register, such as in your test (I think @hanhanW was working on something for these cases). However, I'm not sure we can apply it in a general way. For example, we couldn't turn an arbitrary vector<2x2xf32>
load/store into a vector<4xf32>
one because elements might not be contiguous in memory across the dimensions. Also, by linearizing an n-D vector instead of unrolling it, we are losing control on the actual vector length used and relying on the hardware backend to do what we expected, which is something we have been trying to avoid.
Would vector unrolling work for your case?
In our case underlying API/ABI (SPIR-V intrinsics) expects 1D vectors, but it interprets them as ND and we represent them (and surrounding arith ops) as ND at the intermediate steps, so we need to linearize it at some point. Also, vector unrolling and linearization are not mutually exclusive. We can have some partial unroll first and than run linearization on result (and potentially, some legalization pass after, breaking big 1D vectors into smaller ones, supported by HW). Regarding loads/stores and similar potentially non-linearizable ops, they are not covered by the patch and will stay untouched + I understand this transform is quite usecase-specific, so don't want to force it into default vector/llvm/spir-v pipeline. |
Ok, got it. Would it make sense then to remove the pass in |
Perhaps we can also add a callback function so that the user can decide when an op should be linearized or not. That should give us the level of flexibility that we need. |
Makes sense, I'll do it.
User already have some control over it by overriding |
Not directly applicable, but our solution to a similar problem in memref was to linearize all but the last dimensions (for D>2). This would keep the inner dimension intact and not make assumptions about the vector length, but would potentially break noncontiguous higher dimensions if the layout isn't changed accordingly? But this does require the underlying user (print in out case) to change behaviour, which is not always a generic property. |
/// Linearizes ND vectors (N >= 2) into 1D. | ||
void populateVectorLinearizeTypeConversionsAndLegality( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove Legality
or describe what it means in the doc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated the doc
// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> | ||
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> | ||
// Arith and math ops are handled in generic way, check some of them | ||
// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a few more tests? Perhaps some where the transformation shouldn't trigger?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added more checks for shape_cast
on the boundaries and check for return
op unchanged, not sure what else to check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks!
Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as
array<array<... vector>>
and SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidim vectors as linearized 1D. Add pass to do this. Only constants and simple elementwise ops are supported for now.@krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place.
Also, add ConversionPattern class operating on traits.