Skip to content

Commit 35ef399

Browse files
authored
[mlir][vector] ND vectors linearization pass (#81159)
Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as `array<array<... vector>>` and SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidim vectors as linearized 1D. Add pass to do this. Only constants and simple elementwise ops are supported for now. @krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place. Also, add ConversionPattern class operating on traits.
1 parent bfc0b7c commit 35ef399

File tree

8 files changed

+206
-13
lines changed

8 files changed

+206
-13
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
2121

2222
namespace mlir {
23+
class ConversionTarget;
2324
class RewritePatternSet;
25+
class TypeConverter;
2426

2527
namespace arith {
2628
class AndIOp;
@@ -375,6 +377,13 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
375377
void populateVectorTransposeNarrowTypeRewritePatterns(
376378
RewritePatternSet &patterns, PatternBenefit benefit = 1);
377379

380+
/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
381+
/// provided ConversionTarget with the appropriate legality configuration for
382+
/// the ops to get converted properly.
383+
void populateVectorLinearizeTypeConversionsAndLegality(
384+
TypeConverter &typeConverter, RewritePatternSet &patterns,
385+
ConversionTarget &target);
386+
378387
} // namespace vector
379388
} // namespace mlir
380389

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
604604
using ConversionPattern::matchAndRewrite;
605605
};
606606

607+
/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
608+
/// for matching and rewriting against instances of an operation that possess a
609+
/// given trait.
610+
template <template <typename> class TraitType>
611+
class OpTraitConversionPattern : public ConversionPattern {
612+
public:
613+
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
614+
: ConversionPattern(Pattern::MatchTraitOpTypeTag(),
615+
TypeID::get<TraitType>(), benefit, context) {}
616+
OpTraitConversionPattern(const TypeConverter &typeConverter,
617+
MLIRContext *context, PatternBenefit benefit = 1)
618+
: ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
619+
TypeID::get<TraitType>(), benefit, context) {}
620+
};
621+
622+
/// Generic utility to convert op result types according to type converter
623+
/// without knowing exact op type.
624+
/// Clones existing op with new result types and returns it.
625+
FailureOr<Operation *>
626+
convertOpResultTypes(Operation *op, ValueRange operands,
627+
const TypeConverter &converter,
628+
ConversionPatternRewriter &rewriter);
629+
607630
/// Add a pattern to the given pattern list to convert the signature of a
608631
/// FunctionOpInterface op with the given type converter. This only supports
609632
/// ops which use FunctionType to represent their type.

mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
7676
ConversionPatternRewriter &rewriter) const {
7777
Location loc = op->getLoc();
7878
const TypeConverter *converter = getTypeConverter();
79-
if (converter->isLegal(op))
80-
return rewriter.notifyMatchFailure(loc, "op already legal");
81-
OperationState newOp(loc, op->getName());
82-
newOp.addOperands(operands);
79+
FailureOr<Operation *> legalized =
80+
convertOpResultTypes(op, operands, *converter, rewriter);
81+
if (failed(legalized))
82+
return failure();
8383

84-
SmallVector<Type> newResultTypes;
85-
if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
86-
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
87-
newOp.addTypes(newResultTypes);
88-
newOp.addAttributes(op->getAttrs());
89-
Operation *legalized = rewriter.create(newOp);
90-
SmallVector<Value> results = legalized->getResults();
91-
for (auto [result, newType, origType] :
92-
llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
84+
SmallVector<Value> results = (*legalized)->getResults();
85+
for (auto [result, newType, origType] : llvm::zip_equal(
86+
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
9387
if (newType != origType)
9488
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
9589
}

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
1616
VectorEmulateMaskedLoadStore.cpp
1717
VectorEmulateNarrowType.cpp
1818
VectorInsertExtractStridedSliceRewritePatterns.cpp
19+
VectorLinearize.cpp
1920
VectorTransferOpTransforms.cpp
2021
VectorTransferSplitRewritePatterns.cpp
2122
VectorTransforms.cpp
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
//===- VectorLinearize.cpp - vector linearization transforms --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns and pass for linearizing ND vectors into 1D.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Arith/IR/Arith.h"
14+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/IR/TypeUtilities.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
using namespace mlir;
21+
22+
namespace {
23+
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
24+
using OpConversionPattern::OpConversionPattern;
25+
26+
LogicalResult
27+
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
28+
ConversionPatternRewriter &rewriter) const override {
29+
Location loc = constOp.getLoc();
30+
auto resType =
31+
getTypeConverter()->convertType<VectorType>(constOp.getType());
32+
if (!resType)
33+
return rewriter.notifyMatchFailure(loc, "can't convert return type");
34+
35+
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
36+
if (!dstElementsAttr)
37+
return rewriter.notifyMatchFailure(loc, "unsupported attr type");
38+
39+
dstElementsAttr = dstElementsAttr.reshape(resType);
40+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
41+
dstElementsAttr);
42+
return success();
43+
}
44+
};
45+
46+
struct LinearizeVectorizable final
47+
: OpTraitConversionPattern<OpTrait::Vectorizable> {
48+
using OpTraitConversionPattern::OpTraitConversionPattern;
49+
50+
LogicalResult
51+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
52+
ConversionPatternRewriter &rewriter) const override {
53+
FailureOr<Operation *> newOp =
54+
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
55+
if (failed(newOp))
56+
return failure();
57+
58+
rewriter.replaceOp(op, (*newOp)->getResults());
59+
return success();
60+
}
61+
};
62+
} // namespace
63+
64+
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
65+
TypeConverter &typeConverter, RewritePatternSet &patterns,
66+
ConversionTarget &target) {
67+
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
68+
// Ignore scalable vectors for now.
69+
if (type.getRank() <= 1 || type.isScalable())
70+
return type;
71+
72+
return VectorType::get(type.getNumElements(), type.getElementType());
73+
});
74+
75+
auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
76+
Location loc) -> Value {
77+
if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
78+
!isa<VectorType>(type))
79+
return nullptr;
80+
81+
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
82+
};
83+
typeConverter.addArgumentMaterialization(materializeCast);
84+
typeConverter.addSourceMaterialization(materializeCast);
85+
typeConverter.addTargetMaterialization(materializeCast);
86+
87+
target.markUnknownOpDynamicallyLegal(
88+
[&](Operation *op) -> std::optional<bool> {
89+
if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
90+
return typeConverter.isLegal(op);
91+
92+
return std::nullopt;
93+
});
94+
95+
patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
96+
patterns.getContext());
97+
}

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3130,6 +3130,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
31303130
};
31313131
} // namespace
31323132

3133+
FailureOr<Operation *>
3134+
mlir::convertOpResultTypes(Operation *op, ValueRange operands,
3135+
const TypeConverter &converter,
3136+
ConversionPatternRewriter &rewriter) {
3137+
assert(op && "Invalid op");
3138+
Location loc = op->getLoc();
3139+
if (converter.isLegal(op))
3140+
return rewriter.notifyMatchFailure(loc, "op already legal");
3141+
3142+
OperationState newOp(loc, op->getName());
3143+
newOp.addOperands(operands);
3144+
3145+
SmallVector<Type> newResultTypes;
3146+
if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
3147+
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
3148+
3149+
newOp.addTypes(newResultTypes);
3150+
newOp.addAttributes(op->getAttrs());
3151+
return rewriter.create(newOp);
3152+
}
3153+
31333154
void mlir::populateFunctionOpInterfaceTypeConversionPattern(
31343155
StringRef functionLikeOpName, RewritePatternSet &patterns,
31353156
const TypeConverter &converter) {
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
2+
3+
// CHECK-LABEL: test_linearize
4+
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
5+
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
6+
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
7+
// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
8+
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
9+
// CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
10+
11+
// Arith and math ops are handled in generic way, check some of them
12+
// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
13+
%1 = math.sin %arg0 : vector<2x2xf32>
14+
// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
15+
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
16+
17+
// CHECK: return %[[RES]] : vector<2x2xf32>
18+
return %0 : vector<2x2xf32>
19+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,33 @@ struct TestVectorEmulateMaskedLoadStore final
823823
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
824824
}
825825
};
826+
827+
struct TestVectorLinearize final
828+
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
829+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
830+
831+
StringRef getArgument() const override { return "test-vector-linearize"; }
832+
StringRef getDescription() const override {
833+
return "Linearizes ND vectors for N >= 2 into 1D vectors";
834+
}
835+
void getDependentDialects(DialectRegistry &registry) const override {
836+
registry.insert<vector::VectorDialect>();
837+
}
838+
839+
void runOnOperation() override {
840+
auto *context = &getContext();
841+
842+
TypeConverter typeConverter;
843+
RewritePatternSet patterns(context);
844+
ConversionTarget target(*context);
845+
846+
vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
847+
patterns, target);
848+
if (failed(applyPartialConversion(getOperation(), target,
849+
std::move(patterns))))
850+
return signalPassFailure();
851+
}
852+
};
826853
} // namespace
827854

828855
namespace mlir {
@@ -867,6 +894,8 @@ void registerTestVectorLowerings() {
867894
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
868895

869896
PassRegistration<TestVectorEmulateMaskedLoadStore>();
897+
898+
PassRegistration<TestVectorLinearize>();
870899
}
871900
} // namespace test
872901
} // namespace mlir

0 commit comments

Comments
 (0)