
[mlir][Vector] Tighten up application conditions in TransferReadAfterWriteToBroadcast #143869

Merged
30 changes: 23 additions & 7 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast

LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
if (!defWrite)
return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
// TODO: If the written transfer chunk is a superset of the read transfer
// chunk we could do an extract_strided_slice.
if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,28 @@ struct TransferReadAfterWriteToBroadcast
if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
getUnusedDimsBitVector({defWrite.getPermutationMap()}))
return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case, the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // Bail in the masked case (too complex atm and needed to properly account
+    // for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
      return failure();

Value vec = defWrite.getVector();
// TODO: loop through the chain of transfer_write if we can prove that they
// don't overlap with the transfer_read. This requires improving
// `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
AffineMap map = readMap.compose(writeMap);
if (map.getNumResults() == 0)
return failure();
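
For reference, here is a minimal MLIR sketch (illustrative only, not part of the diff) of the broadcast case that the tightened pattern still folds. It is modeled on the existing @store_to_load_tensor_broadcast test; the function name and shapes here are made up for illustration.

func.func @example_store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
    %v0 : vector<4x4xf32>) -> vector<4x4x6xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  // Both transfers use pure tensor semantics, are fully in-bounds, unmasked,
  // and share the same indices, so none of the new bail-outs trigger.
  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] {in_bounds = [true, true]} :
    vector<4x4xf32>, tensor<4x4xf32>
  // The read's permutation map broadcasts a trailing dimension of size 6.
  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
    tensor<4x4xf32>, vector<4x4x6xf32>
  // Expected canonicalized form (roughly):
  //   %b = vector.broadcast %v0 : vector<4x4xf32> to vector<6x4x4xf32>
  //   %t = vector.transpose %b, [1, 2, 0] : vector<6x4x4xf32> to vector<4x4x6xf32>
  return %0 : vector<4x4x6xf32>
}
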
108 changes: 94 additions & 14 deletions mlir/test/Dialect/Vector/canonicalize.mlir
@@ -408,7 +408,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
// -----

// Negative test where the extract is not a subset of the element inserted.
-// CHECK-LABEL: extract_strided_fold_negative
+// CHECK-LABEL: negative_extract_strided_fold
// CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
@@ -417,7 +417,7 @@ func.func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>
// CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
// CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
// CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
-func.func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
+func.func @negative_extract_strided_fold(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
-> (vector<6x4xf32>) {
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
@@ -753,10 +753,10 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,

// -----

-// CHECK-LABEL: fold_extract_broadcast_negative
+// CHECK-LABEL: negative_fold_extract_broadcast
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
-func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
+func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
return %r : vector<4xf32>
@@ -895,11 +895,11 @@ func.func @fold_extract_shapecast_0d_source(%arg0 : vector<f32>) -> f32 {

// -----

-// CHECK-LABEL: fold_extract_shapecast_negative
+// CHECK-LABEL: negative_fold_extract_shapecast
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<4x2xf32> from vector<2x4x2xf32>
// CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
+func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
%r = vector.extract %0[1] : vector<4x2xf32> from vector<2x4x2xf32>
return %r : vector<4x2xf32>
@@ -1460,11 +1460,11 @@ func.func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {

// -----

-// CHECK-LABEL: func @store_after_load_tensor_negative
+// CHECK-LABEL: func @negative_store_after_load_tensor
// CHECK: vector.transfer_read
// CHECK: vector.transfer_write
// CHECK: return
-func.func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+func.func @negative_store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -1499,12 +1499,12 @@ func.func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,

// -----

-// CHECK-LABEL: func @store_to_load_negative_tensor
+// CHECK-LABEL: func @negative_store_to_load_tensor
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: %[[V:.*]] = vector.transfer_read
// CHECK: return %[[V]] : vector<1x4xf32>
-func.func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
+func.func @negative_store_to_load_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
@@ -1540,6 +1540,86 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,

// -----

+// CHECK-LABEL: func @negative_store_to_load_tensor_memref
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_memref(
+    %arg0 : tensor<?x?xf32>,
+    %arg1 : memref<?x?xf32>,
+    %v0 : vector<4x2xf32>
+  ) -> vector<4x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c0, %c0] {in_bounds = [true, true]} :
+    vector<4x2xf32>, memref<?x?xf32>
+  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_no_actual_broadcast
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+    %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_out_of_bounds
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<?x?xf32>,
+    %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_store_to_load_tensor_broadcast_masked
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @negative_store_to_load_tensor_broadcast_masked(
+    %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
+  -> vector<4x2x6xf32>
+{
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0], %mask {in_bounds = [true, true]} :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
@@ -1604,15 +1684,15 @@ func.func @dead_store_tensor(%arg0 : tensor<4x4xf32>,

// -----

-// CHECK-LABEL: func @dead_store_tensor_negative
+// CHECK-LABEL: func @negative_dead_store_tensor
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: vector.transfer_write
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
// CHECK: return %[[VTW]] : tensor<4x4xf32>
-func.func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
+func.func @negative_dead_store_tensor(%arg0 : tensor<4x4xf32>,
%v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
@@ -2063,10 +2143,10 @@ func.func @extract_insert_rank_reduce(%a: vector<4xf32>, %b: vector<8x16xf32>)

// -----

-// CHECK-LABEL: extract_insert_negative
+// CHECK-LABEL: negative_extract_insert
// CHECK: vector.insert_strided_slice
// CHECK: vector.extract
-func.func @extract_insert_negative(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
+func.func @negative_extract_insert(%a: vector<2x15xf32>, %b: vector<12x8x16xf32>)
-> vector<16xf32> {
%0 = vector.insert_strided_slice %a, %b {offsets = [4, 2, 0], strides = [1, 1]}
: vector<2x15xf32> into vector<12x8x16xf32>
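
To exercise the updated tests locally, something along these lines should work, assuming a local MLIR build with mlir-opt on the path (the exact flags in the file's RUN line may differ):

  mlir-opt -split-input-file -canonicalize mlir/test/Dialect/Vector/canonicalize.mlir

or run the whole lit suite with ninja check-mlir from the build directory.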