@@ -2476,39 +2476,42 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2476
2476
void transform::TileUsingForOp::build (
2477
2477
OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2478
2478
Value target, ArrayRef<int64_t > staticTileSizes,
2479
- ArrayRef<int64_t > interchange,
2479
+ ArrayRef<bool > continuousTiles, ArrayRef< int64_t > interchange,
2480
2480
std::optional<ArrayRef<bool >> scalableSizes) {
2481
2481
return build (builder, result, loopTypes,
2482
2482
/* target=*/ target,
2483
2483
/* mixedTileSizes=*/
2484
2484
getAsOpFoldResult (builder.getI64ArrayAttr (staticTileSizes)),
2485
- interchange, scalableSizes);
2485
+ continuousTiles, interchange, scalableSizes);
2486
2486
}
2487
2487
2488
2488
void transform::TileUsingForOp::build (
2489
2489
OpBuilder &builder, OperationState &result, Value target,
2490
- ArrayRef<int64_t > staticTileSizes, ArrayRef<int64_t > interchange,
2490
+ ArrayRef<int64_t > staticTileSizes, ArrayRef<bool > continuousTiles,
2491
+ ArrayRef<int64_t > interchange,
2491
2492
std::optional<ArrayRef<bool >> scalableSizes) {
2492
2493
build (builder, result, target,
2493
2494
getAsOpFoldResult (builder.getI64ArrayAttr (staticTileSizes)),
2494
- interchange, scalableSizes);
2495
+ builder.getDenseBoolArrayAttr (continuousTiles), interchange,
2496
+ scalableSizes);
2495
2497
}
2496
2498
2497
2499
void transform::TileUsingForOp::build (
2498
2500
OpBuilder &builder, OperationState &result, Value target,
2499
- ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t > interchange,
2501
+ ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<bool > continuousTiles,
2502
+ ArrayRef<int64_t > interchange,
2500
2503
std::optional<ArrayRef<bool >> scalableSizes) {
2501
2504
// Loop types are automaticaly splat by the callee, setting up one is
2502
2505
// enough.
2503
2506
SmallVector<Type> loopTypes (1 , builder.getType <transform::AnyOpType>());
2504
- build (builder, result, loopTypes, target, mixedTileSizes, interchange ,
2505
- scalableSizes);
2507
+ build (builder, result, loopTypes, target, mixedTileSizes, continuousTiles ,
2508
+ interchange, scalableSizes);
2506
2509
}
2507
2510
2508
2511
void transform::TileUsingForOp::build (
2509
2512
OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2510
2513
Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2511
- ArrayRef<int64_t > interchange,
2514
+ ArrayRef<bool > continuousTiles, ArrayRef< int64_t > interchange,
2512
2515
std::optional<ArrayRef<bool >> scalableSizes) {
2513
2516
SmallVector<int64_t > staticTileSizes;
2514
2517
SmallVector<Value> dynamicTileSizes;
@@ -2517,6 +2520,7 @@ void transform::TileUsingForOp::build(
2517
2520
// attributes for multiple variadic operands. In the absence of this,
2518
2521
// horrible bugs ensue.
2519
2522
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr (staticTileSizes);
2523
+ auto continuousTilesAttr = builder.getDenseBoolArrayAttr (continuousTiles);
2520
2524
unsigned numExpectedLoops =
2521
2525
staticTileSizes.size () - llvm::count (staticTileSizes, 0 );
2522
2526
SmallVector<Type> resultTypes;
@@ -2535,6 +2539,7 @@ void transform::TileUsingForOp::build(
2535
2539
/* target=*/ target,
2536
2540
/* dynamic_sizes=*/ dynamicTileSizes,
2537
2541
/* static_sizes=*/ staticTileSizesAttr,
2542
+ /* continuous_tiles=*/ continuousTilesAttr,
2538
2543
/* interchange=*/ builder.getDenseI64ArrayAttr (interchange),
2539
2544
/* scalable_sizes=*/ expandedScalableSizes);
2540
2545
}
@@ -2675,8 +2680,15 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2675
2680
}
2676
2681
2677
2682
tilingOptions.setInterchange (getInterchange ());
2678
- FailureOr<scf::SCFTilingResult> maybeTilingResult =
2679
- tileUsingSCF (rewriter, tilingInterface, tilingOptions);
2683
+ tilingOptions.setCTileMapping (getContinuousTiles ());
2684
+
2685
+ FailureOr<scf::SCFTilingResult> maybeTilingResult;
2686
+ if (tilingOptions.continuousTileMappingVector .empty ())
2687
+ maybeTilingResult =
2688
+ tileUsingSCF (rewriter, tilingInterface, tilingOptions);
2689
+ else
2690
+ maybeTilingResult =
2691
+ continuousTileUsingSCF (rewriter, tilingInterface, tilingOptions);
2680
2692
if (failed (maybeTilingResult))
2681
2693
return DiagnosedSilenceableFailure::definiteFailure ();
2682
2694
@@ -2725,6 +2737,18 @@ ParseResult parseOptionalInterchange(OpAsmParser &parser,
2725
2737
return success ();
2726
2738
}
2727
2739
2740
+ ParseResult parseOptionalContinuousTiles (OpAsmParser &parser,
2741
+ OperationState &result) {
2742
+ if (failed (parser.parseOptionalKeyword (" continuous_tiles" )))
2743
+ return success ();
2744
+ if (failed (parser.parseEqual ()))
2745
+ return failure ();
2746
+ result.addAttribute (
2747
+ transform::TileUsingForOp::getContinuousTilesAttrName (result.name ),
2748
+ DenseBoolArrayAttr::parse (parser, Type{}));
2749
+ return success ();
2750
+ }
2751
+
2728
2752
void printOptionalInterchange (OpAsmPrinter &p,
2729
2753
ArrayRef<int64_t > interchangeVals) {
2730
2754
if (!interchangeVals.empty ()) {
@@ -2747,6 +2771,7 @@ ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
2747
2771
if (parser.parseOperand (target) || parser.getCurrentLocation (&operandLoc) ||
2748
2772
parseDynamicIndexList (parser, dynamicSizes, staticSizes, scalableVals) ||
2749
2773
parseOptionalInterchange (parser, result) ||
2774
+ parseOptionalContinuousTiles (parser, result) ||
2750
2775
parser.parseOptionalAttrDict (result.attributes ) ||
2751
2776
parser.parseColonType (functionalType))
2752
2777
return ParseResult::failure ();
0 commit comments