Skip to content

Commit 543446a

Browse files
authored
[mlir][vector] canonicalize vector.from_elements from ascending extracts (#139819)
Example: ```mlir %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8> %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8> %2 = vector.from_elements %0, %1 : vector<2xi8> ``` becomes ```mlir %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8> ``` It was decided that we should spill canonicalization tests into new files (see [discussion](#135096 (review))). In view of this, I added the new tests to a new file specifically for canonicalization of from_elements. To be consistent in the location of the tests, I moved the existing tests `extract_scalar_from_from_elements`, `extract_1d_from_from_elements`, `extract_2d_from_from_elements` and `from_elements_to_splat` from `canonicalize.mlir` to `canonicalize/vector-from-elements.mlir`. In addition to moving them, I changed the LIT variables to all be upper-case for consistency.
1 parent 3dffd71 commit 543446a

File tree

3 files changed

+389
-69
lines changed

3 files changed

+389
-69
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/IR/OpImplementation.h"
3434
#include "mlir/IR/PatternMatch.h"
3535
#include "mlir/IR/TypeUtilities.h"
36+
#include "mlir/IR/ValueRange.h"
3637
#include "mlir/Interfaces/SubsetOpInterface.h"
3738
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
3839
#include "mlir/Support/LLVM.h"
@@ -2387,9 +2388,129 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
23872388
return success();
23882389
}
23892390

2391+
/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2392+
/// on a single extract. Example:
2393+
/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2394+
/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2395+
/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2396+
///
2397+
/// becomes
2398+
/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2399+
/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2400+
///
2401+
/// The requirements for this to be valid are
2402+
///
2403+
/// i) The elements are extracted from the same vector (%source).
2404+
///
2405+
/// ii) The elements form a suffix of %source. Specifically, the number
2406+
/// of elements is the same as the product of the last N dimension sizes
2407+
/// of %source, for some N.
2408+
///
2409+
/// iii) The elements are extracted contiguously in ascending order.
2410+
2411+
class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2412+
2413+
using OpRewritePattern::OpRewritePattern;
2414+
2415+
LogicalResult matchAndRewrite(FromElementsOp fromElements,
2416+
PatternRewriter &rewriter) const override {
2417+
2418+
// Handled by `rewriteFromElementsAsSplat`
2419+
if (fromElements.getType().getNumElements() == 1)
2420+
return failure();
2421+
2422+
// The common source that all elements are extracted from, if one exists.
2423+
TypedValue<VectorType> source;
2424+
// The position of the combined extract operation, if one is created.
2425+
ArrayRef<int64_t> combinedPosition;
2426+
// The expected index of extraction of the current element in the loop, if
2427+
// elements are extracted contiguously in ascending order.
2428+
SmallVector<int64_t> expectedPosition;
2429+
2430+
for (auto [insertIndex, element] :
2431+
llvm::enumerate(fromElements.getElements())) {
2432+
2433+
// Check that the element is from a vector.extract operation.
2434+
auto extractOp =
2435+
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
2436+
if (!extractOp) {
2437+
return rewriter.notifyMatchFailure(fromElements,
2438+
"element not from vector.extract");
2439+
}
2440+
2441+
// Check condition (i) by checking that all elements have the same source
2442+
// as the first element.
2443+
if (insertIndex == 0) {
2444+
source = extractOp.getVector();
2445+
} else if (extractOp.getVector() != source) {
2446+
return rewriter.notifyMatchFailure(fromElements,
2447+
"element from different vector");
2448+
}
2449+
2450+
ArrayRef<int64_t> position = extractOp.getStaticPosition();
2451+
int64_t rank = position.size();
2452+
assert(rank == source.getType().getRank() &&
2453+
"scalar extract must have full rank position");
2454+
2455+
// Check condition (ii) by checking that the position that the first
2456+
// element is extracted from has sufficient trailing 0s. For example, in
2457+
//
2458+
// %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2459+
// [...]
2460+
// %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2461+
//
2462+
// The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2463+
// elements, which is the number of elements of %n, so this is valid.
2464+
if (insertIndex == 0) {
2465+
const int64_t numElms = fromElements.getType().getNumElements();
2466+
int64_t numSuffixElms = 1;
2467+
int64_t index = rank;
2468+
while (index > 0 && position[index - 1] == 0 &&
2469+
numSuffixElms < numElms) {
2470+
numSuffixElms *= source.getType().getDimSize(index - 1);
2471+
--index;
2472+
}
2473+
if (numSuffixElms != numElms) {
2474+
return rewriter.notifyMatchFailure(
2475+
fromElements, "elements do not form a suffix of source");
2476+
}
2477+
expectedPosition = llvm::to_vector(position);
2478+
combinedPosition = position.drop_back(rank - index);
2479+
}
2480+
2481+
// Check condition (iii).
2482+
else if (expectedPosition != position) {
2483+
return rewriter.notifyMatchFailure(
2484+
fromElements, "elements not in ascending order (static order)");
2485+
}
2486+
increment(expectedPosition, source.getType().getShape());
2487+
}
2488+
2489+
auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2490+
fromElements.getLoc(), source, combinedPosition);
2491+
2492+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2493+
fromElements, fromElements.getType(), extracted);
2494+
2495+
return success();
2496+
}
2497+
2498+
/// Increments n-D `indices` by 1 starting from the innermost dimension.
2499+
static void increment(MutableArrayRef<int64_t> indices,
2500+
ArrayRef<int64_t> shape) {
2501+
for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2502+
indices[dim] += 1;
2503+
if (indices[dim] < shape[dim])
2504+
break;
2505+
indices[dim] = 0;
2506+
}
2507+
}
2508+
};
2509+
23902510
// Registers the canonicalization patterns for vector.from_elements into
// `results`: the function-style `rewriteFromElementsAsSplat` rewrite and the
// `FromElementsToShapeCast` pattern class (constructed with `context`).
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
23912511
MLIRContext *context) {
23922512
results.add(rewriteFromElementsAsSplat);
2513+
results.add<FromElementsToShapeCast>(context);
23932514
}
23942515

23952516
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2943,75 +2943,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
29432943

29442944
// -----
29452945

2946-
// CHECK-LABEL: func @extract_scalar_from_from_elements(
2947-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2948-
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
2949-
// Extract from 0D.
2950-
%0 = vector.from_elements %a : vector<f32>
2951-
%1 = vector.extract %0[] : f32 from vector<f32>
2952-
2953-
// Extract from 1D.
2954-
%2 = vector.from_elements %a : vector<1xf32>
2955-
%3 = vector.extract %2[0] : f32 from vector<1xf32>
2956-
%4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
2957-
%5 = vector.extract %4[4] : f32 from vector<5xf32>
2958-
2959-
// Extract from 2D.
2960-
%6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
2961-
%7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
2962-
%8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
2963-
%9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
2964-
%10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
2965-
2966-
// CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
2967-
return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
2968-
}
2969-
2970-
// -----
2971-
2972-
// CHECK-LABEL: func @extract_1d_from_from_elements(
2973-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2974-
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
2975-
%0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
2976-
// CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
2977-
%1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
2978-
// CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
2979-
%2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
2980-
// CHECK: return %[[splat1]], %[[splat2]]
2981-
return %1, %2 : vector<3xf32>, vector<3xf32>
2982-
}
2983-
2984-
// -----
2985-
2986-
// CHECK-LABEL: func @extract_2d_from_from_elements(
2987-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
2988-
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
2989-
%0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
2990-
// CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
2991-
%1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
2992-
// CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
2993-
%2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
2994-
// CHECK: return %[[splat1]], %[[splat2]]
2995-
return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
2996-
}
2997-
2998-
// -----
2999-
3000-
// CHECK-LABEL: func @from_elements_to_splat(
3001-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
3002-
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
3003-
// CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
3004-
%0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
3005-
// CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
3006-
%1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
3007-
// CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
3008-
%2 = vector.from_elements %a : vector<f32>
3009-
// CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
3010-
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
3011-
}
3012-
3013-
// -----
3014-
30152946
// CHECK-LABEL: func @vector_insert_const_regression(
30162947
// CHECK: llvm.mlir.undef
30172948
// CHECK: vector.insert

0 commit comments

Comments
 (0)