Skip to content

Commit 018d8ac

Browse files
committed
[mlir][Vector] Fix a propagation bug with transfer_read
In the vector distribute patterns, we used to move `vector.transfer_read`s out of `vector.warp_execute_on_lane0`s irrespectively of how they were defined. This could create transfer_read operations that would read values from within the warpOp's body from outside of the body. E.g., ``` warpop { %defined_in_body %read = transfer_read %defined_in_body vector.yield %read } ``` => ``` warpop { %defined_in_body vector.yield ... } // %defined_in_body is referenced outside of its scope. %read = transfer_read %defined_in_body ``` The fix consists in checking that all the values feeding the new `transfer_read` are defined outside of warpOp's body. Note: We could do this check before creating any operation, but that would mean knowing what `affine::makeComposedAffineApply` actually do. So the current fix is a trade off of coupling the implementations of this propagation and `makeComposedAffineApply` versus compile time. Differential Revision: https://reviews.llvm.org/D152149
1 parent 27aea17 commit 018d8ac

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,10 +757,31 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
757757
rewriter, read.getLoc(), d0 + scale * d1,
758758
{indices[indexPos], warpOp.getLaneid()});
759759
}
760-
Value newRead = rewriter.create<vector::TransferReadOp>(
760+
auto newRead = rewriter.create<vector::TransferReadOp>(
761761
read.getLoc(), distributedVal.getType(), read.getSource(), indices,
762762
read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
763763
read.getInBoundsAttr());
764+
765+
// Check that the produced operation is legal.
766+
// The transfer op may be reading from values that are defined within
767+
// warpOp's body, which is illegal.
768+
// We do the check late because incdices may be changed by
769+
// makeComposeAffineApply. This rewrite may remove dependencies from
770+
// warOp's body.
771+
// E.g., warop {
772+
// %idx = affine.apply...[%outsideDef]
773+
// ... = transfer_read ...[%idx]
774+
// }
775+
// will be rewritten in:
776+
// warop {
777+
// }
778+
// %new_idx = affine.apply...[%outsideDef]
779+
// ... = transfer_read ...[%new_idx]
780+
if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
781+
return warpOp.isDefinedOutsideOfRegion(value);
782+
}))
783+
return failure();
784+
764785
rewriter.replaceAllUsesWith(distributedVal, newRead);
765786
return success();
766787
}

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,3 +1109,46 @@ func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
11091109
}
11101110
return %r : vector<4x96xf32>
11111111
}
1112+
1113+
// -----
1114+
1115+
// Check that we don't propagate transfer_reads that have dependencies on
1116+
// values inside the warp_execute_on_lane_0.
1117+
// In this case, propagating would create transfer_read that depends on the
1118+
// extractelment defined in the body.
1119+
1120+
// CHECK-PROP-LABEL: func @transfer_read_no_prop(
1121+
// CHECK-PROP-SAME: %[[IN2:[^ :]*]]: vector<1x2xindex>,
1122+
// CHECK-PROP-SAME: %[[AR1:[^ :]*]]: memref<1x4x2xi32>,
1123+
// CHECK-PROP-SAME: %[[AR2:[^ :]*]]: memref<1x4x1024xf32>)
1124+
// CHECK-PROP-DAG: %[[C0:.*]] = arith.constant 0 : index
1125+
// CHECK-PROP-DAG: %[[THREADID:.*]] = gpu.thread_id x
1126+
// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]]
1127+
// CHECK-PROP: %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}]
1128+
// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<1x64xi32>
1129+
// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
1130+
// CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex>
1131+
// CHECK-PROP: %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[EXTRACTELT]], %[[C0]]],
1132+
// CHECK-PROP: vector.yield %[[TRANSFERREAD]] : vector<64xf32>
1133+
// CHECK-PROP: return %[[W]]
1134+
func.func @transfer_read_no_prop(%in2: vector<1x2xindex>, %ar1 : memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
1135+
%0 = gpu.thread_id x
1136+
%c0_i32 = arith.constant 0 : i32
1137+
%c0 = arith.constant 0 : index
1138+
%cst = arith.constant dense<0> : vector<1x64xi32>
1139+
%cst_0 = arith.constant dense<true> : vector<1x64xi1>
1140+
%cst_1 = arith.constant dense<3> : vector<64xindex>
1141+
%cst_2 = arith.constant dense<0> : vector<64xindex>
1142+
%cst_6 = arith.constant 0.000000e+00 : f32
1143+
1144+
%18 = vector.warp_execute_on_lane_0(%0)[32] args(%in2 : vector<1x2xindex>) -> (vector<2xf32>) {
1145+
^bb0(%arg4: vector<1x64xindex>):
1146+
%28 = vector.gather %ar1[%c0, %c0, %c0] [%arg4], %cst_0, %cst : memref<1x4x2xi32>, vector<1x64xindex>, vector<1x64xi1>, vector<1x64xi32> into vector<1x64xi32>
1147+
%29 = vector.extract %28[0] : vector<1x64xi32>
1148+
%30 = arith.index_cast %29 : vector<64xi32> to vector<64xindex>
1149+
%36 = vector.extractelement %30[%c0_i32 : i32] : vector<64xindex>
1150+
%37 = vector.transfer_read %ar2[%c0, %36, %c0], %cst_6 {in_bounds = [true]} : memref<1x4x1024xf32>, vector<64xf32>
1151+
vector.yield %37 : vector<64xf32>
1152+
}
1153+
return %18 : vector<2xf32>
1154+
}

0 commit comments

Comments
 (0)