@@ -70,3 +70,39 @@ func.func @update_notinplace(%argb: tensor<10xf32>, %arga: tensor<10xf32, #SV>)
70
70
} -> tensor <10 xf32 >
71
71
return %0 , %argb : tensor <10 xf32 >, tensor <10 xf32 >
72
72
}
73
+
74
+ #map = affine_map <(d0 , d1 ) -> (d0 , d1 )>
75
+ #map1 = affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>
76
+ #map2 = affine_map <(d0 , d1 , d2 ) -> (d2 , d1 )>
77
+ #map3 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
78
+ #sparse = #sparse_tensor.encoding <{ map = (d0 , d1 ) -> (d0 : dense , d1 : compressed), posWidth = 64 , crdWidth = 64 }>
79
+
80
+ // linalg.generic with sparse tensors does not necessarily bufferize to
81
+ // element-wise access into the underlying sparse data structures.
82
+
83
+ // CHECK-LABEL: func @sparse_non_elementwise(
84
+ func.func @sparse_non_elementwise (%arg0: tensor <64 x64 xf32 , #sparse >, %arg1: tensor <64 x64 xf32 >, %arg2: tensor <64 x64 xf32 >) -> tensor <64 x64 xf32 > {
85
+ %cst = arith.constant 0.000000e+00 : f32
86
+ // CHECK: %[[alloc0:.*]] = bufferization.alloc_tensor()
87
+ // CHECK: %[[alloc1:.*]] = bufferization.alloc_tensor()
88
+ %0 = bufferization.alloc_tensor () : tensor <64 x64 xf32 >
89
+ // CHECK: %[[generic0:.*]] = linalg.generic {{.*}} outs(%[[alloc1]] : {{.*}})
90
+ %1 = linalg.generic {index ing_maps = [#map ], iterator_types = [" parallel" , " parallel" ]} outs (%0 : tensor <64 x64 xf32 >) {
91
+ ^bb0 (%out: f32 ):
92
+ linalg.yield %cst : f32
93
+ } -> tensor <64 x64 xf32 >
94
+ // CHECK: linalg.generic {{.*}} outs(%[[generic0]] : {{.*}})
95
+ %2 = linalg.generic {index ing_maps = [#map1 , #map2 , #map3 ], iterator_types = [" parallel" , " parallel" , " reduction" ]} ins (%arg2 , %arg2 : tensor <64 x64 xf32 >, tensor <64 x64 xf32 >) outs (%1 : tensor <64 x64 xf32 >) {
96
+ ^bb0 (%in: f32 , %in_0: f32 , %out: f32 ):
97
+ %4 = arith.mulf %in , %in_0 : f32
98
+ %5 = arith.addf %out , %4 : f32
99
+ linalg.yield %5 : f32
100
+ } -> tensor <64 x64 xf32 >
101
+ // CHECK: linalg.generic {{.*}} outs(%[[alloc0]] : {{.*}})
102
+ %3 = linalg.generic {index ing_maps = [#map , #map , #map ], iterator_types = [" parallel" , " parallel" ]} ins (%arg0 , %2 : tensor <64 x64 xf32 , #sparse >, tensor <64 x64 xf32 >) outs (%0 : tensor <64 x64 xf32 >) attrs = {sorted = true } {
103
+ ^bb0 (%in: f32 , %in_0: f32 , %out: f32 ):
104
+ %4 = arith.mulf %in , %in_0 : f32
105
+ linalg.yield %4 : f32
106
+ } -> tensor <64 x64 xf32 >
107
+ return %3 : tensor <64 x64 xf32 >
108
+ }
0 commit comments