Skip to content

Commit 3fa5ee6

Browse files
committed
[mlir][ArmSME] Introduce custom TypeConverter for ArmSME
At the moment, SME-to-LLVM lowerings rely entirely on `LLVMTypeConverter`. This patch introduces a dedicated `TypeConverter` that inherits from `LLVMTypeConverter` (it will also be used when lowering ArmSME Ops to LLVM). The new type converter merely disables lowerings for `VectorType` to prevent 2-d scalable vectors (common in the context of ArmSME), e.g. `vector<[16]x[16]xi8>`, entering the LLVM Type converter. LLVM does not support arrays of scalable vectors and hence the need for specialisation. In the case of SME such types are effectively eliminated when emitting LLVM IR intrinsics for SME. Differential Revision: https://reviews.llvm.org/D155365
1 parent e65cabb commit 3fa5ee6

File tree

5 files changed

+40
-1
lines changed

5 files changed

+40
-1
lines changed

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
1010
#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
1111

12+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1213
#include "mlir/Pass/Pass.h"
1314

1415
namespace mlir {
1516

1617
class RewritePatternSet;
1718

1819
namespace arm_sme {
20+
//===----------------------------------------------------------------------===//
21+
// The EnableArmStreaming pass.
22+
//===----------------------------------------------------------------------===//
1923
// Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of
2024
// the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
2125
// In a locally streaming function PSTATE.SM is kept internal and the callee
@@ -33,6 +37,14 @@ createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
3337
/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
3438
std::unique_ptr<Pass> createTileAllocationPass();
3539

40+
//===----------------------------------------------------------------------===//
41+
// Type ArmSMETypeConverter pass.
42+
//===----------------------------------------------------------------------===//
43+
class ArmSMETypeConverter : public LLVMTypeConverter {
44+
public:
45+
ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
46+
};
47+
3648
//===----------------------------------------------------------------------===//
3749
// Registration
3850
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRVectorToLLVM
1717
MLIRArmNeonDialect
1818
MLIRArmSMEDialect
1919
MLIRArmSMETransforms
20+
MLIRVectorToArmSME
2021
MLIRArmSVEDialect
2122
MLIRArmSVETransforms
2223
MLIRAMXDialect

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
1717
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
18+
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
1819
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
1920
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
2021
#include "mlir/Dialect/ArmSVE/Transforms.h"
@@ -96,6 +97,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
9697
target.addLegalDialect<arith::ArithDialect>();
9798
target.addLegalDialect<memref::MemRefDialect>();
9899
target.addLegalOp<UnrealizedConversionCastOp>();
100+
arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
101+
99102
if (armNeon) {
100103
// TODO: we may or may not want to include in-dialect lowering to
101104
// LLVM-compatible operations here. So far, all operations in the dialect
@@ -108,7 +111,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
108111
}
109112
if (armSME) {
110113
configureArmSMELegalizeForExportTarget(target);
111-
populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
114+
populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
112115
}
113116
if (amx) {
114117
configureAMXLegalizeForExportTarget(target);
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
10+
11+
using namespace mlir;
12+
arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
13+
MLIRContext *ctx, const LowerToLLVMOptions &options)
14+
: LLVMTypeConverter(ctx, options) {
15+
// Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
16+
// vectors (common in the context of ArmSME), e.g.
17+
// `vector<[16]x[16]xi8>`,
18+
// entering the LLVM Type converter. LLVM does not support arrays of scalable
19+
// vectors, but in the case of SME such types are effectively eliminated when
20+
// emitting ArmSME LLVM IR intrinsics.
21+
addConversion([&](VectorType type) { return type; });
22+
}

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRArmSMETransforms
2+
ArmSMETypeConverter.cpp
23
EnableArmStreaming.cpp
34
LegalizeForLLVMExport.cpp
45
TileAllocation.cpp

0 commit comments

Comments
 (0)