Skip to content

Commit 6dd68c1

Browse files
committed
check for equal loop types in checkFusionStructuralLegality
1 parent f50c6aa commit 6dd68c1

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,11 @@ static bool isOpSibling(Operation *target, Operation *source,
11551155
bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface target,
11561156
LoopLikeOpInterface source,
11571157
Diagnostic &diag) {
1158+
if (target->getName() != source->getName()) {
1159+
diag << "target and source must be same loop type";
1160+
return false;
1161+
}
1162+
11581163
bool iterSpaceEq =
11591164
target.getLoopLowerBounds() == source.getLoopLowerBounds() &&
11601165
target.getLoopUpperBounds() == source.getLoopUpperBounds() &&

mlir/test/Dialect/SCF/transform-loop-fuse-sibling.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,3 +526,36 @@ module attributes {transform.with_named_sequence} {
526526
transform.yield
527527
}
528528
}
529+
530+
// -----
531+
532+
// Negative test for sibling loop fusion: the payload contains an scf.for and
// an scf.parallel that use the same bounds (%c0 to %c2, step %c1) but are
// different loop ops. The expected-error annotation below pins the diagnostic
// "target and source must be same loop type" to the scf.for.
func.func @non_matching_loop_types_err(%A: memref<2xf32>, %B: memref<2xf32>) {
533+
%c2 = arith.constant 2 : index
534+
%c0 = arith.constant 0 : index
535+
%c1 = arith.constant 1 : index
536+
%c1fp = arith.constant 1.0 : f32
537+
%sum = memref.alloc() : memref<2xf32>
538+
// expected-error @below {{target and source must be same loop type}}
539+
scf.for %i = %c0 to %c2 step %c1 {
540+
%B_elem = memref.load %B[%i] : memref<2xf32>
541+
%sum_elem = arith.addf %B_elem, %c1fp : f32
542+
memref.store %sum_elem, %sum[%i] : memref<2xf32>
543+
}
544+
// Second loop: same bounds as the scf.for above, but expressed as scf.parallel.
// It reads %sum (written by the first loop) and %A, and writes the product to %B.
scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
545+
%sum_elem = memref.load %sum[%i] : memref<2xf32>
546+
%A_elem = memref.load %A[%i] : memref<2xf32>
547+
%product_elem = arith.mulf %sum_elem, %A_elem : f32
548+
memref.store %product_elem, %B[%i] : memref<2xf32>
549+
scf.reduce
550+
}
551+
memref.dealloc %sum : memref<2xf32>
552+
return
553+
}
554+
// Transform script for the test case above: matches the scf.for ops (%0) and
// the scf.parallel ops (%1) in the payload, then applies
// transform.loop.fuse_sibling to fuse %0 into %1. The payload function carries
// an expected-error annotation, so this fusion attempt is expected to fail.
module attributes {transform.with_named_sequence} {
555+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
556+
%0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
557+
%1 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
558+
%fused = transform.loop.fuse_sibling %0 into %1 : (!transform.any_op,!transform.any_op) -> !transform.any_op
559+
transform.yield
560+
}
561+
}

0 commit comments

Comments
 (0)