Skip to content

[mlir][vector] Address linearization comments (post commit) #138075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -407,13 +407,22 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);

/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
/// This registers (1) which operations are legal and hence should not be
/// linearized, (2) what converted types are (rank-1 vectors) and how to
///
/// Definition: here 'linearization' means converting a single operation with
/// 1+ vector operand/result of rank>1, into a new single operation whose
/// vector operands and results are all of rank<=1.
///
/// This function registers (1) which operations are legal, and hence should not
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
/// materialze the conversion (with shape_cast)
///
/// Note: the set of legal operations can be extended by a user if for example
/// certain rank>1 vectors are considered valid, but adding additional
/// certain rank>1 vectors are considered valid, by adding additional
/// dynamically legal ops to `conversionTarget`.
///
/// Further note: the choice to use a dialect conversion design for
/// linearization is to make it easy to reuse generic structural type
/// conversions for linearizing scf/cf/func operations
void populateForVectorLinearize(TypeConverter &typeConverter,
ConversionTarget &conversionTarget);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ConvertForOpTypes
// PR47938 tracks this issue, but it seems hard to fix. Instead, we need
// to clone the op.
//
// 2. We need to resue the original region instead of cloning it, otherwise
// 2. We need to reuse the original region instead of cloning it, otherwise
// the dialect conversion framework thinks that we just inserted all the
// cloned child ops. But what we want is to "take" the child regions and let
// the dialect conversion framework continue recursively into ops inside
Expand Down
72 changes: 38 additions & 34 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,45 +626,49 @@ struct LinearizeVectorCreateMask final

} // namespace

/// Return true if the operation `op` does not support scalable vectors and
/// has at least 1 scalable vector result. These ops should all eventually
/// support scalable vectors, and this function should be removed.
static bool isNotLinearizableBecauseScalable(Operation *op) {

bool unsupported =
isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
vector::ExtractOp, vector::InsertOp>(op);
if (!unsupported)
return false;

// Check if any of the results is a scalable vector type.
auto types = op->getResultTypes();
bool containsScalableResult =
std::any_of(types.begin(), types.end(), [](Type type) {
auto vecType = dyn_cast<VectorType>(type);
return vecType && vecType.isScalable();
});

return containsScalableResult;
}

static bool isNotLinearizable(Operation *op) {
/// This method defines the set of operations that are linearizable, and hence
/// that are considered illegal for the conversion target.
static bool isLinearizable(Operation *op) {

// Only ops that are in the vector dialect, are ConstantLike, or
// are Vectorizable might be linearized currently.
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
StringRef opDialect = op->getDialect()->getNamespace();
bool unsupported = (opDialect != vectorDialect) &&
!op->hasTrait<OpTrait::ConstantLike>() &&
!op->hasTrait<OpTrait::Vectorizable>();
if (unsupported)
return true;

// Some ops currently don't support scalable vectors.
if (isNotLinearizableBecauseScalable(op))
return true;
bool supported = (opDialect == vectorDialect) ||
op->hasTrait<OpTrait::ConstantLike>() ||
op->hasTrait<OpTrait::Vectorizable>();
if (!supported)
return false;

return false;
return TypeSwitch<Operation *, bool>(op)
// As type legalization is done with vector.shape_cast, shape_cast
// itself cannot be linearized (will create new shape_casts to linearize
// ad infinitum).
.Case<vector::ShapeCastOp>([&](auto) { return false; })
// The operations
// - vector.extract_strided_slice
// - vector.extract
// - vector.insert_strided_slice
// - vector.insert
// are linearized to a rank-1 vector.shuffle by the current patterns.
// vector.shuffle only supports fixed size vectors, so it is impossible to
// use this approach to linearize these ops if they operate on scalable
// vectors.
.Case<vector::ExtractStridedSliceOp>(
[&](vector::ExtractStridedSliceOp extractOp) {
return !extractOp.getType().isScalable();
})
.Case<vector::InsertStridedSliceOp>(
[&](vector::InsertStridedSliceOp insertOp) {
return !insertOp.getType().isScalable();
})
.Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
return !insertOp.getType().isScalable();
})
.Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
return !extractOp.getSourceVectorType().isScalable();
})
.Default([&](auto) { return true; });
}

void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
Expand Down Expand Up @@ -698,7 +702,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,

target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
if (isNotLinearizable(op))
if (!isLinearizable(op))
return true;
// This will return true if, for all operand and result types `t`,
// convertType(t) = t. This is true if there are no rank>=2 vectors.
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Vector/linearize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,28 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {

// -----

// CHECK-LABEL: test_linearize_across_for
func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index

// CHECK: scf.for {{.*}} -> (vector<4xi8>)
%1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {

// CHECK: arith.addi {{.*}} : vector<4xi8>
%2 = arith.addi %arg1, %0 : vector<2x2xi8>

// CHECK: scf.yield {{.*}} : vector<4xi8>
scf.yield %2 : vector<2x2xi8>
}
%3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
return %3 : vector<4xi8>
}

// -----

// CHECK-LABEL: linearize_vector_splat
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
Expand All @@ -414,6 +436,7 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
// CHECK: return %[[CAST]] : vector<4x[2]xi32>
%0 = vector.splat %arg0 : vector<4x[2]xi32>
return %0 : vector<4x[2]xi32>

}

// -----
Expand Down
10 changes: 4 additions & 6 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
Expand Down Expand Up @@ -836,9 +837,6 @@ struct TestVectorEmulateMaskedLoadStore final
}
};

// TODO: move this code into the user project.
namespace vendor {

/// Get the set of operand/result types to check for sufficiently
/// small inner-most dimension size.
static SmallVector<std::pair<Type, unsigned>>
Expand Down Expand Up @@ -960,8 +958,6 @@ struct TestVectorBitWidthLinearize final
}
};

} // namespace vendor

struct TestVectorLinearize final
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
Expand All @@ -987,6 +983,8 @@ struct TestVectorLinearize final
vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
patterns);
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
converter, patterns, target);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand Down Expand Up @@ -1067,7 +1065,7 @@ void registerTestVectorLowerings() {

PassRegistration<TestVectorLinearize>();

PassRegistration<vendor::TestVectorBitWidthLinearize>();
PassRegistration<TestVectorBitWidthLinearize>();

PassRegistration<TestEliminateVectorMasks>();
}
Expand Down