Skip to content

Commit a177be5

Browse files
authored
[mlir][Linalg] Bugfix in decompose generic by unfolding permutation (#126737)
The pattern was returning success() by default which made the greedy pattern application act as if the IR was modified and even though nothing was changed and thus it can prevent it from converging for no legitimate reason. The patch makes the rewrite pattern return failure() by default and success() if and only if the IR changed. An example of unexpected behavior is by running `mlir-opt input.mlir --linalg-specialize-generic-ops`, we obtain an empty mlir as output with `input.mlir` as follows: ``` #map = affine_map<(d0) -> (d0)> func.func @f(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> { %0 = tensor.empty() : tensor<8xi32> %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1: tensor<8xi32>, tensor<8xi32>) outs(%0: tensor<8xi32>) { ^bb0(%in: i32, %in_0: i32, %out: i32): %2 = arith.addi %in, %in_0: i32 linalg.yield %2: i32 } -> tensor<8xi32> return %1 : tensor<8xi32> } ```
1 parent 837b89f commit a177be5

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,21 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
223223
newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
224224
}
225225

226-
if (isChanged) {
227-
SmallVector<Value> operands = op->getOperands();
228-
ValueRange operandsRef(operands);
229-
230-
auto newOp = rewriter.create<linalg::GenericOp>(
231-
/*location=*/op.getLoc(),
232-
/*resultTensorTypes=*/op->getResultTypes(),
233-
/*inputs=*/newInitValues,
234-
/*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
235-
/*indexingMaps=*/newMap,
236-
/*iteratorTypes=*/op.getIteratorTypesArray());
237-
238-
newOp.getRegion().takeBody(op->getRegion(0));
239-
rewriter.replaceOp(op, newOp->getResults());
240-
}
226+
if (!isChanged)
227+
return failure();
228+
229+
SmallVector<Value> operands = op->getOperands();
230+
ValueRange operandsRef(operands);
231+
232+
auto newOp = rewriter.create<linalg::GenericOp>(
233+
/*location=*/op.getLoc(),
234+
/*resultTensorTypes=*/op->getResultTypes(),
235+
/*inputs=*/newInitValues,
236+
/*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
237+
/*indexingMaps=*/newMap,
238+
/*iteratorTypes=*/op.getIteratorTypesArray());
239+
newOp.getRegion().takeBody(op->getRegion(0));
240+
rewriter.replaceOp(op, newOp->getResults());
241241
return success();
242242
}
243243

0 commit comments

Comments
 (0)