Skip to content

Commit e0afc77

Browse files
committed
further enhancements
1 parent 40c73e1 commit e0afc77

File tree

4 files changed

+45
-16
lines changed

4 files changed

+45
-16
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,16 +409,20 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
409409
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
410410
///
411411
/// Definition: here 'linearization' means converting a single operation with
412-
/// 1+ vector operands and results of rank>1, into a single operation whose
413-
/// vector operands are all of rank<=1.
412+
/// 1+ vector operand/result of rank>1, into a new single operation whose
413+
/// vector operands and results are all of rank<=1.
414414
///
415415
/// This function registers (1) which operations are legal, and hence should not
416416
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
417417
/// materialze the conversion (with shape_cast)
418418
///
419419
/// Note: the set of legal operations can be extended by a user if for example
420-
/// certain rank>1 vectors are considered valid, but adding additional
420+
/// certain rank>1 vectors are considered valid, by adding additional
421421
/// dynamically legal ops to `conversionTarget`.
422+
///
423+
/// Further note: the choice to use a dialect conversion design for
424+
/// linearization is to make it easy to reuse generic structural type
425+
/// conversions for linearizing scf/cf/func operations
422426
void populateForVectorLinearize(TypeConverter &typeConverter,
423427
ConversionTarget &conversionTarget);
424428

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,9 @@ struct LinearizeVectorBitCast final
423423

424424
} // namespace
425425

426-
/// Some operations currently cannot be linearized if they have scalable vector
427-
/// results. This function returns true if `op` is such an operation.
426+
/// Some operations currently will not be linearized if they have scalable
427+
/// vector results, although support should be added in the future. This
428+
/// function returns true if `op` is such an operation.
428429
static bool isNotLinearizableBecauseScalable(Operation *op) {
429430

430431
bool unsupported =
@@ -448,15 +449,14 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
448449
return containsScalableResult;
449450
}
450451

451-
/// This method defines a set of operations that are not linearizable,
452-
/// and hence considered legal for the conversion target. These ops are
453-
/// currently
452+
/// This method defines a set of operations that are not linearizable, and hence
453+
/// they are considered legal for the conversion target. These ops are
454+
/// currently,
454455
///
455-
/// 1) Ops that are not in the vector dialect, are not ConstantLike, and are not
456-
/// Vectorizable.
456+
/// 1) ones that are not in the vector dialect, are not ConstantLike, and are
457+
/// not Vectorizable, or
457458
///
458-
/// 2) Certain ops with scalable vector results, for which support has not yet
459-
/// been added.
459+
/// 2) have scalable vector results, for which support has not yet been added.
460460
static bool isNotLinearizable(Operation *op) {
461461

462462
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
@@ -467,6 +467,10 @@ static bool isNotLinearizable(Operation *op) {
467467
if (unsupported)
468468
return true;
469469

470+
// vector.shape_cast cannot be linearized.
471+
if (isa<vector::ShapeCastOp>(op))
472+
return true;
473+
470474
// Some ops currently don't support scalable vectors.
471475
if (isNotLinearizableBecauseScalable(op))
472476
return true;

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,25 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
413413
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
414414
return %1 : vector<[4]x4xf16>
415415
}
416+
417+
// -----
418+
419+
// DEFAULT-LABEL: test_linearize_across_for
420+
func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
421+
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
422+
%c0 = arith.constant 0 : index
423+
%c1 = arith.constant 1 : index
424+
%c4 = arith.constant 4 : index
425+
426+
// DEFAULT: scf.for {{.*}} -> (vector<4xi8>)
427+
%1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {
428+
429+
// DEFAULT: arith.addi {{.*}} : vector<4xi8>
430+
%2 = arith.addi %arg1, %0 : vector<2x2xi8>
431+
432+
// DEFAULT: scf.yield {{.*}} : vector<4xi8>
433+
scf.yield %2 : vector<2x2xi8>
434+
}
435+
%3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
436+
return %3 : vector<4xi8>
437+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ struct TestVectorEmulateMaskedLoadStore final
837837
}
838838
};
839839

840-
namespace bit_width_constrained_vector_linearize {
840+
namespace bit_width_constrained_linearization {
841841

842842
/// Get the set of operand/result types to check for sufficiently
843843
/// small inner-most dimension size.
@@ -960,7 +960,7 @@ struct TestVectorBitWidthLinearize final
960960
}
961961
};
962962

963-
} // namespace bit_width_constrained_vector_linearize
963+
} // namespace bit_width_constrained_linearization
964964

965965
struct TestVectorLinearize final
966966
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
@@ -989,7 +989,6 @@ struct TestVectorLinearize final
989989
vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
990990
vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
991991
patterns);
992-
993992
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
994993
converter, patterns, target);
995994

@@ -1073,7 +1072,7 @@ void registerTestVectorLowerings() {
10731072
PassRegistration<TestVectorLinearize>();
10741073

10751074
PassRegistration<
1076-
bit_width_constrained_vector_linearize::TestVectorBitWidthLinearize>();
1075+
bit_width_constrained_linearization::TestVectorBitWidthLinearize>();
10771076

10781077
PassRegistration<TestEliminateVectorMasks>();
10791078
}

0 commit comments

Comments
 (0)