-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Move ArmSME -> intrinsics lowerings to convert-arm-sme-to-llvm
pass
#72890
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…-to-llvm pass (NFC) This gives more flexibility with when these lowerings are performed, without also lowering unrelated vector ops.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThis gives more flexibility with when these lowerings are performed, without also lowering unrelated vector ops. This is a NFC (other than adding a new Patch is 46.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72890.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
new file mode 100644
index 000000000000000..ce778581b2cee37
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -0,0 +1,26 @@
+//===- ArmSMEToLLVM.h - Convert ArmSME to LLVM dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
+#define MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Create a pass to convert a subset of ArmSME ops to SCF.
+std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 3078d909a8946dd..a25fd17ea923fb5 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -15,6 +15,7 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 626f5f3d19d307e..a0cc05319bb7299 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1241,6 +1241,19 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArmSMEToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+ let summary = "Lower the operations from the ArmSME dialect into the LLVM "
+ "dialect";
+ let constructor = "mlir::createConvertArmSMEToLLVMPass()";
+ let dependentDialects = [
+ "arm_sme::ArmSMEDialect",
+ "LLVM::LLVMDialect"];
+}
+
//===----------------------------------------------------------------------===//
// VectorToLLVM
//===----------------------------------------------------------------------===//
@@ -1280,10 +1293,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
- Option<"armSME", "enable-arm-sme",
- "bool", /*default=*/"false",
- "Enables the use of ArmSME dialect while lowering the vector "
- "dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index fae04513859938b..8ea3e1e57b7caa5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -20,15 +20,6 @@ void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace arm_sme
-/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
-/// intrinsics.
-void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
-
-/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
-/// intrinsics.
-void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
-
} // namespace mlir
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
similarity index 86%
rename from mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
rename to mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 6ccb652ecbbc29e..66eee98cd23e4be 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -1,24 +1,36 @@
-//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of ArmSME operations to LLVM intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
using namespace mlir;
-using namespace mlir::arm_sme;
namespace {
@@ -40,11 +52,11 @@ namespace {
/// The 'arm_sme.cast_tile_to_vector' (which models the return) and the
/// 'arith.shli' (which generates the mask) will be folded away after tile
/// allocation and canonization.
-struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
- using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
+struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
+ using ConvertOpToLLVMPattern<arm_sme::ZeroOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(ZeroOp zero, OpAdaptor adaptor,
+ matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();
@@ -121,7 +133,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
};
/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
-struct LoadTileSliceToArmSMELowering
+struct LoadTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
@@ -220,7 +232,7 @@ struct LoadTileSliceToArmSMELowering
};
/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
-struct StoreTileSliceToArmSMELowering
+struct StoreTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
@@ -313,7 +325,7 @@ struct StoreTileSliceToArmSMELowering
};
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
-struct MoveVectorToTileSliceToArmSMELowering
+struct MoveVectorToTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
@@ -373,7 +385,7 @@ struct MoveVectorToTileSliceToArmSMELowering
};
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
-struct MoveTileSliceToVectorArmSMELowering
+struct MoveTileSliceToVectorConversion
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
using ConvertOpToLLVMPattern<
arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
@@ -456,7 +468,8 @@ struct OuterProductOpConversion
// * half-precision - +sme2p1,+b16b16
//
// It should be possible to control lowering based on target features.
- // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+ // [1]
+ // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
return false;
@@ -475,7 +488,7 @@ struct OuterProductOpConversion
};
// TODO: Support CombiningKind::Sub for outer products.
- if (outerProductOp.getKind() != CombiningKind::Add)
+ if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
return outerProductOp.emitError("unsupported kind");
auto resultVectorType = outerProductOp.getResultType();
@@ -522,32 +535,49 @@ struct OuterProductOpConversion
} // namespace
-void mlir::configureArmSMELegalizeForExportTarget(
- LLVMConversionTarget &target) {
- target.addLegalOp<
- scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
- arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
- arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
- arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
- arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
- arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
- arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
- arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
- arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
- arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
- target.addLegalOp<GetTileID>();
- target.addIllegalOp<vector::OuterProductOp>();
-}
+namespace {
+
+struct ConvertArmSMEToLLVMPass
+ : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ arm_sme::ArmSMETypeConverter converter(&getContext(),
+ LowerToLLVMOptions(&getContext()));
+
+ patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
+ MoveVectorToTileSliceConversion, StoreTileSliceConversion,
+ OuterProductOpConversion, ZeroOpConversion>(converter);
+
+ LLVMConversionTarget target(getContext());
+ target.addLegalDialect<arith::ArithDialect>();
+ target.addLegalOp<UnrealizedConversionCastOp>();
+ target.addLegalOp<
+ scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
+ arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
+ arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+ arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+ arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+ arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+ arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+ arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+ arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+ arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+ arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+ arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+ arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+ arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+ target.addLegalOp<arm_sme::GetTileID>();
+ target.addIllegalOp<vector::OuterProductOp>();
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
-void mlir::populateArmSMELegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<
- LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
- MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
- OuterProductOpConversion, ZeroOpConversion>(converter);
+std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
+ return std::make_unique<ConvertArmSMEToLLVMPass>();
}
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000000..9914f39e17a1a91
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArmSMEToLLVM
+ ArmSMEToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMETransforms
+ MLIRArmSMEDialect
+ MLIRArmSMEUtils
+ MLIRTransforms
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 822ce5aca255510..c3a2481975040c9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
add_subdirectory(ArmSMEToSCF)
+add_subdirectory(ArmSMEToLLVM)
add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLibm)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4c6d0672d4108ef..ff8e78a668e0f10 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,9 +14,6 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -52,8 +49,6 @@ struct LowerVectorToLLVMPass
registry.insert<arm_neon::ArmNeonDialect>();
if (armSVE)
registry.insert<arm_sve::ArmSVEDialect>();
- if (armSME)
- registry.insert<arm_sme::ArmSMEDialect>();
if (amx)
registry.insert<amx::AMXDialect>();
if (x86Vector)
@@ -96,7 +91,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
- arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
if (armNeon) {
// TODO: we may or may not want to include in-dialect lowering to
@@ -108,10 +102,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
- if (armSME) {
- configureArmSMELegalizeForExportTarget(target);
- populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
- }
if (amx) {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 8f485db4e8438b1..e2407d9f48f7061 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_dialect_library(MLIRArmSMETransforms
ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
- LegalizeForLLVMExport.cpp
TileAllocation.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
index 2c26c62ad42481e..65996e81c42d909 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s
// This test verifies the temporary casts that are emitted when lowering to
// intrinsics to preserve data flow are correct. Canonicalization will remove
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 8fdcf69958244f3..fa62332bc3f5b17 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
// Test conversion of ArmSME ops to LLVM intrinsics.
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index 0f31278eefd1550..ba650b031e6110b 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING
// CHECK-LABEL: @declaration
func.func private @declaration()
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 26cd91bd3e8956a..2378f4234aef1ef 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
// RUN: -allocate-arm-sme-tiles -canonicalize \
// RUN: -allow-unregistered-dialect \
// RUN: | FileCheck %s
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 721ff8f2c3589d4..c288f786f89a947 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
//===----------------------------------------------------------------------===//
// vector.transfer_write
@@ -17,9 +17,8 @@
// CHEC...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Benjamin Maxwell (MacDue) ChangesThis gives more flexibility with when these lowerings are performed, without also lowering unrelated vector ops. This is a NFC (other than adding a new Patch is 46.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72890.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
new file mode 100644
index 000000000000000..ce778581b2cee37
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -0,0 +1,26 @@
+//===- ArmSMEToLLVM.h - Convert ArmSME to LLVM dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
+#define MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Create a pass to convert a subset of ArmSME ops to SCF.
+std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 3078d909a8946dd..a25fd17ea923fb5 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -15,6 +15,7 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 626f5f3d19d307e..a0cc05319bb7299 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1241,6 +1241,19 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArmSMEToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+ let summary = "Lower the operations from the ArmSME dialect into the LLVM "
+ "dialect";
+ let constructor = "mlir::createConvertArmSMEToLLVMPass()";
+ let dependentDialects = [
+ "arm_sme::ArmSMEDialect",
+ "LLVM::LLVMDialect"];
+}
+
//===----------------------------------------------------------------------===//
// VectorToLLVM
//===----------------------------------------------------------------------===//
@@ -1280,10 +1293,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
- Option<"armSME", "enable-arm-sme",
- "bool", /*default=*/"false",
- "Enables the use of ArmSME dialect while lowering the vector "
- "dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index fae04513859938b..8ea3e1e57b7caa5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -20,15 +20,6 @@ void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace arm_sme
-/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
-/// intrinsics.
-void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
-
-/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
-/// intrinsics.
-void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
-
} // namespace mlir
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
similarity index 86%
rename from mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
rename to mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 6ccb652ecbbc29e..66eee98cd23e4be 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -1,24 +1,36 @@
-//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of ArmSME operations to LLVM intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
using namespace mlir;
-using namespace mlir::arm_sme;
namespace {
@@ -40,11 +52,11 @@ namespace {
/// The 'arm_sme.cast_tile_to_vector' (which models the return) and the
/// 'arith.shli' (which generates the mask) will be folded away after tile
/// allocation and canonization.
-struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
- using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
+struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
+ using ConvertOpToLLVMPattern<arm_sme::ZeroOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(ZeroOp zero, OpAdaptor adaptor,
+ matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();
@@ -121,7 +133,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
};
/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
-struct LoadTileSliceToArmSMELowering
+struct LoadTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
@@ -220,7 +232,7 @@ struct LoadTileSliceToArmSMELowering
};
/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
-struct StoreTileSliceToArmSMELowering
+struct StoreTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
@@ -313,7 +325,7 @@ struct StoreTileSliceToArmSMELowering
};
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
-struct MoveVectorToTileSliceToArmSMELowering
+struct MoveVectorToTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
@@ -373,7 +385,7 @@ struct MoveVectorToTileSliceToArmSMELowering
};
/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
-struct MoveTileSliceToVectorArmSMELowering
+struct MoveTileSliceToVectorConversion
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
using ConvertOpToLLVMPattern<
arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
@@ -456,7 +468,8 @@ struct OuterProductOpConversion
// * half-precision - +sme2p1,+b16b16
//
// It should be possible to control lowering based on target features.
- // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+ // [1]
+ // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
return false;
@@ -475,7 +488,7 @@ struct OuterProductOpConversion
};
// TODO: Support CombiningKind::Sub for outer products.
- if (outerProductOp.getKind() != CombiningKind::Add)
+ if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
return outerProductOp.emitError("unsupported kind");
auto resultVectorType = outerProductOp.getResultType();
@@ -522,32 +535,49 @@ struct OuterProductOpConversion
} // namespace
-void mlir::configureArmSMELegalizeForExportTarget(
- LLVMConversionTarget &target) {
- target.addLegalOp<
- scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
- arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
- arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
- arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
- arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
- arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
- arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
- arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
- arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
- arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
- target.addLegalOp<GetTileID>();
- target.addIllegalOp<vector::OuterProductOp>();
-}
+namespace {
+
+struct ConvertArmSMEToLLVMPass
+ : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ arm_sme::ArmSMETypeConverter converter(&getContext(),
+ LowerToLLVMOptions(&getContext()));
+
+ patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
+ MoveVectorToTileSliceConversion, StoreTileSliceConversion,
+ OuterProductOpConversion, ZeroOpConversion>(converter);
+
+ LLVMConversionTarget target(getContext());
+ target.addLegalDialect<arith::ArithDialect>();
+ target.addLegalOp<UnrealizedConversionCastOp>();
+ target.addLegalOp<
+ scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
+ arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
+ arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+ arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+ arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+ arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+ arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+ arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+ arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+ arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+ arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+ arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+ arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+ arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
+ target.addLegalOp<arm_sme::GetTileID>();
+ target.addIllegalOp<vector::OuterProductOp>();
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
-void mlir::populateArmSMELegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<
- LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
- MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
- OuterProductOpConversion, ZeroOpConversion>(converter);
+std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
+ return std::make_unique<ConvertArmSMEToLLVMPass>();
}
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
new file mode 100644
index 000000000000000..9914f39e17a1a91
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArmSMEToLLVM
+ ArmSMEToLLVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMETransforms
+ MLIRArmSMEDialect
+ MLIRArmSMEUtils
+ MLIRTransforms
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 822ce5aca255510..c3a2481975040c9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
add_subdirectory(ArmSMEToSCF)
+add_subdirectory(ArmSMEToLLVM)
add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLibm)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4c6d0672d4108ef..ff8e78a668e0f10 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,9 +14,6 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -52,8 +49,6 @@ struct LowerVectorToLLVMPass
registry.insert<arm_neon::ArmNeonDialect>();
if (armSVE)
registry.insert<arm_sve::ArmSVEDialect>();
- if (armSME)
- registry.insert<arm_sme::ArmSMEDialect>();
if (amx)
registry.insert<amx::AMXDialect>();
if (x86Vector)
@@ -96,7 +91,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
- arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
if (armNeon) {
// TODO: we may or may not want to include in-dialect lowering to
@@ -108,10 +102,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
- if (armSME) {
- configureArmSMELegalizeForExportTarget(target);
- populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
- }
if (amx) {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 8f485db4e8438b1..e2407d9f48f7061 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_dialect_library(MLIRArmSMETransforms
ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
- LegalizeForLLVMExport.cpp
TileAllocation.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
index 2c26c62ad42481e..65996e81c42d909 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s
// This test verifies the temporary casts that are emitted when lowering to
// intrinsics to preserve data flow are correct. Canonicalization will remove
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 8fdcf69958244f3..fa62332bc3f5b17 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
// Test conversion of ArmSME ops to LLVM intrinsics.
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index 0f31278eefd1550..ba650b031e6110b 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING
// CHECK-LABEL: @declaration
func.func private @declaration()
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 26cd91bd3e8956a..2378f4234aef1ef 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \
+// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
// RUN: -allocate-arm-sme-tiles -canonicalize \
// RUN: -allow-unregistered-dialect \
// RUN: | FileCheck %s
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 721ff8f2c3589d4..c288f786f89a947 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
//===----------------------------------------------------------------------===//
// vector.transfer_write
@@ -17,9 +17,8 @@
// CHEC...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couple of minor comments but otherwise LGTM, cheers
This gives more flexibility with when these lowerings are performed, without also lowering unrelated vector ops.
This is a NFC (other than adding a new
-convert-arm-sme-to-llvm
pass)