[mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) #73523

Merged
79 changes: 68 additions & 11 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -511,6 +511,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
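A note for readers, not part of the patch: below is a minimal sketch of the behaviour that the doc-comment and the new TODO describe. The `Sketch` suffix and the exact control flow are assumptions made for illustration; the upstream helper may differ in detail.

// Succeed only if every index from `firstDimToCollapse` onwards is a constant
// zero; on success keep the leading indices plus a single zero for the new
// innermost (collapsed) dimension. This is the "writes to outIndices" part
// that the TODO above wants to split out.
static LogicalResult
checkAndCollapseInnerZeroIndicesSketch(ValueRange indices,
                                       int64_t firstDimToCollapse,
                                       SmallVector<Value> &outIndices) {
  if (firstDimToCollapse >= static_cast<int64_t>(indices.size()))
    return failure();
  for (int64_t i = firstDimToCollapse, e = indices.size(); i < e; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || *cst != 0)
      return failure();
  }
  outIndices.assign(indices.begin(), indices.end());
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}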
@@ -542,45 +544,100 @@ class FlattenContiguousRowMajorTransferReadPattern
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferReadOp.getSource();
+    auto source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
     if (!transferReadOp.getPermutationMap().isMinorIdentity())
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //     memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //     memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //    %offset = %arg0 * 43
+      AffineExpr offsetExpr, idxExpr;
+      bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
+
+      int64_t outputRank = transferReadOp.getIndices().size();
+      OpFoldResult offset =
+          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+        offset = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, offsetExpr + dim * idxExpr,
+            {offset, transferReadOp.getIndices()[i]});
+      }
+      if (offset.is<Value>()) {
+        collapsedIndices.push_back(offset.get<Value>());
+      } else {
+        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+            loc, *getConstantIntValue(offset)));
+      }
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
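The index computation above is easy to sanity-check by hand. The standalone snippet below (plain C++, not code from the patch; all names are made up for illustration) mirrors the accumulation built by the chain of makeComposedFoldedAffineApply calls and reproduces the `%offset = %arg0 * 43` result quoted in the comment, using the dim sizes from that example.

#include <cassert>
#include <cstdint>
#include <vector>

// offset = sum over the collapsed dims of dimSize[i] * index[i], starting at
// firstDimToCollapse, exactly as the folded affine.apply chain builds it up.
static int64_t collapsedOffset(const std::vector<int64_t> &dimSizes,
                               const std::vector<int64_t> &indices,
                               int64_t firstDimToCollapse) {
  int64_t offset = 0;
  for (size_t i = firstDimToCollapse; i < indices.size(); ++i)
    offset += dimSizes[i] * indices[i];
  return offset;
}

int main() {
  // memref<1x43x2xi32> read at [%c0, %arg0, %c0] with firstDimToCollapse = 1:
  // for %arg0 = 5 the collapsed index is 5 * 43 = 215, i.e. %offset = %arg0 * 43.
  assert(collapsedOffset({1, 43, 2}, {0, 5, 0}, 1) == 5 * 43);
  return 0;
}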
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -265,6 +265,11 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     return false;
   auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
 
+  // TODO: Add support for memref with trailing dynamic shapes. Memrefs
+  // with leading dynamic dimensions are already supported.
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+
   // Cond 1: A contiguous memref will always have a unit trailing stride.
   if (strides.back() != 1)
     return false;
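As an orientation aid, not code from the patch: `ShapedType::isDynamicShape` reports whether any entry of the shape is the dynamic sentinel, so with this guard a dynamic size anywhere in `memrefShape` makes the slice count as non-contiguous and the flattening patterns bail out, which is what the new `..._dynamic_shapes` test below locks in. A spelled-out equivalent, assuming `memrefShape` holds the dim sizes being checked:

static bool anyDynamicDim(ArrayRef<int64_t> memrefShape) {
  // ShapedType::kDynamic is the sentinel used for '?' dimensions.
  return llvm::any_of(memrefShape, [](int64_t dimSize) {
    return ShapedType::isDynamic(dimSize);
  });
}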
55 changes: 55 additions & 0 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -41,6 +41,61 @@ func.func @transfer_read_dims_mismatch_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x43x4x6xi32>,
+    %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
+// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
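Reading the captured affine map back against the source op: the trailing dim sizes of %m_in are 43, 4 and 6, so the accumulated offset is 43 * %idx_1 + 4 * %idx_2 + 6 * 0, which is #[[$ATTR_0]] = (s0 * 4 + s1 * 43) applied to [%[[IDX_2]], %[[IDX_1]]]; the collapsed trailing size is 43 * 4 * 6 = 1032, hence memref<1x1032xi32>.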

+// -----
+
+func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x?x4x6xi32>,
+    %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x?x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
+// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+
+// -----
+
 func.func @transfer_read_dims_mismatch_non_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
   %c0 = arith.constant 0 : index
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -454,6 +454,7 @@ struct TestFlattenVectorTransferPatterns
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
+    registry.insert<affine::AffineDialect>();
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());