Skip to content

Commit cfec2c7

Browse files
committed
Refine how bcast dims are handled
Only mark bcast dims as "in bounds" when all non-bcast dims are "in bounds".
1 parent c6797cb commit cfec2c7

File tree

4 files changed

+39
-5
lines changed

4 files changed

+39
-5
lines changed

mlir/include/mlir/IR/AffineMap.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ class AffineMap {
146146
/// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
147147
bool isMinorIdentity() const;
148148

149+
/// Returns the list of broadcast dimensions (i.e. dims indicated by value 0
150+
/// in the result).
151+
/// Ex:
152+
/// * (d0, d1, d2) -> (0, d1) gives [0]
153+
/// * (d0, d1, d2) -> (d2, d1) gives []
154+
/// * (d0, d1, d2, d3) -> (d0, 0, d1, 0) gives [1, 3]
155+
SmallVector<unsigned> getBroadcastDims() const;
156+
149157
/// Returns true if this affine map is a minor identity up to broadcasted
150158
/// dimensions which are indicated by value 0 in the result. If
151159
/// `broadcastedDims` is not null, it will be populated with the indices of

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4134,6 +4134,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
41344134
bool changed = false;
41354135
SmallVector<bool, 4> newInBounds;
41364136
newInBounds.reserve(op.getTransferRank());
4137+
SmallVector<unsigned> nonBcastDims;
41374138
for (unsigned i = 0; i < op.getTransferRank(); ++i) {
41384139
// 1. Already marked as in-bounds, nothing to see here.
41394140
if (op.isDimInBounds(i)) {
@@ -4148,15 +4149,27 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
41484149
// 2.a Non-broadcast dim
41494150
inBounds = isInBounds(op, /*resultIdx=*/i,
41504151
/*indicesIdx=*/dimExpr.getPosition());
4151-
} else {
4152-
// 2.b Broadcast dim
4153-
inBounds = true;
4152+
// 2.b Broadcast dims are handled after processing non-bcast dims
4153+
// FIXME: constant exprs != 0 are not broadcasts - should such
4154+
// constants be allowed at all?
4155+
nonBcastDims.push_back(i);
41544156
}
41554157

41564158
newInBounds.push_back(inBounds);
41574159
// We commit the pattern if it is "more inbounds".
41584160
changed |= inBounds;
41594161
}
4162+
4163+
// Handle broadcast dims: if all non-broadcast dims are "in
4164+
// bounds", then all bcast dims should be "in bounds" as well.
4165+
bool allNonBcastDimsInBounds = llvm::all_of(
4166+
nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
4167+
if (allNonBcastDimsInBounds)
4168+
llvm::for_each(permutationMap.getBroadcastDims(), [&](unsigned idx) {
4169+
changed |= !newInBounds[idx];
4170+
newInBounds[idx] = true;
4171+
});
4172+
41604173
if (!changed)
41614174
return failure();
41624175
// OpBuilder is only used as a helper to build an I64ArrayAttr.

mlir/lib/IR/AffineMap.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,19 @@ bool AffineMap::isMinorIdentity() const {
158158
getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
159159
}
160160

161+
SmallVector<unsigned> AffineMap::getBroadcastDims() const {
162+
SmallVector<unsigned> broadcastedDims;
163+
for (const auto &[resIdx, expr] : llvm::enumerate(getResults())) {
164+
if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
165+
if (constExpr.getValue() != 0)
166+
continue;
167+
broadcastedDims.push_back(resIdx);
168+
}
169+
}
170+
171+
return broadcastedDims;
172+
}
173+
161174
/// Returns true if this affine map is a minor identity up to broadcasted
162175
/// dimensions which are indicated by value 0 in the result.
163176
bool AffineMap::isMinorIdentityWithBroadcasting(

mlir/test/Dialect/Vector/vector-transfer-unroll.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func.func @transfer_read_unroll_permutation(%mem : memref<6x4xf32>) -> vector<4x
207207
func.func @transfer_read_unroll_broadcast(%mem : memref<6x4xf32>) -> vector<6x4xf32> {
208208
%c0 = arith.constant 0 : index
209209
%cf0 = arith.constant 0.0 : f32
210-
%res = vector.transfer_read %mem[%c0, %c0], %cf0 permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
210+
%res = vector.transfer_read %mem[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32>
211211
return %res : vector<6x4xf32>
212212
}
213213

@@ -234,7 +234,7 @@ func.func @transfer_read_unroll_broadcast(%mem : memref<6x4xf32>) -> vector<6x4x
234234
func.func @transfer_read_unroll_broadcast_permuation(%mem : memref<6x4xf32>) -> vector<4x6xf32> {
235235
%c0 = arith.constant 0 : index
236236
%cf0 = arith.constant 0.0 : f32
237-
%res = vector.transfer_read %mem[%c0, %c0], %cf0 permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
237+
%res = vector.transfer_read %mem[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32>
238238
return %res : vector<4x6xf32>
239239
}
240240

0 commit comments

Comments
 (0)