@@ -54,6 +54,50 @@ func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x
54
54
return %2 : tensor <128 x32 x32 x1 xf32 , #CCCD >
55
55
}
56
56
57
// Elementwise binary trait: two inputs and one output, all indexed by the
// identity map over a 4-D fully-parallel iteration space.
#trait_bin = {
  indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  ],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}

// Tests folding of sparse_tensor.convert into the producing kernel when the
// kernel's init tensor (%0) has a second use. Per the CHECK-FOLD lines below,
// the fold is expected to materialize a dense empty for the first kernel and
// a sparse empty for the second, leaving no convert behind.
// CHECK-FOLD-LABEL: func.func @fold_convert_multi_use(
// CHECK-FOLD: tensor.empty() : tensor<128x32x32x1xf32>
// CHECK-FOLD: linalg.generic
// CHECK-FOLD: tensor.empty() : tensor<128x32x32x1xf32, #sparse>
// CHECK-FOLD: linalg.generic
// CHECK-FOLD-NOT: sparse_tensor.convert
func.func @fold_convert_multi_use(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
                                  %arg2: tensor<128x32x32x1xf32>, %arg3: tensor<128x32x32x1xf32>)
    -> (tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>) {
  // Shared init tensor: consumed by both kernels below.
  %0 = tensor.empty() : tensor<128x32x32x1xf32>

  // First kernel: elementwise multiply, result returned dense.
  %1 = linalg.generic #trait_bin
      ins(%arg0, %arg1 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
      outs(%0 : tensor<128x32x32x1xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %mul = arith.mulf %in, %in_1 : f32
      linalg.yield %mul : f32
  } -> tensor<128x32x32x1xf32>

  // A second kernel that uses %0 as the init operand; its result feeds the
  // convert that the pass under test should fold away.
  %2 = linalg.generic #trait_bin
      ins(%arg2, %arg3 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
      outs(%0 : tensor<128x32x32x1xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %mul = arith.mulf %in, %in_1 : f32
      linalg.yield %mul : f32
  } -> tensor<128x32x32x1xf32>
  %3 = sparse_tensor.convert %2 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>

  return %1, %3 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>
}

// FIXME: The following kernel is not sparsifiable because `arith.select`
// operations are not handled by the sparse compiler at the moment.