Skip to content

Commit 0210067

Browse files
[mlir][Vector] Tighten up application conditions in TransferReadAfterWriteToBroadcast
The pattern would previously apply in spurious cases and generate incorrect IR. In the process, we disable the application of this pattern in the case where there is no broadcast; this should be handled separately and may more easily support masking. The case {no-broadcast, yes-transpose} was previously caught by this pattern and arguably could also generate incorrect IR (and was also untested): this case does not apply anymore. The last case {yes-broadcast, yes-transpose} continues to apply but should arguably be removed in the future because creating transposes as part of canonicalization feels dangerous. There are other patterns that move permutation logic: - either into the transfer, or - outside of the transfer Ideally, this would be target-dependent and not a canonicalization (i.e. does your DMA HW allow transpose on the fly or not) but this is beyond the scope of this PR.
1 parent 9d491bc commit 0210067

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
46684668

46694669
LogicalResult matchAndRewrite(TransferReadOp readOp,
46704670
PatternRewriter &rewriter) const override {
4671-
if (readOp.hasOutOfBoundsDim() ||
4672-
!llvm::isa<RankedTensorType>(readOp.getShapedType()))
4673-
return failure();
46744671
auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
46754672
if (!defWrite)
46764673
return failure();
4674+
// Bail if we need an alias analysis.
4675+
if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
4676+
return failure();
4677+
// Bail if we need a bounds analysis.
4678+
if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
4679+
return failure();
46774680
// TODO: If the written transfer chunk is a superset of the read transfer
46784681
// chunk we could do an extract_strided_slice.
46794682
if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,28 @@ struct TransferReadAfterWriteToBroadcast
46844687
if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
46854688
getUnusedDimsBitVector({defWrite.getPermutationMap()}))
46864689
return failure();
4687-
if (readOp.getIndices() != defWrite.getIndices() ||
4688-
readOp.getMask() != defWrite.getMask())
4690+
// This pattern should only catch the broadcast case, the non-broadcast case
4691+
// should be done separately to keep application conditions clean and
4692+
// separate.
4693+
AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4694+
AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
4695+
bool bcast = !readMap.getBroadcastDims().empty() ||
4696+
!writeMap.getBroadcastDims().empty();
4697+
if (!bcast)
4698+
return failure();
4699+
// At this point, we know we have a bcast.
4700+
// Bail in the masked case (too complex atm and needed to properly account
4701+
// for padding).
4702+
if (readOp.getMask() || defWrite.getMask())
4703+
return failure();
4704+
// If indices are not the same a shift may be required, bail.
4705+
if (readOp.getIndices() != defWrite.getIndices())
46894706
return failure();
4707+
46904708
Value vec = defWrite.getVector();
46914709
// TODO: loop through the chain of transfer_write if we can prove that they
46924710
// don't overlap with the transfer_read. This requires improving
46934711
// `isDisjointTransferIndices` helper.
4694-
AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
4695-
AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
46964712
AffineMap map = readMap.compose(writeMap);
46974713
if (map.getNumResults() == 0)
46984714
return failure();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,43 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
15401540

15411541
// -----
15421542

1543+
// CHECK-LABEL: func @store_to_load_tensor_no_actual_broadcast
1544+
// CHECK-NOT: vector.broadcast
1545+
// CHECK-NOT: vector.transpose
1546+
// CHECK: vector.transfer_write
1547+
// CHECK: vector.transfer_read
1548+
func.func @store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
1549+
%v0 : vector<4x2xf32>) -> vector<4x2xf32> {
1550+
%c0 = arith.constant 0 : index
1551+
%cf0 = arith.constant 0.0 : f32
1552+
%w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
1553+
vector<4x2xf32>, tensor<?x?xf32>
1554+
%0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
1555+
tensor<?x?xf32>, vector<4x2xf32>
1556+
return %0 : vector<4x2xf32>
1557+
}
1558+
1559+
// -----
1560+
1561+
// CHECK-LABEL: func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize
1562+
// CHECK-NOT: vector.broadcast
1563+
// CHECK-NOT: vector.transpose
1564+
// CHECK: vector.transfer_write
1565+
// CHECK: vector.transfer_read
1566+
func.func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize(%arg0 : tensor<?x?xf32>,
1567+
%v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
1568+
%c0 = arith.constant 0 : index
1569+
%cf0 = arith.constant 0.0 : f32
1570+
%w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
1571+
vector<4x2xf32>, tensor<?x?xf32>
1572+
%0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
1573+
permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
1574+
tensor<?x?xf32>, vector<4x2x6xf32>
1575+
return %0 : vector<4x2x6xf32>
1576+
}
1577+
1578+
// -----
1579+
15431580
// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
15441581
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
15451582
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>

0 commit comments

Comments
 (0)