[mlir][vector] Add scalable lowering for transfer_write(transpose) #101353

Merged: 6 commits, Aug 12, 2024

4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
@@ -1300,7 +1300,9 @@ def ConvertVectorToSCF : Pass<"convert-vector-to-scf"> {
Option<"targetRank", "target-rank", "unsigned", /*default=*/"1",
"Target vector rank to which transfer ops should be lowered">,
Option<"lowerTensors", "lower-tensors", "bool", /*default=*/"false",
"Lower transfer ops that operate on tensors">
"Lower transfer ops that operate on tensors">,
Option<"lowerScalable", "lower-scalable", "bool", /*default=*/"false",
"Add scalable vector specific lowerings (that introduce loops)">
];
}

8 changes: 8 additions & 0 deletions mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -69,6 +69,14 @@ struct VectorTransferToSCFOptions {
unroll = u;
return *this;
}
/// Enable scalable vector specific lowerings (which introduce loops). These
/// work alongside fullUnroll (which unrolls until the first scalable
/// dimension).
bool lowerScalable = false;
VectorTransferToSCFOptions enableLowerScalable(bool enable = true) {
lowerScalable = enable;
return *this;
}
};

/// Collect a set of patterns to convert from the Vector dialect to SCF + func.
232 changes: 214 additions & 18 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
@@ -44,6 +45,18 @@ namespace {
/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Return true if this transfer op operates on a source tensor.
static bool isTensorOp(VectorTransferOpInterface xferOp) {
if (isa<RankedTensorType>(xferOp.getShapedType())) {
if (isa<vector::TransferWriteOp>(xferOp)) {
// TransferWriteOps on tensors have a result.
assert(xferOp->getNumResults() > 0);
}
return true;
}
return false;
}

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
@@ -52,6 +65,15 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
VectorTransferToSCFOptions opt)
: OpRewritePattern<OpTy>(context), options(opt) {}

LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
PatternRewriter &rewriter) const {
if (isTensorOp(xferOp) && !options.lowerTensors) {
return rewriter.notifyMatchFailure(
xferOp, "lowering tensor transfers is disabled");
}
return success();
}

VectorTransferToSCFOptions options;
};

@@ -257,19 +279,6 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
if (isa<RankedTensorType>(xferOp.getShapedType())) {
if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
// TransferWriteOps on tensors have a result.
assert(xferOp->getNumResults() > 0);
}
return true;
}
return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
@@ -987,6 +996,189 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
}
};

/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
/// and ConstantMaskOp.
template <typename VscaleConstantBuilder>
static FailureOr<SmallVector<OpFoldResult>>
getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
if (!mask)
return SmallVector<OpFoldResult>{};
if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
return OpFoldResult(dimSize);
});
}
if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
int dimIdx = 0;
VectorType maskType = constantMask.getVectorType();
auto indexType = IndexType::get(mask.getContext());
return llvm::map_to_vector(
constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
// A scalable dim in a constant_mask means vscale x dimSize.
if (maskType.getScalableDims()[dimIdx++])
return OpFoldResult(createVscaleMultiple(dimSize));
return OpFoldResult(IntegerAttr::get(indexType, dimSize));
});
}
return failure();
}
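
For illustration, here is a minimal sketch of the two mask producers this helper resolves; the function and value names below are assumptions, not taken from the patch:

```mlir
func.func @mask_examples(%a: index, %b: index) {
  // create_mask: the resolved dimension sizes are simply its operands, [%a, %b].
  %m0 = vector.create_mask %a, %b : vector<[4]x4xi1>
  // constant_mask: a size on a scalable dimension is multiplied by vscale, so
  // the resolved sizes here are [4 * vscale, 2].
  %m1 = vector.constant_mask [4, 2] : vector<[4]x4xi1>
  return
}
```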

/// Scalable vector lowering of transfer_write(transpose). This lowering only
/// supports rank 2 (scalable) vectors, but can be used in conjunction with
/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
/// unrolls until the first scalable dimension.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vec, [1, 0]
/// : vector<4x[4]xf32> to vector<[4]x4xf32>
/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
/// : vector<[4]x4xf32>, memref<?x?xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %c1 = arith.constant 1 : index
/// %c4 = arith.constant 4 : index
/// %c0 = arith.constant 0 : index
/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
/// %vscale = vector.vscale
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// scf.for %idx = %c0 to %c4_vscale step %c1 {
/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
/// %slice_i = affine.apply #map(%idx)[%i]
/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
/// : vector<4xf32>, memref<?x?xf32>
/// }
/// ```
struct ScalableTransposeTransferWriteConversion
: VectorToSCFPattern<vector::TransferWriteOp> {
using VectorToSCFPattern::VectorToSCFPattern;

LogicalResult matchAndRewrite(TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
if (failed(checkLowerTensors(writeOp, rewriter)))
return failure();

VectorType vectorType = writeOp.getVectorType();

// Note: Comparing the scalable dims to an ArrayRef of length two also
// implicitly checks that the rank is two.
ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
if (scalableFlags != ArrayRef<bool>{true, false}) {
return rewriter.notifyMatchFailure(
writeOp, "expected vector of the form vector<[N]xMxty>");
}

auto permutationMap = writeOp.getPermutationMap();
if (!permutationMap.isIdentity()) {
return rewriter.notifyMatchFailure(
writeOp, "non-identity permutations are unsupported (lower first)");
}

// Note: This pattern is only lowering the leading dimension (to a loop),
// so we only check if the leading dimension is in bounds. The in-bounds
// attribute for the trailing dimension will be propagated.
if (!writeOp.isDimInBounds(0)) {
return rewriter.notifyMatchFailure(
writeOp, "out-of-bounds dims are unsupported (use masking)");
}

Value vector = writeOp.getVector();
auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
if (!transposeOp ||
transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
return rewriter.notifyMatchFailure(writeOp, "source not transpose");
}

auto loc = writeOp.getLoc();
auto createVscaleMultiple =
vector::makeVscaleConstantBuilder(rewriter, loc);

auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
if (failed(maskDims)) {
return rewriter.notifyMatchFailure(writeOp,
"failed to resolve mask dims");
}

int64_t fixedDimSize = vectorType.getDimSize(1);
auto fixedDimOffsets = llvm::seq(fixedDimSize);

// Extract all slices from the source of the transpose.
auto transposeSource = transposeOp.getVector();
SmallVector<Value> transposeSourceSlices =
llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
});

// Loop bounds and step.
auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto ub =
maskDims->empty()
? Value(createVscaleMultiple(vectorType.getDimSize(0)))
: vector::getAsValues(rewriter, loc, maskDims->front()).front();
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);

// Generate a new mask for the slice.
VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
Value sliceMask = nullptr;
if (!maskDims->empty()) {
sliceMask = rewriter.create<vector::CreateMaskOp>(
loc, sliceType.clone(rewriter.getI1Type()),
ArrayRef<OpFoldResult>(*maskDims).drop_front());
}

Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
auto result = rewriter.create<scf::ForOp>(
loc, lb, ub, step, initLoopArgs,
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
// Indices for the new transfer op.
SmallVector<Value, 8> xferIndices;
getXferIndices(b, writeOp, iv, xferIndices);

// Extract a transposed slice from the source vector.
SmallVector<Value> transposeElements =
llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
return b.create<vector::ExtractOp>(
loc, transposeSourceSlices[idx], iv);
});
auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
transposeElements);

// Create the transfer_write for the slice.
Value dest =
loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
auto newWriteOp = b.create<vector::TransferWriteOp>(
loc, sliceVec, dest, xferIndices,
ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
if (sliceMask)
newWriteOp.getMaskMutable().assign(sliceMask);

// Yield from the loop.
b.create<scf::YieldOp>(loc, loopIterArgs.empty()
? ValueRange{}
: newWriteOp.getResult());
});

if (isTensorOp(writeOp))
rewriter.replaceOp(writeOp, result);
else
rewriter.eraseOp(writeOp);

return success();
}
};
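
The documentation above only shows the unmasked case. The pattern also accepts a masked write: the leading mask size becomes the loop upper bound, and the remaining sizes form a per-slice mask. A minimal sketch of a masked input, assuming illustrative value names:

```mlir
func.func @masked_scalable_transpose_store(%vec: vector<4x[4]xf32>,
    %dest: memref<?x?xf32>, %i: index, %j: index, %a: index, %b: index) {
  %mask = vector.create_mask %a, %b : vector<[4]x4xi1>
  %t = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
  vector.transfer_write %t, %dest[%i, %j], %mask {in_bounds = [true, true]}
      : vector<[4]x4xf32>, memref<?x?xf32>
  return
}
```

After the rewrite, the `scf.for` upper bound is `%a`, and each per-row write carries a `vector.create_mask %b : vector<4xi1>`, i.e. the resolved mask sizes with the leading dimension dropped.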

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {
@@ -1100,9 +1292,8 @@ struct UnrollTransferReadConversion
if (xferOp.getVectorType().getRank() <= options.targetRank)
return rewriter.notifyMatchFailure(
xferOp, "vector rank is less or equal to target rank");
if (isTensorOp(xferOp) && !options.lowerTensors)
return rewriter.notifyMatchFailure(
xferOp, "transfers operating on tensors are excluded");
if (failed(checkLowerTensors(xferOp, rewriter)))
return failure();
// Transfer ops that modify the element type are not supported atm.
if (xferOp.getVectorType().getElementType() !=
xferOp.getShapedType().getElementType())
@@ -1238,7 +1429,7 @@ struct UnrollTransferWriteConversion
if (inputVectorTy.getRank() <= options.targetRank)
return failure();

if (isTensorOp(xferOp) && !options.lowerTensors)
if (failed(checkLowerTensors(xferOp, rewriter)))
return failure();
// Transfer ops that modify the element type are not supported atm.
if (inputVectorTy.getElementType() !=
@@ -1503,7 +1694,10 @@ void mlir::populateVectorToSCFConversionPatterns(
lowering_n_d::TransferOpConversion<TransferWriteOp>>(
patterns.getContext(), options);
}

if (options.lowerScalable) {
patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
patterns.getContext(), options);
}
if (options.targetRank == 1) {
patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
Expand All @@ -1522,13 +1716,15 @@ struct ConvertVectorToSCFPass
this->fullUnroll = options.unroll;
this->targetRank = options.targetRank;
this->lowerTensors = options.lowerTensors;
this->lowerScalable = options.lowerScalable;
}

void runOnOperation() override {
VectorTransferToSCFOptions options;
options.unroll = fullUnroll;
options.targetRank = targetRank;
options.lowerTensors = lowerTensors;
options.lowerScalable = lowerScalable;

// Lower permutation maps first.
RewritePatternSet lowerTransferPatterns(&getContext());
15 changes: 14 additions & 1 deletion mlir/test/Conversion/VectorToSCF/tensor-transfer-ops.mlir
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{lower-tensors=true lower-scalable=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s

// CHECK-LABEL: func @transfer_read_2d(
// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<vector<4x9xf32>>
@@ -36,3 +36,16 @@ func.func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
return %t : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @scalable_transpose_store
// CHECK-SAME: %[[TENSOR:[a-z0-9]+]]: tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} iter_args(%[[ITER_ARG:.*]] = %[[TENSOR]]) -> (tensor<?x?xf32>)
// CHECK: %[[WRITE_SLICE:.*]] = vector.transfer_write %{{.*}} %[[ITER_ARG]]
// CHECK: scf.yield %[[WRITE_SLICE]]
// CHECK: return %[[RESULT]]
func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %A: tensor<?x?xf32>, %base1: index, %base2: index) -> tensor<?x?xf32> {
%transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
%result = vector.transfer_write %transpose, %A[%base1, %base2] {in_bounds = [true, true]} : vector<[4]x4xf32>, tensor<?x?xf32>
return %result : tensor<?x?xf32>
}
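
For context, a hedged sketch of roughly what the lowered loop looks like for this tensor test case; SSA names, constant ordering, and the affine map spelling are illustrative rather than the verbatim pass output:

```mlir
#row = affine_map<(d0)[s0] -> (d0 + s0)>
func.func @scalable_transpose_store_lowered(%vec: vector<4x[4]xf32>,
    %A: tensor<?x?xf32>, %base1: index, %base2: index) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // One slice per row of the untransposed source vector.
  %r0 = vector.extract %vec[0] : vector<[4]xf32> from vector<4x[4]xf32>
  %r1 = vector.extract %vec[1] : vector<[4]xf32> from vector<4x[4]xf32>
  %r2 = vector.extract %vec[2] : vector<[4]xf32> from vector<4x[4]xf32>
  %r3 = vector.extract %vec[3] : vector<[4]xf32> from vector<4x[4]xf32>
  %vscale = vector.vscale
  %ub = arith.muli %vscale, %c4 : index
  // The destination tensor is threaded through the loop via iter_args.
  %result = scf.for %idx = %c0 to %ub step %c1
      iter_args(%iter = %A) -> (tensor<?x?xf32>) {
    %e0 = vector.extract %r0[%idx] : f32 from vector<[4]xf32>
    %e1 = vector.extract %r1[%idx] : f32 from vector<[4]xf32>
    %e2 = vector.extract %r2[%idx] : f32 from vector<[4]xf32>
    %e3 = vector.extract %r3[%idx] : f32 from vector<[4]xf32>
    %slice = vector.from_elements %e0, %e1, %e2, %e3 : vector<4xf32>
    %row = affine.apply #row(%idx)[%base1]
    %updated = vector.transfer_write %slice, %iter[%row, %base2]
        {in_bounds = [true]} : vector<4xf32>, tensor<?x?xf32>
    scf.yield %updated : tensor<?x?xf32>
  }
  return %result : tensor<?x?xf32>
}
```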