[mlir][vector] Refine vectorisation of tensor.extract #109580


Merged: 2 commits, Sep 24, 2024
85 changes: 50 additions & 35 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,27 +810,35 @@ static Value calculateGatherOffset(RewriterBase &rewriter,

enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

-/// Find the non-unit dim in a linalgOp.
-/// When executing this hook, it is expected that only one dim will be non-unit.
-/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
-/// loads before calling this method. This is used for finding contiguous loads
-/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
-/// this condition is expected to hold for statically shaped Linalg Ops only.
-static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
-  uint64_t nonUnitDim = 0;
-  uint64_t countNonUnitDim = 0;
-  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
-    if (tripCount.value() != 1) {
-      nonUnitDim = tripCount.index();
-      countNonUnitDim++;
-    }
-  }
-
+/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
+/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
+/// represents a contiguous load operation.
+///
+/// Note that when calling this hook, it is assumed that the output vector is
+/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
+/// labelled as a gather load before entering this method.
+///
+/// Following on from the above, it is assumed that:
+///  * for statically shaped loops, when no masks are used, only one dim is
+///    != 1 (that's what the shape of the output vector is based on).
+///  * for dynamically shaped loops, there might be more non-unit dims
+///    as the output vector type is user-specified.
+///
+/// TODO: Statically shaped loops + vector masking
+static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
+  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(linalgOp.hasDynamicShape() ||
-         countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
-                                 "non-unit loop dim is expected");
-  (void)countNonUnitDim;
-  return nonUnitDim;
+         llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
+                 1 &&
+             "For statically shaped Linalg Ops, only one "
+             "non-unit loop dim is expected");
+
+  size_t idx = loopRanges.size() - 1;
+  for (; idx >= 0; idx--)
Review comment (Contributor): idx is unsigned, so idx >= 0 is true

Reply (Author): Thanks! #112900

+    if (loopRanges[idx] != 1)
+      break;
+
+  return idx;
}
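For readers following along outside the MLIR tree, here is a minimal, self-contained sketch of the scan this helper performs. Everything here is illustrative rather than the upstream API: a plain std::vector stands in for `LinalgOp::getStaticLoopRanges()`, and a signed loop index sidesteps the unsigned wrap-around flagged in the review comment above (fixed upstream in #112900).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch only: index of the trailing dim whose trip count != 1, mirroring
// getTrailingNonUnitLoopDimIdx. The signed index avoids the wrap-around that
// `for (; idx >= 0; idx--)` exhibits with an unsigned size_t.
static int64_t trailingNonUnitDimIdx(const std::vector<int64_t> &loopRanges) {
  assert(!loopRanges.empty() && "expected at least one loop dim");
  for (int64_t idx = static_cast<int64_t>(loopRanges.size()) - 1; idx >= 0;
       --idx)
    if (loopRanges[idx] != 1)
      return idx;
  return 0; // All-unit iteration space: fall back to dim 0.
}

int main() {
  assert(trailingNonUnitDimIdx({1, 8, 1}) == 1); // static case: one non-unit dim
  assert(trailingNonUnitDimIdx({4, 8, 1}) == 1); // multiple non-unit dims: trailing one wins
}
```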

/// Checks whether `val` can be used for calculating a loop invariant index.
@@ -854,11 +862,11 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
assert(defOp && "This is neither a block argument nor an operation result");

// IndexOp is loop invariant as long as its result remains constant across
-  // iterations. Given the assumptions on the loop ranges above, only the
-  // trailing loop dim ever changes.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() != trailingLoopDim);
+  // iterations. Note that for dynamic shapes, the corresponding dim will also
+  // be conservatively treated as != 1.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
+  }

auto *ancestor = block->findAncestorOpInBlock(*defOp);

Expand All @@ -877,7 +885,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
return result;
}
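As a worked illustration of the rule the new return statement encodes (a sketch under stated assumptions, not the upstream API): a dim of extent 1 can only ever yield the index value 0, so `linalg.index` on such a dim is loop invariant, while dynamic extents are a sentinel value != 1 and therefore conservatively count as varying.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: `linalg.index <dim>` is loop invariant iff the static trip count of
// <dim> is 1. Dynamic extents (encoded as a non-1 sentinel) compare != 1 and
// are thus treated as varying.
static bool isIndexOpInvariant(const std::vector<int64_t> &staticLoopRanges,
                               unsigned dim) {
  return staticLoopRanges[dim] == 1;
}

int main() {
  std::vector<int64_t> ranges = {1, 128}; // 1x128 iteration space
  assert(isIndexOpInvariant(ranges, 0));  // linalg.index 0 is always 0
  assert(!isIndexOpInvariant(ranges, 1)); // linalg.index 1 changes every iteration
}
```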

-/// Check whether \p val could be used for calculating the trailing index for a
+/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.
///
/// There are currently 3 types of values that are allowed here:
Expand All @@ -886,13 +894,14 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
/// 3. results of basic arithmetic operations (linear and continuous)
/// involving 1., 2. and 3.
/// This method returns True if indeed only such values are used in calculating
-/// \p val.
+/// `val`.
///
/// Additionally, the trailing index for a contiguous load operation should
/// increment by 1 with every loop iteration, i.e. be based on:
/// * `linalg.index <dim>` ,
-/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
-/// updated to `true` when such an op is found.
+/// where <dim> is the trailing non-unit dim of the iteration space (this way,
+/// `linalg.index <dim>` increments by 1 with every loop iteration).
+/// `foundIndexOp` is updated to `true` when such an Op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
bool &foundIndexOp, VectorType resType) {

@@ -912,12 +921,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");

-  // Given the assumption on the loop ranges above, we expect only 1 non-unit
-  // loop dim.
-  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);

if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
+    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
+
+    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
return true;
}
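To make the doc-comment's three allowed value kinds concrete, here is a toy model of the walk (all names hypothetical): a small expression tree stands in for MLIR's use-def chains, the three node kinds map onto items 1.-3. above, and `|=` is used here so an earlier `linalg.index` match is not lost, a sketch of the intent rather than the exact upstream logic.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for the SSA values isContiguousLoadIdx walks. Leaves are
// either loop-invariant values (block arguments, constants - items 1. and 2.)
// or `linalg.index <dim>` ops; interior nodes are the basic arithmetic ops of
// item 3.
struct Expr {
  enum Kind { Invariant, IndexOp, Arith } kind;
  int dim = -1;                       // only meaningful for IndexOp
  std::vector<const Expr *> operands; // only meaningful for Arith
};

// Returns true iff only allowed values feed the index; foundIndexOp records
// whether `linalg.index <trailing non-unit dim>` occurs somewhere inside.
static bool isContiguousIdx(const Expr &e, int trailingNonUnitDim,
                            bool &foundIndexOp) {
  switch (e.kind) {
  case Expr::Invariant:
    return true;
  case Expr::IndexOp:
    foundIndexOp |= (e.dim == trailingNonUnitDim);
    return true;
  case Expr::Arith: {
    bool allOk = true;
    for (const Expr *op : e.operands)
      allOk &= isContiguousIdx(*op, trailingNonUnitDim, foundIndexOp);
    return allOk;
  }
  }
  return false;
}

int main() {
  // idx = %cst + linalg.index 1, with trailing non-unit dim == 1: contiguous.
  Expr cst{Expr::Invariant, -1, {}};
  Expr idx1{Expr::IndexOp, 1, {}};
  Expr add{Expr::Arith, -1, {&cst, &idx1}};
  bool foundIndexOp = false;
  assert(isContiguousIdx(add, /*trailingNonUnitDim=*/1, foundIndexOp));
  assert(foundIndexOp);
}
```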

@@ -1012,7 +1019,10 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
bool foundIndexOp = false;
bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
foundIndexOp, resType);
-  isContiguousLoad &= foundIndexOp;
+  // TODO: Support generating contiguous loads for column vectors - that will
+  // require adding a permutation map to transfer_read Ops.
+  bool isRowVector = resType.getShape().back() != 1;
+  isContiguousLoad &= (foundIndexOp && isRowVector);

if (isContiguousLoad) {
LDBG("Found contigous load: " << extractOp);
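A sketch of what the new row-vector guard distinguishes, shapes only and assuming the usual row-major layout: a result like vector<1x4xf32> has a trailing dim != 1, so consecutive lanes touch consecutive memory and a plain vector.transfer_read works; a column vector like vector<4x1xf32> would need a permutation map on the transfer_read, hence the TODO above.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of `resType.getShape().back() != 1`: true for row vectors (eligible
// for the contiguous lowering), false for column vectors (gather for now).
static bool isRowVectorShape(const std::vector<int64_t> &shape) {
  return shape.back() != 1;
}

int main() {
  assert(isRowVectorShape({1, 4}));  // vector<1x4xf32>: contiguous load OK
  assert(!isRowVectorShape({4, 1})); // vector<4x1xf32>: lowered as gather
}
```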
@@ -1073,6 +1083,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// b. contiguous loads.
// Both cases use vector.transfer_read.

+  assert(llvm::count_if(resultType.getShape(),
+                        [](uint64_t dim) { return dim != 1; }) &&
+         "Contiguous loads and scalar loads + broadcast only support 1-D "
+         "vectors ATM!");
+
// Collect indices for `vector.transfer_read`. At this point, the indices will
// either be scalars or would have been broadcast to vectors matching the
// result type. For indices that are vectors, there are two options:
90 changes: 90 additions & 0 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -307,6 +307,96 @@ module attributes {transform.with_named_sequence} {

// -----

// Reading a 1D column vector (hence a candidate for a contiguous load), but
// the leading index %1 varies with the loop, so this is in fact a gather load.

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<8x1xf32>
%res = linalg.generic {
indexing_maps = [#map],
iterator_types = ["parallel", "parallel"]
} outs(%0 : tensor<8x1xf32>) {
^bb0(%arg1: f32):
%1 = linalg.index 0 : index
%extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
linalg.yield %extracted : f32
} -> tensor<8x1xf32>
return %res : tensor<8x1xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// CHECK: return %[[RES]] : tensor<8x1xf32>

// -----

// Same as above, but the access indices have been swapped and hence this _is_
// a contiguous load. Currently not supported and lowered as vector.gather
// instead.
// TODO: Make sure that this is lowered as a contiguous load.

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<8x1xf32>
%res = linalg.generic {
indexing_maps = [#map],
iterator_types = ["parallel", "parallel"]
} outs(%0 : tensor<8x1xf32>) {
^bb0(%arg1: f32):
%1 = linalg.index 0 : index
%extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
linalg.yield %extracted : f32
} -> tensor<8x1xf32>
return %res : tensor<8x1xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func.func @index_from_output_column_vector_contiguous_load(
// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
// CHECK: return %[[RES]] : tensor<8x1xf32>

// -----

#map = affine_map<(d0) -> (d0)>
func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32>, %arg1: tensor<5xi32>) -> tensor<5xf32> {
%c5 = arith.constant 5 : index