Skip to content

Commit 6c61c2e

Browse files
newlingTIFitis
authored andcommitted
[mlir][vector] Address linearization comments (post commit) (llvm#138075)
This PR adds some documentation to address comments in llvm#136581 This PR adds a test for linearization across scf.for. This new test might be considered redundant by more experienced MLIRers, but might help newer users understand how to linearize scf/cf/func operations easily The documentation added in this PR also tightens our definition of linearization, to now exclude unrolling (which creates multiple ops from 1 op). We hadn't really specified what linearization meant before.
1 parent 86fafd5 commit 6c61c2e

File tree

5 files changed

+78
-44
lines changed

5 files changed

+78
-44
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,13 +407,22 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
407407
RewritePatternSet &patterns, PatternBenefit benefit = 1);
408408

409409
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
410-
/// This registers (1) which operations are legal and hence should not be
411-
/// linearized, (2) what converted types are (rank-1 vectors) and how to
410+
///
411+
/// Definition: here 'linearization' means converting a single operation with
412+
/// 1+ vector operand/result of rank>1, into a new single operation whose
413+
/// vector operands and results are all of rank<=1.
414+
///
415+
/// This function registers (1) which operations are legal, and hence should not
416+
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
412417
/// materialze the conversion (with shape_cast)
413418
///
414419
/// Note: the set of legal operations can be extended by a user if for example
415-
/// certain rank>1 vectors are considered valid, but adding additional
420+
/// certain rank>1 vectors are considered valid, by adding additional
416421
/// dynamically legal ops to `conversionTarget`.
422+
///
423+
/// Further note: the choice to use a dialect conversion design for
424+
/// linearization is to make it easy to reuse generic structural type
425+
/// conversions for linearizing scf/cf/func operations
417426
void populateForVectorLinearize(TypeConverter &typeConverter,
418427
ConversionTarget &conversionTarget);
419428

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class ConvertForOpTypes
9999
// PR47938 tracks this issue, but it seems hard to fix. Instead, we need
100100
// to clone the op.
101101
//
102-
// 2. We need to resue the original region instead of cloning it, otherwise
102+
// 2. We need to reuse the original region instead of cloning it, otherwise
103103
// the dialect conversion framework thinks that we just inserted all the
104104
// cloned child ops. But what we want is to "take" the child regions and let
105105
// the dialect conversion framework continue recursively into ops inside

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -626,45 +626,49 @@ struct LinearizeVectorCreateMask final
626626

627627
} // namespace
628628

629-
/// Return true if the operation `op` does not support scalable vectors and
630-
/// has at least 1 scalable vector result. These ops should all eventually
631-
/// support scalable vectors, and this function should be removed.
632-
static bool isNotLinearizableBecauseScalable(Operation *op) {
633-
634-
bool unsupported =
635-
isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
636-
vector::ExtractOp, vector::InsertOp>(op);
637-
if (!unsupported)
638-
return false;
639-
640-
// Check if any of the results is a scalable vector type.
641-
auto types = op->getResultTypes();
642-
bool containsScalableResult =
643-
std::any_of(types.begin(), types.end(), [](Type type) {
644-
auto vecType = dyn_cast<VectorType>(type);
645-
return vecType && vecType.isScalable();
646-
});
647-
648-
return containsScalableResult;
649-
}
650-
651-
static bool isNotLinearizable(Operation *op) {
629+
/// This method defines the set of operations that are linearizable, and hence
630+
/// that are considered illegal for the conversion target.
631+
static bool isLinearizable(Operation *op) {
652632

653633
// Only ops that are in the vector dialect, are ConstantLike, or
654634
// are Vectorizable might be linearized currently.
655635
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
656636
StringRef opDialect = op->getDialect()->getNamespace();
657-
bool unsupported = (opDialect != vectorDialect) &&
658-
!op->hasTrait<OpTrait::ConstantLike>() &&
659-
!op->hasTrait<OpTrait::Vectorizable>();
660-
if (unsupported)
661-
return true;
662-
663-
// Some ops currently don't support scalable vectors.
664-
if (isNotLinearizableBecauseScalable(op))
665-
return true;
637+
bool supported = (opDialect == vectorDialect) ||
638+
op->hasTrait<OpTrait::ConstantLike>() ||
639+
op->hasTrait<OpTrait::Vectorizable>();
640+
if (!supported)
641+
return false;
666642

667-
return false;
643+
return TypeSwitch<Operation *, bool>(op)
644+
// As type legalization is done with vector.shape_cast, shape_cast
645+
// itself cannot be linearized (will create new shape_casts to linearize
646+
// ad infinitum).
647+
.Case<vector::ShapeCastOp>([&](auto) { return false; })
648+
// The operations
649+
// - vector.extract_strided_slice
650+
// - vector.extract
651+
// - vector.insert_strided_slice
652+
// - vector.insert
653+
// are linearized to a rank-1 vector.shuffle by the current patterns.
654+
// vector.shuffle only supports fixed size vectors, so it is impossible to
655+
// use this approach to linearize these ops if they operate on scalable
656+
// vectors.
657+
.Case<vector::ExtractStridedSliceOp>(
658+
[&](vector::ExtractStridedSliceOp extractOp) {
659+
return !extractOp.getType().isScalable();
660+
})
661+
.Case<vector::InsertStridedSliceOp>(
662+
[&](vector::InsertStridedSliceOp insertOp) {
663+
return !insertOp.getType().isScalable();
664+
})
665+
.Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
666+
return !insertOp.getType().isScalable();
667+
})
668+
.Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
669+
return !extractOp.getSourceVectorType().isScalable();
670+
})
671+
.Default([&](auto) { return true; });
668672
}
669673

670674
void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
@@ -698,7 +702,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
698702

699703
target.markUnknownOpDynamicallyLegal(
700704
[=](Operation *op) -> std::optional<bool> {
701-
if (isNotLinearizable(op))
705+
if (!isLinearizable(op))
702706
return true;
703707
// This will return true if, for all operand and result types `t`,
704708
// convertType(t) = t. This is true if there are no rank>=2 vectors.

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,28 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
392392

393393
// -----
394394

395+
// CHECK-LABEL: test_linearize_across_for
396+
func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
397+
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
398+
%c0 = arith.constant 0 : index
399+
%c1 = arith.constant 1 : index
400+
%c4 = arith.constant 4 : index
401+
402+
// CHECK: scf.for {{.*}} -> (vector<4xi8>)
403+
%1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {
404+
405+
// CHECK: arith.addi {{.*}} : vector<4xi8>
406+
%2 = arith.addi %arg1, %0 : vector<2x2xi8>
407+
408+
// CHECK: scf.yield {{.*}} : vector<4xi8>
409+
scf.yield %2 : vector<2x2xi8>
410+
}
411+
%3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
412+
return %3 : vector<4xi8>
413+
}
414+
415+
// -----
416+
395417
// CHECK-LABEL: linearize_vector_splat
396418
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
397419
func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
@@ -414,6 +436,7 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
414436
// CHECK: return %[[CAST]] : vector<4x[2]xi32>
415437
%0 = vector.splat %arg0 : vector<4x[2]xi32>
416438
return %0 : vector<4x[2]xi32>
439+
417440
}
418441

419442
// -----

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
1919
#include "mlir/Dialect/SCF/IR/SCF.h"
20+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
2021
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2122
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2223
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -836,9 +837,6 @@ struct TestVectorEmulateMaskedLoadStore final
836837
}
837838
};
838839

839-
// TODO: move this code into the user project.
840-
namespace vendor {
841-
842840
/// Get the set of operand/result types to check for sufficiently
843841
/// small inner-most dimension size.
844842
static SmallVector<std::pair<Type, unsigned>>
@@ -960,8 +958,6 @@ struct TestVectorBitWidthLinearize final
960958
}
961959
};
962960

963-
} // namespace vendor
964-
965961
struct TestVectorLinearize final
966962
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
967963
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
@@ -987,6 +983,8 @@ struct TestVectorLinearize final
987983
vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
988984
vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
989985
patterns);
986+
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
987+
converter, patterns, target);
990988

991989
if (failed(applyPartialConversion(getOperation(), target,
992990
std::move(patterns))))
@@ -1067,7 +1065,7 @@ void registerTestVectorLowerings() {
10671065

10681066
PassRegistration<TestVectorLinearize>();
10691067

1070-
PassRegistration<vendor::TestVectorBitWidthLinearize>();
1068+
PassRegistration<TestVectorBitWidthLinearize>();
10711069

10721070
PassRegistration<TestEliminateVectorMasks>();
10731071
}

0 commit comments

Comments
 (0)