Skip to content

Commit 569c560

Browse files
committed
first commit (needs refinement)
1 parent 8b9ae65 commit 569c560

File tree

4 files changed

+34
-16
lines changed

4 files changed

+34
-16
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,13 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
407407
RewritePatternSet &patterns, PatternBenefit benefit = 1);
408408

409409
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
410-
/// This registers (1) which operations are legal and hence should not be
411-
/// linearized, (2) what converted types are (rank-1 vectors) and how to
410+
///
411+
/// Definition: here 'linearization' means converting a single operation with
412+
/// 1+ vector operands and results of rank>1, into a single operation whose
413+
/// vector operands are all of rank<=1.
414+
///
415+
/// This function registers (1) which operations are legal, and hence should not
416+
/// be linearized, (2) what the converted types are (rank-1 vectors) and how to
412417
/// materialze the conversion (with shape_cast)
413418
///
414419
/// Note: the set of legal operations can be extended by a user if for example

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class ConvertForOpTypes
9999
// PR47938 tracks this issue, but it seems hard to fix. Instead, we need
100100
// to clone the op.
101101
//
102-
// 2. We need to resue the original region instead of cloning it, otherwise
102+
// 2. We need to reuse the original region instead of cloning it, otherwise
103103
// the dialect conversion framework thinks that we just inserted all the
104104
// cloned child ops. But what we want is to "take" the child regions and let
105105
// the dialect conversion framework continue recursively into ops inside

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,6 @@ struct LinearizeVectorExtractStridedSlice final
134134
VectorType dstType =
135135
getTypeConverter()->convertType<VectorType>(extractOp.getType());
136136
assert(dstType && "vector type destination expected.");
137-
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
138-
return rewriter.notifyMatchFailure(extractOp,
139-
"scalable vectors are not supported.");
140137

141138
ArrayAttr offsets = extractOp.getOffsets();
142139
ArrayAttr sizes = extractOp.getSizes();
@@ -447,18 +444,21 @@ struct LinearizeVectorSplat final
447444

448445
} // namespace
449446

450-
/// Return true if the operation `op` does not support scalable vectors and
451-
/// has at least 1 scalable vector result. These ops should all eventually
452-
/// support scalable vectors, and this function should be removed.
447+
/// Some operations currently cannot be linearized if they have scalable vector
448+
/// results. This function returns true if `op` is such an operation.
453449
static bool isNotLinearizableBecauseScalable(Operation *op) {
454450

455451
bool unsupported =
456452
isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
457453
op);
454+
455+
// Case where linearization is possible even when there are scalable vector
456+
// results.
458457
if (!unsupported)
459458
return false;
460459

461-
// Check if any of the results is a scalable vector type.
460+
// Check if any of the results is a scalable vector type, and if there are
461+
// return true (not linearizable).
462462
auto types = op->getResultTypes();
463463
bool containsScalableResult =
464464
std::any_of(types.begin(), types.end(), [](Type type) {
@@ -469,10 +469,17 @@ static bool isNotLinearizableBecauseScalable(Operation *op) {
469469
return containsScalableResult;
470470
}
471471

472+
/// This method defines a set of operations that are not linearizable,
473+
/// and hence considered legal for the conversion target. These ops are
474+
/// currently
475+
///
476+
/// 1) Ops that are not in the vector dialect, are not ConstantLike, and are not
477+
/// Vectorizable.
478+
///
479+
/// 2) Certain ops with scalable vector results, for which support has not yet
480+
/// been added.
472481
static bool isNotLinearizable(Operation *op) {
473482

474-
// Only ops that are in the vector dialect, are ConstantLike, or
475-
// are Vectorizable might be linearized currently.
476483
StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
477484
StringRef opDialect = op->getDialect()->getNamespace();
478485
bool unsupported = (opDialect != vectorDialect) &&

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
1919
#include "mlir/Dialect/SCF/IR/SCF.h"
20+
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
2021
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2122
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2223
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -836,8 +837,7 @@ struct TestVectorEmulateMaskedLoadStore final
836837
}
837838
};
838839

839-
// TODO: move this code into the user project.
840-
namespace vendor {
840+
namespace bit_width_constrained_vector_linearize {
841841

842842
/// Get the set of operand/result types to check for sufficiently
843843
/// small inner-most dimension size.
@@ -960,7 +960,7 @@ struct TestVectorBitWidthLinearize final
960960
}
961961
};
962962

963-
} // namespace vendor
963+
} // namespace bit_width_constrained_vector_linearize
964964

965965
struct TestVectorLinearize final
966966
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
@@ -982,12 +982,17 @@ struct TestVectorLinearize final
982982
RewritePatternSet patterns(&context);
983983
ConversionTarget target(context);
984984

985+
SmallVector<Operation *> ops;
986+
985987
vector::populateForVectorLinearize(converter, target);
986988

987989
vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
988990
vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
989991
patterns);
990992

993+
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
994+
converter, patterns, target);
995+
991996
if (failed(applyPartialConversion(getOperation(), target,
992997
std::move(patterns))))
993998
return signalPassFailure();
@@ -1067,7 +1072,8 @@ void registerTestVectorLowerings() {
10671072

10681073
PassRegistration<TestVectorLinearize>();
10691074

1070-
PassRegistration<vendor::TestVectorBitWidthLinearize>();
1075+
PassRegistration<
1076+
bit_width_constrained_vector_linearize::TestVectorBitWidthLinearize>();
10711077

10721078
PassRegistration<TestEliminateVectorMasks>();
10731079
}

0 commit comments

Comments
 (0)