Skip to content

Commit 44a047c

Browse files
[MLIR][ArmSVE] Add initial lowering of vector.contract to SVE *MMLA instructions (#135636)
1 parent f849866 commit 44a047c

File tree

11 files changed

+841
-1
lines changed

11 files changed

+841
-1
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14311431
"bool", /*default=*/"false",
14321432
"Enables the use of ArmSVE dialect while lowering the vector "
14331433
"dialect.">,
1434+
Option<"armI8MM", "enable-arm-i8mm",
1435+
"bool", /*default=*/"false",
1436+
"Enables the use of Arm FEAT_I8MM instructions while lowering "
1437+
"the vector dialect.">,
14341438
Option<"x86Vector", "enable-x86vector",
14351439
"bool", /*default=*/"false",
14361440
"Enables the use of X86Vector dialect while lowering the vector "

mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class RewritePatternSet;
2020
void populateArmSVELegalizeForLLVMExportPatterns(
2121
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
2222

23+
void populateLowerContractionToSVEI8MMPatternPatterns(
24+
RewritePatternSet &patterns);
25+
2326
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
2427
/// intrinsics.
2528
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);

mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
3535
MLIRVectorToLLVM
3636

3737
MLIRArmNeonDialect
38+
MLIRArmNeonTransforms
3839
MLIRArmSVEDialect
3940
MLIRArmSVETransforms
4041
MLIRAMXDialect

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/AMX/Transforms.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
17+
#include "mlir/Dialect/ArmNeon/Transforms.h"
1718
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
1819
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
1920
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8283
populateVectorStepLoweringPatterns(patterns);
8384
populateVectorRankReducingFMAPattern(patterns);
8485
populateVectorGatherLoweringPatterns(patterns);
86+
if (armI8MM) {
87+
if (armNeon)
88+
arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
89+
if (armSVE)
90+
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
91+
}
8592
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
8693
}
8794

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ class LowerContractionToSMMLAPattern
5656
// Avoid 0-D vectors and 1-D rhs:
5757
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
5858
return failure();
59+
// This codegen does not work for scalable vectors. Return failure so this
60+
// pattern is not accidentally chosen over patterns that lower to ArmSVE.
61+
if (lhsType.isScalable() || rhsType.isScalable())
62+
return failure();
5963
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
6064
auto dimN = rhsType.getDimSize(0);
6165
auto dimK = rhsType.getDimSize(1);
@@ -238,5 +242,5 @@ class LowerContractionToSMMLAPattern
238242
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
239243
RewritePatternSet &patterns) {
240244
MLIRContext *context = patterns.getContext();
241-
patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
245+
patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
242246
}

mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRArmSVETransforms
22
LegalizeForLLVMExport.cpp
33
LegalizeVectorStorage.cpp
4+
LowerContractionToSVEI8MMPattern.cpp
45

56
DEPENDS
67
MLIRArmSVEConversionsIncGen

0 commit comments

Comments
 (0)