@@ -157,3 +157,82 @@ def pass_an_op_directly(arg0, arg1):
157
157
return linalg .matmul (lhs , rhs , outs = init )
158
158
159
159
print (module )
160
+
161
+
162
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
    """Emit linalg transpose and broadcast ops and print the resulting module.

    Each op kind is built two ways -- through the op class constructor
    (`linalg.TransposeOp` / `linalg.BroadcastOp`) followed by a call to
    `linalg.fill_builtin_region`, and through the `linalg.transpose` /
    `linalg.broadcast` helper functions -- first on tensor operands at module
    level and then on memref arguments inside a generated function.

    The printed module is matched against the FileCheck directives embedded
    in the comments below, so the order in which ops are created matters.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()

        # Type objects used further down; constructing them emits no ops.
        tensor_13x1 = RankedTensorType.get((13, 1), f32)
        tensor_16x64 = RankedTensorType.get((16, 64), f32)
        memref_1x13 = MemRefType.get((1, 13), f32)
        memref_13x1 = MemRefType.get((13, 1), f32)
        memref_16 = MemRefType.get((16,), f32)
        memref_16x64 = MemRefType.get((16, 64), f32)
        memref_64 = MemRefType.get((64,), f32)

        with InsertionPoint(module.body):
            # ---- transpose, tensor operands --------------------------------
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
            row_tensor = tensor.EmptyOp([1, 13], f32)
            col_tensor = tensor.EmptyOp([13, 1], f32)
            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
            tensor_transpose = linalg.TransposeOp(
                result=[tensor_13x1],
                input=row_tensor,
                init=col_tensor,
                permutation=[1, 0],
            )
            # The class constructor is paired with an explicit region fill;
            # NOTE(review): presumably the helper form below does this
            # internally -- confirm against the linalg bindings.
            linalg.fill_builtin_region(tensor_transpose.operation)

            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
            linalg.transpose(col_tensor, outs=[row_tensor], permutation=[1, 0])

            # ---- transpose, memref operands --------------------------------
            # CHECK: func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
            @func.FuncOp.from_py_func(memref_1x13, memref_13x1)
            def transpose_op(op1, op2):
                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
                memref_transpose = linalg.TransposeOp(
                    result=[],  # no result types are passed for memref operands
                    input=op1,
                    init=op2,
                    permutation=[1, 0],
                )
                linalg.fill_builtin_region(memref_transpose.operation)
                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
                linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # ---- broadcast, tensor operands --------------------------------
            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
            vec16_tensor = tensor.EmptyOp([16], f32)
            mat_tensor = tensor.EmptyOp([16, 64], f32)
            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
            tensor_broadcast = linalg.BroadcastOp(
                result=[tensor_16x64],
                input=vec16_tensor,
                init=mat_tensor,
                dimensions=[1],
            )
            linalg.fill_builtin_region(tensor_broadcast.operation)

            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
            vec64_tensor = tensor.EmptyOp([64], f32)
            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
            linalg.broadcast(vec64_tensor, outs=[mat_tensor], dimensions=[0])

            # ---- broadcast, memref operands --------------------------------
            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
            @func.FuncOp.from_py_func(memref_16, memref_16x64, memref_64)
            def broadcast_op(op1, op2, op3):
                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
                memref_broadcast = linalg.BroadcastOp(
                    result=[],  # no result types are passed for memref operands
                    input=op1,
                    init=op2,
                    dimensions=[1],
                )
                linalg.fill_builtin_region(memref_broadcast.operation)
                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
                linalg.broadcast(op3, outs=[op2], dimensions=[0])

        # Dump the IR so the FileCheck directives above can be matched.
        print(module)
0 commit comments