
Commit fc398a1

[mlir][sparse] test optimization of binary-valued operations (#90986)
Make sure consumer-producer fusion happens (to avoid the temporary dense tensor) and constant folding occurs in the generated code.
1 parent dce13b4 commit fc398a1
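To sketch the intent (illustrative only, not part of the commit): because the sparse encoding below declares explicitVal = 1.0 : f32, every stored value of A is known to be 1.0, so inside the loop over A's stored entries the naive body (hypothetical names %values, %idx, %acc)

  %v = memref.load %values[%idx] : memref<?xf32>  // read stored value of A
  %m = arith.mulf %v, %v : f32                    // A * A
  %x = arith.addf %acc, %m : f32                  // accumulate

constant-folds down to a single addition of the constant 1.0:

  %x = arith.addf %acc, %cst_1 : f32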

File tree

1 file changed: +145 -0 lines changed

@@ -0,0 +1,145 @@
// RUN: mlir-opt %s --linalg-fuse-elementwise-ops \
// RUN:   --sparsification-and-bufferization | FileCheck %s

#Sparse = #sparse_tensor.encoding<{
  map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
  explicitVal = 1.0 : f32
}>
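// Note: explicitVal = 1.0 means every explicitly stored entry of #Sparse
// equals 1.0, so the sparsifier never needs to read the values array.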

#trait3p = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>,  // A
    affine_map<(i,j,k) -> (i,j,k)>,  // B
    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel", "parallel"]
}

#trait3r = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>,  // A
    affine_map<(i,j,k) -> ()>        // X (out)
  ],
  iterator_types = ["reduction", "reduction", "reduction"]
}
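
// #trait3p describes an elementwise (parallel) product over the full 3-d
// iteration space; #trait3r reduces a 3-d tensor into a scalar.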

//
// Make sure X += A * A => X += 1 in a single loop.
//
// CHECK-LABEL: func.func @sum_squares(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK: linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_7]] : index
// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_14]]) -> (f32) {
// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_15]] : index
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_18]]) -> (f32) {
// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_4]] : f32
// CHECK: scf.yield %[[VAL_26]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_23]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_16]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
// CHECK: return %[[VAL_10]] : memref<f32>
// CHECK: }
//
func.func @sum_squares(%a: tensor<2x3x8xf32, #Sparse>) -> tensor<f32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x3x8xf32>
  %1 = linalg.generic #trait3p
    ins(%a, %a : tensor<2x3x8xf32, #Sparse>, tensor<2x3x8xf32, #Sparse>)
    outs(%0 : tensor<2x3x8xf32>) {
      ^bb0(%in1: f32, %in2: f32, %out: f32):
        %mul = arith.mulf %in1, %in2 : f32
        linalg.yield %mul : f32
  } -> tensor<2x3x8xf32>
  %2 = tensor.empty() : tensor<f32>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<f32>) -> tensor<f32>
  %4 = linalg.generic #trait3r
    ins(%1 : tensor<2x3x8xf32>)
    outs(%3 : tensor<f32>) {
      ^bb0(%in: f32, %out: f32):
        %add = arith.addf %in, %out : f32
        linalg.yield %add : f32
  } -> tensor<f32>

  return %4 : tensor<f32>
}
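
// Observe in the CHECKs above: fusion removed the temporary dense
// tensor<2x3x8xf32>, and A * A constant-folded to 1.0 (%[[VAL_4]]), so the
// innermost loop simply adds 1.0 once per stored entry.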

//
// Make sure X += A * B => X += B in a single loop.
//
// CHECK-LABEL: func.func @sum_products(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#{{.*}}>,
// CHECK-SAME: %[[VAL_4:.*4]]: memref<2x3x8xf32>) -> memref<f32> {
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK: linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_7]] : index
// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_14]]) -> (f32) {
// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_15]] : index
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_18]]) -> (f32) {
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_13]], %[[VAL_17]], %[[VAL_26]]] : memref<2x3x8xf32>
// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_25]] : f32
// CHECK: scf.yield %[[VAL_28]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_23]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_16]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
// CHECK: return %[[VAL_10]] : memref<f32>
// CHECK: }
//
func.func @sum_products(%a: tensor<2x3x8xf32, #Sparse>, %b: tensor<2x3x8xf32>) -> tensor<f32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x3x8xf32>
  %1 = linalg.generic #trait3p
    ins(%a, %b : tensor<2x3x8xf32, #Sparse>, tensor<2x3x8xf32>)
    outs(%0 : tensor<2x3x8xf32>) {
      ^bb0(%in1: f32, %in2: f32, %out: f32):
        %mul = arith.mulf %in1, %in2 : f32
        linalg.yield %mul : f32
  } -> tensor<2x3x8xf32>
  %2 = tensor.empty() : tensor<f32>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<f32>) -> tensor<f32>
  %4 = linalg.generic #trait3r
    ins(%1 : tensor<2x3x8xf32>)
    outs(%3 : tensor<f32>) {
      ^bb0(%in: f32, %out: f32):
        %add = arith.addf %in, %out : f32
        linalg.yield %add : f32
  } -> tensor<f32>

  return %4 : tensor<f32>
}
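
// Here A * B folds to just B: the loop loads B at the coordinates of A's
// stored entries (%[[VAL_26]] above) and accumulates it directly; A's values
// are never read.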
