[MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns #112394

Merged
67 changes: 50 additions & 17 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
: public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
FailureOr<Value>
matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
vector::MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -376,6 +378,10 @@ class TransferReadDropUnitDimsPattern
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
return failure();
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
// out.
if (reducedRank == 0 && maskingOp)
return failure();
// Check if the reduced vector shape matches the reduced source shape.
// Otherwise, this case is not supported yet.
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -406,27 +412,37 @@ class TransferReadDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
transferReadOp.getPadding(), maskOp,
rewriter.getBoolArrayAttr(inBounds));

if (maskingOp) {
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
maskingOp.getMask());
newTransferReadOp = mlir::vector::maskOperation(
rewriter, newTransferReadOp, shapeCastMask);
}

auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, vectorType, newTransferReadOp);
rewriter.replaceOp(transferReadOp, shapeCast);
loc, vectorType, newTransferReadOp->getResults()[0]);

return success();
return shapeCast;
}
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
: public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
FailureOr<Value>
matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
vector::MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -444,6 +460,10 @@ class TransferWriteDropUnitDimsPattern
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
return failure();
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
// out.
if (reducedRank == 0 && maskingOp)
return failure();
// Check if the reduced vector shape matches the reduced destination shape.
// Otherwise, this case is not supported yet.
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -474,13 +494,26 @@ class TransferWriteDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
maskOp, rewriter.getBoolArrayAttr(inBounds));

if (maskingOp) {
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
maskingOp.getMask());
newXferWrite =
mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
}

return success();
if (transferWriteOp.hasPureTensorSemantics())
return newXferWrite->getResults()[0];

// With Memref semantics, there's no return value. Use empty value to signal
// success.
return Value();
}
};

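For illustration, here is a hand-written sketch (assumed input and approximate output, not verbatim pass results) of what the updated TransferReadDropUnitDimsPattern does to a masked read whose vector type also carries the unit dims: the memref is rank-reduced via memref.subview, the mask is shape_cast to the reduced rank, the reduced read is re-masked via vector::maskOperation, and the result is shape_cast back to the original type.

func.func @masked_read_example(%arg: memref<1x1x3x2xi8>,
                               %mask: vector<1x1x3x2xi1>) -> vector<1x1x3x2xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  // Masked read; both the source and the vector have two leading unit dims.
  %v = vector.mask %mask {
    vector.transfer_read %arg[%c0, %c0, %c0, %c0], %pad
      : memref<1x1x3x2xi8>, vector<1x1x3x2xi8>
  } : vector<1x1x3x2xi1> -> vector<1x1x3x2xi8>
  return %v : vector<1x1x3x2xi8>
}

// Approximate result of the rewrite:
//   %sv  = memref.subview %arg[0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
//        : memref<1x1x3x2xi8> to memref<3x2xi8>
//   %m   = vector.shape_cast %mask : vector<1x1x3x2xi1> to vector<3x2xi1>
//   %r   = vector.mask %m {
//            vector.transfer_read %sv[%c0, %c0], %pad
//              : memref<3x2xi8>, vector<3x2xi8>
//          } : vector<3x2xi1> -> vector<3x2xi8>
//   %res = vector.shape_cast %r : vector<3x2xi8> to vector<1x1x3x2xi8>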
9 changes: 9 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
@@ -1717,6 +1717,15 @@ func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) ->

// -----

func.func @vector_mask_passthru_type_mismatch(%t0: tensor<f32>, %m0: vector<i1>) -> vector<f32> {
%ft0 = arith.constant 0.0 : f32
// expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
%0 = vector.mask %m0 { vector.transfer_read %t0[], %ft0 : tensor<f32>, vector<f32> } : vector<i1> -> vector<f32>
return %0 : vector<f32>
}

// -----

// expected-note@+1 {{prior use here}}
func.func @vector_mask_passthru_type_mismatch(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> {
%ft0 = arith.constant 0.0 : f32
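For contrast with the 0-d rejection tested above (and the reducedRank == 0 bail-out in the patterns), rank 1 is the smallest mask rank vector.mask currently verifies. A hand-written sketch, not part of this PR's tests:

func.func @masked_read_1d(%t: tensor<1xf32>, %m: vector<1xi1>) -> vector<1xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  // A rank-1 mask on a rank-1 transfer_read verifies fine.
  %v = vector.mask %m {
    vector.transfer_read %t[%c0], %pad : tensor<1xf32>, vector<1xf32>
  } : vector<1xi1> -> vector<1xf32>
  return %v : vector<1xf32>
}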
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

//-----------------------------------------------------------------------------
// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
//-----------------------------------------------------------------------------

func.func @transfer_read_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
%c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]

func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
func.func @transfer_read_rank_reducing_masked(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
%mask: vector<3x2xi1>) -> vector<3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
%v = vector.mask %mask {
vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
} : vector<3x2xi1> -> vector<3x2xi8>
return %v : vector<3x2xi8>
}
// CHECK-LABEL: func @transfer_read_rank_reducing_masked
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.mask %[[MASK]]
// CHECK-SAME: vector.transfer_read %[[SUBVIEW]]

func.func @transfer_write_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
%vec : vector<3x2xi8>) {

%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]

func.func @transfer_write_rank_reducing_masked(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
%vec : vector<3x2xi8>,
%mask: vector<3x2xi1>) {
%c0 = arith.constant 0 : index
vector.mask %mask {
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
} : vector<3x2xi1>
return
}
// CHECK-LABEL: func @transfer_write_rank_reducing_masked
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.mask %[[MASK]]
// CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]

func.func @transfer_read_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
%c0 = arith.constant 0 : index
@@ -68,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>

func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
%arg : memref<1x1x1x1x1xf32>,
%mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {

%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%v = vector.mask %mask {
vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst
: memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
} : vector<1x1x1xi1> -> vector<1x1x1xf32>
return %v : vector<1x1x1xf32>
}
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
// CHECK-NOT: vector.shape_cast
// CHECK-NOT: memref.subview

func.func @transfer_write_and_vector_rank_reducing_to_0d(
%arg : memref<1x1x1x1x1xf32>,
%vec : vector<1x1x1xf32>) {
@@ -82,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>

func.func @transfer_write_and_vector_rank_reducing_to_0d_masked(
%arg : memref<1x1x1x1x1xf32>,
%vec : vector<1x1x1xf32>,
%mask: vector<1x1x1xi1>) {

%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
vector.mask %mask {
vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0] :
vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
} : vector<1x1x1xi1>
return
}
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
// CHECK-NOT: vector.shape_cast
// CHECK-NOT: memref.subview

func.func @transfer_read_dynamic_rank_reducing(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
%c0 = arith.constant 0 : index
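The RUN line above drives these patterns through the transform interpreter. A minimal transform script of the kind such a test relies on looks roughly as follows; the transform op name transform.apply_patterns.vector.rank_reducing_subview_patterns is an assumption based on the upstream vector transform ops, which wrap populateVectorTransferDropUnitDimsPatterns:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    %f = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f {
      // Registers TransferReadDropUnitDimsPattern and
      // TransferWriteDropUnitDimsPattern, among other rank-reducing patterns.
      transform.apply_patterns.vector.rank_reducing_subview_patterns
    } : !transform.any_op
    transform.yield
  }
}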