@@ -284,6 +284,19 @@ void GenerateLoopNest<TiledLoopOp>::doit(
284
284
SmallVector<Value, 4 > lbs, ubs, steps;
285
285
unpackRanges (loopRanges, lbs, ubs, steps);
286
286
287
+ auto dropNonShapedValues =
288
+ [](ArrayRef<OpOperand *> operands) -> SmallVector<Value, 2 > {
289
+ SmallVector<Value, 2 > filteredOperands;
290
+ for (OpOperand *operand : operands) {
291
+ Type type = operand->get ().getType ();
292
+ if (type.isa <ShapedType>())
293
+ filteredOperands.push_back (operand->get ());
294
+ }
295
+ return filteredOperands;
296
+ };
297
+ auto inputOperands = dropNonShapedValues (linalgOp.getInputOperands ());
298
+ auto outputOperands = dropNonShapedValues (linalgOp.getOutputOperands ());
299
+
287
300
auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
288
301
ValueRange ivs, ValueRange inputs,
289
302
ValueRange outputs) {
@@ -292,9 +305,6 @@ void GenerateLoopNest<TiledLoopOp>::doit(
292
305
bodyBuilderFn (nestedBuilder, nestedLoc, ivs, outputTensors);
293
306
nestedBuilder.create <linalg::YieldOp>(nestedLoc, results);
294
307
};
295
-
296
- SmallVector<Value> inputOperands = linalgOp.getInputOperands ();
297
- SmallVector<Value> outputOperands = linalgOp.getOutputOperands ();
298
308
auto tiledLoop =
299
309
b.create <TiledLoopOp>(loc, lbs, ubs, steps, inputOperands, outputOperands,
300
310
b.getArrayAttr (iteratorTypes), wrappedBuilderFn);
0 commit comments