Skip to content

Commit 5c5278c

Browse files
committed
last comments
1 parent 6c0e2a1 commit 5c5278c

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,8 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
14171417
Value padValue) {
14181418
assert(llvm::none_of(readShape,
14191419
[](int64_t s) { return s == ShapedType::kDynamic; }));
1420+
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
1421+
assert(sourceShape.size() == readShape.size());
14201422
auto maskType = VectorType::get(readShape, builder.getI1Type());
14211423
auto vectorType = VectorType::get(readShape, padValue.getType());
14221424
int64_t readRank = readShape.size();
@@ -1428,12 +1430,7 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
14281430
/*indices=*/SmallVector<Value>(readRank, zero),
14291431
/*padding=*/padValue,
14301432
/*inBounds=*/SmallVector<bool>(readRank, true));
1431-
auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
1432-
if (sourceShape.size() == readShape.size() &&
1433-
llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
1434-
return std::get<0>(it) != ShapedType::kDynamic &&
1435-
std::get<0>(it) == std::get<1>(it);
1436-
})) {
1433+
if (llvm::equal(readShape, sourceShape)) {
14371434
return transferReadOp;
14381435
}
14391436
SmallVector<OpFoldResult> mixedSourceDims =
@@ -1469,10 +1466,8 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
14691466
destShape.drop_front(inputVectorSizes.size()),
14701467
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
14711468
"Only dims aligned with inputVectorSizes may be dynamic");
1472-
bool needMaskForWrite = llvm::any_of(
1473-
llvm::zip_equal(inputVectorSizes,
1474-
destShape.take_front(inputVectorSizes.size())),
1475-
[](auto it) { return std::get<0>(it) != std::get<1>(it); });
1469+
bool needMaskForWrite = !llvm::equal(
1470+
inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
14761471
if (needMaskForWrite) {
14771472
SmallVector<int64_t> writeMaskShape;
14781473
writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
@@ -1490,12 +1485,12 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
14901485
/// padding value into:
14911486
/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
14921487
/// As in the following example:
1493-
/// ```mlir
1488+
///
14941489
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
14951490
/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1496-
/// ```
1491+
///
14971492
/// This pack would be vectorized to:
1498-
/// ```mlir
1493+
///
14991494
/// %load = vector.mask %mask {
15001495
/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
15011496
/// {in_bounds = [true, true, true]} :

0 commit comments

Comments
 (0)