Skip to content

Commit 5c0bdc4

Browse files
committed
squash commits to make rebase easier
1 parent 6ed05ed commit 5c0bdc4

File tree

5 files changed

+70
-44
lines changed

5 files changed

+70
-44
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,13 +407,22 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
407407
RewritePatternSet &patterns, PatternBenefit benefit = 1);
408408

409409
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
410-
/// This registers (1) which operations are legal and hence should not be
411-
/// linearized, (2) what converted types are (rank-1 vectors) and how to
410+
///
411+
/// Definition: here 'linearization' means converting a single operation with
412+
/// 1+ vector operand/result of rank>1, into a new single operation whose
413+
/// vector operands and results are all of rank<=1.
414+
///
415+
/// This function registers (1) which operations are legal, and hence should not
416+
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
412417
/// materialze the conversion (with shape_cast)
413418
///
414419
/// Note: the set of legal operations can be extended by a user if for example
415-
/// certain rank>1 vectors are considered valid, but adding additional
420+
/// certain rank>1 vectors are considered valid, by adding additional
416421
/// dynamically legal ops to `conversionTarget`.
422+
///
423+
/// Further note: the choice to use a dialect conversion design for
424+
/// linearization is to make it easy to reuse generic structural type
425+
/// conversions for linearizing scf/cf/func operations
417426
void populateForVectorLinearize(TypeConverter &typeConverter,
418427
ConversionTarget &conversionTarget);
419428

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class ConvertForOpTypes
9999
// PR47938 tracks this issue, but it seems hard to fix. Instead, we need
100100
// to clone the op.
101101
//
102-
// 2. We need to resue the original region instead of cloning it, otherwise
102+
// 2. We need to reuse the original region instead of cloning it, otherwise
103103
// the dialect conversion framework thinks that we just inserted all the
104104
// cloned child ops. But what we want is to "take" the child regions and let
105105
// the dialect conversion framework continue recursively into ops inside

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -568,45 +568,41 @@ struct LinearizeVectorSplat final
568568

569569
} // namespace
570570

571-
/// Return true if the operation `op` does not support scalable vectors and
572-
/// has at least 1 scalable vector result. These ops should all eventually
573-
/// support scalable vectors, and this function should be removed.
574-
static bool isNotLinearizableBecauseScalable(Operation *op) {
575-
576-
bool unsupported =
577-
isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
578-
vector::ExtractOp, vector::InsertOp>(op);
579-
if (!unsupported)
580-
return false;
581-
582-
// Check if any of the results is a scalable vector type.
583-
auto types = op->getResultTypes();
584-
bool containsScalableResult =
585-
std::any_of(types.begin(), types.end(), [](Type type) {
586-
auto vecType = dyn_cast<VectorType>(type);
587-
return vecType && vecType.isScalable();
588-
});
589-
590-
return containsScalableResult;
591-
}
592-
593-
static bool isNotLinearizable(Operation *op) {
571+
/// This method defines the set of operations that are linearizable, and hence
572+
/// that are considered illegal for the conversion target.
573+
static bool isLinearizable(Operation *op) {
594574

595575
// Only ops that are in the vector dialect, are ConstantLike, or
596576
// are Vectorizable might be linearized currently.
597577
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
598578
StringRef opDialect = op->getDialect()->getNamespace();
599-
bool unsupported = (opDialect != vectorDialect) &&
600-
!op->hasTrait<OpTrait::ConstantLike>() &&
601-
!op->hasTrait<OpTrait::Vectorizable>();
602-
if (unsupported)
603-
return true;
604-
605-
// Some ops currently don't support scalable vectors.
606-
if (isNotLinearizableBecauseScalable(op))
607-
return true;
579+
bool supported = (opDialect == vectorDialect) ||
580+
op->hasTrait<OpTrait::ConstantLike>() ||
581+
op->hasTrait<OpTrait::Vectorizable>();
582+
if (!supported)
583+
return false;
608584

609-
return false;
585+
return TypeSwitch<Operation *, bool>(op)
586+
// As type legalization is done with vector.shape_cast, shape_cast
587+
// itself cannot be linearized (will create new shape_casts to linearize
588+
// ad infinitum).
589+
.Case<vector::ShapeCastOp>([&](auto) { return false; })
590+
// vector.extract_strided_slice, vector.extract, and vector.insert
591+
// operations are linearized to a rank-1 vector.shuffle by the current
592+
// patterns. vector.shuffle only supports fixed size vectors, so it is
593+
// impossible to use this approach to linearize these ops if they operate
594+
// on scalable vectors.
595+
.Case<vector::ExtractStridedSliceOp>(
596+
[&](vector::ExtractStridedSliceOp extractOp) {
597+
return !extractOp.getType().isScalable();
598+
})
599+
.Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
600+
return !insertOp.getType().isScalable();
601+
})
602+
.Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
603+
return !extractOp.getSourceVectorType().isScalable();
604+
})
605+
.Default([&](auto) { return true; });
610606
}
611607

612608
void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
@@ -640,7 +636,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
640636

641637
target.markUnknownOpDynamicallyLegal(
642638
[=](Operation *op) -> std::optional<bool> {
643-
if (isNotLinearizable(op))
639+
if (!isLinearizable(op))
644640
return true;
645641
// This will return true if, for all operand and result types `t`,
646642
// convertType(t) = t. This is true if there are no rank>=2 vectors.

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,28 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
392392

393393
// -----
394394

395+
// CHECK-LABEL: test_linearize_across_for
396+
func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
397+
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
398+
%c0 = arith.constant 0 : index
399+
%c1 = arith.constant 1 : index
400+
%c4 = arith.constant 4 : index
401+
402+
// CHECK: scf.for {{.*}} -> (vector<4xi8>)
403+
%1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {
404+
405+
// CHECK: arith.addi {{.*}} : vector<4xi8>
406+
%2 = arith.addi %arg1, %0 : vector<2x2xi8>
407+
408+
// CHECK: scf.yield {{.*}} : vector<4xi8>
409+
scf.yield %2 : vector<2x2xi8>
410+
}
411+
%3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
412+
return %3 : vector<4xi8>
413+
}
414+
415+
// -----
416+
395417
// CHECK-LABEL: linearize_vector_splat
396418
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
397419
func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
@@ -414,5 +436,6 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
414436
// CHECK: return %[[CAST]] : vector<4x[2]xi32>
415437
%0 = vector.splat %arg0 : vector<4x[2]xi32>
416438
return %0 : vector<4x[2]xi32>
439+
417440
}
418441

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
1919
#include "mlir/Dialect/SCF/IR/SCF.h"
20+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
2021
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2122
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2223
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -836,9 +837,6 @@ struct TestVectorEmulateMaskedLoadStore final
836837
}
837838
};
838839

839-
// TODO: move this code into the user project.
840-
namespace vendor {
841-
842840
/// Get the set of operand/result types to check for sufficiently
843841
/// small inner-most dimension size.
844842
static SmallVector<std::pair<Type, unsigned>>
@@ -960,8 +958,6 @@ struct TestVectorBitWidthLinearize final
960958
}
961959
};
962960

963-
} // namespace vendor
964-
965961
struct TestVectorLinearize final
966962
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
967963
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
@@ -987,6 +983,8 @@ struct TestVectorLinearize final
987983
vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
988984
vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
989985
patterns);
986+
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
987+
converter, patterns, target);
990988

991989
if (failed(applyPartialConversion(getOperation(), target,
992990
std::move(patterns))))
@@ -1067,7 +1065,7 @@ void registerTestVectorLowerings() {
10671065

10681066
PassRegistration<TestVectorLinearize>();
10691067

1070-
PassRegistration<vendor::TestVectorBitWidthLinearize>();
1068+
PassRegistration<TestVectorBitWidthLinearize>();
10711069

10721070
PassRegistration<TestEliminateVectorMasks>();
10731071
}

0 commit comments

Comments
 (0)