[mlir][linalg] Add support for masked vectorization of tensor.insert_slice (2/N) #123031

Merged
44 changes: 22 additions & 22 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2716,56 +2716,56 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
}
auto vecType = VectorType::get(vecShape, sourceType.getElementType());

// 3. Generate TransferReadOp.
SmallVector<Value> readIndices(
vecType.getRank(),
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
Operation *read = rewriter.create<vector::TransferReadOp>(
sliceOp.getLoc(), vecType, source, readIndices, padValue,
ArrayRef<bool>{readInBounds});
// 3. Generate TransferReadOp + TransferWriteOp
ReifiedRankedShapedTypeDims reifiedSrcSizes;
Value maskOp;
Contributor
Move the declaration to where it is initialized, i.e., l.2742?

Contributor Author
Note that l.2742 sits within an if block and the generated mask is also used outside, e.g. l.2756.

This is roughly the structure:

Value maskOp;
if (!inputVectorSizes.empty()) {
  // Generate the mask - this will make `if (maskOp)` below evaluate to TRUE
}

// Generate readOp

if (maskOp) {
  // Mask the readOp
}

// Generate writeOp (depends on readOp)

if (maskOp) {
  // Mask the writeOp
}

I have ideas for how to improve this, but no spare cycles 😢 (there are createWriteOrMaskedWrite and createReadOrMaskedRead that we should re-use here, but they won't work as-is).

If that's OK, I'll add this to my TODO list.

Contributor
Sorry that I missed it. I see, thanks!


// If vector sizes are user provided, make sure to mask xfer_read.
// If vector sizes are user provided, make sure to mask. First, generate the
// mask.
Contributor
Couldn't user-provided vector sizes lead to an unmasked scenario? We have a method that checks whether a mask is needed here (can't remember the name right now). Couldn't we use it for this case?

Contributor Author
Yup, that's createWriteOrMaskedWrite that I mentioned in my reply to HanHan :) I hit some issues with that, so I will revisit it in a separate PR (I'm quite keen to progress this and to reduce my PR backlog).
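
For readers skimming the interleaved diff below, here is a condensed sketch of the new code path: the mask is built once from the reified source sizes, then both the transfer_read and the transfer_write are wrapped via vector::maskOperation. The wrapper function name and parameter list are hypothetical and for illustration only; the real logic lives inline in vectorizeAsInsertSliceOp, and the usual Vectorization.cpp includes, the LDBG logging, and steps 1-2 (computing vecType, padValue and the in-bounds flags) are assumed to be in scope.

// Sketch only: condensed from the patch; the wrapper name/signature is
// hypothetical. The usual Vectorization.cpp includes and helpers are assumed.
static LogicalResult vectorizeInsertSliceMaskedSketch(
    RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, Value source,
    Value padValue, VectorType vecType, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> readInBounds, ArrayRef<bool> writeInBounds,
    SmallVectorImpl<Value> &newResults) {
  Location loc = sliceOp.getLoc();

  // If vector sizes are user provided, build the mask first; the same mask is
  // reused for both the read and the write below.
  Value maskOp;
  if (!inputVectorSizes.empty()) {
    auto *srcDefOp = source.getDefiningOp();
    if (!srcDefOp)
      return failure();
    ReifiedRankedShapedTypeDims reifiedSrcSizes;
    if (failed(cast<ReifyRankedShapedTypeOpInterface>(srcDefOp)
                   .reifyResultShapes(rewriter, reifiedSrcSizes)))
      return failure();
    auto readMaskType =
        VectorType::get(inputVectorSizes, rewriter.getI1Type());
    maskOp = rewriter.create<vector::CreateMaskOp>(loc, readMaskType,
                                                   reifiedSrcSizes[0]);
  }

  // Generate the TransferReadOp and mask it if a mask was created.
  SmallVector<Value> readIndices(
      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Operation *read = rewriter.create<vector::TransferReadOp>(
      loc, vecType, source, readIndices, padValue, readInBounds);
  if (maskOp)
    read = mlir::vector::maskOperation(rewriter, read, maskOp);

  // Generate the TransferWriteOp at the insert_slice offsets, reusing the
  // same mask.
  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write = rewriter.create<vector::TransferWriteOp>(
      loc, read->getResult(0), sliceOp.getDest(), writeIndices, writeInBounds);
  if (maskOp)
    write = mlir::vector::maskOperation(rewriter, write, maskOp);

  newResults.push_back(write->getResult(0));
  return success();
}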

if (!inputVectorSizes.empty()) {
auto *srcDefOp = source.getDefiningOp();
if (!srcDefOp) {
LDBG("Unable to get the defining Op of " << sliceOp);
return failure();
}

ReifiedRankedShapedTypeDims reifiedSrcSizes;
LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
rewriter, reifiedSrcSizes);
if (status.failed()) {
LDBG("Unable to reify result shapes of " << sliceOp);
LDBG("Unable to reify result shapes of " << srcDefOp);
return failure();
}

// Create the mask
SmallVector<int64_t> readMaskShape(
sliceOp.getSource().getType().getShape());
auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
Value maskOp = rewriter.create<vector::CreateMaskOp>(
maskOp = rewriter.create<vector::CreateMaskOp>(
sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);

// Mask the xfer_read Op
read = mlir::vector::maskOperation(rewriter, read, maskOp);
}

// 4. Generate TransferWriteOp.
if (!inputVectorSizes.empty() &&
ShapedType::isDynamicShape(resultType.getShape())) {
LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
return failure();
SmallVector<Value> readIndices(
vecType.getRank(),
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
Operation *read = rewriter.create<vector::TransferReadOp>(
sliceOp.getLoc(), vecType, source, readIndices, padValue,
ArrayRef<bool>{readInBounds});

if (maskOp) {
read = mlir::vector::maskOperation(rewriter, read, maskOp);
}

auto writeIndices = getValueOrCreateConstantIndexOp(
rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());

// 5. Finalize
Operation *write = rewriter.create<vector::TransferWriteOp>(
sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
ArrayRef<bool>{writeInBounds});

if (maskOp) {
write = mlir::vector::maskOperation(rewriter, write, maskOp);
}

// 4. Finalize
newResults.push_back(write->getResult(0));

return success();
23 changes: 0 additions & 23 deletions mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -280,26 +280,3 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).

func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
%c2 = arith.constant 2 : index
%init = tensor.empty(%size) : tensor<?x3xi32>

%source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
// expected-error @+1 {{Attempted to vectorize, but failed}}
%res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>

return %res : tensor<?x3xi32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
transform.yield
}
}
89 changes: 82 additions & 7 deletions mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1130,14 +1130,14 @@ func.func private @insert_slice_static_sizes(%source: tensor<?x3x?x1xi32>) -> te
// CHECK: %[[C_2:.*]] = arith.constant 2 : index
// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<5x3xi32>
// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SEC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C_5:.*]] = arith.constant 5 : index
// CHECK: %[[C_1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C_5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C0]], %[[C0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
// CHECK: %[[C_0:.*]] = arith.constant 0 : index
// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
// CHECK: return %[[RES]] : tensor<5x3xi32>

module attributes {transform.with_named_sequence} {
@@ -1170,11 +1170,11 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x3x?x1xi32>, %s
// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C_1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
// CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C_1]] : vector<8x1xi1>
// CHECK: %[[C_0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32> } : vector<8x1xi1> -> tensor<5x3xi32>
// CHECK: return %[[RES]] : tensor<5x3xi32>

module attributes {transform.with_named_sequence} {
@@ -1184,3 +1184,78 @@ func.func private @insert_slice_dynamic_src_dim(%source: tensor<?x3x?x1xi32>, %s
transform.yield
}
}

// -----

// One of the _destination_ dimensions is dynamic (but _source_ dimensions are static).

func.func private @insert_slice_dynamic_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
%c2 = arith.constant 2 : index
%init = tensor.empty(%size) : tensor<?x3xi32>

%source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
%res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>

return %res : tensor<?x3xi32>
}

// CHECK-LABEL: func.func private @insert_slice_dynamic_dest_dim(
// CHECK-SAME: %[[SRC:.*]]: tensor<?x3x?x1xi32>,
// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?x3xi32> {
// CHECK: %[[C_2:.*]] = arith.constant 2 : index
// CHECK: %[[INIT:.*]] = tensor.empty(%[[SIZE]]) : tensor<?x3xi32>
// CHECK: %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK: %[[C_5:.*]] = arith.constant 5 : index
// CHECK: %[[C_1:.*]] = arith.constant 1 : index
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
// CHECK: %[[C_0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
// CHECK: return %[[WRITE]] : tensor<?x3xi32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
transform.yield
}
}

// -----

// At least one _source_ and one _destination_ dimensions are dynamic.

func.func private @insert_slice_dynamic_source_and_dest_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
%c2 = arith.constant 2 : index
%init = tensor.empty(%size) : tensor<?x3xi32>

%source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
%res = tensor.insert_slice %source_slice into %init[0, %c2] [%size, 1] [1, 1] : tensor<?x1xi32> into tensor<?x3xi32>

return %res : tensor<?x3xi32>
}

// CHECK-LABEL: func.func private @insert_slice_dynamic_source_and_dest_dim(
// CHECK-SAME: %[[SRC:.*]]: tensor<?x3x?x1xi32>,
// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?x3xi32> {
// CHECK: %[[C_2:.*]] = arith.constant 2 : index
// CHECK: %[[INIT:.*]] = tensor.empty(%[[SIZE]]) : tensor<?x3xi32>
// CHECK: %[[SRC_SIZE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C1]] : vector<8x1xi1>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SIZE]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
// CHECK: %[[C_0_1:.*]] = arith.constant 0 : index
// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[INIT]]{{\[}}%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<?x3xi32> } : vector<8x1xi1> -> tensor<?x3xi32>
// CHECK: return %[[WRITE]] : tensor<?x3xi32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
transform.yield
}
}