Commit 3324f4d

[mlir][sparse] avoid incompatible linalg fuse-into-consumer (#86752)

This fixes an "infinite" loop bug where the incoming IR was rewritten over and over, each round adding an identical cast operation. The test for compatible types should include the notion of an encoding: if the encodings differ, a naive fusion into the consumer is invalid.
1 parent c43932e commit 3324f4d
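
To make the failure mode concrete, here is a minimal standalone C++ toy (hypothetical names and types, not MLIR code): a greedy fixed-point driver keeps applying patterns as long as one reports success, so a cast-compatibility test that ignores encodings lets the fold pattern fire forever while re-creating the very cast it just removed.

// Hypothetical toy, not MLIR code: models the pre-fix vs. post-fix
// compatibility test for folding away a tensor cast.
#include <iostream>
#include <string>

struct TensorType {
  std::string shape;    // e.g. "10x20x30xf64"
  std::string encoding; // e.g. "#sparse"; empty means dense
};

// Pre-fix test: shapes alone decide foldability, so a dense->sparse
// cast looks like a removable no-op.
bool foldablePreFix(const TensorType &src, const TensorType &dst) {
  return src.shape == dst.shape;
}

// Post-fix test: a differing encoding makes the cast meaningful.
bool foldablePostFix(const TensorType &src, const TensorType &dst) {
  return src.shape == dst.shape && src.encoding == dst.encoding;
}

int main() {
  TensorType dense{"10x20x30xf64", ""};
  TensorType sparse{"10x20x30xf64", "#sparse"};

  // A greedy driver reruns patterns while one reports a change. Folding
  // the dense->sparse cast must immediately re-insert an identical cast,
  // so pre-fix every round "succeeds"; the loop is capped here, but the
  // real rewriter had no such cap.
  int rounds = 0;
  while (foldablePreFix(dense, sparse) && rounds < 3)
    ++rounds; // IR is unchanged, yet the pattern keeps firing
  std::cout << "pre-fix: pattern fired " << rounds << "+ times\n";
  std::cout << "post-fix foldable? " << foldablePostFix(dense, sparse)
            << "\n"; // prints 0: the encoding mismatch blocks the fold
}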

File tree

2 files changed: +51 -0 lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 4 additions & 0 deletions
@@ -276,6 +276,10 @@ bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
   if (sourceType.getRank() != targetType.getRank())
     return false;
 
+  // Requires same encoding.
+  if (sourceType.getEncoding() != targetType.getEncoding())
+    return false;
+
   // If cast is towards more static sizes along any dimension, don't fold.
   for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
     if (!ShapedType::isDynamic(std::get<0>(t)) &&
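
For context, a sketch of how the full predicate reads after this change. The guard ahead of the rank check (requiring ranked tensor types) is an assumption reconstructed from the surrounding code, not part of this diff; only the encoding check is new in this commit.

// Sketch of mlir::tensor::preservesStaticInformation after the fix;
// paraphrased, not the verbatim upstream function.
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"

static bool preservesStaticInformationSketch(mlir::Type source,
                                             mlir::Type target) {
  auto sourceType = llvm::dyn_cast<mlir::RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<mlir::RankedTensorType>(target);
  if (!sourceType || !targetType) // assumed guard, not in the diff
    return false;
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // New in this commit: casts that change the encoding (e.g. dense to
  // #sparse) carry real semantics and are not pure shape refinements.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!mlir::ShapedType::isDynamic(std::get<0>(t)) &&
        mlir::ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }
  return true;
}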
Lines changed: 47 additions & 0 deletions (new test file)
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --canonicalize --pre-sparsification-rewrite | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#sparse = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) ->
+    (d0 : compressed(nonunique),
+     d1 : singleton(nonunique, soa),
+     d2 : singleton(soa)),
+  posWidth = 64,
+  crdWidth = 64
+}>
+
+
+module {
+  //
+  // This IR should not end up in an infinite loop trying to fold
+  // the linalg producer into the tensor cast consumer (even though
+  // static sizes can fold, the different encodings cannot). The
+  // cast was sloppy to begin with (but it has been observed by
+  // external sources) and can be easily repaired by the sparsifier.
+  //
+  // CHECK-LABEL: func @avoid_fold
+  // CHECK: arith.constant
+  // CHECK: tensor.empty()
+  // CHECK: linalg.generic
+  // CHECK: sparse_tensor.convert
+  // CHECK: return
+  //
+  func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
+    %1 = tensor.empty() : tensor<10x20x30xf64>
+    %2 = linalg.generic { indexing_maps = [#map, #map],
+                          iterator_types = ["parallel", "parallel", "parallel"]
+                        }
+         ins (%0 : tensor<10x20x30xf64, #sparse>)
+         outs(%1 : tensor<10x20x30xf64>) {
+      ^bb0(%in: f64, %out: f64):
+        %cst = arith.constant 0.000000e+00 : f64
+        %4 = arith.cmpf ugt, %in, %cst : f64
+        %5 = arith.select %4, %in, %cst : f64
+        linalg.yield %5 : f64
+    } -> tensor<10x20x30xf64>
+    %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
+    return %cast : tensor<10x20x30xf64, #sparse>
+  }
+}
+
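
Note the expected pipeline behavior encoded in the CHECK lines: after the fix, canonicalization leaves the sloppy dense-to-sparse tensor.cast in place, and --pre-sparsification-rewrite then repairs it into an explicit sparse_tensor.convert, matching the "repaired by the sparsifier" remark in the test comment.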
