@@ -1417,6 +1417,8 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                     Value padValue) {
   assert(llvm::none_of(readShape,
                        [](int64_t s) { return s == ShapedType::kDynamic; }));
+  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
+  assert(sourceShape.size() == readShape.size());
   auto maskType = VectorType::get(readShape, builder.getI1Type());
   auto vectorType = VectorType::get(readShape, padValue.getType());
   int64_t readRank = readShape.size();
@@ -1428,12 +1430,7 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
       /*indices=*/SmallVector<Value>(readRank, zero),
       /*padding=*/padValue,
       /*inBounds=*/SmallVector<bool>(readRank, true));
-  auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
-  if (sourceShape.size() == readShape.size() &&
-      llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
-        return std::get<0>(it) != ShapedType::kDynamic &&
-               std::get<0>(it) == std::get<1>(it);
-      })) {
+  if (llvm::equal(readShape, sourceShape)) {
     return transferReadOp;
   }
   SmallVector<OpFoldResult> mixedSourceDims =
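
Note on the `createReadOrMaskedRead` hunks above: once the new asserts guarantee that `readShape` contains no `ShapedType::kDynamic` entry and has the same rank as `sourceShape`, the removed `all_of`-over-`zip_equal` predicate reduces to a plain element-wise range comparison, which is exactly what `llvm::equal` computes. A minimal standalone sketch of that equivalence follows; the helper names `oldCheck`/`newCheck` are illustrative and not part of the patch.

```cpp
// Sketch only: assumes the usual LLVM/MLIR headers are available.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypes.h"

#include <cstdint>
#include <tuple>

// The predicate removed by this patch.
static bool oldCheck(llvm::ArrayRef<int64_t> readShape,
                     llvm::ArrayRef<int64_t> sourceShape) {
  return sourceShape.size() == readShape.size() &&
         llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
           return std::get<0>(it) != mlir::ShapedType::kDynamic &&
                  std::get<0>(it) == std::get<1>(it);
         });
}

// The predicate installed by this patch. With readShape asserted to be fully
// static and of the same rank as sourceShape, every input that reaches this
// comparison yields the same answer from oldCheck and newCheck.
static bool newCheck(llvm::ArrayRef<int64_t> readShape,
                     llvm::ArrayRef<int64_t> sourceShape) {
  return llvm::equal(readShape, sourceShape);
}
```

The only shapes on which the two predicates could disagree are those with a dynamic entry in `readShape`, and the hoisted asserts rule those out before the comparison runs.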
@@ -1469,10 +1466,8 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                       destShape.drop_front(inputVectorSizes.size()),
                       [](int64_t size) { return size == ShapedType::kDynamic; }) &&
          "Only dims aligned with inputVectorSizes may be dynamic");
-  bool needMaskForWrite = llvm::any_of(
-      llvm::zip_equal(inputVectorSizes,
-                      destShape.take_front(inputVectorSizes.size())),
-      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
+  bool needMaskForWrite = !llvm::equal(
+      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
   if (needMaskForWrite) {
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
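
For the `createWriteOrMaskedWrite` hunk, the simplified condition compares the requested vector sizes against the leading destination dims. A small sketch of how that condition behaves; the helper name and the example shapes are hypothetical, chosen only for this note.

```cpp
// Sketch only: assumes llvm/ADT headers; the shapes below are made up.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

#include <cstdint>

static bool wouldNeedWriteMask(llvm::ArrayRef<int64_t> inputVectorSizes,
                               llvm::ArrayRef<int64_t> destShape) {
  // Mirrors the simplified condition: mask only when the leading destination
  // dims differ from the requested vector sizes.
  return !llvm::equal(inputVectorSizes,
                      destShape.take_front(inputVectorSizes.size()));
}

// wouldNeedWriteMask({8, 16}, {8, 16, 4}) -> false: the leading dims match,
//   so the vector.transfer_write stays unmasked.
// wouldNeedWriteMask({8, 16}, {5, 16, 4}) -> true: the vector is wider than
//   the first destination dim (8 > 5), so the write has to be masked.
```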
@@ -1490,12 +1485,12 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
 /// padding value into:
 /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
 /// As in the following example:
-/// ```mlir
+///
 /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
 ///   into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
-/// ```
+///
 /// This pack would be vectorized to:
-/// ```mlir
+///
 /// %load = vector.mask %mask {
 ///          vector.transfer_read %arg0[%c0, %c0, %c0], %cst
 ///            {in_bounds = [true, true, true]} :