Commit e4de74b

[mlir][Vector] Tighten up application conditions in TransferReadAfterWriteToBroadcast (llvm#143869)

The pattern would previously apply in spurious cases and generate incorrect IR.

In the process, we disable the application of this pattern in the case where there is no broadcast; this should be handled separately and may more easily support masking.

The case {no-broadcast, yes-transpose} was previously caught by this pattern, could arguably also generate incorrect IR, and was untested: this case does not apply anymore.

The last case {yes-broadcast, yes-transpose} continues to apply but should arguably be removed in the future, because creating transposes as part of canonicalization feels dangerous. There are other patterns that move permutation logic:
- either into the transfer, or
- outside of the transfer.

Ideally, this would be target-dependent and not a canonicalization (i.e., does your DMA HW allow transpose on the fly or not), but this is beyond the scope of this PR.

Co-authored-by: Nicolas Vasilache <[email protected]>
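
For illustration, a minimal sketch of the {yes-broadcast} case that this pattern still folds, modeled on the store_to_load_tensor_broadcast tests in canonicalize.mlir (the function name and shapes below are illustrative, not taken from the patch):

// transfer_write followed by a transfer_read of the just-written value, both
// with pure tensor semantics, in bounds, unmasked, same indices, and with a
// broadcasting read permutation map: all tightened conditions are satisfied.
func.func @store_to_load_tensor_broadcast_sketch(%arg0 : tensor<?xf32>,
    %v0 : vector<4xf32>) -> vector<6x4xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %arg0[%c0] {in_bounds = [true]} :
    vector<4xf32>, tensor<?xf32>
  %0 = vector.transfer_read %w0[%c0], %cf0 {in_bounds = [true, true],
    permutation_map = affine_map<(d0) -> (0, d0)>} :
    tensor<?xf32>, vector<6x4xf32>
  return %0 : vector<6x4xf32>
}
// Expected to canonicalize to roughly:
//   %b = vector.broadcast %v0 : vector<4xf32> to vector<6x4xf32>
//   return %b : vector<6x4xf32>
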
1 parent 62b6940 commit e4de74b

2 files changed: +117 −21 lines

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 23 additions & 7 deletions
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,28 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case, the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // Bail in the masked case (too complex atm and needed to properly account
+    // for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
+
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
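
For the {yes-broadcast, yes-transpose} case that continues to apply, the composed map readMap.compose(writeMap) both broadcasts and permutes, so the rewrite materializes a vector.transpose. A sketch, modeled on the existing @store_to_load_tensor_broadcast test in canonicalize.mlir (names and the expected output are illustrative, not quoted from the patch):

// The read broadcasts a trailing dimension; after forwarding the written
// vector the broadcast dimension is leading, so a transpose is created to
// recover the requested vector<4x4x6xf32> layout.
func.func @store_to_load_tensor_broadcast_transpose_sketch(%arg0 : tensor<4x4xf32>,
    %v0 : vector<4x4xf32>) -> vector<4x4x6xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] {in_bounds = [true, true]} :
    vector<4x4xf32>, tensor<4x4xf32>
  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
    tensor<4x4xf32>, vector<4x4x6xf32>
  return %0 : vector<4x4x6xf32>
}
// Expected to canonicalize to roughly:
//   %b = vector.broadcast %v0 : vector<4x4xf32> to vector<6x4x4xf32>
//   %t = vector.transpose %b, [1, 2, 0] : vector<6x4x4xf32> to vector<4x4x6xf32>
//   return %t : vector<4x4x6xf32>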

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 94 additions & 14 deletions
@@ -408,7 +408,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // -----
 
 // Negative test where the extract is not a subset of the element inserted.
-// CHECK-LABEL: extract_strided_fold_negative
+// CHECK-LABEL: negative_extract_strided_fold
 // CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
 // CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
 // CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
@@ -417,7 +417,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
 // CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
 // CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
-func.func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+func.func @negative_extract_strided_fold(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
   -> (vector<6x4xf32>) {
   %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
     : vector<4x4xf32> into vector<8x16xf32>
@@ -753,10 +753,10 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_negative
+// CHECK-LABEL: negative_fold_extract_broadcast
 // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 // CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
-func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
+func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
   %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   return %r : vector<4xf32>
@@ -895,11 +895,11 @@ func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
 
 // -----
 
-// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK-LABEL: negative_fold_extract_shapecast
 // CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
 // CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
 // CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
   %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
   return %r : vector<4x2xf32>
@@ -1460,11 +1460,11 @@ func.func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
 
 // -----
 
-// CHECK-LABEL: func @store_after_load_tensor_negative
+// CHECK-LABEL: func @negative_store_after_load_tensor
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 // CHECK: return
-func.func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+func.func @negative_store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1499,12 +1499,12 @@ func.func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @store_to_load_negative_tensor
+// CHECK-LABEL: func @negative_store_to_load_tensor
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: %[[V:.*]] = vector.transfer_read
 // CHECK: return %[[V]] : vector<1x4xf32>
-func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
+func.func @negative_store_to_load_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -1540,6 +1540,86 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @negative_store_to_load_tensor_memref
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_memref(
+    %arg0 : tensor<?x?xf32>,
+    %arg1 : memref<?x?xf32>,
+    %v0 : vector<4x2xf32>
+  ) -> vector<4x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c0, %c0] {in_bounds = [true, true]} :
+    vector<4x2xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+    %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<?x?xf32>,
+    %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_masked
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_masked(
+    %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
+  -> vector<4x2x6xf32>
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 // CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
@@ -1604,15 +1684,15 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @dead_store_tensor_negative
+// CHECK-LABEL: func @negative_dead_store_tensor
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_read
 // CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
 // CHECK: return %[[VTW]] : tensor<4x4xf32>
-func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
+func.func @negative_dead_store_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -2063,10 +2143,10 @@ func.func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
 
 // -----
 
-// CHECK-LABEL: extract_insert_negative
+// CHECK-LABEL: negative_extract_insert
 // CHECK: vector.insert_strided_slice
 // CHECK: vector.extract
-func.func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+func.func @negative_extract_insert(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
   -> vector<16xf32> {
   %0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
     : vector<2x15xf32> into vector<12x8x16xf32>
