Skip to content

Commit eaff02f

Browse files
authored
[mlir][ArmSME] Switch to an attribute-based tile allocation scheme (#73253)
This reworks the ArmSME dialect to use attributes for tile allocation. This has a number of advantages and corrects some issues with the previous approach: * Tile allocation can now be done ASAP (i.e. immediately after `-convert-vector-to-arm-sme`) * SSA form for control flow is now supported (e.g.`scf.for` loops that yield tiles) * ArmSME ops can be converted to intrinsics very late (i.e. after lowering to control flow) * Tests are simplified by removing constants and casts * Avoids correctness issues with representing LLVM `immargs` as MLIR values - The tile ID on the SME intrinsics is an `immarg` (so is required to be a compile-time constant), `immargs` should be mapped to MLIR attributes (this is already the case for intrinsics in the LLVM dialect) - Using MLIR values for `immargs` can lead to invalid LLVM IR being generated (and passes such as -cse making incorrect optimizations) As part of this patch we bid farewell to the following operations: ```mlir arm_sme.get_tile_id : i32 arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32> arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32 ``` These are now replaced with: ```mlir // Allocates a new tile with (indeterminate) state: arm_sme.get_tile : vector<[4]x[4]xi32> // A placeholder operation for lowering ArmSME ops to intrinsics: arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32> ``` The new tile allocation works by operations implementing the `ArmSMETileOpInterface`. This interface says that an operation needs to be assigned a tile ID, and may conditionally allocate a new SME tile. Operations allocate a new tile by implementing... ```c++ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() ``` ...and returning what type of tile the op allocates (ZAB, ZAH, etc). Operations that don't allocate a tile return `std::nullopt` (which is the default behaviour). Currently the following ops are defined as allocating: ```mlir arm_sme.get_tile arm_sme.zero arm_sme.tile_load arm_sme.outerproduct // (if no accumulator is specified) ``` Allocating operations become the roots for the tile allocation pass, which currently just (naively) assigns all transitive uses of a root operation the same tile ID. However, this is enough to handle current use cases. Once tile IDs have been allocated subsequent rewrites can forward the tile IDs to any newly created operations.
1 parent b04a419 commit eaff02f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1406
-1525
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#define MLIR_DIALECT_ARMSME_IR_ARMSME_H
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
18+
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
1719
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1820
#include "mlir/Dialect/SCF/IR/SCF.h"
1921
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -22,7 +24,9 @@
2224
#include "mlir/IR/OpDefinition.h"
2325
#include "mlir/Interfaces/SideEffectInterfaces.h"
2426

25-
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
27+
namespace mlir::arm_sme {
28+
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
29+
}
2630

2731
#define GET_ATTRDEF_CLASSES
2832
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//===- ArmSMEEnums.h - Arm SME Dialect Enums --------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARMSME_ENUMS_H
10+
#define MLIR_DIALECT_ARMSME_ENUMS_H
11+
12+
#include "mlir/IR/Dialect.h"
13+
14+
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
15+
16+
#endif

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
5454
}];
5555
}
5656

57-
class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
57+
class ArmSME_IntrOp<string mnemonic,
58+
list<int> immArgPositions = [],
59+
list<string> immArgAttrNames = [],
60+
list<int> overloadedOperands = [],
5861
list<Trait> traits = [], int numResults = 0,
5962
list<int> overloadedResults = []>
6063
: LLVM_IntrOpBase<
@@ -64,16 +67,27 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
6467
/*list<int> overloadedResults=*/overloadedResults,
6568
/*list<int> overloadedOperands=*/overloadedOperands,
6669
/*list<Trait> traits=*/traits,
67-
/*int numResults=*/numResults>;
70+
/*int numResults=*/numResults,
71+
/*bit requiresAccessGroup=*/0,
72+
/*bit requiresAliasAnalysis=*/0,
73+
/*bit requiresFastmath=*/0,
74+
/*list<int> immArgPositions=*/immArgPositions,
75+
/*list<string> immArgAttrNames=*/immArgAttrNames>;
6876

6977
// Zero
70-
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
71-
Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
78+
def LLVM_aarch64_sme_zero
79+
: ArmSME_IntrOp<"zero",
80+
/*immArgPositions=*/[0],
81+
/*immArgAttrNames=*/["tile_mask"]>,
82+
Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
7283

7384
// MOP's
7485
class ArmSME_IntrMopOverloadedOp<string mnemonic>
75-
: ArmSME_IntrOp<mnemonic, [4]>,
76-
Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
86+
: ArmSME_IntrOp<mnemonic,
87+
/*immArgPositions=*/[0],
88+
/*immArgAttrNames=*/["tile_id"],
89+
/*overloadedOperands=*/[4]>,
90+
Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
7791
Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
7892
Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
7993
Arg<MOPVector, "LHS vector operand">:$lhs_vector,
@@ -92,12 +106,17 @@ def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
92106
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
93107
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
94108

109+
class ArmSME_IntrLoadStoreOp<string mnemonic>
110+
: ArmSME_IntrOp<mnemonic,
111+
/*immArgPositions=*/[2],
112+
/*immArgAttrNames=*/["tile_id"]>;
113+
95114
// Loads
96115
class ArmSME_IntrLoadOp<string mnemonic>
97-
: ArmSME_IntrOp<mnemonic>,
116+
: ArmSME_IntrLoadStoreOp<mnemonic>,
98117
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
99118
Arg<LLVM_AnyPointer, "Load address">:$load_address,
100-
Arg<I32, "Virtual tile ID">:$tile_id,
119+
Arg<I32Attr, "Virtual tile ID">:$tile_id,
101120
Arg<I32, "Tile slice">:$tile_slice_index)>;
102121

103122
def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
@@ -113,10 +132,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
113132

114133
// Stores
115134
class ArmSME_IntrStoreOp<string mnemonic>
116-
: ArmSME_IntrOp<mnemonic>,
135+
: ArmSME_IntrLoadStoreOp<mnemonic>,
117136
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
118137
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
119-
Arg<I32, "Virtual tile ID">:$tile_id,
138+
Arg<I32Attr, "Virtual tile ID">:$tile_id,
120139
Arg<I32, "Tile slice">:$tile_slice_index)>;
121140

122141
def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
@@ -138,22 +157,28 @@ def LLVM_aarch64_sme_str
138157

139158
// Vector to tile slice
140159
class LLVM_aarch64_sme_write<string direction>
141-
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
160+
: ArmSME_IntrOp<"write." # direction,
161+
/*immArgPositions=*/[0],
162+
/*immArgAttrNames=*/["tile_id"],
163+
/*overloadedOperands=*/[3],
142164
[AllShapesMatch<["predicate", "vector"]>]>,
143-
Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
165+
Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
144166
Arg<I32, "Tile slice">:$tile_slice_index,
145167
Arg<SVEPredicate, "Vector predicate">:$predicate,
146168
Arg<SVEVector, "Vector operand">:$vector)>;
147169

148170
// Tile slice to vector
149171
class LLVM_aarch64_sme_read<string direction>
150-
: ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
172+
: ArmSME_IntrOp<"read." # direction,
173+
/*immArgPositions=*/[2],
174+
/*immArgAttrNames=*/["tile_id"],
175+
/*overloadedOperands=*/[],
151176
[AllShapesMatch<["vector", "predicate", "res"]>,
152177
AllElementTypesMatch<["vector", "res"]>],
153178
/*numResults=*/1, /*overloadedResults=*/[0]>,
154179
Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
155180
Arg<SVEPredicate, "Vector predicate">:$predicate,
156-
Arg<I32, "Virtual tile ID">:$tile_id,
181+
Arg<I32Attr, "Virtual tile ID">:$tile_id,
157182
Arg<I32, "Tile slice">:$tile_slice_index)>;
158183

159184
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;

0 commit comments

Comments
 (0)