Skip to content

Commit 81e8c07

Browse files
committed
Move builder to utils
1 parent f847f45 commit 81e8c07

File tree

3 files changed

+26
-12
lines changed

3 files changed

+26
-12
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1010
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
1111

12+
#include "mlir/Dialect/Arith/IR/Arith.h"
1213
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1314
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1415
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -101,6 +102,24 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
101102
std::optional<StaticTileOffsetRange>
102103
createUnrollIterator(VectorType vType, int64_t targetRank = 1);
103104

105+
/// Returns a functor (int64_t -> Value) which returns a constant vscale
106+
/// multiple.
107+
///
108+
/// Example:
109+
/// ```c++
110+
/// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
111+
/// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
112+
/// ```
113+
inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
114+
Value vscale = nullptr;
115+
return [loc, vscale, &rewriter](int64_t multiplier) mutable {
116+
if (!vscale)
117+
vscale = rewriter.create<vector::VectorScaleOp>(loc);
118+
return rewriter.create<arith::MulIOp>(
119+
loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
120+
};
121+
}
122+
104123
/// A wrapper for getMixedSizes for vector.transfer_read and
105124
/// vector.transfer_write Ops (for source and destination, respectively).
106125
///

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Dialect/SCF/IR/SCF.h"
2525
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
2626
#include "mlir/Dialect/Utils/IndexingUtils.h"
27+
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
2728
#include "mlir/Transforms/OneToNTypeConversion.h"
2829

2930
#define DEBUG_TYPE "arm-sme-vector-legalization"
@@ -376,14 +377,6 @@ struct LegalizeTransferWriteOpsByDecomposition
376377
}
377378
};
378379

379-
/// Returns a functor taking an `int64_t` multiplier that emits
/// `vector.vscale * multiplier` and returns the resulting `arith.muli` op.
///
/// The `vector.vscale` op is created eagerly, when this function is called,
/// and the same value is reused by every invocation of the functor.
/// `rewriter` is captured by reference, so the functor must not be used once
/// the rewriter is gone.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
380-
Value vscale = rewriter.create<vector::VectorScaleOp>(loc);
381-
return [loc, vscale, &rewriter](int64_t multiplier) {
382-
return rewriter.create<arith::MulIOp>(
383-
loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
384-
};
385-
}
386-
387380
/// Legalize a multi-tile transfer_write as a single store loop. This is done as
388381
/// part of type decomposition as at this level we know each tile write is
389382
/// disjoint, but that information is lost after decomposition (without analysis
@@ -450,7 +443,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
450443
kMatchFailureUnsupportedMaskOp);
451444

452445
auto loc = writeOp.getLoc();
453-
auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
446+
auto createVscaleMultiple =
447+
vector::makeVscaleConstantBuilder(rewriter, loc);
454448

455449
// Get SME tile and slice types.
456450
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -846,7 +840,8 @@ struct LowerIllegalTransposeStoreViaZA
846840
return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
847841

848842
auto loc = writeOp.getLoc();
849-
auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
843+
auto createVscaleMultiple =
844+
vector::makeVscaleConstantBuilder(rewriter, loc);
850845

851846
auto transposeMap = AffineMapAttr::get(
852847
AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,12 +557,12 @@ func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
557557
func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
558558
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
559559
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
560-
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
561560
// CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
562561
// CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
563562
// CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
564563
// CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
565564
// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
565+
// CHECK-NEXT: %[[VSCALE:.*]] = vector.vscale
566566
// CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
567567
// CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1>
568568
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
@@ -597,11 +597,11 @@ func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest
597597
// CHECK-SAME: %[[J:.*]]: index)
598598
func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
599599
// CHECK: %[[C4:.*]] = arith.constant 4 : index
600-
// CHECK: %[[VSCALE:.*]] = vector.vscale
601600

602601
// <skip 3x other extract+insert chain>
603602
// CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
604603
// CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
604+
// CHECK: %[[VSCALE:.*]] = vector.vscale
605605
// CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
606606
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C4]] : vector<[4]x[4]xi1>
607607
// CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>

0 commit comments

Comments
 (0)