Skip to content

Commit ebc8153

Browse files
committed
Revert "Revert "[mlir] Purge linalg.copy and use memref.copy instead.""
This reverts commit 25bf6a2.
1 parent 9c52a19 commit ebc8153

39 files changed

+215
-704
lines changed

mlir/docs/Dialects/Linalg/_index.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,6 @@ seem generally appealing.
545545
Additionally, `linalg` provides a small subset of commonly named operations:
546546

547547
```
548-
* `linalg.copy`,
549548
* `linalg.fill`,
550549
* `linalg.dot`,
551550
* `linalg.matmul`,

mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,6 @@ class LinalgOpToLibraryCallRewrite
3939
PatternRewriter &rewriter) const override;
4040
};
4141

42-
/// Rewrite pattern specialization for CopyOp, kicks in when both input and
43-
/// output permutations are left unspecified or are the identity.
44-
class CopyOpToLibraryCallRewrite : public OpRewritePattern<CopyOp> {
45-
public:
46-
using OpRewritePattern<CopyOp>::OpRewritePattern;
47-
LogicalResult matchAndRewrite(CopyOp op,
48-
PatternRewriter &rewriter) const override;
49-
};
50-
51-
/// Rewrite CopyOp with permutations into a sequence of TransposeOp and
52-
/// permutation-free CopyOp. This interplays with TransposeOpConversion and
53-
/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
54-
class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
55-
public:
56-
using OpRewritePattern<CopyOp>::OpRewritePattern;
57-
LogicalResult matchAndRewrite(CopyOp op,
58-
PatternRewriter &rewriter) const override;
59-
};
60-
6142
/// Populate the given list with patterns that convert from Linalg to Standard.
6243
void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns);
6344

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
4242
}: memref<2xf32>, memref<2xf32>
4343
br ^bb3(%0 : memref<2xf32>)
4444
^bb3(%1: memref<2xf32>):
45-
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
45+
"memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
4646
return
4747
}
4848
}
@@ -58,7 +58,7 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
5858
cond_br %arg0, ^bb1, ^bb2
5959
^bb1: // pred: ^bb0
6060
%0 = memref.alloc() : memref<2xf32>
61-
linalg.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
61+
memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
6262
br ^bb3(%0 : memref<2xf32>)
6363
^bb2: // pred: ^bb0
6464
%1 = memref.alloc() : memref<2xf32>
@@ -72,11 +72,11 @@ def BufferDeallocation : Pass<"buffer-deallocation", "FuncOp"> {
7272
linalg.yield %4 : f32
7373
}: memref<2xf32>, memref<2xf32>
7474
%2 = memref.alloc() : memref<2xf32>
75-
linalg.copy(%1, %2) : memref<2xf32>, memref<2xf32>
75+
memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
7676
dealloc %1 : memref<2xf32>
7777
br ^bb3(%2 : memref<2xf32>)
7878
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
79-
linalg.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
79+
memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
8080
dealloc %3 : memref<2xf32>
8181
return
8282
}

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 0 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
1818
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
19-
include "mlir/Interfaces/CopyOpInterface.td"
2019
include "mlir/Interfaces/InferTypeOpInterface.td"
2120
include "mlir/Interfaces/SideEffectInterfaces.td"
2221

@@ -57,119 +56,6 @@ class LinalgStructured_Op<string mnemonic, list<Trait> props>
5756
//===----------------------------------------------------------------------===//
5857
// Named Linalg ops, implemented as special configurations of generic ops.
5958
//===----------------------------------------------------------------------===//
60-
// At the moment these are not declarative and require a bunch of C++ code.
61-
// In the future, these should be migrated to a declarative specification.
62-
def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
63-
let description = [{
64-
Copies the data in the input view into the output view.
65-
66-
Usage:
67-
68-
```mlir
69-
linalg.copy(%arg0, %arg1) : memref<?xf32, stride_specification>,
70-
memref<?xf32, stride_specification>
71-
```
72-
73-
One possible lowering to loop form is:
74-
75-
```mlir
76-
%0 = linalg.dim %arg0, 0 : index
77-
scf.for %i0 = %c0 to %0 step %c1 {
78-
%1 = load %arg0[%i0] : memref<?xf32, stride_specification>
79-
store %1, %arg1[%i0] : memref<?xf32, stride_specification>
80-
}
81-
```
82-
83-
Optionally, can take `input_permutation` and `output_permutation` attributes
84-
to reorder the dimensions of the input and output views.
85-
86-
Usage:
87-
88-
```mlir
89-
linalg.copy(%arg0, %arg1) {inputPermutation : (i, j, k) -> (i, k, j),
90-
outputPermutation : (i, j, k) -> (k, j, i)} :
91-
memref<?x?x?xf32, stride_specification>,
92-
memref<?x?x?xf32, stride_specification>
93-
```
94-
95-
One possible lowering to loop form is:
96-
97-
```mlir
98-
%0 = linalg.dim %arg0, 0
99-
%1 = linalg.dim %arg0, 1
100-
%2 = linalg.dim %arg0, 2
101-
scf.for %i0 = %c0 to %{{.*}} step %c1 {
102-
scf.for %i1 = %c0 to %{{.*}} step %c1 {
103-
scf.for %i2 = %c0 to %{{.*}} step %c1 {
104-
%3 = load %arg0[%i0, %i2, %i1] :
105-
memref<?x?x?xf32, stride_specification>
106-
store %3, %arg1[%i2, %i1, %i0] :
107-
memref<?x?x?xf32, stride_specification>
108-
```
109-
110-
The views are expected to be compatible for correctness but this is not
111-
enforced at the moment.
112-
}];
113-
114-
let arguments = (ins
115-
AnyStridedMemRef:$input,
116-
AnyStridedMemRef:$output,
117-
OptionalAttr<AffineMapAttr>:$inputPermutation,
118-
OptionalAttr<AffineMapAttr>:$outputPermutation);
119-
let regions = (region AnyRegion:$region);
120-
121-
let builders = [
122-
OpBuilder<(ins "Value":$input, "Value":$output,
123-
CArg<"AffineMap", "AffineMap()">:$inputPermutation,
124-
CArg<"AffineMap", "AffineMap()">:$outputPermutation,
125-
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
126-
127-
let extraClassDeclaration = structuredOpsDecls # [{
128-
ValueRange inputs() { return getOperands().take_front(); }
129-
ValueRange outputs() { return getOperands().take_back(); }
130-
131-
// Rank-polymorphic.
132-
// filling_value -> O(ivs) with parallel iterators.
133-
ArrayAttr iterator_types() {
134-
int64_t nPar = getRank(getInputOperand(0));
135-
return Builder(getContext()).getStrArrayAttr(
136-
SmallVector<StringRef, 8>(nPar, getParallelIteratorTypeName()));
137-
}
138-
139-
// I(input_perm(ivs)) -> O(output_perm(ivs))
140-
ArrayAttr indexing_maps() {
141-
MLIRContext *context = getContext();
142-
auto maybeInputMap = inputPermutation();
143-
auto maybeOutputMap = outputPermutation();
144-
int64_t inputRank = getRank(getInputOperand(0));
145-
int64_t outputRank = getRank(getOutputOperand(0));
146-
return Builder(getContext()).getAffineMapArrayAttr({
147-
extractOrIdentityMap(maybeInputMap, inputRank, context),
148-
extractOrIdentityMap(maybeOutputMap, outputRank, context)});
149-
}
150-
151-
Value getSource() { return input();}
152-
Value getTarget() { return output(); }
153-
154-
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
155-
static std::function<void(ImplicitLocOpBuilder &b, Block &block)>
156-
getRegionBuilder() {
157-
return &regionBuilder;
158-
}
159-
static unsigned getNumRegionArgs() { return 2; }
160-
}];
161-
let verifier = [{ return ::verify(*this); }];
162-
163-
let assemblyFormat = [{
164-
`(` $input `,` $output `)` attr-dict `:`
165-
type($input) `,` type($output)
166-
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
167-
}];
168-
169-
let hasCanonicalizer = 1;
170-
let hasFolder = 1;
171-
let skipDefaultBuilders = 1;
172-
}
17359

17460
def FillOp : LinalgStructured_Op<"fill", []> {
17561
let arguments = (ins

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def LinalgComprehensiveModuleBufferize :
5252
Option<"useAlloca", "use-alloca", "bool",
5353
/*default=*/"false",
5454
"Use stack allocations for memrefs (for testing purposes only)">,
55-
Option<"useLinalgCopy", "use-linalg-copy", "bool",
55+
Option<"useLinalgCopy", "use-memref.copy", "bool",
5656
/*default=*/"false",
5757
"Use a copy operation implemented as a Linalg op.">,
5858
Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ struct LinalgPromotionOptions {
349349
return *this;
350350
}
351351
/// Callback function to do the copy of data to and from the promoted
352-
/// subview. If None then a linalg.copy is used.
352+
/// subview. If None then a memref.copy is used.
353353
Optional<CopyCallbackFn> copyInFn = None;
354354
Optional<CopyCallbackFn> copyOutFn = None;
355355
LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
@@ -390,6 +390,9 @@ FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
390390
/// Emit a suitable vector form for a Linalg op with fully static shape.
391391
LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
392392

393+
/// Emit a suitable vector form for a Copy op with fully static shape.
394+
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
395+
393396
/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
394397
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
395398
LinalgOp linalgOp);
@@ -934,6 +937,15 @@ struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
934937
LinalgTransformationFilter filter;
935938
};
936939

940+
/// `filter` controls LinalgTransformMarker matching and update when specified.
941+
/// See `vectorizeLinalgOp` for more details.
942+
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
943+
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
944+
945+
LogicalResult matchAndRewrite(memref::CopyOp copyOp,
946+
PatternRewriter &rewriter) const override;
947+
};
948+
937949
/// Return vector::CombiningKind for the given op.
938950
llvm::Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
939951

@@ -1206,7 +1218,7 @@ void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
12061218
/// %subView = subview %allocOrView ...
12071219
/// [optional] linalg.fill(%allocOrView, %cst) ...
12081220
/// ...
1209-
/// linalg.copy(%in, %subView) ...
1221+
/// memref.copy(%in, %subView) ...
12101222
/// vector.transfer_read %allocOrView[...], %cst ...
12111223
/// ```
12121224
/// into
@@ -1217,8 +1229,8 @@ void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
12171229
/// ...
12181230
/// vector.transfer_read %in[...], %cst ...
12191231
/// ```
1220-
/// Where there is no interleaved use between linalg.copy and transfer_read as
1221-
/// well as no interleaved use between linalg.fill and linalg.copy (if
1232+
/// Where there is no interleaved use between memref.copy and transfer_read as
1233+
/// well as no interleaved use between linalg.fill and memref.copy (if
12221234
/// linalg.fill is specified).
12231235
/// This is a custom rewrite to forward partial reads (with optional fills) to
12241236
/// vector.transfer_read.
@@ -1237,7 +1249,7 @@ struct LinalgCopyVTRForwardingPattern
12371249
/// %subView = subview %allocOrView...
12381250
/// ...
12391251
/// vector.transfer_write %..., %allocOrView[...]
1240-
/// linalg.copy(%subView, %out)
1252+
/// memref.copy(%subView, %out)
12411253
/// ```
12421254
/// into
12431255
/// ```
@@ -1247,7 +1259,7 @@ struct LinalgCopyVTRForwardingPattern
12471259
/// ...
12481260
/// vector.transfer_write %..., %out[...]
12491261
/// ```
1250-
/// Where there is no interleaved use between transfer_write and linalg.copy.
1262+
/// Where there is no interleaved use between transfer_write and memref.copy.
12511263
/// This is a custom rewrite to forward partial writes to vector.transfer_write.
12521264
struct LinalgCopyVTWForwardingPattern
12531265
: public OpRewritePattern<vector::TransferWriteOp> {

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def CopyOp : MemRef_Op<"copy",
407407
}];
408408

409409
let hasCanonicalizer = 1;
410+
let hasFolder = 1;
410411
let verifier = ?;
411412
}
412413

mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,6 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
9696

9797
LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
9898
LinalgOp op, PatternRewriter &rewriter) const {
99-
// Only LinalgOp for which there is no specialized pattern go through this.
100-
if (isa<CopyOp>(op))
101-
return failure();
102-
10399
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
104100
if (!libraryCallName)
105101
return failure();
@@ -113,65 +109,13 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
113109
return success();
114110
}
115111

116-
LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
117-
CopyOp op, PatternRewriter &rewriter) const {
118-
auto inputPerm = op.inputPermutation();
119-
if (inputPerm.hasValue() && !inputPerm->isIdentity())
120-
return failure();
121-
auto outputPerm = op.outputPermutation();
122-
if (outputPerm.hasValue() && !outputPerm->isIdentity())
123-
return failure();
124-
125-
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
126-
if (!libraryCallName)
127-
return failure();
128-
129-
rewriter.replaceOpWithNewOp<mlir::CallOp>(
130-
op, libraryCallName.getValue(), TypeRange(),
131-
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
132-
op.getOperands()));
133-
return success();
134-
}
135-
136-
LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
137-
CopyOp op, PatternRewriter &rewriter) const {
138-
Value in = op.input(), out = op.output();
139-
140-
// If either inputPerm or outputPerm are non-identities, insert transposes.
141-
auto inputPerm = op.inputPermutation();
142-
if (inputPerm.hasValue() && !inputPerm->isIdentity())
143-
in = rewriter.create<memref::TransposeOp>(op.getLoc(), in,
144-
AffineMapAttr::get(*inputPerm));
145-
auto outputPerm = op.outputPermutation();
146-
if (outputPerm.hasValue() && !outputPerm->isIdentity())
147-
out = rewriter.create<memref::TransposeOp>(op.getLoc(), out,
148-
AffineMapAttr::get(*outputPerm));
149-
150-
// If nothing was transposed, fail and let the conversion kick in.
151-
if (in == op.input() && out == op.output())
152-
return failure();
153-
154-
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
155-
if (!libraryCallName)
156-
return failure();
157-
158-
rewriter.replaceOpWithNewOp<mlir::CallOp>(
159-
op, libraryCallName.getValue(), TypeRange(),
160-
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
161-
return success();
162-
}
163112

164113
/// Populate the given list with patterns that convert from Linalg to Standard.
165114
void mlir::linalg::populateLinalgToStandardConversionPatterns(
166115
RewritePatternSet &patterns) {
167116
// TODO: ConvOp conversion needs to export a descriptor with relevant
168117
// attribute values such as kernel striding and dilation.
169-
// clang-format off
170-
patterns.add<
171-
CopyOpToLibraryCallRewrite,
172-
CopyTransposeRewrite,
173-
LinalgOpToLibraryCallRewrite>(patterns.getContext());
174-
// clang-format on
118+
patterns.add<LinalgOpToLibraryCallRewrite>(patterns.getContext());
175119
}
176120

177121
namespace {

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -914,9 +914,10 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
914914
auto sourcePtr = promote(unrankedSource);
915915
auto targetPtr = promote(unrankedTarget);
916916

917+
unsigned bitwidth = mlir::DataLayout::closest(op).getTypeSizeInBits(
918+
srcType.getElementType());
917919
auto elemSize = rewriter.create<LLVM::ConstantOp>(
918-
loc, getIndexType(),
919-
rewriter.getIndexAttr(srcType.getElementTypeBitWidth() / 8));
920+
loc, getIndexType(), rewriter.getIndexAttr(bitwidth / 8));
920921
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
921922
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
922923
rewriter.create<LLVM::CallOp>(loc, copyFn,

0 commit comments

Comments
 (0)