Skip to content

[mlir][ArmSME] Move ArmSME -> intrinsics lowerings to convert-arm-sme-to-llvm pass #72890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- ArmSMEToLLVM.h - Convert ArmSME to LLVM dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
#define MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_

#include <memory>

#include "mlir/Dialect/ArmSME/Transforms/Passes.h"

namespace mlir {
// Forward declarations for the types that the signatures below only name by
// reference or as a smart-pointer element, so this header does not rely on
// another header having declared them first.
class ConversionTarget;
class Pass;
class RewritePatternSet;

#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"

// Makes the ArmSME type converter nameable as `mlir::ArmSMETypeConverter`.
using arm_sme::ArmSMETypeConverter;

/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();

/// Configure `target` to convert from the ArmSME dialect to LLVM intrinsics.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);

/// Populate `patterns` with patterns that convert from the ArmSME dialect to
/// LLVM intrinsics. The patterns are constructed with `converter`.
void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
                                            RewritePatternSet &patterns);

} // namespace mlir

#endif // MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
Expand Down
18 changes: 14 additions & 4 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,20 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
];
}

//===----------------------------------------------------------------------===//
// ArmSMEToLLVM
//===----------------------------------------------------------------------===//

// Declares the `convert-arm-sme-to-llvm` pass. The pass instance is created by
// `mlir::createConvertArmSMEToLLVMPass()`, and the ArmSME and LLVM dialects are
// listed as dependent dialects.
def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
let summary = "Lower the operations from the ArmSME dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertArmSMEToLLVMPass()";
let dependentDialects = [
"arm_sme::ArmSMEDialect",
"LLVM::LLVMDialect"
];
}

//===----------------------------------------------------------------------===//
// VectorToLLVM
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1280,10 +1294,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
Option<"armSME", "enable-arm-sme",
"bool", /*default=*/"false",
"Enables the use of ArmSME dialect while lowering the vector "
"dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
Expand Down
9 changes: 0 additions & 9 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,6 @@ void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace arm_sme

/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
/// intrinsics.
void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);

/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
/// intrinsics.
void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);

} // namespace mlir

#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of ArmSME operations to LLVM intrinsics.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

Expand All @@ -40,11 +50,11 @@ namespace {
/// The 'arm_sme.cast_tile_to_vector' (which models the return) and the
/// 'arith.shli' (which generates the mask) will be folded away after tile
/// allocation and canonicalization.
struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
using ConvertOpToLLVMPattern<arm_sme::ZeroOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(ZeroOp zero, OpAdaptor adaptor,
matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();

Expand Down Expand Up @@ -121,7 +131,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
};

/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceToArmSMELowering
struct LoadTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -220,7 +230,7 @@ struct LoadTileSliceToArmSMELowering
};

/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
struct StoreTileSliceToArmSMELowering
struct StoreTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -313,7 +323,7 @@ struct StoreTileSliceToArmSMELowering
};

/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
struct MoveVectorToTileSliceToArmSMELowering
struct MoveVectorToTileSliceConversion
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -373,7 +383,7 @@ struct MoveVectorToTileSliceToArmSMELowering
};

/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
struct MoveTileSliceToVectorArmSMELowering
struct MoveTileSliceToVectorConversion
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
using ConvertOpToLLVMPattern<
arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
Expand Down Expand Up @@ -456,7 +466,8 @@ struct OuterProductOpConversion
// * half-precision - +sme2p1,+b16b16
//
// It should be possible to control lowering based on target features.
// [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
// [1]
// https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
return false;

Expand All @@ -475,7 +486,7 @@ struct OuterProductOpConversion
};

// TODO: Support CombiningKind::Sub for outer products.
if (outerProductOp.getKind() != CombiningKind::Add)
if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
return outerProductOp.emitError("unsupported kind");

auto resultVectorType = outerProductOp.getResultType();
Expand Down Expand Up @@ -522,32 +533,56 @@ struct OuterProductOpConversion

} // namespace

void mlir::configureArmSMELegalizeForExportTarget(
LLVMConversionTarget &target) {
namespace {

/// Pass that lowers ArmSME operations to LLVM intrinsics by running a partial
/// dialect conversion over the operation the pass is applied to.
struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
  void runOnOperation() override {
    auto &context = getContext();

    // Conversion patterns, constructed with an ArmSMETypeConverter.
    ArmSMETypeConverter typeConverter(&context, LowerToLLVMOptions(&context));
    RewritePatternSet conversionPatterns(&context);
    populateArmSMEToLLVMConversionPatterns(typeConverter, conversionPatterns);

    // Legality rules that decide which operations have to be rewritten.
    LLVMConversionTarget conversionTarget(context);
    configureArmSMEToLLVMConversionLegality(conversionTarget);

    const LogicalResult status = applyPartialConversion(
        getOperation(), conversionTarget, std::move(conversionPatterns));
    if (succeeded(status))
      return;
    signalPassFailure();
  }
};

} // namespace

void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();
arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
arm_sme::aarch64_sme_mopa>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
}

/// Adds the tile-slice load/store/move, outer-product and zero conversion
/// patterns to `patterns`, constructing each of them from `converter`.
void mlir::populateArmSMEToLLVMConversionPatterns(
    ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
  // The patterns are appended in the same order as one combined `add` call.
  patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion>(
      converter);
  patterns.add<MoveVectorToTileSliceConversion, StoreTileSliceConversion>(
      converter);
  patterns.add<OuterProductOpConversion, ZeroOpConversion>(converter);
}

void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
OuterProductOpConversion, ZeroOpConversion>(converter);
/// Returns a newly allocated ConvertArmSMEToLLVMPass instance, owned by the
/// caller through the returned `std::unique_ptr`.
std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
return std::make_unique<ConvertArmSMEToLLVMPass>();
}
16 changes: 16 additions & 0 deletions mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Conversion library built from ArmSMEToLLVM.cpp (the ArmSME -> LLVM lowering).
# It depends on the generated conversion pass declarations
# (MLIRConversionPassIncGen) and publicly links the libraries listed under
# LINK_LIBS.
add_mlir_conversion_library(MLIRArmSMEToLLVM
ArmSMEToLLVM.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToLLVM

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRArmSMETransforms
MLIRArmSMEDialect
MLIRArmSMEUtils
MLIRTransforms
MLIRLLVMCommonConversion
MLIRLLVMDialect)
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
add_subdirectory(ArmSMEToSCF)
add_subdirectory(ArmSMEToLLVM)
add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexToLibm)
Expand Down
10 changes: 0 additions & 10 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -52,8 +49,6 @@ struct LowerVectorToLLVMPass
registry.insert<arm_neon::ArmNeonDialect>();
if (armSVE)
registry.insert<arm_sve::ArmSVEDialect>();
if (armSME)
registry.insert<arm_sme::ArmSMEDialect>();
if (amx)
registry.insert<amx::AMXDialect>();
if (x86Vector)
Expand Down Expand Up @@ -96,7 +91,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);

if (armNeon) {
// TODO: we may or may not want to include in-dialect lowering to
Expand All @@ -108,10 +102,6 @@ void LowerVectorToLLVMPass::runOnOperation() {
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
if (armSME) {
configureArmSMELegalizeForExportTarget(target);
populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
}
if (amx) {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
add_mlir_dialect_library(MLIRArmSMETransforms
ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
LegalizeForLLVMExport.cpp
TileAllocation.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s

// This test verifies the temporary casts that are emitted when lowering to
// intrinsics to preserve data flow are correct. Canonicalization will remove
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s

// Test conversion of ArmSME ops to LLVM intrinsics.

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/ArmSME/enable-arm-za.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING

// CHECK-LABEL: @declaration
func.func private @declaration()
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \
// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
// RUN: -allocate-arm-sme-tiles -canonicalize \
// RUN: -allow-unregistered-dialect \
// RUN: | FileCheck %s
Expand Down
Loading