
Commit a3f6350

bjacob authored and Groverkss committed
Revert "[mlir][Vector] Tighten up application conditions in TransferReadAfter… (llvm#143869)"
This reverts commit e4de74b.
1 parent: d31ba52 · commit: a3f6350


2 files changed: +21 -117 lines


mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 23 deletions
@@ -4758,15 +4758,12 @@ struct TransferReadAfterWriteToBroadcast
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
+    if (readOp.hasOutOfBoundsDim() ||
+        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
+      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
-    // Bail if we need an alias analysis.
-    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
-      return failure();
-    // Bail if we need a bounds analysis.
-    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
-      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4777,28 +4774,15 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    // This pattern should only catch the broadcast case, the non-broadcast case
-    // should be done separately to keep application conditions clean and
-    // separate.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
-    bool bcast = !readMap.getBroadcastDims().empty() ||
-                 !writeMap.getBroadcastDims().empty();
-    if (!bcast)
-      return failure();
-    // At this point, we know we have a bcast.
-    // Bail in the masked case (too complex atm and needed to properly account
-    // for padding).
-    if (readOp.getMask() || defWrite.getMask())
-      return failure();
-    // If indices are not the same a shift may be required, bail.
-    if (readOp.getIndices() != defWrite.getIndices())
+    if (readOp.getIndices() != defWrite.getIndices() ||
+        readOp.getMask() != defWrite.getMask())
       return failure();
-
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
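
For context, here is a minimal MLIR sketch of the rewrite TransferReadAfterWriteToBroadcast performs. This is an illustration modeled on the @store_to_load_tensor_broadcast test that remains in canonicalize.mlir, not code from this commit: a transfer_read that reads back the chunk just written by a transfer_write through a broadcasting permutation map folds to a vector.broadcast of the written vector, plus a vector.transpose when the broadcast dimension is not leading.

// Hypothetical example; the function name and shapes are illustrative only.
func.func @sketch_store_to_load_broadcast(%arg0 : tensor<4x4xf32>,
    %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  // Write a 4x2 chunk at [0, 0] of the tensor...
  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
    vector<4x2xf32>, tensor<4x4xf32>
  // ...then read the same chunk back, broadcasting along a trailing dim.
  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
    tensor<4x4xf32>, vector<4x2x6xf32>
  // Expected canonicalization, roughly:
  //   %b = vector.broadcast %v0 : vector<4x2xf32> to vector<6x4x2xf32>
  //   %t = vector.transpose %b, [1, 2, 0] : vector<6x4x2xf32> to vector<4x2x6xf32>
  return %0 : vector<4x2x6xf32>
}

The guards edited above decide when this fold may fire; the reverted-to version bails only on out-of-bounds reads, non-tensor operands, mismatched transfer chunks, mismatched unused dims, and mismatched indices or masks.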

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 94 deletions
@@ -408,7 +408,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // -----
 
 // Negative test where the extract is not a subset of the element inserted.
-// CHECK-LABEL: negative_extract_strided_fold
+// CHECK-LABEL: extract_strided_fold_negative
 // CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
 // CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
 // CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
@@ -417,7 +417,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
 // CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
 // CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
 // CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
-func.func @negative_extract_strided_fold(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+func.func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
   -> (vector<6x4xf32>) {
   %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
     : vector<4x4xf32> into vector<8x16xf32>
@@ -753,10 +753,10 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
-// CHECK-LABEL: negative_fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_negative
 // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 // CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
-func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
+func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
   %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   return %r : vector<4xf32>
@@ -895,11 +895,11 @@ func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {
 
 // -----
 
-// CHECK-LABEL: negative_fold_extract_shapecast
+// CHECK-LABEL: fold_extract_shapecast_negative
 // CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
 // CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
 // CHECK: return %[[R]] : vector<4x2xf32>
-func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
   %r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
   return %r : vector<4x2xf32>
@@ -1460,11 +1460,11 @@ func.func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
 
 // -----
 
-// CHECK-LABEL: func @negative_store_after_load_tensor
+// CHECK-LABEL: func @store_after_load_tensor_negative
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 // CHECK: return
-func.func @negative_store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+func.func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1499,12 +1499,12 @@ func.func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @negative_store_to_load_tensor
+// CHECK-LABEL: func @store_to_load_negative_tensor
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: %[[V:.*]] = vector.transfer_read
 // CHECK: return %[[V]] : vector<1x4xf32>
-func.func @negative_store_to_load_tensor(%arg0 : tensor<4x4xf32>,
+func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -1540,86 +1540,6 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @negative_store_to_load_tensor_memref
-// CHECK-NOT: vector.broadcast
-// CHECK-NOT: vector.transpose
-// CHECK: vector.transfer_write
-// CHECK: vector.transfer_read
-func.func @negative_store_to_load_tensor_memref(
-    %arg0 : tensor<?x?xf32>,
-    %arg1 : memref<?x?xf32>,
-    %v0 : vector<4x2xf32>
-  ) -> vector<4x2xf32>
-{
-  %c0 = arith.constant 0 : index
-  %cf0 = arith.constant 0.0 : f32
-  vector.transfer_write %v0, %arg1[%c0, %c0] {in_bounds = [true, true]} :
-    vector<4x2xf32>, memref<?x?xf32>
-  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
-    tensor<?x?xf32>, vector<4x2xf32>
-  return %0 : vector<4x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
-// CHECK-NOT: vector.broadcast
-// CHECK-NOT: vector.transpose
-// CHECK: vector.transfer_write
-// CHECK: vector.transfer_read
-func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
-  %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
-  %c0 = arith.constant 0 : index
-  %cf0 = arith.constant 0.0 : f32
-  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
-    vector<4x2xf32>, tensor<?x?xf32>
-  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
-    tensor<?x?xf32>, vector<4x2xf32>
-  return %0 : vector<4x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
-// CHECK-NOT: vector.broadcast
-// CHECK-NOT: vector.transpose
-// CHECK: vector.transfer_write
-// CHECK: vector.transfer_read
-func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<?x?xf32>,
-  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
-  %c0 = arith.constant 0 : index
-  %cf0 = arith.constant 0.0 : f32
-  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
-    vector<4x2xf32>, tensor<?x?xf32>
-  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
-    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
-    tensor<?x?xf32>, vector<4x2x6xf32>
-  return %0 : vector<4x2x6xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_masked
-// CHECK-NOT: vector.broadcast
-// CHECK-NOT: vector.transpose
-// CHECK: vector.transfer_write
-// CHECK: vector.transfer_read
-func.func @negative_store_to_load_tensor_broadcast_masked(
-  %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-  -> vector<4x2x6xf32>
-{
-  %c0 = arith.constant 0 : index
-  %cf0 = arith.constant 0.0 : f32
-  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} :
-    vector<4x2xf32>, tensor<?x?xf32>
-  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
-    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
-    tensor<?x?xf32>, vector<4x2x6xf32>
-  return %0 : vector<4x2x6xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 // CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
@@ -1684,15 +1604,15 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
 
 // -----
 
-// CHECK-LABEL: func @negative_dead_store_tensor
+// CHECK-LABEL: func @dead_store_tensor_negative
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_read
 // CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
 // CHECK: return %[[VTW]] : tensor<4x4xf32>
-func.func @negative_dead_store_tensor(%arg0 : tensor<4x4xf32>,
+func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
   %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -2143,10 +2063,10 @@ func.func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)
 
 // -----
 
-// CHECK-LABEL: negative_extract_insert
+// CHECK-LABEL: extract_insert_negative
 // CHECK: vector.insert_strided_slice
 // CHECK: vector.extract
-func.func @negative_extract_insert(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+func.func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
   -> vector<16xf32> {
   %0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
     : vector<2x15xf32> into vector<12x8x16xf32>
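
For anyone rerunning these tests locally: MLIR lit tests of this kind are driven by a RUN line at the top of the file. A hedged sketch of what it typically looks like for canonicalize.mlir (the actual line in the repository may carry additional flags):

// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s

The // ----- separators visible in the diff are what -split-input-file splits on, so each renamed or deleted function above is an independent test case checked against its own CHECK lines.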
