Skip to content

Commit 5ad4213

Browse files
authored
[mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_for (#120118)
The API used internally expects PartialReductionOpInterface. This patch allows any operation implementing this interface to use this transform op (instead of just LinalgOp).
1 parent ccfe0de commit 5ad4213

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,8 +1765,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
17651765
let arguments = (ins TransformHandleTypeInterface:$target,
17661766
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
17671767
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
1768-
TransformHandleTypeInterface:$split_linalg_op,
1769-
TransformHandleTypeInterface:$combining_linalg_op,
1768+
TransformHandleTypeInterface:$split_op,
1769+
TransformHandleTypeInterface:$combining_op,
17701770
TransformHandleTypeInterface:$for_op);
17711771

17721772
let builders = [
@@ -1784,7 +1784,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
17841784
let extraClassDeclaration = [{
17851785
::mlir::DiagnosedSilenceableFailure applyToOne(
17861786
::mlir::transform::TransformRewriter &rewriter,
1787-
::mlir::linalg::LinalgOp target,
1787+
Operation *target,
17881788
::mlir::transform::ApplyToEachResultList &results,
17891789
::mlir::transform::TransformState &state);
17901790
}];

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2626,12 +2626,19 @@ void transform::TileReductionUsingForOp::build(
26262626
}
26272627

26282628
DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2629-
transform::TransformRewriter &rewriter, LinalgOp target,
2629+
transform::TransformRewriter &rewriter, Operation *target,
26302630
transform::ApplyToEachResultList &results,
26312631
transform::TransformState &state) {
26322632
rewriter.setInsertionPoint(target);
2633+
2634+
auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2635+
if (!partialReductionOp) {
2636+
return emitSilenceableFailure(
2637+
target->getLoc(),
2638+
"Operation should implement PartialReductionOpInterface");
2639+
}
26332640
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2634-
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2641+
rewriter, partialReductionOp,
26352642
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
26362643

26372644
if (failed(result))

0 commit comments

Comments
 (0)