|
1 | 1 | # RUN: %PYTHON %s | FileCheck %s
|
2 | 2 |
|
3 |
| -from mlir.dialects import arith, builtin, func, linalg, tensor |
| 3 | +from mlir.dialects import arith, func, linalg, tensor, memref |
4 | 4 | from mlir.dialects.linalg.opdsl.lang import *
|
5 | 5 | from mlir.ir import *
|
6 | 6 |
|
@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
|
84 | 84 |
|
85 | 85 | print(module)
|
86 | 86 |
|
| 87 | + |
87 | 88 | # CHECK-LABEL: TEST: testIdentityRegionOps
|
88 | 89 | @run
|
89 | 90 | def testIdentityRegionOps():
|
@@ -161,3 +162,97 @@ def broadcast_op(op1, op2, op3):
|
161 | 162 | op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
|
162 | 163 |
|
163 | 164 | print(module)
|
| 165 | + |
| 166 | + |
| 167 | +# CHECK-LABEL: TEST: testGenericOp |
| 168 | +@run |
| 169 | +def testGenericOp(): |
| 170 | + with Context(), Location.unknown(): |
| 171 | + module = Module.create() |
| 172 | + f32 = F32Type.get() |
| 173 | + memref_t = MemRefType.get([10, 10], f32) |
| 174 | + with InsertionPoint(module.body): |
| 175 | + id_map_1 = AffineMap.get_identity(2) |
| 176 | + # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32> |
| 177 | + # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32> |
| 178 | + x = tensor.empty((16, 16), f32) |
| 179 | + y = tensor.empty((16, 16), f32) |
| 180 | + |
| 181 | + # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) { |
| 182 | + # CHECK: ^bb0(%in: f32, %out: f32): |
| 183 | + # CHECK: linalg.yield %in : f32 |
| 184 | + # CHECK: } -> tensor<16x16xf32> |
| 185 | + @linalg.generic( |
| 186 | + [x], |
| 187 | + [y], |
| 188 | + [id_map_1, id_map_1], |
| 189 | + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], |
| 190 | + ) |
| 191 | + def f(a, b): |
| 192 | + assert isinstance(a, Value) |
| 193 | + assert isinstance(a.type, F32Type) |
| 194 | + assert isinstance(b, Value) |
| 195 | + assert isinstance(b.type, F32Type) |
| 196 | + return a |
| 197 | + |
| 198 | + assert isinstance(f, Value) |
| 199 | + assert isinstance(f.type, RankedTensorType) |
| 200 | + |
| 201 | + # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32> |
| 202 | + z = tensor.empty((16, 16, 16), f32) |
| 203 | + |
| 204 | + minor_id = AffineMap.get_minor_identity(3, 2) |
| 205 | + id_map_2 = AffineMap.get_identity(3) |
| 206 | + |
| 207 | + # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) { |
| 208 | + # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32): |
| 209 | + # CHECK: linalg.yield %in, %out : f32, f32 |
| 210 | + # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>) |
| 211 | + @linalg.generic( |
| 212 | + [x], |
| 213 | + [z, z], |
| 214 | + [minor_id, id_map_2, id_map_2], |
| 215 | + [ |
| 216 | + linalg.IteratorType.parallel, |
| 217 | + linalg.IteratorType.parallel, |
| 218 | + linalg.IteratorType.parallel, |
| 219 | + ], |
| 220 | + ) |
| 221 | + def g(a, b, c): |
| 222 | + assert isinstance(a, Value) |
| 223 | + assert isinstance(a.type, F32Type) |
| 224 | + assert isinstance(b, Value) |
| 225 | + assert isinstance(b.type, F32Type) |
| 226 | + assert isinstance(c, Value) |
| 227 | + assert isinstance(c.type, F32Type) |
| 228 | + return a, b |
| 229 | + |
| 230 | + assert isinstance(g, OpResultList) |
| 231 | + assert len(g) == 2 |
| 232 | + assert isinstance(g[0].type, RankedTensorType) |
| 233 | + assert isinstance(g[1].type, RankedTensorType) |
| 234 | + |
| 235 | + # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32> |
| 236 | + # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32> |
| 237 | + xx = memref.alloc(memref_t, [], []) |
| 238 | + yy = memref.alloc(memref_t, [], []) |
| 239 | + |
| 240 | + # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) { |
| 241 | + # CHECK: ^bb0(%in: f32, %out: f32): |
| 242 | + # CHECK: linalg.yield %in : f32 |
| 243 | + # CHECK: } |
| 244 | + @linalg.generic( |
| 245 | + [xx], |
| 246 | + [yy], |
| 247 | + [id_map_1, id_map_1], |
| 248 | + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], |
| 249 | + ) |
| 250 | + def f(a, b): |
| 251 | + assert isinstance(a, Value) |
| 252 | + assert isinstance(a.type, F32Type) |
| 253 | + assert isinstance(b, Value) |
| 254 | + assert isinstance(b.type, F32Type) |
| 255 | + return a |
| 256 | + |
| 257 | + module.operation.verify() |
| 258 | + print(module) |
0 commit comments