@@ -604,15 +604,17 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
604
604
}
605
605
}
606
606
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
607
- tileAndFuseResult->tiledValues [0 ]};
607
+ tileAndFuseResult->tiledValues [0 ],
608
+ tileAndFuseResult->tiledOps };
608
609
}
609
610
610
611
// / Reconstruct the fused producer from within the tiled-and-fused code.
611
612
void mlir::scf::yieldReplacementForFusedProducer (
612
613
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
613
614
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
614
615
MutableArrayRef<scf::ForOp> loops) {
615
- auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
616
+ auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
617
+ fusedProducerInfo;
616
618
SmallVector<Value> initValues;
617
619
FailureOr<Value> initValue = tensor::getOrCreateDestination (
618
620
rewriter, fusableProducer.getOwner ()->getLoc (), fusableProducer);
@@ -623,8 +625,11 @@ void mlir::scf::yieldReplacementForFusedProducer(
623
625
yieldTiledValues (rewriter, initValue.value (), fusedProducerValue,
624
626
resultOffsets, resultSizes, loops);
625
627
}
626
- if (auto dstStyleProducer =
627
- fusedProducerValue.getDefiningOp <DestinationStyleOpInterface>()) {
628
+ for (auto tileAndFusedOp : tileAndFusedOps) {
629
+ auto dstStyleProducer =
630
+ dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
631
+ if (!dstStyleProducer)
632
+ continue ;
628
633
Value dstValue =
629
634
dstStyleProducer.getDpsInitOperand (fusableProducer.getResultNumber ())
630
635
->get ();
0 commit comments