Skip to content

Commit eb537ff

Browse files
committed
[MLIR][Linalg] Introduce SpecializeOp
Introduce an operation to specialize linalg.generics, for example, detecting a linalg.generic that is semantically equivalent to a linalg.copy and replacing the former with the latter. After code generation, it is helpful to lower named operations to vendor-optimized libraries.
1 parent 4fed3d3 commit eb537ff

File tree

8 files changed

+267
-0
lines changed

8 files changed

+267
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
110110
// TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
111111
bool isaConvolutionOpInterface(LinalgOp linalgOp);
112112

113+
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
114+
bool isaCopyOpInterface(LinalgOp linalgOp);
115+
113116
namespace detail {
114117

115118
/// Returns true if the block contains a contraction of the following form:

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,43 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
390390
}];
391391
}
392392

393+
//===----------------------------------------------------------------------===//
394+
// SpecializeOp
395+
//===----------------------------------------------------------------------===//
396+
397+
def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
398+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
399+
TransformOpInterface, TransformEachOpTrait,
400+
ReportTrackingListenerFailuresOpTrait]> {
401+
let description = [{
402+
Transforms a generic operation into the equivalent named form.
403+
404+
#### Return modes
405+
406+
This operation ignores non-Linalg ops and drops them in the return. If all
407+
the operations referred to by the `target` handle specialize, the transform
408+
succeeds; otherwise, the operation produces a silenceable failure. The return
409+
handle points to only the subset of successfully produced equivalent named
410+
operations, which can be empty or contain the original ops if they were already
411+
in named form. The supported specialization to named Linalg operations are:
412+
- linalg.copy of any rank.
413+
}];
414+
415+
let arguments = (ins TransformHandleTypeInterface:$target);
416+
let results = (outs TransformHandleTypeInterface:$transformed);
417+
let assemblyFormat =
418+
"$target attr-dict `:` "
419+
"custom<SemiFunctionType>(type($target), type($transformed))";
420+
421+
let extraClassDeclaration = [{
422+
::mlir::DiagnosedSilenceableFailure applyToOne(
423+
::mlir::transform::TransformRewriter &rewriter,
424+
::mlir::linalg::LinalgOp target,
425+
::mlir::transform::ApplyToEachResultList &results,
426+
::mlir::transform::TransformState &state);
427+
}];
428+
}
429+
393430
//===----------------------------------------------------------------------===//
394431
// InterchangeOp
395432
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,11 @@ FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
668668
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
669669
LinalgOp namedOp);
670670

671+
/// Create a namedOp from the given GenericOp and replace the GenericOp.
672+
/// Currently we can specialize only trivial linalg copy operations.
673+
FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter,
674+
GenericOp genericOp);
675+
671676
/// Create a new buffer using the `allocationFn` provided. The size of this
672677
/// buffer is the smallest constant bounding size along each dimension that
673678
/// can be computed for the size of the result of `subView`. Returns the

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ using namespace mlir::linalg;
3232
//===----------------------------------------------------------------------===//
3333
// Interface utility functions
3434
//===----------------------------------------------------------------------===//
35+
3536
bool linalg::detail::canOpOperandsBeDroppedImpl(
3637
linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
3738
SmallVector<AffineMap> indexingMaps;
@@ -48,6 +49,27 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
4849
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
4950
}
5051

52+
//===----------------------------------------------------------------------===//
53+
// CopyOpInterface implementation
54+
//===----------------------------------------------------------------------===//
55+
56+
bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
57+
// Structural.
58+
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
59+
return false;
60+
61+
// Operands and maps.
62+
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
63+
return false;
64+
auto mapRange = linalgOp.getIndexingMapsArray();
65+
if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
66+
!mapRange.back().isIdentity()) {
67+
return false;
68+
}
69+
// Region.
70+
return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
71+
}
72+
5173
//===----------------------------------------------------------------------===//
5274
// ContractionOpInterface implementation
5375
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,30 @@ transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
10181018
return emitDefaultSilenceableFailure(target);
10191019
}
10201020

1021+
//===----------------------------------------------------------------------===//
1022+
// SpecializeOp
1023+
//===----------------------------------------------------------------------===/
1024+
1025+
DiagnosedSilenceableFailure
1026+
transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1027+
LinalgOp target,
1028+
transform::ApplyToEachResultList &results,
1029+
transform::TransformState &state) {
1030+
// Exit early if the operation is not a generic.
1031+
if (!isa<GenericOp>(target)) {
1032+
results.push_back(target);
1033+
return DiagnosedSilenceableFailure::success();
1034+
}
1035+
rewriter.setInsertionPoint(target);
1036+
FailureOr<LinalgOp> named =
1037+
specializeGenericOp(rewriter, cast<GenericOp>(target));
1038+
if (succeeded(named)) {
1039+
results.push_back(named->getOperation());
1040+
return DiagnosedSilenceableFailure::success();
1041+
}
1042+
return emitDefaultSilenceableFailure(target);
1043+
}
1044+
10211045
//===----------------------------------------------------------------------===//
10221046
// InterchangeOp
10231047
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2424
NamedOpConversions.cpp
2525
Padding.cpp
2626
Promotion.cpp
27+
Specialize.cpp
2728
Split.cpp
2829
SplitReduction.cpp
2930
SubsetHoisting.cpp
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===- Specialize.cpp - linalg generic ops to named ops ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a method to specialize generic operations to named
10+
// operations. Conceptually it is the opposite of generalize.cpp.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
16+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17+
#include "llvm/Support/Debug.h"
18+
19+
#define DEBUG_TYPE "linalg-specialization"
20+
21+
using namespace mlir;
22+
using namespace mlir::linalg;
23+
24+
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
25+
GenericOp genericOp) {
26+
if (isaCopyOpInterface(genericOp)) {
27+
LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
28+
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
29+
return namedOp;
30+
}
31+
return failure();
32+
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
2+
3+
#map = affine_map<(d0, d1) -> (d0, d1)>
4+
#map1 = affine_map<(d0, d1) -> (d0)>
5+
#map2 = affine_map<(d0, d1) -> (d1, d0)>
6+
7+
func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
8+
// expected-note @below {{when applied to this op}}
9+
linalg.generic {
10+
indexing_maps = [#map1, #map],
11+
iterator_types = ["parallel", "parallel"]}
12+
ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
13+
^bb0(%in: f32, %out: f32):
14+
linalg.yield %in : f32
15+
}
16+
return
17+
}
18+
19+
func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
20+
// expected-note @below {{when applied to this op}}
21+
linalg.generic {
22+
indexing_maps = [#map, #map],
23+
iterator_types = ["parallel", "parallel"]}
24+
ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
25+
^bb0(%in: f32, %out: f32):
26+
%0 = arith.addf %in, %out : f32
27+
linalg.yield %0 : f32
28+
}
29+
return
30+
}
31+
32+
func.func @transpose_op_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
33+
// expected-note @below {{when applied to this op}}
34+
linalg.generic {
35+
indexing_maps = [#map, #map2],
36+
iterator_types = ["parallel", "parallel"]}
37+
ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
38+
^bb0(%in: f32, %out: f32):
39+
linalg.yield %in : f32
40+
}
41+
return
42+
}
43+
44+
func.func @copy_with_up_cast(%arg0: memref<?x?xf16>, %arg1: memref<?x?xf32>) {
45+
// expected-note @below {{when applied to this op}}
46+
linalg.generic {
47+
indexing_maps = [#map, #map],
48+
iterator_types = ["parallel", "parallel"]}
49+
ins(%arg0 : memref<?x?xf16>) outs(%arg1 : memref<?x?xf32>) {
50+
^bb0(%in: f16, %out: f32):
51+
%0 = arith.extf %in : f16 to f32
52+
linalg.yield %0 : f32
53+
}
54+
return
55+
}
56+
57+
func.func @copy_with_down_cast(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf16>) {
58+
// expected-note @below {{when applied to this op}}
59+
linalg.generic {
60+
indexing_maps = [#map, #map],
61+
iterator_types = ["parallel", "parallel"]}
62+
ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf16>) {
63+
^bb0(%in: f32, %out: f16):
64+
%0 = arith.truncf %in : f32 to f16
65+
linalg.yield %0 : f16
66+
}
67+
return
68+
}
69+
70+
module attributes {transform.with_named_sequence} {
71+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
72+
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
73+
// expected-error @below {{failed to apply}}
74+
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
75+
transform.yield
76+
}
77+
}
78+
79+
// -----
80+
81+
#map = affine_map<(d0, d1) -> (d0, d1)>
82+
83+
func.func @specialize_trivial_copy_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
84+
linalg.generic {
85+
indexing_maps = [#map, #map],
86+
iterator_types = ["parallel", "parallel"]}
87+
ins(%arg0 : memref<?x?xf32>) outs(%arg1 : memref<?x?xf32>) {
88+
^bb0(%in: f32, %out: f32):
89+
linalg.yield %in : f32
90+
}
91+
return
92+
}
93+
94+
// CHECK-LABEL: specialize_trivial_copy_memref
95+
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>, %[[ARG1:.+]]: memref<?x?xf32>
96+
// CHECK-NOT: linalg.generic
97+
// CHECK: linalg.copy ins(%[[ARG0]] : memref<?x?xf32>) outs(%[[ARG1]] : memref<?x?xf32>)
98+
99+
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
100+
101+
func.func @specialize_trivial_copy_tensor(%arg0: tensor<?x?x?xf32>,
102+
%arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
103+
%0 = linalg.generic {
104+
indexing_maps = [#map1, #map1],
105+
iterator_types = ["parallel", "parallel", "parallel"]}
106+
ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
107+
^bb0(%in: f32, %out: f32):
108+
linalg.yield %in : f32
109+
} -> tensor<?x?x?xf32>
110+
return %0 : tensor<?x?x?xf32>
111+
}
112+
113+
// CHECK-LABEL: specialize_trivial_copy_tensor
114+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>
115+
// CHECK-NOT: linalg.generic
116+
// CHECK: %{{.+}} = linalg.copy ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>)
117+
118+
func.func @already_trivial_copy_memref(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
119+
linalg.copy ins(%arg0: memref<?x?xf32>) outs(%arg1: memref<?x?xf32>)
120+
return
121+
}
122+
123+
// CHECK-LABEL: already_trivial_copy_memref
124+
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xf32>, %[[ARG1:.+]]: memref<?x?xf32>
125+
// CHECK: linalg.copy ins(%[[ARG0]] : memref<?x?xf32>) outs(%[[ARG1]] : memref<?x?xf32>)
126+
127+
func.func @already_trivial_copy_tensor(%arg0: tensor<?x?x?xf32>,
128+
%arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
129+
%0 = linalg.copy ins(%arg0: tensor<?x?x?xf32>) outs(%arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
130+
return %0 : tensor<?x?x?xf32>
131+
}
132+
133+
// CHECK-LABEL: already_trivial_copy_tensor
134+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>
135+
// CHECK: %{{.+}} = linalg.copy ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>)
136+
137+
module attributes {transform.with_named_sequence} {
138+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
139+
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
140+
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
141+
transform.yield
142+
}
143+
}

0 commit comments

Comments
 (0)