Skip to content

Fix unsupported transpose ops for scalable vectors in LowerVectorTransfer #86163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,19 @@ struct TransferWritePermutationLowering
// Generate new transfer_write operation.
Value newVec = rewriter.create<vector::TransposeOp>(
op.getLoc(), op.getVector(), indices);

auto vectorType = cast<VectorType>(newVec.getType());

if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe *vectorType.getScalableDims().end() is could be an out-of-bounds access. The .end() method returns an iterator passed the end of the array, I think what you probably want is:

Suggested change
if (vectorType.isScalable() && !*vectorType.getScalableDims().end()) {
if (vectorType.isScalable() && !vectorType.getScalableDims().back()) {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, for this check there's isLegalVectorType() here: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp#L421-L430

Maybe this could be moved to some general utils?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, for this check there's isLegalVectorType()

Re-using that hook would make sense to me, but we can't call it isLegalVectorType (and, in general, we need to be careful when labeling vectors as "illegal"):

  • vectors like vector<[4]x4xi32> are perfectly legal at the "abstract" Vector dialect level,
  • only once we start lowering to LLVM (and/or SVE/SME), these types needs to be eliminated and it makes sense to consider them as "illegal".

For scalable vectors that have only 1 scalable dim, this code is correct . @cfRod, I suggest adding a comment that we are assuming that at most 1 dim is scalable. @MacDue , I can rename and move isLegalVectorType in a separate PR. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine with me 👍 (I agree with the naming, the isLegal was only within the context of the ArmSME pass).

rewriter.eraseOp(newVec.getDefiningOp());
return failure();
Comment on lines +212 to +213
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this for? It returns failure() so this is a case where the pattern should not apply (and any changes rolled-back, I think), but it's erasing an operation.

Copy link
Member

@MacDue MacDue Mar 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I've misunderstood this change. What's happening here? Did the test case previously lower with TransferWritePermutationLowering, and now TransferWriteNonPermutationLowering now manages to lower the same thing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and any changes rolled-back, I think), but it's erasing an operation.

It's erasing the Op to roll-back the changes :) And making sure that the new Op is not added to the list of Ops to be processed by the pattern rewriter driver.

I think I've misunderstood this change. What's happening here? Did the test case previously lower with TransferWritePermutationLowering, and now TransferWriteNonPermutationLowering now manages to lower the same thing?

The test case used to trigger TransferWritePermutationLowering, but that's now disabled and looks like some other pattern is triggered. @cfRod , do you know which one?

Copy link
Contributor Author

@cfRod cfRod Mar 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously it was: TransferWriteNonPermutationLowering called first and so the two vector.broadcast are added and then a transpose for the mask and then TransferWritePermutationLowering is called to add the transpose for the input

* Pattern (anonymous namespace)::TransferWriteNonPermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWriteNonPermutationLowering"
  ** Insert  : 'vector.broadcast'(0xac685f872c10)
  ** Insert  : 'vector.broadcast'(0xac685f879e40)
  ** Insert  : 'vector.transpose'(0xac685f879ed0)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::vector::detail::TransferWriteOpGenericAdaptorBase::Properties)
  ** Insert  : 'vector.transfer_write'(0xac685f7e1d00)
  ** Replace : 'vector.transfer_write'(0xac685f7b1eb0)
  ** Erase   : 'vector.transfer_write'(0xac685f7b1eb0)
"(anonymous namespace)::TransferWriteNonPermutationLowering" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
%c0 = arith.constant 0 : index
%0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
%1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
%2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
return
}


} -> success : pattern matched
//===-------------------------------------------===//

//===-------------------------------------------===//
Processing operation : 'vector.transfer_write'(0xac685f7e1d00) {
"vector.transfer_write"(%1, %arg1, %0, %0, %0, %0, %0, %0, %0, %3) <{in_bounds = [true, true, true, true, true, true], operandSegmentSizes = array<i32: 1, 1, 7, 1>, permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6, d1, d2)>}> : (vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>, index, index, index, index, index, index, index, vector<4x[8]x1x1x1x1xi1>) -> ()

ImplicitTypeIDRegistry::lookupOrInsert(mlir::OpTrait::OneTypedResult<mlir::VectorType>::Impl<Empty>)

* Pattern (anonymous namespace)::TransferWritePermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWritePermutationLowering"
  ** Insert  : 'vector.transpose'(0xac685f87b6b0)
  ** Insert  : 'vector.transfer_write'(0xac685f7b1eb0)
  ** Replace : 'vector.transfer_write'(0xac685f7e1d00)
  ** Erase   : 'vector.transfer_write'(0xac685f7e1d00)
"(anonymous namespace)::TransferWritePermutationLowering" result 1
} -> success : pattern applied successfully
// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %arg2: vector<4x[8]xi1>) {
%c0 = arith.constant 0 : index
%0 = vector.broadcast %arg0 : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
%1 = vector.broadcast %arg2 : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
%2 = vector.transpose %1, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
%3 = vector.transpose %0, [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
vector.transfer_write %3, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %2 {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
return
}

After this patch, the second transpose is "erased"

* Pattern (anonymous namespace)::TransferWritePermutationLowering : 'vector.transfer_write -> ()' {
Trying to match "(anonymous namespace)::TransferWritePermutationLowering"
    ** Insert  : 'vector.transpose'(0xaea769721de0)
    ** Erase   : 'vector.transpose'(0xaea769721de0)
"(anonymous namespace)::TransferWritePermutationLowering" result 0
  } -> failure : pattern failed to match

}

auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
op.getMask(), newInBoundsAttr);

return success();
}
};
Expand Down Expand Up @@ -273,7 +280,7 @@ struct TransferWriteNonPermutationLowering
missingInnerDim.size());
// Mask: add unit dims at the end of the shape.
Value newMask;
if (op.getMask())
if (op.getMask() && !op.getVectorType().isScalable())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can the mask be omitted for scalable vectors? I see the input test case is also masked, but the new vector.transfer_write is unmasked.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can erase the unsupported transpose. Similar to how we do it TransferWritePermutationLowering?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In TransferWritePermutationLowering the rewrite fails (i.e. it does nothing), erasing the transpose is just doing a cleanup. Here the rewrite succeeds but the mask is omitted, which changes the semantics of the operation, which could lead to incorrect results.

newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
missingInnerDim.size());
exprs.append(map.getResults().begin(), map.getResults().end());
Expand Down
25 changes: 12 additions & 13 deletions mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,23 @@ func.func @permutation_with_mask_scalable(%2: memref<?x?xf32>, %dim_1: index, %d
return %1 : vector<8x[4]x2xf32>
}

// CHECK: func.func @permutation_with_mask_transfer_write_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[BCAST_1:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
// CHECK: %[[BCAST_2:.*]] = vector.broadcast %[[MASK]] : vector<4x[8]xi1> to vector<1x1x1x1x4x[8]xi1>
// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BCAST_2]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi1> to vector<4x[8]x1x1x1x1xi1>
// CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BCAST_1]], [4, 5, 0, 1, 2, 3] : vector<1x1x1x1x4x[8]xi16> to vector<4x[8]x1x1x1x1xi16>
// CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[TRANSPOSE_1]] {in_bounds = [true, true, true, true, true, true]} : vector<4x[8]x1x1x1x1xi16>, memref<1x4x?x1x1x1x1xi16>
// CHECK: return
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
// CHECK-LABEL: func.func @permutation_with_mask_transfer_write_scalable(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
// CHECK-SAME: %[[ARG_1:.*]]: memref<1x4x?x1x1x1x1xi16>,
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : vector<4x[8]xi16> to vector<1x1x1x1x4x[8]xi16>
// CHECK: vector.transfer_write %[[BCAST]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true, true], permutation_map = #map} : vector<1x1x1x1x4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>
// CHECK: return
// CHECK: }
func.func @permutation_with_mask_transfer_write_scalable(%arg0: vector<4x[8]xi16>, %arg1: memref<1x4x?x1x1x1x1xi16>, %mask: vector<4x[8]xi1>){
%c0 = arith.constant 0 : index
vector.transfer_write %arg0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0, %c0], %mask {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2)>
} : vector<4x[8]xi16>, memref<1x4x?x1x1x1x1xi16>

return
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
Expand Down