Commit 4b2ba5a
[mlir][sve] Add an e2e for linalg.matmul with mixed types (llvm#73773)
Apart from the test itself, this patch also updates a few patterns to fix how new VectorType(s) are created. Namely, it makes sure that "scalability" is correctly propagated. Regression tests will be updated separately while auditing Vector dialect tests in the context of scalable vectors: https://github.com/orgs/llvm/projects/23
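For context, a minimal sketch of the API issue being fixed (the helper below is hypothetical and not part of the patch; VectorType::get and clone are the MLIR APIs that appear in the diff):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper: rebuild a vector type with a new element type while
// keeping the source type's shape.
static VectorType withNewElementType(VectorType srcTy, Type newElemTy) {
  // Shape-only re-creation drops the scalable-dims mask: for a scalable
  // source such as vector<[4]x4xi8> this returns the *fixed* vector<4x4xi32>:
  //   return VectorType::get(srcTy.getShape(), newElemTy);

  // clone() preserves both the shape and the scalability flags, so
  // vector<[4]x4xi8> correctly becomes vector<[4]x4xi32>:
  return srcTy.clone(newElemTy);
}

The three hunks below apply exactly this change in ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose, and CanonicalizeContractMatmulToMMT.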
1 parent ae4d7ac · commit 4b2ba5a

2 files changed: +89 −8

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 6 additions & 8 deletions
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast

     Type castResTy = getElementTypeOrSelf(op->getResult(0));
     if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
-      castResTy = VectorType::get(vecTy.getShape(), castResTy);
+      castResTy = vecTy.clone(castResTy);
     auto *castOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                         bcastOp.getSource(), castResTy, op->getAttrs());
@@ -527,16 +527,14 @@ struct ReorderElementwiseOpsOnTranspose final
         srcValues.push_back(transposeOp.getVector());
       } else {
         // This is a constant. Create a reverse transpose op for it.
-        auto vectorType = VectorType::get(
-            srcType.getShape(),
-            cast<VectorType>(operand.getType()).getElementType());
+        auto vectorType =
+            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
         srcValues.push_back(rewriter.create<vector::TransposeOp>(
             operand.getLoc(), vectorType, operand, invOrder));
       }
     }

-    auto vectorType = VectorType::get(
-        srcType.getShape(),
+    auto vectorType = srcType.clone(
         cast<VectorType>(op->getResultTypes()[0]).getElementType());
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
@@ -1314,8 +1312,8 @@ struct CanonicalizeContractMatmulToMMT final
       Value trans =
           rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
       VectorType newType =
-          VectorType::get(cast<VectorType>(trans.getType()).getShape(),
-                          cast<VectorType>(mat.getType()).getElementType());
+          cast<VectorType>(trans.getType())
+              .clone(cast<VectorType>(mat.getType()).getElementType());
       return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
     }
     if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
New test file · Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -transform-interpreter -test-transform-dialect-erase-schedule \
+// DEFINE:   -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t
+// DEFINE: %{entry_point} = matmul_mixed_ty
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+func.func @matmul_mixed_ty() {
+  // Matrix dimensions
+  %K = arith.constant 3 : index
+  %M = arith.constant 5 : index
+  %N = arith.constant 15 : index
+  %c0_i8 = arith.constant 0 : i8
+  %c0_i32 = arith.constant 0 : i32
+
+  // Allocate the matrices
+  %A_alloc = bufferization.alloc_tensor(%M, %K) : tensor<?x?xi8>
+  %B_alloc = bufferization.alloc_tensor(%K, %N) : tensor<?x?xi8>
+  %C_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xi32>
+
+  // Initialise the matrices
+  %pi = arith.constant 123 : i8
+  %A = linalg.fill ins(%pi : i8) outs(%A_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+  %B = linalg.fill ins(%pi : i8) outs(%B_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+  %C_in = linalg.fill ins(%c0_i32 : i32) outs(%C_alloc : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+  // Matmul
+  %C_out = linalg.matmul ins(%A, %B: tensor<?x?xi8>, tensor<?x?xi8>) outs(%C_in: tensor<?x?xi32>) -> tensor<?x?xi32>
+
+  // Print and verify the output
+  // CHECK-LABEL: SVE: START OF TEST OUTPUT
+  vector.print str "SVE: START OF TEST OUTPUT"
+
+  // CHECK-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
+  // CHECK-COUNT-5: [45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387]
+  %xf = tensor.cast %C_out : tensor<?x?xi32> to tensor<*xi32>
+  call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+  // CHECK-NEXT: SVE: END OF TEST OUTPUT
+  vector.print str "SVE: END OF TEST OUTPUT"
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 1: Tile
+    %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    %tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
+
+    // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
+    %func = transform.structured.match ops{["func.func"]} in %module
+      : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.lower_masked_transfers
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.fma
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
+
+func.func private @printMemrefI32(%ptr : tensor<*xi32>)
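A note on the expected output: both inputs are filled with 123 and the reduction dimension is K = 3, so every element of the result is 123 × 123 × 3 = 45387. That value fits in the i32 accumulator but would overflow the i8 inputs, which is precisely what the mixed-type (i8 inputs, i32 output) matmul exercises.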
