Skip to content

Commit eca7698

Browse files
committed
[mlir][vector] NFC: Expose castAwayContractionLeadingOneDim
This commit exposes the transformation behind the pattern. It is useful for more targeted application on a specific op for once. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D148758
1 parent ca554ad commit eca7698

File tree

2 files changed

+122
-106
lines changed

2 files changed

+122
-106
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ enum class AtomicRMWKind : uint64_t;
4444
} // namespace arith
4545

4646
namespace vector {
47+
class ContractionOp;
4748
class TransferReadOp;
4849
class TransferWriteOp;
4950
class VectorDialect;
@@ -76,6 +77,11 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
7677
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
7778
PatternBenefit benefit = 1);
7879

80+
/// Cast away the leading unit dim, if exists, for the given contract op.
81+
/// Return success if the transformation applies; return failure otherwise.
82+
LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
83+
RewriterBase &rewriter);
84+
7985
/// Collect a set of leading one dimension removal patterns.
8086
///
8187
/// These patterns insert vector.shape_cast to remove leading one dimensions

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 116 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,121 @@ struct CastAwayTransferWriteLeadingOneDim
279279
}
280280
};
281281

282+
} // namespace
283+
284+
LogicalResult
285+
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
286+
RewriterBase &rewriter) {
287+
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
288+
if (oldAccType == nullptr)
289+
return failure();
290+
if (oldAccType.getRank() < 2)
291+
return failure();
292+
if (oldAccType.getShape()[0] != 1)
293+
return failure();
294+
// currently we support only dropping one dim but the pattern can be applied
295+
// greedily to drop more.
296+
int64_t dropDim = 1;
297+
298+
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
299+
SmallVector<AffineMap> newIndexingMaps;
300+
301+
auto oldIteratorTypes = contractOp.getIteratorTypes();
302+
SmallVector<Attribute> newIteratorTypes;
303+
304+
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
305+
306+
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
307+
// only parallel type iterators can be dropped.
308+
return failure();
309+
310+
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
311+
int64_t currDim = it.index();
312+
if (currDim == dimToDrop)
313+
continue;
314+
newIteratorTypes.push_back(it.value());
315+
}
316+
317+
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
318+
contractOp.getAcc()};
319+
SmallVector<Value> newOperands;
320+
321+
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
322+
// Check if the dim to be dropped exists as a leading dim in the operand
323+
// if it does then we use vector.extract to drop it.
324+
bool validExtract = false;
325+
SmallVector<AffineExpr> results;
326+
auto map = it.value();
327+
int64_t orginalZeroDim = it.value().getDimPosition(0);
328+
if (orginalZeroDim != dimToDrop) {
329+
// There are two reasons to be in this path, 1. We need to
330+
// tranpose the operand to make the dim to be dropped
331+
// leading. 2. The dim to be dropped does not exist and in
332+
// that case we dont want to add a unit tranpose but we must
333+
// check all the indices to make sure this is the case.
334+
bool tranposeNeeded = false;
335+
SmallVector<int64_t> perm;
336+
SmallVector<AffineExpr> transposeResults;
337+
338+
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
339+
int64_t currDim = map.getDimPosition(i);
340+
if (currDim == dimToDrop) {
341+
tranposeNeeded = true;
342+
perm.insert(perm.begin(), i);
343+
auto targetExpr = rewriter.getAffineDimExpr(currDim);
344+
transposeResults.insert(transposeResults.begin(), targetExpr);
345+
} else {
346+
perm.push_back(i);
347+
auto targetExpr = rewriter.getAffineDimExpr(currDim);
348+
transposeResults.push_back(targetExpr);
349+
}
350+
}
351+
// Do the tranpose now if needed so that we can drop the
352+
// correct dim using extract later.
353+
if (tranposeNeeded) {
354+
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
355+
contractOp.getContext());
356+
operands[it.index()] = rewriter.create<vector::TransposeOp>(
357+
contractOp.getLoc(), operands[it.index()], perm);
358+
}
359+
}
360+
// We have taken care to have the dim to be dropped be
361+
// the leading dim. If its still not leading that means it
362+
// does not exist in this operand and hence we do not need
363+
// an extract.
364+
if (map.getDimPosition(0) == dimToDrop)
365+
validExtract = true;
366+
367+
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
368+
int64_t currDim = map.getDimPosition(i);
369+
if (currDim == dimToDrop)
370+
// This is the dim we are dropping.
371+
continue;
372+
auto targetExpr = rewriter.getAffineDimExpr(
373+
currDim < dimToDrop ? currDim : currDim - 1);
374+
results.push_back(targetExpr);
375+
}
376+
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
377+
contractOp.getContext()));
378+
// Extract if its a valid extraction, otherwise use the operand
379+
// without extraction.
380+
newOperands.push_back(
381+
validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
382+
operands[it.index()],
383+
splatZero(dropDim))
384+
: operands[it.index()]);
385+
}
386+
auto newContractOp = rewriter.create<vector::ContractionOp>(
387+
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
388+
rewriter.getAffineMapArrayAttr(newIndexingMaps),
389+
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
390+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
391+
contractOp, contractOp->getResultTypes()[0], newContractOp);
392+
return success();
393+
}
394+
395+
namespace {
396+
282397
/// Turns vector.contract on vector with leading 1 dimensions into
283398
/// vector.extract followed by vector.contract on vector without leading
284399
/// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
@@ -289,112 +404,7 @@ struct CastAwayContractionLeadingOneDim
289404

290405
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
291406
PatternRewriter &rewriter) const override {
292-
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
293-
if (oldAccType == nullptr)
294-
return failure();
295-
if (oldAccType.getRank() < 2)
296-
return failure();
297-
if (oldAccType.getShape()[0] != 1)
298-
return failure();
299-
// currently we support only dropping one dim but the pattern can be applied
300-
// greedily to drop more.
301-
int64_t dropDim = 1;
302-
303-
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
304-
SmallVector<AffineMap> newIndexingMaps;
305-
306-
auto oldIteratorTypes = contractOp.getIteratorTypes();
307-
SmallVector<Attribute> newIteratorTypes;
308-
309-
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
310-
311-
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
312-
// only parallel type iterators can be dropped.
313-
return failure();
314-
315-
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
316-
int64_t currDim = it.index();
317-
if (currDim == dimToDrop)
318-
continue;
319-
newIteratorTypes.push_back(it.value());
320-
}
321-
322-
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
323-
contractOp.getAcc()};
324-
SmallVector<Value> newOperands;
325-
326-
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
327-
// Check if the dim to be dropped exists as a leading dim in the operand
328-
// if it does then we use vector.extract to drop it.
329-
bool validExtract = false;
330-
SmallVector<AffineExpr> results;
331-
auto map = it.value();
332-
int64_t orginalZeroDim = it.value().getDimPosition(0);
333-
if (orginalZeroDim != dimToDrop) {
334-
// There are two reasons to be in this path, 1. We need to
335-
// tranpose the operand to make the dim to be dropped
336-
// leading. 2. The dim to be dropped does not exist and in
337-
// that case we dont want to add a unit tranpose but we must
338-
// check all the indices to make sure this is the case.
339-
bool tranposeNeeded = false;
340-
SmallVector<int64_t> perm;
341-
SmallVector<AffineExpr> transposeResults;
342-
343-
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
344-
int64_t currDim = map.getDimPosition(i);
345-
if (currDim == dimToDrop) {
346-
tranposeNeeded = true;
347-
perm.insert(perm.begin(), i);
348-
auto targetExpr = rewriter.getAffineDimExpr(currDim);
349-
transposeResults.insert(transposeResults.begin(), targetExpr);
350-
} else {
351-
perm.push_back(i);
352-
auto targetExpr = rewriter.getAffineDimExpr(currDim);
353-
transposeResults.push_back(targetExpr);
354-
}
355-
}
356-
// Do the tranpose now if needed so that we can drop the
357-
// correct dim using extract later.
358-
if (tranposeNeeded) {
359-
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
360-
contractOp.getContext());
361-
operands[it.index()] = rewriter.create<vector::TransposeOp>(
362-
contractOp.getLoc(), operands[it.index()], perm);
363-
}
364-
}
365-
// We have taken care to have the dim to be dropped be
366-
// the leading dim. If its still not leading that means it
367-
// does not exist in this operand and hence we do not need
368-
// an extract.
369-
if (map.getDimPosition(0) == dimToDrop)
370-
validExtract = true;
371-
372-
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
373-
int64_t currDim = map.getDimPosition(i);
374-
if (currDim == dimToDrop)
375-
// This is the dim we are dropping.
376-
continue;
377-
auto targetExpr = rewriter.getAffineDimExpr(
378-
currDim < dimToDrop ? currDim : currDim - 1);
379-
results.push_back(targetExpr);
380-
}
381-
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
382-
contractOp.getContext()));
383-
// Extract if its a valid extraction, otherwise use the operand
384-
// without extraction.
385-
newOperands.push_back(validExtract
386-
? rewriter.create<vector::ExtractOp>(
387-
contractOp.getLoc(), operands[it.index()],
388-
splatZero(dropDim))
389-
: operands[it.index()]);
390-
}
391-
auto newContractOp = rewriter.create<vector::ContractionOp>(
392-
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
393-
rewriter.getAffineMapArrayAttr(newIndexingMaps),
394-
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
395-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
396-
contractOp, contractOp->getResultTypes()[0], newContractOp);
397-
return success();
407+
return castAwayContractionLeadingOneDim(contractOp, rewriter);
398408
}
399409
};
400410

0 commit comments

Comments
 (0)