Skip to content

Commit 1ee1143

Browse files
[mlir][Linalg][Vector] Add forwarding patterns between linalg.copy and vector.transfer
This revision adds custom rewrites for patterns that arise during linalg structured ops vectorization. These patterns allow the composition of linalg promotion, vectorization and removal of redundant copies. The patterns are voluntarily limited and restrictive atm. More robust behavior will be implemented once more powerful side effect modeling and analyses are available on view/subview. On the transfer_read side, the following pattern is rewritten: ``` %alloc = ... [optional] %view = std.view %alloc ... %subView = subview %allocOrView ... [optional] linalg.fill(%allocOrView, %cst) ... ... linalg.copy(%in, %subView) ... vector.transfer_read %allocOrView[...], %cst ... ``` into ``` [unchanged] %alloc = ... [unchanged] [optional] %view = std.view %alloc ... [unchanged] [unchanged] %subView = subview %allocOrView ... ... vector.transfer_read %in[...], %cst ... ``` On the transfer_write side, the following pattern is rewriten: ``` %alloc = ... [optional] %view = std.view %alloc ... %subView = subview %allocOrView... ... vector.transfer_write %..., %allocOrView[...] linalg.copy(%subView, %out) ``` Differential Revision: https://reviews.llvm.org/D80728
1 parent ea7db62 commit 1ee1143

File tree

4 files changed

+443
-46
lines changed

4 files changed

+443
-46
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
#include "llvm/ADT/SmallBitVector.h"
1515

1616
namespace mlir {
17+
namespace vector {
18+
19+
class TransferReadOp;
20+
class TransferWriteOp;
21+
22+
} // namespace vector
23+
1724
namespace linalg {
1825

1926
struct LinalgTilingOptions;
@@ -437,6 +444,67 @@ struct LinalgLoweringPattern : public RewritePattern {
437444
LinalgLoweringType loweringType;
438445
};
439446

447+
//===----------------------------------------------------------------------===//
448+
// Op-specific patterns.
449+
//===----------------------------------------------------------------------===//
450+
/// Match and rewrite for the pattern:
451+
/// ```
452+
/// %alloc = ...
453+
/// [optional] %view = std.view %alloc ...
454+
/// %subView = subview %allocOrView ...
455+
/// [optional] linalg.fill(%allocOrView, %cst) ...
456+
/// ...
457+
/// linalg.copy(%in, %subView) ...
458+
/// vector.transfer_read %allocOrView[...], %cst ...
459+
/// ```
460+
/// into
461+
/// ```
462+
/// [unchanged] %alloc = ...
463+
/// [unchanged] [optional] %view = std.view %alloc ...
464+
/// [unchanged] [unchanged] %subView = subview %allocOrView ...
465+
/// ...
466+
/// vector.transfer_read %in[...], %cst ...
467+
/// ```
468+
/// Where there is no interleaved use between linalg.copy and transfer_read as
469+
/// well as no interleaved use between linalg.fill and linalg.copy (if
470+
/// linalg.fill is specified).
471+
/// This is a custom rewrite to forward partial reads (with optional fills) to
472+
/// vector.transfer_read.
473+
struct LinalgCopyVTRForwardingPattern
474+
: public OpRewritePattern<vector::TransferReadOp> {
475+
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
476+
477+
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
478+
PatternRewriter &rewriter) const override;
479+
};
480+
481+
/// Match and rewrite for the pattern:
482+
/// ```
483+
/// %alloc = ...
484+
/// [optional] %view = std.view %alloc ...
485+
/// %subView = subview %allocOrView...
486+
/// ...
487+
/// vector.transfer_write %..., %allocOrView[...]
488+
/// linalg.copy(%subView, %out)
489+
/// ```
490+
/// into
491+
/// ```
492+
/// [unchanged] %alloc = ...
493+
/// [unchanged] [optional] %view = std.view %alloc ...
494+
/// [unchanged] %subView = subview %allocOrView...
495+
/// ...
496+
/// vector.transfer_write %..., %out[...]
497+
/// ```
498+
/// Where there is no interleaved use between transfer_write and linalg.copy.
499+
/// This is a custom rewrite to forward partial writes to vector.transfer_write.
500+
struct LinalgCopyVTWForwardingPattern
501+
: public OpRewritePattern<vector::TransferWriteOp> {
502+
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
503+
504+
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
505+
PatternRewriter &rewriter) const override;
506+
};
507+
440508
//===----------------------------------------------------------------------===//
441509
// Support for staged pattern application.
442510
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 171 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,13 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
103103
llvm_unreachable("Unexpected conv with padding");
104104
}
105105

106+
StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
107+
(void)dbgPref;
106108
edsc::ScopedContext scope(builder, op->getLoc());
107109
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
108110
// Vectorize fill as a vector.broadcast.
109-
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
110-
"]: Rewrite linalg.fill as vector.broadcast: "
111-
<< *op << ":\n");
111+
LLVM_DEBUG(dbgs() << dbgPref
112+
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
112113
Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
113114
Value dst = std_load(memref);
114115
Value res = vector_broadcast(dst.getType(), fillOp.value());
@@ -117,9 +118,8 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
117118
}
118119

119120
// Vectorize other ops as vector contraction (currently only matmul).
120-
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
121-
"]: Rewrite linalg op as vector.contract: "
122-
<< *op << ":\n");
121+
LLVM_DEBUG(dbgs() << dbgPref
122+
<< "Rewrite linalg op as vector.contract: " << *op);
123123
auto linalgOp = cast<linalg::LinalgOp>(op);
124124
Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
125125
Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
@@ -129,3 +129,168 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
129129
linalgOp.iterator_types());
130130
std_store(res, memref);
131131
}
132+
133+
/// Check whether there is any interleaved use of any `values` between `firstOp`
134+
/// and `secondOp`. Conservatively return `true` if any op or value is in a
135+
/// different block.
136+
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
137+
ValueRange values) {
138+
StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
139+
(void)dbgPref;
140+
if (firstOp->getBlock() != secondOp->getBlock() ||
141+
!firstOp->isBeforeInBlock(secondOp)) {
142+
LLVM_DEBUG(llvm::dbgs()
143+
<< dbgPref << "interleavedUses precondition failed, firstOp: "
144+
<< *firstOp << ", second op: " << *secondOp);
145+
return true;
146+
}
147+
for (auto v : values) {
148+
for (auto &u : v.getUses()) {
149+
Operation *owner = u.getOwner();
150+
if (owner == firstOp || owner == secondOp)
151+
continue;
152+
// TODO: this is too conservative, use dominance info in the future.
153+
if (owner->getBlock() == firstOp->getBlock() &&
154+
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
155+
continue;
156+
LLVM_DEBUG(llvm::dbgs()
157+
<< dbgPref << " found interleaved op " << *owner
158+
<< ", firstOp: " << *firstOp << ", second op: " << *secondOp);
159+
return true;
160+
}
161+
}
162+
return false;
163+
}
164+
165+
/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
166+
static SubViewOp getSubViewUseIfUnique(Value v) {
167+
SubViewOp subViewOp;
168+
for (auto &u : v.getUses()) {
169+
if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
170+
if (subViewOp)
171+
return SubViewOp();
172+
subViewOp = newSubViewOp;
173+
}
174+
}
175+
return subViewOp;
176+
}
177+
178+
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
179+
/// when available.
180+
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
181+
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
182+
183+
// Transfer into `view`.
184+
Value viewOrAlloc = xferOp.memref();
185+
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
186+
!viewOrAlloc.getDefiningOp<AllocOp>())
187+
return failure();
188+
189+
StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
190+
(void)dbgPref;
191+
LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);
192+
193+
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
194+
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
195+
if (!subViewOp)
196+
return failure();
197+
Value subView = subViewOp.getResult();
198+
LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);
199+
200+
// Find the copy into `subView` without interleaved uses.
201+
CopyOp copyOp;
202+
for (auto &u : subView.getUses()) {
203+
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
204+
if (newCopyOp.getOutputBuffer(0) != subView)
205+
continue;
206+
LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
207+
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
208+
continue;
209+
copyOp = newCopyOp;
210+
break;
211+
}
212+
}
213+
if (!copyOp)
214+
return failure();
215+
LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);
216+
217+
// Find the fill into `viewOrAlloc` without interleaved uses before the copy.
218+
FillOp maybeFillOp;
219+
for (auto &u : viewOrAlloc.getUses()) {
220+
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
221+
if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
222+
continue;
223+
LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
224+
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
225+
continue;
226+
maybeFillOp = newFillOp;
227+
break;
228+
}
229+
}
230+
// Ensure padding matches.
231+
if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
232+
return failure();
233+
if (maybeFillOp)
234+
LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);
235+
236+
// `in` is the subview that linalg.copy reads. Replace it.
237+
Value in = copyOp.getInput(0);
238+
239+
Value res = rewriter.create<vector::TransferReadOp>(
240+
xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
241+
xferOp.permutation_map(), xferOp.padding(),
242+
xferOp.masked() ? *xferOp.masked() : ArrayAttr());
243+
244+
if (maybeFillOp)
245+
rewriter.eraseOp(maybeFillOp);
246+
rewriter.eraseOp(copyOp);
247+
rewriter.replaceOp(xferOp, res);
248+
249+
return success();
250+
}
251+
252+
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
253+
/// when available.
254+
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
255+
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
256+
// Transfer into `viewOrAlloc`.
257+
Value viewOrAlloc = xferOp.memref();
258+
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
259+
!viewOrAlloc.getDefiningOp<AllocOp>())
260+
return failure();
261+
262+
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
263+
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
264+
if (!subViewOp)
265+
return failure();
266+
Value subView = subViewOp.getResult();
267+
268+
// Find the copy from `subView` without interleaved uses.
269+
CopyOp copyOp;
270+
for (auto &u : subViewOp.getResult().getUses()) {
271+
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
272+
if (newCopyOp.getInput(0) != subView)
273+
continue;
274+
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
275+
continue;
276+
copyOp = newCopyOp;
277+
break;
278+
}
279+
}
280+
if (!copyOp)
281+
return failure();
282+
283+
// `out` is the subview copied into that we replace.
284+
Value out = copyOp.getOutputBuffer(0);
285+
286+
// Forward vector.transfer into copy.
287+
rewriter.create<vector::TransferWriteOp>(
288+
xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
289+
xferOp.permutation_map(),
290+
xferOp.masked() ? *xferOp.masked() : ArrayAttr());
291+
292+
rewriter.eraseOp(copyOp);
293+
rewriter.eraseOp(xferOp);
294+
295+
return success();
296+
}

0 commit comments

Comments
 (0)