
Commit 89b144e

[mlir][linalg] Vectorize tensor.extract using contiguous loads
This patch implements vectorization of tensor.extract for n-D tensors (n >= 2) using contiguous load operations, i.e. `vector.transfer_read`. This is a follow-up to https://reviews.llvm.org/D137660, in which gather loads (`vector.gather`) were used. It is always safe to use gather load operations when the underlying memory pattern is contiguous, but not vice versa.

At the moment, the following conditions have to be met for contiguous loads to be generated:
1. The _output tensor_ must be a 1-D vector with the trailing dim > 1, e.g. `tensor<1x1x4xi32>`,
2. The trailing dim in the _input tensor_ must be > 1, e.g. `tensor<1x1x4xi32>` would be fine, but not `tensor<1x4x1xi32>`.

If these conditions are not satisfied, gather loads are generated instead.

Condition 1 guarantees that the iteration space of the corresponding `linalg.generic` Op is relatively simple, which makes analysing the indices for `tensor.extract` rather straightforward. Condition 2 is mostly there to avoid odd vectorisation patterns resulting in vectors like `vector<1x1x1xi32>`. In practice, tensors like `tensor<1x4x1xi32>` should be collapsed to `tensor<1x4xi32>` before vectorisation, but that's beyond the scope of this patch.

If needed, both conditions can be relaxed; I've not been able to find a good motivating example for these, hence skipping. For reference, `tosa.resize` (lowered to Linalg) was the driving example used here.

As a bonus, the test from "vectorization-unsupported.mlir" is moved to "vectorization.mlir" with proper CHECK lines added.

Differential Revision: https://reviews.llvm.org/D141998

Co-authored-by: Diego Caballero <[email protected]>
1 parent f7b7c69 commit 89b144e
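To make the conditions above concrete, here is a minimal hand-written sketch of a kernel that the patch should classify as a contiguous load. The function name and the 3x3x3 / 1x1x3 shapes are made up for illustration; the patch's actual test cases live in vectorization.mlir. The trailing extract index is driven by `linalg.index 2` and so increments with stride one, while the leading indices are loop-invariant constants:

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// Hypothetical example: reads a contiguous row of %src, so tensor.extract
// should be vectorised via vector.transfer_read rather than vector.gather.
func.func @extract_contiguous(%src: tensor<3x3x3xf32>,
                              %init: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
  %c0 = arith.constant 0 : index
  %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel", "parallel"]
    } outs(%init : tensor<1x1x3xf32>) {
  ^bb0(%out: f32):
    // Trailing index follows the trailing loop dimension; the leading
    // indices are loop-invariant.
    %idx = linalg.index 2 : index
    %val = tensor.extract %src[%c0, %c0, %idx] : tensor<3x3x3xf32>
    linalg.yield %val : f32
  } -> tensor<1x1x3xf32>
  return %res : tensor<1x1x3xf32>
}

If, say, the trailing extract index were driven by `linalg.index 0` instead, or the input's trailing dimension were 1, the analysis added below would fall back to `vector.gather`.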

3 files changed: +307 additions, -48 deletions

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 196 additions & 10 deletions
@@ -611,11 +611,11 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
+    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+
     auto dimSize = broadcastIfNeeded(
         rewriter,
-        rewriter.create<arith::ConstantIndexOp>(
-            loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
         indexVecType.getShape());
 
     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,6 +630,143 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
   return offset;
 }
 
+enum VectorMemoryAccessKind {
+  // TODO: ScalarBroadcast,
+  Contiguous,
+  Gather
+};
+
+/// Check whether \p val can be used for calculating an index for a contiguous
+/// load operation, i.e. whether \p val:
+///   * is invariant with respect to \p linalgOp, i.e. whether it remains
+///     constant for all iterations, and
+///   * increments with the loop iterator (when \p strideZero is false) or is
+///     not affected by the loop indices (\p strideZero is true).
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, size_t dim,
+                                bool strideZero) {
+  auto *block = linalgOp.getBlock();
+
+  // Bail out if this is a block argument for this linalg.generic Op.
+  // TODO: We could try analysing the corresponding affine map here.
+  if (val.dyn_cast<BlockArgument>())
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  // Given the assumption on the shape of the target tensor, the index Op is
+  // either:
+  //   * constant (for non-trailing dims), or
+  //   * increments with stride one together with the trailing dimension.
+  // Both cases are fine for contiguous loads.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return strideZero ? (indexOp.getDim() != dim) : (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  // Values defined outside `linalgOp`.
+  if (!ancestor)
+    return true;
+
+  // Values defined inside `linalgOp`, which are constant.
+  if (dyn_cast<arith::ConstantOp>(ancestor))
+    return true;
+
+  bool result = true;
+  for (auto op : ancestor->getOperands())
+    result &= isContiguousLoadIdx(linalgOp, op, dim, strideZero);
+
+  return result;
+}
+
+/// Check whether the calculation of \p val is based on a linalg.index Op with
+/// the dim attribute matching \p dim.
+static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
+  auto *block = linalgOp.getBlock();
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  if (val.isa<BlockArgument>())
+    return false;
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  if (!ancestor)
+    return false;
+
+  bool result = false;
+  for (auto op : ancestor->getOperands())
+    result |= isBasedOnIndexOp(linalgOp, op, dim);
+
+  return result;
+}
+
+/// Check whether \p extractOp would be a gather or a contiguous load Op after
+/// vectorising \p linalgOp. Note that it is always safe to use gather load
+/// operations for contiguous loads (albeit slow), but not vice-versa. When in
+/// doubt, bail out and assume that \p extractOp is a gather load.
+static VectorMemoryAccessKind
+getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
+                                    LinalgOp &linalgOp) {
+
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  // Assume that it's a gather load when reading _into_:
+  //   * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
+  //   * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax these conditions.
+  if ((llvm::count_if(targetShape,
+                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
+      targetShape.back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+
+  // Assume that it's a gather load when reading _from_ a tensor for which the
+  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax this condition.
+  if (inputShape.getShape().back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  bool isContiguous = true;
+
+  // Iterate over all indices. Analyze whether the way each index is calculated
+  // is suitable for contiguous load operations (e.g. loop invariant).
+  auto indices = extractOp.getIndices();
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
+    if (inputShape.getShape()[i] == 1) {
+      // This extractOp index must be a loop-invariant constant.
+      continue;
+    }
+
+    auto extractOpBottomIdx = indices.size() - 1;
+    auto strideOneDim = targetShape.size() - 1;
+    bool strideZero = (i != extractOpBottomIdx);
+    isContiguous &=
+        isContiguousLoadIdx(linalgOp, indexVal, strideOneDim, strideZero);
+  }
+
+  // The calculation of the trailing index must include the loop index. Given
+  // the assumption on the output tensor (which is defined by the iteration
+  // space), only the trailing dim matters.
+  auto extractOpTrailingIdx = indices.back();
+  isContiguous &=
+      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, targetShape.size() - 1);
+
+  if (isContiguous) {
+    LDBG("Found contiguous load: " << extractOp);
+    return VectorMemoryAccessKind::Contiguous;
+  }
+
+  return VectorMemoryAccessKind::Gather;
+}
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -660,15 +797,64 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       extractOp.getIndices().size(),
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+  VectorMemoryAccessKind memAccessKind =
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+
+  // 1. Handle gather access.
+  if (memAccessKind == VectorMemoryAccessKind::Gather) {
+    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+    // Generate the gather load.
+    Operation *gatherOp = rewriter.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), baseIndices, offset,
+        maskConstantOp, passThruConstantOp);
+    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+
+    LDBG("Vectorised as gather load: " << extractOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // 2. Handle contiguous access.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+
+  // Collect indices for `vector.transfer_read`. At this point, the indices
+  // will either be scalars or would have been broadcast to vectors matching
+  // the result type. For indices that are vectors, there are two options:
+  //   * for non-trailing indices, all elements are identical (contiguous
+  //     loads are identified by looking for non-trailing indices that are
+  //     invariant with respect to the corresponding linalg.generic), or
+  //   * for trailing indices, the index vector will contain values with
+  //     stride one, but for `vector.transfer_read` only the first (i.e. 0th)
+  //     index is needed.
+  // This means that:
+  //   * for scalar indices - just re-use them,
+  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
+  //     (0th) element and use that.
+  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
+    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    if (idx.getType().isIndex()) {
+      transferReadIdxs.push_back(idx);
+      continue;
+    }
+
+    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
+        bvm.lookup(extractOp.getIndices()[i]));
+    transferReadIdxs.push_back(
+        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+  }
+
+  // `tensor.extract_element` is always in-bounds, hence the following holds.
+  SmallVector<bool> inBounds(resultType.getRank(), true);
 
-  // Generate the gather load
-  Operation *gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), baseIndices, offset,
-      maskConstantOp, passThruConstantOp);
-  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
 
-  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  LDBG("Vectorised as contiguous load: " << extractOp);
+  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
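For a kernel like the one sketched near the top of this page, the contiguous path added here should produce IR along these lines. This is a rough, hand-written sketch: the value names, the zero padding constant, and the exact attribute spelling are assumptions, and the authoritative expectations are the CHECK lines in vectorization.mlir.

%c0 = arith.constant 0 : index
%pad = arith.constant 0.000000e+00 : f32
// A single contiguous read instead of a gather of the trailing row.
%read = vector.transfer_read %src[%c0, %c0, %c0], %pad
    {in_bounds = [true, true, true]}
    : tensor<3x3x3xf32>, vector<1x1x3xf32>
%res = vector.transfer_write %read, %init[%c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x1x3xf32>, tensor<1x1x3xf32>

The gather path keeps emitting `vector.gather` exactly as introduced in D137660; the only change there is that the offset calculation now queries the dimension size via tensor.DimOp rather than baking in a static constant.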

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 0 additions & 29 deletions
This file was deleted.
