Skip to content

Commit 25bf6a2

Browse files
committed
Revert "[mlir] Purge linalg.copy and use memref.copy instead."
This reverts commit 016956b. Reverting it to fix NVidia build without being in a hurry.
1 parent f85c6b7 commit 25bf6a2

33 files changed

+741
-188
lines changed

mlir/docs/Dialects/Linalg/_index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ seem generally appealing.
545545
Additionally, `linalg` provides a small subset of commonly named operations:
546546

547547
```
548+
* `linalg.copy`,
548549
* `linalg.fill`,
549550
* `linalg.dot`,
550551
* `linalg.matmul`,

mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,25 @@ class LinalgOpToLibraryCallRewrite
3939
PatternRewriter &rewriter) const override;
4040
};
4141

42+
/// Rewrite pattern specialization for CopyOp, kicks in when both input and
43+
/// output permutations are left unspecified or are the identity.
44+
class CopyOpToLibraryCallRewrite : public OpRewritePattern<CopyOp> {
45+
public:
46+
using OpRewritePattern<CopyOp>::OpRewritePattern;
47+
LogicalResult matchAndRewrite(CopyOp op,
48+
PatternRewriter &rewriter) const override;
49+
};
50+
51+
/// Rewrite CopyOp with permutations into a sequence of TransposeOp and
52+
/// permutation-free CopyOp. This interplays with TransposeOpConversion and
53+
/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
54+
class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
55+
public:
56+
using OpRewritePattern<CopyOp>::OpRewritePattern;
57+
LogicalResult matchAndRewrite(CopyOp op,
58+
PatternRewriter &rewriter) const override;
59+
};
60+
4261
/// Populate the given list with patterns that convert from Linalg to Standard.
4362
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
4463

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
4242
}: memref<2xf32>, memref<2xf32>
4343
br ^bb3(%0 : memref<2xf32>)
4444
^bb3(%1: memref<2xf32>):
45-
"memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
45+
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
4646
return
4747
}
4848
}
@@ -58,7 +58,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
5858
cond_br %arg0, ^bb1, ^bb2
5959
^bb1: // pred: ^bb0
6060
%0 = memref.alloc() : memref<2xf32>
61-
memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
61+
linalg.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
6262
br ^bb3(%0 : memref<2xf32>)
6363
^bb2: // pred: ^bb0
6464
%1 = memref.alloc() : memref<2xf32>
@@ -72,11 +72,11 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
7272
linalg.yield %4 : f32
7373
}: memref<2xf32>, memref<2xf32>
7474
%2 = memref.alloc() : memref<2xf32>
75-
memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
75+
linalg.copy(%1, %2) : memref<2xf32>, memref<2xf32>
7676
dealloc %1 : memref<2xf32>
7777
br ^bb3(%2 : memref<2xf32>)
7878
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
79-
memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
79+
linalg.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
8080
dealloc %3 : memref<2xf32>
8181
return
8282
}

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
1818
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
19+
include "mlir/Interfaces/CopyOpInterface.td"
1920
include "mlir/Interfaces/InferTypeOpInterface.td"
2021
include "mlir/Interfaces/SideEffectInterfaces.td"
2122

@@ -56,6 +57,119 @@ class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
5657
//===----------------------------------------------------------------------===//
5758
// Named Linalg ops, implemented as special configurations of generic ops.
5859
//===----------------------------------------------------------------------===//
60+
// At the moment these are not declarative and require a bunch of C++ code.
61+
// In the future, these should be migrated to a declarative specification.
62+
def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
63+
let description = [{
64+
Copies the data in the input view into the output view.
65+
66+
Usage:
67+
68+
```mlir
69+
linalg.copy(%arg0, %arg1) : memref<?xf32, stride_specification>,
70+
memref<?xf32, stride_specification>
71+
```
72+
73+
One possible lowering to loop form is:
74+
75+
```mlir
76+
%0 = linalg.dim %arg0, 0 : index
77+
scf.for %i0 = %c0 to %0 step %c1 {
78+
%1 = load %arg0[%i0] : memref<?xf32, stride_specification>
79+
store %1, %arg1[%i0] : memref<?xf32, stride_specification>
80+
}
81+
```
82+
83+
Optionally, can take `input_permutation` and `output_permutation` attributes
84+
to reorder the dimensions of the input and output views.
85+
86+
Usage:
87+
88+
```mlir
89+
linalg.copy(%arg0, %arg1) {inputPermutation : (i, j, k) -> (i, k, j),
90+
outputPermutation : (i, j, k) -> (k, j, i)} :
91+
memref<?x?x?xf32, stride_specification>,
92+
memref<?x?x?xf32, stride_specification>
93+
```
94+
95+
One possible lowering to loop form is:
96+
97+
```mlir
98+
%0 = linalg.dim %arg0, 0
99+
%1 = linalg.dim %arg0, 1
100+
%2 = linalg.dim %arg0, 2
101+
scf.for %i0 = %c0 to %{{.*}} step %c1 {
102+
scf.for %i1 = %c0 to %{{.*}} step %c1 {
103+
scf.for %i2 = %c0 to %{{.*}} step %c1 {
104+
%3 = load %arg0[%i0, %i2, %i1] :
105+
memref<?x?x?xf32, stride_specification>
106+
store %3, %arg1[%i2, %i1, %i0] :
107+
memref<?x?x?xf32, stride_specification>
108+
```
109+
110+
The views are expected to be compatible for correctness but this is not
111+
enforced at the moment.
112+
}];
113+
114+
let arguments = (ins
115+
AnyStridedMemRef:$input,
116+
AnyStridedMemRef:$output,
117+
OptionalAttr<AffineMapAttr>:$inputPermutation,
118+
OptionalAttr<AffineMapAttr>:$outputPermutation);
119+
let regions = (region AnyRegion:$region);
120+
121+
let builders = [
122+
OpBuilder<(ins "Value":$input, "Value":$output,
123+
CArg<"AffineMap", "AffineMap()">:$inputPermutation,
124+
CArg<"AffineMap", "AffineMap()">:$outputPermutation,
125+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
126+
127+
let extraClassDeclaration = structuredOpsDecls # [{
128+
ValueRange inputs() { return getOperands().take_front(); }
129+
ValueRange outputs() { return getOperands().take_back(); }
130+
131+
// Rank-polymorphic.
132+
// filling_value -> O(ivs) with parallel iterators.
133+
ArrayAttr iterator_types() {
134+
int64_t nPar = getRank(getInputOperand(0));
135+
return Builder(getContext()).getStrArrayAttr(
136+
SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
137+
}
138+
139+
// I(input_perm(ivs)) -> O(output_perm(ivs))
140+
ArrayAttr indexing_maps() {
141+
MLIRContext *context = getContext();
142+
auto maybeInputMap = inputPermutation();
143+
auto maybeOutputMap = outputPermutation();
144+
int64_t inputRank = getRank(getInputOperand(0));
145+
int64_t outputRank = getRank(getOutputOperand(0));
146+
return Builder(getContext()).getAffineMapArrayAttr({
147+
extractOrIdentityMap(maybeInputMap, inputRank, context),
148+
extractOrIdentityMap(maybeOutputMap, outputRank, context)});
149+
}
150+
151+
Value getSource() { return input();}
152+
Value getTarget() { return output(); }
153+
154+
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
155+
static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
156+
getRegionBuilder() {
157+
return &regionBuilder;
158+
}
159+
static unsigned getNumRegionArgs() { return 2; }
160+
}];
161+
let verifier = [{ return ::verify(*this); }];
162+
163+
let assemblyFormat = [{
164+
`(` $input `,` $output `)` attr-dict `:`
165+
type($input) `,` type($output)
166+
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
167+
}];
168+
169+
let hasCanonicalizer = 1;
170+
let hasFolder = 1;
171+
let skipDefaultBuilders = 1;
172+
}
59173

60174
def FillOp : LinalgStructured_Op<"fill", []> {
61175
let arguments = (ins

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def LinalgComprehensiveModuleBufferize :
5252
Option<"useAlloca", "use-alloca", "bool",
5353
/*default=*/"false",
5454
"Use stack allocations for memrefs (for testing purposes only)">,
55-
Option<"useLinalgCopy", "use-memref.copy", "bool",
55+
Option<"useLinalgCopy", "use-linalg-copy", "bool",
5656
/*default=*/"false",
5757
"Use a copy operation implemented as a Linalg op.">,
5858
Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ struct LinalgPromotionOptions {
349349
return *this;
350350
}
351351
/// Callback function to do the copy of data to and from the promoted
352-
/// subview. If None then a memref.copy is used.
352+
/// subview. If None then a linalg.copy is used.
353353
Optional<CopyCallbackFn> copyInFn = None;
354354
Optional<CopyCallbackFn> copyOutFn = None;
355355
LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
@@ -390,9 +390,6 @@ FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
390390
/// Emit a suitable vector form for a Linalg op with fully static shape.
391391
LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
392392

393-
/// Emit a suitable vector form for a Copy op with fully static shape.
394-
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
395-
396393
/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
397394
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
398395
LinalgOp linalgOp);
@@ -937,15 +934,6 @@ struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
937934
LinalgTransformationFilter filter;
938935
};
939936

940-
/// `filter` controls LinalgTransformMarker matching and update when specified.
941-
/// See `vectorizeLinalgOp` for more details.
942-
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
943-
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
944-
945-
LogicalResult matchAndRewrite(memref::CopyOp copyOp,
946-
PatternRewriter &rewriter) const override;
947-
};
948-
949937
/// Return vector::CombiningKind for the given op.
950938
llvm::Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
951939

@@ -1218,7 +1206,7 @@ void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
12181206
/// %subView = subview %allocOrView ...
12191207
/// [optional] linalg.fill(%allocOrView, %cst) ...
12201208
/// ...
1221-
/// memref.copy(%in, %subView) ...
1209+
/// linalg.copy(%in, %subView) ...
12221210
/// vector.transfer_read %allocOrView[...], %cst ...
12231211
/// ```
12241212
/// into
@@ -1229,8 +1217,8 @@ void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
12291217
/// ...
12301218
/// vector.transfer_read %in[...], %cst ...
12311219
/// ```
1232-
/// Where there is no interleaved use between memref.copy and transfer_read as
1233-
/// well as no interleaved use between linalg.fill and memref.copy (if
1220+
/// Where there is no interleaved use between linalg.copy and transfer_read as
1221+
/// well as no interleaved use between linalg.fill and linalg.copy (if
12341222
/// linalg.fill is specified).
12351223
/// This is a custom rewrite to forward partial reads (with optional fills) to
12361224
/// vector.transfer_read.
@@ -1249,7 +1237,7 @@ struct LinalgCopyVTRForwardingPattern
12491237
/// %subView = subview %allocOrView...
12501238
/// ...
12511239
/// vector.transfer_write %..., %allocOrView[...]
1252-
/// memref.copy(%subView, %out)
1240+
/// linalg.copy(%subView, %out)
12531241
/// ```
12541242
/// into
12551243
/// ```
@@ -1259,7 +1247,7 @@ struct LinalgCopyVTRForwardingPattern
12591247
/// ...
12601248
/// vector.transfer_write %..., %out[...]
12611249
/// ```
1262-
/// Where there is no interleaved use between transfer_write and memref.copy.
1250+
/// Where there is no interleaved use between transfer_write and linalg.copy.
12631251
/// This is a custom rewrite to forward partial writes to vector.transfer_write.
12641252
struct LinalgCopyVTWForwardingPattern
12651253
: public OpRewritePattern<vector::TransferWriteOp> {

mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
9696

9797
LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
9898
LinalgOp op, PatternRewriter &rewriter) const {
99+
// Only LinalgOp for which there is no specialized pattern go through this.
100+
if (isa<CopyOp>(op))
101+
return failure();
102+
99103
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
100104
if (!libraryCallName)
101105
return failure();
@@ -109,12 +113,65 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
109113
return success();
110114
}
111115

116+
LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
117+
CopyOp op, PatternRewriter &rewriter) const {
118+
auto inputPerm = op.inputPermutation();
119+
if (inputPerm.hasValue() && !inputPerm->isIdentity())
120+
return failure();
121+
auto outputPerm = op.outputPermutation();
122+
if (outputPerm.hasValue() && !outputPerm->isIdentity())
123+
return failure();
124+
125+
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
126+
if (!libraryCallName)
127+
return failure();
128+
129+
rewriter.replaceOpWithNewOp<mlir::CallOp>(
130+
op, libraryCallName.getValue(), TypeRange(),
131+
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
132+
op.getOperands()));
133+
return success();
134+
}
135+
136+
LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
137+
CopyOp op, PatternRewriter &rewriter) const {
138+
Value in = op.input(), out = op.output();
139+
140+
// If either inputPerm or outputPerm are non-identities, insert transposes.
141+
auto inputPerm = op.inputPermutation();
142+
if (inputPerm.hasValue() && !inputPerm->isIdentity())
143+
in = rewriter.create<memref::TransposeOp>(op.getLoc(), in,
144+
AffineMapAttr::get(*inputPerm));
145+
auto outputPerm = op.outputPermutation();
146+
if (outputPerm.hasValue() && !outputPerm->isIdentity())
147+
out = rewriter.create<memref::TransposeOp>(op.getLoc(), out,
148+
AffineMapAttr::get(*outputPerm));
149+
150+
// If nothing was transposed, fail and let the conversion kick in.
151+
if (in == op.input() && out == op.output())
152+
return failure();
153+
154+
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
155+
if (!libraryCallName)
156+
return failure();
157+
158+
rewriter.replaceOpWithNewOp<mlir::CallOp>(
159+
op, libraryCallName.getValue(), TypeRange(),
160+
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
161+
return success();
162+
}
163+
112164
/// Populate the given list with patterns that convert from Linalg to Standard.
113165
void mlir::linalg::populateLinalgToStandardConversionPatterns(
114166
RewritePatternSet &patterns) {
115167
// TODO: ConvOp conversion needs to export a descriptor with relevant
116168
// attribute values such as kernel striding and dilation.
117-
patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext());
169+
// clang-format off
170+
patterns.add<
171+
CopyOpToLibraryCallRewrite,
172+
CopyTransposeRewrite,
173+
LinalgOpToLibraryCallRewrite>(patterns.getContext());
174+
// clang-format on
118175
}
119176

120177
namespace {

0 commit comments

Comments
 (0)