|
33 | 33 | #include "mlir/IR/OpImplementation.h"
|
34 | 34 | #include "mlir/IR/PatternMatch.h"
|
35 | 35 | #include "mlir/IR/TypeUtilities.h"
|
| 36 | +#include "mlir/IR/ValueRange.h" |
36 | 37 | #include "mlir/Interfaces/SubsetOpInterface.h"
|
37 | 38 | #include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
38 | 39 | #include "mlir/Support/LLVM.h"
|
@@ -2387,9 +2388,129 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
|
2387 | 2388 | return success();
|
2388 | 2389 | }
|
2389 | 2390 |
|
| 2391 | +/// Rewrite from_elements on multiple scalar extracts as a shape_cast |
| 2392 | +/// on a single extract. Example: |
| 2393 | +/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8> |
| 2394 | +/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8> |
| 2395 | +/// %2 = vector.from_elements %0, %1 : vector<2xi8> |
| 2396 | +/// |
| 2397 | +/// becomes |
| 2398 | +/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8> |
| 2399 | +/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8> |
| 2400 | +/// |
| 2401 | +/// The requirements for this to be valid are |
| 2402 | +/// |
| 2403 | +/// i) The elements are extracted from the same vector (%source). |
| 2404 | +/// |
| 2405 | +/// ii) The elements form a suffix of %source. Specifically, the number |
| 2406 | +/// of elements is the same as the product of the last N dimension sizes |
| 2407 | +/// of %source, for some N. |
| 2408 | +/// |
| 2409 | +/// iii) The elements are extracted contiguously in ascending order. |
| 2410 | + |
| 2411 | +class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> { |
| 2412 | + |
| 2413 | + using OpRewritePattern::OpRewritePattern; |
| 2414 | + |
| 2415 | + LogicalResult matchAndRewrite(FromElementsOp fromElements, |
| 2416 | + PatternRewriter &rewriter) const override { |
| 2417 | + |
| 2418 | + // Handled by `rewriteFromElementsAsSplat` |
| 2419 | + if (fromElements.getType().getNumElements() == 1) |
| 2420 | + return failure(); |
| 2421 | + |
| 2422 | + // The common source that all elements are extracted from, if one exists. |
| 2423 | + TypedValue<VectorType> source; |
| 2424 | + // The position of the combined extract operation, if one is created. |
| 2425 | + ArrayRef<int64_t> combinedPosition; |
| 2426 | + // The expected index of extraction of the current element in the loop, if |
| 2427 | + // elements are extracted contiguously in ascending order. |
| 2428 | + SmallVector<int64_t> expectedPosition; |
| 2429 | + |
| 2430 | + for (auto [insertIndex, element] : |
| 2431 | + llvm::enumerate(fromElements.getElements())) { |
| 2432 | + |
| 2433 | + // Check that the element is from a vector.extract operation. |
| 2434 | + auto extractOp = |
| 2435 | + dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); |
| 2436 | + if (!extractOp) { |
| 2437 | + return rewriter.notifyMatchFailure(fromElements, |
| 2438 | + "element not from vector.extract"); |
| 2439 | + } |
| 2440 | + |
| 2441 | + // Check condition (i) by checking that all elements have the same source |
| 2442 | + // as the first element. |
| 2443 | + if (insertIndex == 0) { |
| 2444 | + source = extractOp.getVector(); |
| 2445 | + } else if (extractOp.getVector() != source) { |
| 2446 | + return rewriter.notifyMatchFailure(fromElements, |
| 2447 | + "element from different vector"); |
| 2448 | + } |
| 2449 | + |
| 2450 | + ArrayRef<int64_t> position = extractOp.getStaticPosition(); |
| 2451 | + int64_t rank = position.size(); |
| 2452 | + assert(rank == source.getType().getRank() && |
| 2453 | + "scalar extract must have full rank position"); |
| 2454 | + |
| 2455 | + // Check condition (ii) by checking that the position that the first |
| 2456 | + // element is extracted from has sufficient trailing 0s. For example, in |
| 2457 | + // |
| 2458 | + // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8> |
| 2459 | + // [...] |
| 2460 | + // %elms = vector.from_elements %elm0, [...] : vector<12xi8> |
| 2461 | + // |
| 2462 | + // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12 |
| 2463 | + // elements, which is the number of elements of %n, so this is valid. |
| 2464 | + if (insertIndex == 0) { |
| 2465 | + const int64_t numElms = fromElements.getType().getNumElements(); |
| 2466 | + int64_t numSuffixElms = 1; |
| 2467 | + int64_t index = rank; |
| 2468 | + while (index > 0 && position[index - 1] == 0 && |
| 2469 | + numSuffixElms < numElms) { |
| 2470 | + numSuffixElms *= source.getType().getDimSize(index - 1); |
| 2471 | + --index; |
| 2472 | + } |
| 2473 | + if (numSuffixElms != numElms) { |
| 2474 | + return rewriter.notifyMatchFailure( |
| 2475 | + fromElements, "elements do not form a suffix of source"); |
| 2476 | + } |
| 2477 | + expectedPosition = llvm::to_vector(position); |
| 2478 | + combinedPosition = position.drop_back(rank - index); |
| 2479 | + } |
| 2480 | + |
| 2481 | + // Check condition (iii). |
| 2482 | + else if (expectedPosition != position) { |
| 2483 | + return rewriter.notifyMatchFailure( |
| 2484 | + fromElements, "elements not in ascending order (static order)"); |
| 2485 | + } |
| 2486 | + increment(expectedPosition, source.getType().getShape()); |
| 2487 | + } |
| 2488 | + |
| 2489 | + auto extracted = rewriter.createOrFold<vector::ExtractOp>( |
| 2490 | + fromElements.getLoc(), source, combinedPosition); |
| 2491 | + |
| 2492 | + rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( |
| 2493 | + fromElements, fromElements.getType(), extracted); |
| 2494 | + |
| 2495 | + return success(); |
| 2496 | + } |
| 2497 | + |
| 2498 | + /// Increments n-D `indices` by 1 starting from the innermost dimension. |
| 2499 | + static void increment(MutableArrayRef<int64_t> indices, |
| 2500 | + ArrayRef<int64_t> shape) { |
| 2501 | + for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) { |
| 2502 | + indices[dim] += 1; |
| 2503 | + if (indices[dim] < shape[dim]) |
| 2504 | + break; |
| 2505 | + indices[dim] = 0; |
| 2506 | + } |
| 2507 | + } |
| 2508 | +}; |
| 2509 | + |
2390 | 2510 | void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
2391 | 2511 | MLIRContext *context) {
|
2392 | 2512 | results.add(rewriteFromElementsAsSplat);
|
| 2513 | + results.add<FromElementsToShapeCast>(context); |
2393 | 2514 | }
|
2394 | 2515 |
|
2395 | 2516 | //===----------------------------------------------------------------------===//
|
|
0 commit comments