Commit 1288ba3

Reapply "[mlir][Vector] Tighten up application conditions in TransferReadAfter… (llvm#143869)"
This reverts commit a1e53bc.
1 parent adf3f48 · commit 1288ba3

2 files changed, +117 −21 lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 23 additions & 7 deletions
@@ -4757,12 +4757,15 @@ struct TransferReadAfterWriteToBroadcast

   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4773,15 +4776,28 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case, the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // Bail in the masked case (too complex atm and needed to properly account
+    // for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
+
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
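
For orientation, here is a positive example of the kind of IR this tightened pattern still folds, mirroring the existing @store_to_load_tensor_broadcast test in canonicalize.mlir: both transfers have pure tensor semantics, are in-bounds and unmasked, and share their indices, while the read's permutation map broadcasts a trailing dimension. The expected-output comments are a rough sketch of the canonicalized form, not CHECK lines from this commit.

func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
    %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  // In-bounds, unmasked write of a 4x2 chunk at [0, 0], pure tensor semantics.
  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] {in_bounds = [true, true]} :
    vector<4x2xf32>, tensor<4x4xf32>
  // Read the same chunk back, broadcasting a trailing dimension of size 6.
  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
    tensor<4x4xf32>, vector<4x2x6xf32>
  // Expected to canonicalize to roughly:
  //   %b = vector.broadcast %v0 : vector<4x2xf32> to vector<6x4x2xf32>
  //   %t = vector.transpose %b, [1, 2, 0] : vector<6x4x2xf32> to vector<4x2x6xf32>
  return %0 : vector<4x2x6xf32>
}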

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 94 additions & 14 deletions
@@ -408,7 +408,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // -----

 // Negative test where the extract is not a subset of the element inserted.
-// CHECK-LABEL: extract_strided_fold_negative
+// CHECK-LABEL: negative_extract_strided_fold
 // CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
 // CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
 // CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
@@ -417,7 +417,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
 // CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
 // CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
-func.func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+func.func @negative_extract_strided_fold(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
   -> (vector<6x4xf32>) {
   %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
     : vector<4x4xf32> into vector<8x16xf32>
@@ -753,10 +753,10 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,

 // -----

-// CHECK-LABEL: fold_extract_broadcast_negative
+// CHECK-LABEL: negative_fold_extract_broadcast
 // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 // CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
-func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
+func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
   %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   return %r : vector<4xf32>
@@ -895,11 +895,11 @@ func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {

 // -----

-// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK-LABEL: negative_fold_extract_shapecast
 // CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
 // CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
 // CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
   %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
   return %r : vector<4x2xf32>
@@ -1460,11 +1460,11 @@ func.func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {

 // -----

-// CHECK-LABEL: func @store_after_load_tensor_negative
+// CHECK-LABEL: func @negative_store_after_load_tensor
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 // CHECK: return
-func.func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+func.func @negative_store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1499,12 +1499,12 @@ func.func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,

 // -----

-// CHECK-LABEL: func @store_to_load_negative_tensor
+// CHECK-LABEL: func @negative_store_to_load_tensor
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: %[[V:.*]] = vector.transfer_read
 // CHECK: return %[[V]] : vector<1x4xf32>
-func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
+func.func @negative_store_to_load_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -1540,6 +1540,86 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,

 // -----

+// CHECK-LABEL: func @negative_store_to_load_tensor_memref
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_memref(
+    %arg0 : tensor<?x?xf32>,
+    %arg1 : memref<?x?xf32>,
+    %v0 : vector<4x2xf32>
+  ) -> vector<4x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c0, %c0] {in_bounds = [true, true]} :
+    vector<4x2xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+    %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<?x?xf32>,
+    %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_masked
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_masked(
+    %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
+  -> vector<4x2x6xf32>
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 // CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
@@ -1604,15 +1684,15 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,

 // -----

-// CHECK-LABEL: func @dead_store_tensor_negative
+// CHECK-LABEL: func @negative_dead_store_tensor
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_read
 // CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
 // CHECK: return %[[VTW]] : tensor<4x4xf32>
-func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
+func.func @negative_dead_store_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -2063,10 +2143,10 @@ func.func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)

 // -----

-// CHECK-LABEL: extract_insert_negative
+// CHECK-LABEL: negative_extract_insert
 // CHECK: vector.insert_strided_slice
 // CHECK: vector.extract
-func.func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+func.func @negative_extract_insert(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
   -> vector<16xf32> {
   %0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
     : vector<2x15xf32> into vector<12x8x16xf32>
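
As a usage note, these cases are exercised by lit through the RUN line at the top of canonicalize.mlir; from memory of the upstream file it looks roughly like the line below, so treat the exact flags as an assumption rather than as part of this commit.

// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s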
