Skip to content

[mlir][vector] ND vectors linearization pass #81159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"

namespace mlir {
class ConversionTarget;
class RewritePatternSet;
class TypeConverter;

namespace arith {
class AndIOp;
Expand Down Expand Up @@ -375,6 +377,13 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
/// provided ConversionTarget with the appropriate legality configuration for
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target);

} // namespace vector
} // namespace mlir

Expand Down
23 changes: 23 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
using ConversionPattern::matchAndRewrite;
};

/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
/// for matching and rewriting against instances of an operation that possess a
/// given trait.
template <template <typename> class TraitType>
class OpTraitConversionPattern : public ConversionPattern {
public:
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(Pattern::MatchTraitOpTypeTag(),
TypeID::get<TraitType>(), benefit, context) {}
OpTraitConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
TypeID::get<TraitType>(), benefit, context) {}
};

/// Generic utility to convert op result types according to type converter
/// without knowing exact op type.
/// Clones existing op with new result types and returns it.
FailureOr<Operation *>
convertOpResultTypes(Operation *op, ValueRange operands,
const TypeConverter &converter,
ConversionPatternRewriter &rewriter);

/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionOpInterface op with the given type converter. This only supports
/// ops which use FunctionType to represent their type.
Expand Down
20 changes: 7 additions & 13 deletions mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
const TypeConverter *converter = getTypeConverter();
if (converter->isLegal(op))
return rewriter.notifyMatchFailure(loc, "op already legal");
OperationState newOp(loc, op->getName());
newOp.addOperands(operands);
FailureOr<Operation *> legalized =
convertOpResultTypes(op, operands, *converter, rewriter);
if (failed(legalized))
return failure();

SmallVector<Type> newResultTypes;
if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
newOp.addAttributes(op->getAttrs());
Operation *legalized = rewriter.create(newOp);
SmallVector<Value> results = legalized->getResults();
for (auto [result, newType, origType] :
llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
if (newType != origType)
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorEmulateMaskedLoadStore.cpp
VectorEmulateNarrowType.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorLinearize.cpp
VectorTransferOpTransforms.cpp
VectorTransferSplitRewritePatterns.cpp
VectorTransforms.cpp
Expand Down
97 changes: 97 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
//===- VectorLinearize.cpp - vector linearization transforms --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns and pass for linearizing ND vectors into 1D.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = constOp.getLoc();
auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");

auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
if (!dstElementsAttr)
return rewriter.notifyMatchFailure(loc, "unsupported attr type");

dstElementsAttr = dstElementsAttr.reshape(resType);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
dstElementsAttr);
return success();
}
};

struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Operation *> newOp =
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
if (failed(newOp))
return failure();

rewriter.replaceOp(op, (*newOp)->getResults());
return success();
}
};
} // namespace

void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
// Ignore scalable vectors for now.
if (type.getRank() <= 1 || type.isScalable())
return type;

return VectorType::get(type.getNumElements(), type.getElementType());
});

auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
!isa<VectorType>(type))
return nullptr;

return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
};
typeConverter.addArgumentMaterialization(materializeCast);
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);

target.markUnknownOpDynamicallyLegal(
[&](Operation *op) -> std::optional<bool> {
if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
return typeConverter.isLegal(op);

return std::nullopt;
});

patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
patterns.getContext());
}
21 changes: 21 additions & 0 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
};
} // namespace

FailureOr<Operation *>
mlir::convertOpResultTypes(Operation *op, ValueRange operands,
const TypeConverter &converter,
ConversionPatternRewriter &rewriter) {
assert(op && "Invalid op");
Location loc = op->getLoc();
if (converter.isLegal(op))
return rewriter.notifyMatchFailure(loc, "op already legal");

OperationState newOp(loc, op->getName());
newOp.addOperands(operands);

SmallVector<Type> newResultTypes;
if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");

newOp.addTypes(newResultTypes);
newOp.addAttributes(op->getAttrs());
return rewriter.create(newOp);
}

void mlir::populateFunctionOpInterfaceTypeConversionPattern(
StringRef functionLikeOpName, RewritePatternSet &patterns,
const TypeConverter &converter) {
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Vector/linearize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s

// CHECK-LABEL: test_linearize
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>

// Arith and math ops are handled in generic way, check some of them
// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a few more tests? Perhaps some where the transformation shouldn't trigger?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more checks for shape_cast on the boundaries and check for return op unchanged, not sure what else to check.

%1 = math.sin %arg0 : vector<2x2xf32>
// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
%2 = arith.addf %arg0, %0 : vector<2x2xf32>

// CHECK: return %[[RES]] : vector<2x2xf32>
return %0 : vector<2x2xf32>
}
29 changes: 29 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,33 @@ struct TestVectorEmulateMaskedLoadStore final
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

struct TestVectorLinearize final
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)

StringRef getArgument() const override { return "test-vector-linearize"; }
StringRef getDescription() const override {
return "Linearizes ND vectors for N >= 2 into 1D vectors";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}

void runOnOperation() override {
auto *context = &getContext();

TypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);

vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace

namespace mlir {
Expand Down Expand Up @@ -867,6 +894,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();

PassRegistration<TestVectorEmulateMaskedLoadStore>();

PassRegistration<TestVectorLinearize>();
}
} // namespace test
} // namespace mlir