Skip to content

Commit d3a5396

Browse files
committed
[MLIR] Fix VectorEmulateNarrowType constant op mask bug
This commit adds support for handling mask constants generated by the `arith.constant` op in the `VectorEmulateNarrowType` pattern. Previously, this pattern would not match due to the lack of mask constant handling in `getCompressedMaskOp`. The changes include: 1. Updating `getCompressedMaskOp` to recognize and handle `arith.constant` ops as mask value sources. 2. Handling cases where the mask is not aligned with the emulated load width. The compressed mask is adjusted to account for the offset. Limitations: - The arith.constant op can only have 1-dimensional constant values. Resolves: #115742 Signed-off-by: Alan Li <[email protected]>
1 parent 1e5bfac commit d3a5396

File tree

3 files changed

+239
-2
lines changed

3 files changed

+239
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,18 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
7070
Operation *maskOp = mask.getDefiningOp();
7171
SmallVector<vector::ExtractOp, 2> extractOps;
7272
// Finding the mask creation operation.
73-
while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
73+
while (maskOp &&
74+
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
75+
maskOp)) {
7476
if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
7577
maskOp = extractOp.getVector().getDefiningOp();
7678
extractOps.push_back(extractOp);
7779
}
7880
}
7981
auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
8082
auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
81-
if (!createMaskOp && !constantMaskOp)
83+
auto constantOp = dyn_cast_or_null<arith::ConstantOp>(maskOp);
84+
if (!createMaskOp && !constantMaskOp && !constantOp)
8285
return failure();
8386

8487
// Computing the "compressed" mask. All the emulation logic (i.e. computing
@@ -129,6 +132,45 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
129132
auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
130133
newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
131134
}
135+
} else if (constantOp) {
136+
assert(shape.size() == 1 && "expected 1-D mask");
137+
// Rearrange the original mask values to cover the whole potential loading
138+
// region. For example, in the case of using byte-size for emulation, given
139+
// the following mask:
140+
//
141+
// %mask = vector.constant_mask [0, 1, 0, 1, 0, 0] : vector<6xi2>
142+
//
143+
// with front offset of 1, the mask will be padded zeros in the front and
144+
// back so that its length is multiple of `scale` (and the total coverage
145+
// size is mulitiple of bytes):
146+
// %new_mask = vector.constant_mask [0, 0, 1, 0, 1, 0, 0, 0] :
147+
// vector<8xi2>
148+
//
149+
// The %new_mask is now aligned with the effective loading area and can now
150+
// be compressed.
151+
SmallVector<bool> maskValues(intraDataOffset, false);
152+
if (auto denseAttr =
153+
mlir::dyn_cast<DenseIntElementsAttr>(constantOp.getValue())) {
154+
for (auto value : denseAttr.getValues<bool>()) {
155+
maskValues.push_back(value);
156+
}
157+
while (maskValues.size() < numElements * scale) {
158+
maskValues.push_back(false);
159+
}
160+
} else {
161+
return failure();
162+
}
163+
// Compressing by combining every `scale` elements:
164+
SmallVector<bool> compressedMaskValues;
165+
for (size_t i = 0; i < maskValues.size(); i += scale) {
166+
bool combinedValue = false;
167+
for (int j = 0; j < scale; ++j) {
168+
combinedValue |= maskValues[i + j];
169+
}
170+
compressedMaskValues.push_back(combinedValue);
171+
}
172+
newMask = rewriter.create<arith::ConstantOp>(
173+
loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
132174
}
133175

134176
while (!extractOps.empty()) {

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,174 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
249249
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
250250
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
251251
// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
252+
253+
// -----
254+
255+
func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
256+
%0 = memref.alloc() : memref<3x3xi2>
257+
%c0 = arith.constant 0 : index
258+
%c2 = arith.constant 2 : index
259+
vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
260+
return
261+
}
262+
263+
// in this example, emit 2 atomic stores, with the first storing 1 element and the second storing 2 elements.
264+
// CHECK: func @vector_store_i2_const(
265+
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
266+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
267+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
268+
269+
// atomic store of the first byte
270+
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
271+
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
272+
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
273+
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
274+
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
275+
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
276+
// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
277+
// CHECK: %[[ARG:.+]]: i8):
278+
// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
279+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
280+
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
281+
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
282+
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
283+
// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
284+
285+
// atomic store of the second byte
286+
// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
287+
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
288+
// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
289+
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
290+
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
291+
// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
292+
// CHECK: %[[ARG2:.+]]: i8):
293+
// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
294+
// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
295+
// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
296+
// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
297+
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
298+
// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
299+
300+
// -----
301+
302+
func.func @vector_store_i8_2(%arg0: vector<7xi2>) {
303+
%0 = memref.alloc() : memref<3x7xi2>
304+
%c0 = arith.constant 0 : index
305+
%c1 = arith.constant 1 : index
306+
vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
307+
return
308+
}
309+
310+
// in this example, emit 2 atomic stores and 1 non-atomic store
311+
312+
// CHECK: func @vector_store_i8_2(
313+
// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
314+
// CHECK: %[[ALLOC]] = memref.alloc() : memref<6xi8>
315+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
316+
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
317+
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
318+
319+
// first atomic store
320+
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
321+
// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
322+
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
323+
// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
324+
// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
325+
// CHECK: %[[ARG:.+]]: i8):
326+
// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
327+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
328+
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
329+
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
330+
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
331+
// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
332+
333+
// non atomic store part
334+
// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
335+
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
336+
// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
337+
// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
338+
// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
339+
340+
// second atomic store
341+
// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
342+
// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
343+
// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
344+
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
345+
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
346+
// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
347+
// CHECK: %[[ARG2:.+]]: i8):
348+
// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
349+
// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
350+
// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
351+
// CHECK-SAME: vector<4xi1>, vector<4xi2>
352+
// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
353+
// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
354+
// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
355+
356+
// -----
357+
358+
func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
359+
%0 = memref.alloc() : memref<4x1xi2>
360+
%c0 = arith.constant 0 : index
361+
%c1 = arith.constant 1 : index
362+
vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
363+
return
364+
}
365+
366+
// in this example, only emit 1 atomic store
367+
// CHECK: func @vector_store_i2_single_atomic(
368+
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
369+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
370+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
371+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
372+
// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
373+
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
374+
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
375+
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
376+
377+
// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
378+
// CHECK: %[[ARG:.+]]: i8):
379+
// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
380+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
381+
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
382+
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
383+
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
384+
// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
385+
386+
// -----
387+
388+
func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
389+
%0 = memref.alloc() : memref<3x5xi2>
390+
%mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
391+
%c0 = arith.constant 0 : index
392+
%c1 = arith.constant 1 : index
393+
%1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
394+
memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
395+
return %1 : vector<5xi2>
396+
}
397+
398+
// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
399+
// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
400+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
401+
// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
402+
403+
// CHECK: %[[CST0:.+]] = arith.constant dense<true> : vector<2xi1>
404+
// CHECK: %[[CST1:.+]] = arith.constant dense<0> : vector<8xi2>
405+
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[PTH]], %[[CST1]]
406+
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
407+
408+
// Emulated masked load from alloc:
409+
// CHECK: %[[BCAST:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
410+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
411+
// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[CST0]], %[[BCAST]]
412+
// CHECK: %[[BCAST2:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
413+
414+
// Select from emulated loaded vector and passthru vector:
415+
// TODO: fold this part if possible.
416+
// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
417+
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[BCAST2]], %[[CST2]]
418+
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
419+
// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BCAST2]], %[[INSERT]] : vector<8xi1>, vector<8xi2>
420+
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SELECT]]
421+
// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
422+
// CHECK: return %[[EXTRACT]] : vector<5xi2>

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,3 +624,27 @@ func.func @vector_maskedstore_i4_constant_mask(
624624
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
625625
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
626626
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
627+
628+
// -----
629+
630+
func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vector<8xi4> {
631+
%0 = memref.alloc() : memref<3x8xi4>
632+
%cst = arith.constant dense<0> : vector<8xi4>
633+
%mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
634+
%c0 = arith.constant 0 : index
635+
%1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
636+
memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
637+
return %1 : vector<8xi4>
638+
}
639+
640+
// CHECK: func @vector_maskedload_i4_arith_constant(
641+
// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<8xi4>) -> vector<8xi4> {
642+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<24xi8>
643+
// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
644+
// CHECK: %[[CST:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
645+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[PASSTHRU]] : vector<8xi4> to vector<4xi8>
646+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
647+
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C0]], %[[C0]]], %[[MASK]], %[[BITCAST]] : memref<24xi8>, vector<8xi1>, vector<4xi8> into vector<4xi8>
648+
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
649+
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[BITCAST2]], %[[PASSTHRU]] : vector<4xi1>, vector<8xi4>
650+
// CHECK: return %[[SELECT]] : vector<8xi4>

0 commit comments

Comments
 (0)