Skip to content

[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n) #95743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ namespace {
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if changes in this file have much to do with a test refactoring commit. I am happy to be proven wrong if this is following usual guidelines. I am still a newbie.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle:

  • one patch == one change,
  • avoid unrelated changes.

In this patch, I am violating these rules. First, I am making 2 changes:

Second, this particular change qualifies as "unrelated". So, if we were to go by the book, I should split it into a separate PR. I am happy to do that, but I am also mindful that I'm generating a lot of PR traffic and want to reduce noise 😅

In situation like this, I try to make the intent clear in the summary:

Finally, changes in "VectorTransferOpTransforms.cpp" are merely meant to
unify comments and logic between

FlattenContiguousRowMajorTransferWritePattern and
FlattenContiguousRowMajorTransferReadPattern.

... and then follow the reviewers recommendation. If I was reviewing this, I'd say "move to a different patch - I'll happily review that" ;-)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's ok with me. Code owners might have another point of view.

/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
Expand Down Expand Up @@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
dyn_cast<MemRefType>(collapsedSource.getType());
cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);

Expand Down Expand Up @@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
public:
Expand All @@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

// 0. Check pre-conditions
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
// If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
Expand All @@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
Expand All @@ -697,22 +704,30 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();

SmallVector<Value> collapsedIndices =
getCollapsedIndices(rewriter, loc, sourceType.getShape(),
transferWriteOp.getIndices(), firstDimToCollapse);
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

// 1. Collapse the source memref
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);

// 2. Generate input args for a new vector.transfer_read that will read
// from the collapsed memref.
// 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

// 2.2 New indices
SmallVector<Value> collapsedIndices =
getCollapsedIndices(rewriter, loc, sourceType.getShape(),
transferWriteOp.getIndices(), firstDimToCollapse);

// 3. Create new vector.transfer_write that writes to the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
Value flatVector =
Expand All @@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
rewriter.create<vector::TransferWriteOp>(
loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

// 4. Replace the old transfer_write with the new one writing the
// collapsed shape
rewriter.eraseOp(transferWriteOp);
return success();
}
Expand Down
Loading
Loading