1
- // ===- LegalizeForLLVMExport .cpp - Prepare ArmSME for LLVM translation ----===//
1
+ // ===- ArmSMEToLLVM .cpp - Convert ArmSME to LLVM dialect -------------- ----===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
+ //
9
+ // This file implements lowering of ArmSME operations to LLVM intrinsics.
10
+ //
11
+ // ===----------------------------------------------------------------------===//
12
+
13
+ #include " mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
8
14
9
15
#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
10
16
#include " mlir/Conversion/LLVMCommon/Pattern.h"
11
17
#include " mlir/Dialect/Arith/IR/Arith.h"
12
18
#include " mlir/Dialect/ArmSME/IR/ArmSME.h"
13
- #include " mlir/Dialect/ArmSME/Transforms/Transforms.h"
14
19
#include " mlir/Dialect/ArmSME/Utils/Utils.h"
15
20
#include " mlir/Dialect/Func/IR/FuncOps.h"
16
21
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
17
- #include " mlir/Dialect/SCF/IR/SCF.h"
18
22
#include " mlir/Dialect/Vector/IR/VectorOps.h"
23
+ #include " mlir/Pass/Pass.h"
24
+ #include " mlir/Transforms/DialectConversion.h"
25
+
26
+ namespace mlir {
27
+ #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
28
+ #include " mlir/Conversion/Passes.h.inc"
29
+ } // namespace mlir
19
30
20
31
using namespace mlir ;
21
- using namespace mlir ::arm_sme;
22
32
23
33
namespace {
24
34
@@ -40,11 +50,11 @@ namespace {
40
50
// / The 'arm_sme.cast_tile_to_vector' (which models the return) and the
41
51
// / 'arith.shli' (which generates the mask) will be folded away after tile
42
52
// / allocation and canonicalization.
43
- struct ZeroOpConversion : public ConvertOpToLLVMPattern <ZeroOp> {
44
- using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
53
+ struct ZeroOpConversion : public ConvertOpToLLVMPattern <arm_sme:: ZeroOp> {
54
+ using ConvertOpToLLVMPattern<arm_sme:: ZeroOp>::ConvertOpToLLVMPattern;
45
55
46
56
LogicalResult
47
- matchAndRewrite (ZeroOp zero, OpAdaptor adaptor,
57
+ matchAndRewrite (arm_sme:: ZeroOp zero, OpAdaptor adaptor,
48
58
ConversionPatternRewriter &rewriter) const override {
49
59
auto loc = zero.getLoc ();
50
60
@@ -121,7 +131,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
121
131
};
122
132
123
133
// / Lower `arm_sme.load_tile_slice` to SME intrinsics.
124
- struct LoadTileSliceToArmSMELowering
134
+ struct LoadTileSliceConversion
125
135
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
126
136
using ConvertOpToLLVMPattern<
127
137
arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
@@ -220,7 +230,7 @@ struct LoadTileSliceToArmSMELowering
220
230
};
221
231
222
232
// / Lower `arm_sme.store_tile_slice` to SME intrinsics.
223
- struct StoreTileSliceToArmSMELowering
233
+ struct StoreTileSliceConversion
224
234
: public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
225
235
using ConvertOpToLLVMPattern<
226
236
arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
@@ -313,7 +323,7 @@ struct StoreTileSliceToArmSMELowering
313
323
};
314
324
315
325
// / Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
316
- struct MoveVectorToTileSliceToArmSMELowering
326
+ struct MoveVectorToTileSliceConversion
317
327
: public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
318
328
using ConvertOpToLLVMPattern<
319
329
arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
@@ -373,7 +383,7 @@ struct MoveVectorToTileSliceToArmSMELowering
373
383
};
374
384
375
385
// / Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
376
- struct MoveTileSliceToVectorArmSMELowering
386
+ struct MoveTileSliceToVectorConversion
377
387
: public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
378
388
using ConvertOpToLLVMPattern<
379
389
arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
@@ -456,7 +466,8 @@ struct OuterProductOpConversion
456
466
// * half-precision - +sme2p1,+b16b16
457
467
//
458
468
// It should be possible to control lowering based on target features.
459
- // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
469
+ // [1]
470
+ // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
460
471
if ((vectorType.getRank () != 2 ) || !vectorType.allDimsScalable ())
461
472
return false ;
462
473
@@ -475,7 +486,7 @@ struct OuterProductOpConversion
475
486
};
476
487
477
488
// TODO: Support CombiningKind::Sub for outer products.
478
- if (outerProductOp.getKind () != CombiningKind::Add)
489
+ if (outerProductOp.getKind () != arm_sme:: CombiningKind::Add)
479
490
return outerProductOp.emitError (" unsupported kind" );
480
491
481
492
auto resultVectorType = outerProductOp.getResultType ();
@@ -522,32 +533,56 @@ struct OuterProductOpConversion
522
533
523
534
} // namespace
524
535
525
- void mlir::configureArmSMELegalizeForExportTarget (
526
- LLVMConversionTarget &target) {
536
+ namespace {
537
+
538
+ struct ConvertArmSMEToLLVMPass
539
+ : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
540
+ void runOnOperation () override {
541
+ LLVMConversionTarget target (getContext ());
542
+ RewritePatternSet patterns (&getContext ());
543
+ ArmSMETypeConverter converter (&getContext (),
544
+ LowerToLLVMOptions (&getContext ()));
545
+
546
+ configureArmSMEToLLVMConversionLegality (target);
547
+ populateArmSMEToLLVMConversionPatterns (converter, patterns);
548
+
549
+ if (failed (applyPartialConversion (getOperation (), target,
550
+ std::move (patterns))))
551
+ signalPassFailure ();
552
+ }
553
+ };
554
+
555
+ } // namespace
556
+
557
+ void mlir::configureArmSMEToLLVMConversionLegality (ConversionTarget &target) {
558
+ target.addIllegalDialect <arm_sme::ArmSMEDialect>();
527
559
target.addLegalOp <
528
- scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
529
- arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
530
- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
531
- arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
532
- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
533
- arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
534
- arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
535
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
536
- arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
537
- arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
538
- arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
539
- arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
540
- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
541
- arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
542
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
543
- target.addLegalOp <GetTileID>();
544
- target.addIllegalOp <vector::OuterProductOp>();
560
+ arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
561
+ arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
562
+ arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
563
+ arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
564
+ arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
565
+ arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
566
+ arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
567
+ arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
568
+ arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
569
+ arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
570
+ arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
571
+ arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
572
+ arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
573
+ arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
574
+ arm_sme::aarch64_sme_mopa>();
575
+ target.addLegalDialect <arith::ArithDialect>();
576
+ target.addLegalOp <UnrealizedConversionCastOp>();
577
+ }
578
+
579
+ void mlir::populateArmSMEToLLVMConversionPatterns (
580
+ ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
581
+ patterns.add <LoadTileSliceConversion, MoveTileSliceToVectorConversion,
582
+ MoveVectorToTileSliceConversion, StoreTileSliceConversion,
583
+ OuterProductOpConversion, ZeroOpConversion>(converter);
545
584
}
546
585
547
- void mlir::populateArmSMELegalizeForLLVMExportPatterns (
548
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
549
- patterns.add <
550
- LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
551
- MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
552
- OuterProductOpConversion, ZeroOpConversion>(converter);
586
+ std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass () {
587
+ return std::make_unique<ConvertArmSMEToLLVMPass>();
553
588
}
0 commit comments