Skip to content

Commit 9c3d1a1

Browse files
committed
further enhancements
1 parent 569c560 commit 9c3d1a1

File tree

4 files changed

+46
-16
lines changed

4 files changed

+46
-16
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,16 +409,20 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
409409
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
410410
///
411411
/// Definition: here 'linearization' means converting a single operation with
412-
/// 1+ vector operands and results of rank>1, into a single operation whose
413-
/// vector operands are all of rank<=1.
412+
/// 1+ vector operand/result of rank>1, into a new single operation whose
413+
/// vector operands and results are all of rank<=1.
414414
///
415415
/// This function registers (1) which operations are legal, and hence should not
416416
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
417417
/// materialze the conversion (with shape_cast)
418418
///
419419
/// Note: the set of legal operations can be extended by a user if for example
420-
/// certain rank>1 vectors are considered valid, but adding additional
420+
/// certain rank>1 vectors are considered valid, by adding additional
421421
/// dynamically legal ops to `conversionTarget`.
422+
///
423+
/// Further note: the choice to use a dialect conversion design for
424+
/// linearization is to make it easy to reuse generic structural type
425+
/// conversions for linearizing scf/cf/func operations
422426
void populateForVectorLinearize(TypeConverter &typeConverter,
423427
ConversionTarget &conversionTarget);
424428

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,9 @@ struct LinearizeVectorSplat final
444444

445445
} // namespace
446446

447-
/// Some operations currently cannot be linearized if they have scalable vector
448-
/// results. This function returns true if `op` is such an operation.
447+
/// Some operations currently will not be linearized if they have scalable
448+
/// vector results, although support should be added in the future. This
449+
/// function returns true if `op` is such an operation.
449450
static bool isNotLinearizableBecauseScalable(Operation *op) {
450451

451452
bool unsupported =
@@ -469,15 +470,14 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
469470
return containsScalableResult;
470471
}
471472

472-
/// This method defines a set of operations that are not linearizable,
473-
/// and hence considered legal for the conversion target. These ops are
474-
/// currently
473+
/// This method defines a set of operations that are not linearizable, and hence
474+
/// they are considered legal for the conversion target. These ops are
475+
/// currently,
475476
///
476-
/// 1) Ops that are not in the vector dialect, are not ConstantLike, and are not
477-
/// Vectorizable.
477+
/// 1) ones that are not in the vector dialect, are not ConstantLike, and are
478+
/// not Vectorizable, or
478479
///
479-
/// 2) Certain ops with scalable vector results, for which support has not yet
480-
/// been added.
480+
/// 2) have scalable vector results, for which support has not yet been added.
481481
static bool isNotLinearizable(Operation *op) {
482482

483483
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
@@ -488,6 +488,10 @@ static bool isNotLinearizable(Operation *op) {
488488
if (unsupported)
489489
return true;
490490

491+
// vector.shape_cast cannot be linearized.
492+
if (isa<vector::ShapeCastOp>(op))
493+
return true;
494+
491495
// Some ops currently don't support scalable vectors.
492496
if (isNotLinearizableBecauseScalable(op))
493497
return true;

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,28 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
320320
return %1 : vector<[4]x4xf16>
321321
}
322322

323+
// -----
324+
325+
// CHECK-LABEL: test_linearize_across_for
326+
func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
327+
%0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8>
328+
%c0 = arith.constant 0 : index
329+
%c1 = arith.constant 1 : index
330+
%c4 = arith.constant 4 : index
331+
332+
// CHECK: scf.for {{.*}} -> (vector<4xi8>)
333+
%1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) {
334+
335+
// CHECK: arith.addi {{.*}} : vector<4xi8>
336+
%2 = arith.addi %arg1, %0 : vector<2x2xi8>
337+
338+
// CHECK: scf.yield {{.*}} : vector<4xi8>
339+
scf.yield %2 : vector<2x2xi8>
340+
}
341+
%3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8>
342+
return %3 : vector<4xi8>
343+
}
344+
323345
// -----
324346

325347
// CHECK-LABEL: linearize_vector_splat
@@ -344,4 +366,5 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
344366
// CHECK: return %[[CAST]] : vector<4x[2]xi32>
345367
%0 = vector.splat %arg0 : vector<4x[2]xi32>
346368
return %0 : vector<4x[2]xi32>
369+
347370
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ struct TestVectorEmulateMaskedLoadStore final
837837
}
838838
};
839839

840-
namespace bit_width_constrained_vector_linearize {
840+
namespace bit_width_constrained_linearization {
841841

842842
/// Get the set of operand/result types to check for sufficiently
843843
/// small inner-most dimension size.
@@ -960,7 +960,7 @@ struct TestVectorBitWidthLinearize final
960960
}
961961
};
962962

963-
} // namespace bit_width_constrained_vector_linearize
963+
} // namespace bit_width_constrained_linearization
964964

965965
struct TestVectorLinearize final
966966
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
@@ -989,7 +989,6 @@ struct TestVectorLinearize final
989989
vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
990990
vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
991991
patterns);
992-
993992
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
994993
converter, patterns, target);
995994

@@ -1073,7 +1072,7 @@ void registerTestVectorLowerings() {
10731072
PassRegistration<TestVectorLinearize>();
10741073

10751074
PassRegistration<
1076-
bit_width_constrained_vector_linearize::TestVectorBitWidthLinearize>();
1075+
bit_width_constrained_linearization::TestVectorBitWidthLinearize>();
10771076

10781077
PassRegistration<TestEliminateVectorMasks>();
10791078
}

0 commit comments

Comments
 (0)