Skip to content

Commit e081ff9

Browse files
committed
[mlir][linalg] Adds destinationStyleOpInterface
Tagged winograd.output_transform op with destinationStyleOpInterface
1 parent f480e97 commit e081ff9

File tree

3 files changed

+21
-20
lines changed

3 files changed

+21
-20
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
313313
}
314314

315315
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
316-
[AllElementTypesMatch<["value", "output"]>,
316+
[AllElementTypesMatch<["value", "output"]>, DestinationStyleOpInterface,
317317
DeclareOpInterfaceMethods<TilingInterface,
318318
["getIterationDomain",
319319
"getLoopIteratorTypes",
@@ -396,6 +396,7 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
396396
int64_t getOutputFDim() {
397397
return 3;
398398
}
399+
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
399400
}];
400401
let hasVerifier = 1;
401402
}

mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,15 @@ module attributes {transform.with_named_sequence} {
8585
// CHECK: scf.yield %[[S9]]
8686
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
8787
// CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
88+
// CHECK: %[[S7:.*]] = tensor.empty()
8889
// CHECK: %[[S6:.*]] = linalg.batch_matmul
8990
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2]
90-
// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
91-
// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]])
91+
// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
9292
// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
9393
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
9494
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
9595
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
96-
// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
96+
// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
9797
// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
9898
// CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
9999
// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
@@ -221,16 +221,16 @@ module attributes {transform.with_named_sequence} {
221221
// CHECK: scf.yield %[[S9]]
222222
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
223223
// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]]
224+
// CHECK: %[[S7:.*]] = tensor.empty()
224225
// CHECK: %[[S6:.*]] = linalg.batch_matmul
225226
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2]
226227
// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0]
227-
// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
228-
// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]])
228+
// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[PADDED_8]])
229229
// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
230230
// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
231231
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
232232
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
233-
// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
233+
// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1]
234234
// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
235235
// CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
236236
// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]

mlir/test/Dialect/Linalg/transform-tile-winograd.mlir

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,14 @@ module attributes {transform.with_named_sequence} {
279279
// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index
280280
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
281281
// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
282-
// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
283-
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]]
282+
// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[ARG1]]) -> (tensor<2x8x8x2xf32>)
283+
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] iter_args(%[[ARG6:.*]] = %[[ARG5]]) -> (tensor<2x8x8x2xf32>)
284284
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32>
285285
// CHECK: %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
286286
// CHECK: %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
287287
// CHECK: %[[S5:.*]] = affine.apply #[[$MAP1]]()
288288
// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]]()
289-
// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
289+
// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG6]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32>
290290

291291
// -----
292292

@@ -321,10 +321,10 @@ module attributes {transform.with_named_sequence} {
321321
// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
322322
// CHECK-DAG: %[[C2_5:.*]] = arith.constant 2 : index
323323
// CHECK-DAG: %[[C2_7:.*]] = arith.constant 2 : index
324-
// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]]
325-
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]]
326-
// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]]
327-
// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]]
324+
// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]] iter_args(%[[ARG9:.*]] = %[[ARG1]]) -> (tensor<3x8x8x5xf32>)
325+
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]] iter_args(%[[ARG10:.*]] = %[[ARG9]]) -> (tensor<3x8x8x5xf32>)
326+
// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]] iter_args(%[[ARG11:.*]] = %[[ARG10]])
327+
// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]] iter_args(%[[ARG12:.*]] = %[[ARG11]])
328328
// CHECK: %[[C3_8:.*]] = arith.constant 3 : index
329329
// CHECK: %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]])
330330
// CHECK: %[[C5_9:.*]] = arith.constant 5 : index
@@ -334,7 +334,7 @@ module attributes {transform.with_named_sequence} {
334334
// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]](%[[ARG4]])
335335
// CHECK: %[[S9:.*]] = affine.apply #[[$MAP3]]()
336336
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP3]]()
337-
// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
337+
// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor<?x?x?x?xf32>
338338

339339
// -----
340340

@@ -367,14 +367,14 @@ module attributes {transform.with_named_sequence} {
367367
// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
368368
// CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index
369369
// CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index
370-
// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]]
371-
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]]
372-
// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]]
373-
// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]]
370+
// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[ARG1]]) -> (tensor<3x8x1x5xf32>)
371+
// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]] iter_args(%[[ARG10:.*]] = %[[ARG9]]) -> (tensor<3x8x1x5xf32>)
372+
// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]] iter_args(%[[ARG11:.*]] = %[[ARG10]]) -> (tensor<3x8x1x5xf32>)
373+
// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]] iter_args(%[[ARG12:.*]] = %[[ARG11]]) -> (tensor<3x8x1x5xf32>)
374374
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32>
375375
// CHECK: %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]])
376376
// CHECK: %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
377377
// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]]()
378378
// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]()
379-
// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
379+
// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG12]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32>
380380
// CHECK: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32>

0 commit comments

Comments
 (0)