Skip to content

Commit 21bd52c

Browse files
Update
1 parent 94ab287 commit 21bd52c

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,11 +631,9 @@ struct ConvertVectorMaskedLoad final
631631
*foldedIntraVectorOffset, origElements);
632632
}
633633
} else {
634-
auto resultVector = rewriter.create<arith::ConstantOp>(
635-
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
636634
result = dynamicallyExtractSubVector(
637-
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
638-
linearizedInfo.intraDataOffset, origElements);
635+
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
636+
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
639637
}
640638
rewriter.replaceOp(op, result);
641639

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
205205
// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
206206
// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
207207
// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
208+
209+
// extract passthru vector, and insert into zero vector, this is for constructing a new passthru
208210
// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
209211
// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
210212
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -215,21 +217,33 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
215217
// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
216218
// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
217219
// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
220+
221+
// bitcast the new passthru vector to emulated i8 vector
218222
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
223+
224+
// use the emulated i8 vector to masked load from the memory
219225
// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
220226
// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
227+
228+
// bitcast back to i2 vector
221229
// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
230+
222231
// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
232+
233+
// create a mask vector and select passthru part from the loaded vector.
234+
// note that if indices are known then we can fold the part generating mask.
223235
// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
224236
// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
225237
// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
226238
// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
227239
// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
228240
// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
241+
229242
// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
230-
// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2>
243+
244+
// finally, insert the selected parts into actual passthru vector.
231245
// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
232-
// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2>
246+
// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[PTH]] [0] : i2 into vector<3xi2>
233247
// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
234248
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
235249
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>

0 commit comments

Comments
 (0)