@@ -122,24 +122,37 @@ LoopLikeOpInterface mlir::createFused(LoopLikeOpInterface target,
122
122
NewYieldValuesFn newYieldValuesFn,
123
123
FuseTerminatorFn fuseTerminatorFn) {
124
124
auto targetIterArgs = target.getRegionIterArgs ();
125
- auto targetInductionVar = *target.getLoopInductionVars ();
125
+ std::optional<SmallVector<Value>> targetInductionVar =
126
+ target.getLoopInductionVars ();
126
127
SmallVector<Value> targetYieldOperands (target.getYieldedValues ());
127
128
auto sourceIterArgs = source.getRegionIterArgs ();
128
- auto sourceInductionVar = *source.getLoopInductionVars ();
129
+ std::optional<SmallVector<Value>> sourceInductionVar =
130
+ *source.getLoopInductionVars ();
129
131
SmallVector<Value> sourceYieldOperands (source.getYieldedValues ());
130
132
auto sourceRegion = source.getLoopRegions ().front ();
131
- LoopLikeOpInterface fusedLoop = *target.replaceWithAdditionalYields (
132
- rewriter, source.getInits (), /* replaceInitOperandUsesInLoop=*/ false ,
133
- newYieldValuesFn);
133
+
134
+ FailureOr<LoopLikeOpInterface> maybeFusedLoop =
135
+ target.replaceWithAdditionalYields (rewriter, source.getInits (),
136
+ /* replaceInitOperandUsesInLoop=*/ false ,
137
+ newYieldValuesFn);
138
+ if (failed (maybeFusedLoop))
139
+ llvm_unreachable (" failed to replace loop" );
140
+ LoopLikeOpInterface fusedLoop = *maybeFusedLoop;
134
141
135
142
// Map control operands.
136
143
IRMapping mapping;
137
- mapping.map (targetInductionVar, *fusedLoop.getLoopInductionVars ());
144
+ std::optional<SmallVector<Value>> fusedInductionVar =
145
+ fusedLoop.getLoopInductionVars ();
146
+ if (fusedInductionVar) {
147
+ if (!targetInductionVar || !sourceInductionVar)
148
+ llvm_unreachable (" expected target and source loops to have induction vars" );
149
+ mapping.map (*targetInductionVar, *fusedInductionVar);
150
+ mapping.map (*sourceInductionVar, *fusedInductionVar);
151
+ }
138
152
mapping.map (targetIterArgs,
139
153
fusedLoop.getRegionIterArgs ().take_front (targetIterArgs.size ()));
140
154
mapping.map (targetYieldOperands,
141
155
fusedLoop.getYieldedValues ().take_front (targetIterArgs.size ()));
142
- mapping.map (sourceInductionVar, *fusedLoop.getLoopInductionVars ());
143
156
mapping.map (sourceIterArgs,
144
157
fusedLoop.getRegionIterArgs ().take_back (sourceIterArgs.size ()));
145
158
mapping.map (sourceYieldOperands,
0 commit comments