Skip to content

Commit b3ce6dc

Browse files
authored
[mlir][licm] Make scf.if recursively speculatable (#122031)
This change: - makes **scf.if** recursively speculatable like **affine.if** is. - also introduces related LICM tests for both **scf.if** and **affine.if**
1 parent 29ed600 commit b3ce6dc

File tree

2 files changed

+123
-2
lines changed

2 files changed

+123
-2
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def ForallOp : SCF_Op<"forall", [
302302
AttrSizedOperandSegments,
303303
AutomaticAllocationScope,
304304
DeclareOpInterfaceMethods<LoopLikeOpInterface,
305-
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
305+
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
306306
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
307307
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
308308
RecursiveMemoryEffects,
@@ -671,7 +671,7 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
671671
"getNumRegionInvocations", "getRegionInvocationBounds",
672672
"getEntrySuccessorRegions"]>,
673673
InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
674-
RecursiveMemoryEffects, NoRegionArguments]> {
674+
RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
675675
let summary = "if-then-else operation";
676676
let description = [{
677677
The `scf.if` operation represents an if-then-else construct for

mlir/test/Transforms/loop-invariant-code-motion.mlir

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,64 @@ func.func @invariant_affine_if() {
124124

125125
// -----
126126

127+
func.func @hoist_invariant_affine_if_success(%lb: index, %ub: index, %step: index) -> i32 {
128+
%cst_0 = arith.constant 0 : i32
129+
%cst_42 = arith.constant 42 : i32
130+
%sum_result = affine.for %i = %lb to %ub iter_args(%acc = %cst_0) -> i32 {
131+
%conditional_add = affine.if affine_set<() : ()> () -> (i32) {
132+
%add = arith.addi %cst_42, %cst_42 : i32
133+
affine.yield %add : i32
134+
} else {
135+
%poison = ub.poison : i32
136+
affine.yield %poison : i32
137+
}
138+
%sum = arith.addi %acc, %conditional_add : i32
139+
affine.yield %sum : i32
140+
}
141+
142+
// CHECK-LABEL: hoist_invariant_affine_if_success
143+
// CHECK-NEXT: arith.constant 0 : i32
144+
// CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
145+
// CHECK-NEXT: %[[IF:.*]] = affine.if
146+
// CHECK-NEXT: arith.addi %[[CST]], %[[CST]] : i32
147+
// CHECK: affine.for
148+
// CHECK-NOT: affine.if
149+
// CHECK-NEXT: arith.addi %{{.*}}, %[[IF]]
150+
151+
return %sum_result : i32
152+
}
153+
154+
// -----
155+
156+
func.func @hoist_variant_affine_if_failure(%lb: index, %ub: index, %step: index) -> i32 {
157+
%cst_0 = arith.constant 0 : i32
158+
%cst_42 = arith.constant 42 : i32
159+
%ind_7 = arith.constant 7 : index
160+
%sum_result = affine.for %i = %lb to %ub iter_args(%acc = %cst_0) -> i32 {
161+
%conditional_add = affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%i, %ind_7) -> (i32) {
162+
%add = arith.addi %cst_42, %cst_42 : i32
163+
affine.yield %add : i32
164+
} else {
165+
%poison = ub.poison : i32
166+
affine.yield %poison : i32
167+
}
168+
%sum = arith.addi %acc, %conditional_add : i32
169+
affine.yield %sum : i32
170+
}
171+
172+
// CHECK-LABEL: hoist_variant_affine_if_failure
173+
// CHECK-NEXT: arith.constant 0 : i32
174+
// CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
175+
// CHECK-NEXT: arith.constant 7 : index
176+
// CHECK-NEXT: affine.for
177+
// CHECK-NEXT: %[[IF:.*]] = affine.if
178+
// CHECK: arith.addi %{{.*}}, %[[IF]]
179+
180+
return %sum_result : i32
181+
}
182+
183+
// -----
184+
127185
func.func @hoist_affine_for_with_unknown_trip_count(%lb: index, %ub: index) {
128186
affine.for %arg0 = 0 to 10 {
129187
affine.for %arg1 = %lb to %ub {
@@ -383,6 +441,69 @@ func.func @parallel_loop_with_invariant() {
383441

384442
// -----
385443

444+
func.func @hoist_invariant_scf_if_success(%lb: index, %ub: index, %step: index) -> i32 {
445+
%cst_0 = arith.constant 0 : i32
446+
%cst_42 = arith.constant 42 : i32
447+
%true = arith.constant true
448+
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
449+
%conditional_add = scf.if %true -> (i32) {
450+
%add = arith.addi %cst_42, %cst_42 : i32
451+
scf.yield %add : i32
452+
} else {
453+
%poison = ub.poison : i32
454+
scf.yield %poison : i32
455+
}
456+
%sum = arith.addi %acc, %conditional_add : i32
457+
scf.yield %sum : i32
458+
}
459+
460+
// CHECK-LABEL: hoist_invariant_scf_if_success
461+
// CHECK-NEXT: arith.constant 0 : i32
462+
// CHECK-NEXT: %[[CST:.*]] = arith.constant 42 : i32
463+
// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
464+
// CHECK-NEXT: %[[IF:.*]] = scf.if %[[TRUE]]
465+
// CHECK-NEXT: arith.addi %[[CST]], %[[CST]] : i32
466+
// CHECK: scf.for
467+
// CHECK-NOT: scf.if
468+
// CHECK-NEXT: arith.addi %{{.*}}, %[[IF]]
469+
470+
return %sum_result : i32
471+
}
472+
473+
// -----
474+
475+
func.func @hoist_variant_scf_if_failure(%lb: index, %ub: index, %step: index) -> i32 {
476+
%cst_0 = arith.constant 0 : i32
477+
%cst_42 = arith.constant 42 : i32
478+
%ind_7 = arith.constant 7 : index
479+
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
480+
%cond = arith.cmpi ult, %i, %ind_7 : index
481+
%conditional_add = scf.if %cond -> (i32) {
482+
%add = arith.addi %cst_42, %cst_42 : i32
483+
scf.yield %add : i32
484+
} else {
485+
%poison = ub.poison : i32
486+
scf.yield %poison : i32
487+
}
488+
%sum = arith.addi %acc, %conditional_add : i32
489+
scf.yield %sum : i32
490+
}
491+
492+
// CHECK-LABEL: hoist_variant_scf_if_failure
493+
// CHECK-NEXT: arith.constant 0 : i32
494+
// CHECK-NEXT: %[[CST_42:.*]] = arith.constant 42 : i32
495+
// CHECK-NEXT: %[[CST_7:.*]] = arith.constant 7 : index
496+
// CHECK-NEXT: scf.for %[[IV:.*]] = %{{.*}} to %{{.*}}
497+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[IV]], %[[CST_7]]
498+
// CHECK-NEXT: %[[IF:.*]] = scf.if %[[CMP]]
499+
// CHECK-NEXT: arith.addi %[[CST_42]], %[[CST_42]] : i32
500+
// CHECK: arith.addi %{{.*}}, %[[IF]]
501+
502+
return %sum_result : i32
503+
}
504+
505+
// -----
506+
386507
func.func private @make_val() -> (index)
387508

388509
// CHECK-LABEL: func @nested_uses_inside

0 commit comments

Comments
 (0)