Skip to content

Commit 36f675e

Browse files
committed
Fixups
1 parent f6e3b88 commit 36f675e

File tree

16 files changed

+60
-45
lines changed

16 files changed

+60
-45
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
#include "mlir/IR/OpDefinition.h"
2525
#include "mlir/Interfaces/SideEffectInterfaces.h"
2626

27+
namespace mlir::arm_sme {
2728
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
29+
}
2830

2931
#define GET_ATTRDEF_CLASSES
3032
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
4040
def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
4141
let description = [{
4242
An interface for operations that use or allocate Arm SME tiles. These
43-
operations need to be assigned a tile ID an i32 attribute, which specifies
43+
operations need to be assigned a tile ID, an i32 attribute, which specifies
4444
which virtual tile within the ZA storage to use. The number of tiles
4545
available depends on the type of the tile. This is summarized below:
4646

@@ -52,7 +52,7 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
5252
| `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
5353
| `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
5454

55-
Operations that allocate a new tiles (such as arm_sme.get_tile), are used as
55+
Operations that allocate a new tile (such as arm_sme.get_tile), are used as
5656
the roots for tile allocation, with all operations that (transitively)
5757
depend on a root being assigned the same tile ID.
5858
}];
@@ -71,7 +71,10 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
7171
}]
7272
>,
7373
InterfaceMethod<
74-
"Returns the (possibly null) tile ID assigned to this operation.",
74+
[{
75+
Returns the tile ID assigned to this operation. This will be null before
76+
tile allocation.
77+
}],
7578
/*returnType=*/"mlir::IntegerAttr",
7679
/*methodName=*/"getTileId",
7780
/*arguments=*/(ins),
@@ -82,13 +85,16 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
8285
}]
8386
>,
8487
InterfaceMethod<
85-
"The type of tile this operation allocates (or none)",
88+
[{
89+
The type of tile this operation allocates. Returns none (std::nullopt)
90+
if this operation does not allocate a tile.
91+
}],
8692
/*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
8793
/*methodName=*/"getAllocatedTileType",
8894
/*arguments=*/(ins),
8995
/*methodBody=*/[{}],
9096
/*defaultImpl=*/ [{
91-
// Do not allocate a new tile.
97+
// This operation does not allocate a tile.
9298
return std::nullopt;
9399
}]
94100
>
@@ -104,14 +110,16 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
104110
return op;
105111
}
106112

107-
// A helper to replace this operation and forward any tile ID.
113+
// A helper to replace this operation and forward its tile ID (if present).
108114
template<typename T, typename... Args>
109115
T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
110116
auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
111117
rewriter.replaceOp($_op, newOp);
112118
return newOp;
113119
}
114120
}];
121+
122+
let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
115123
}
116124

117125
//===----------------------------------------------------------------------===//
@@ -222,11 +230,11 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
222230
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
223231
Op<ArmSME_Dialect, mnemonic, traits> {}
224232

225-
def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
233+
def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
226234
let summary = "Returns a SME virtual tile";
227235
let description = [{
228236
Allocates a new SME "virtual tile" within a function. The contents of the
229-
tile returned from this operation undefined.
237+
tile returned from this operation are undefined.
230238

231239
Example 1:
232240

@@ -264,12 +272,12 @@ def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
264272
}];
265273
}
266274

267-
def MaterializeSSATile : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
275+
def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
268276
let summary = "SME tile placeholder";
269277
let description = [{
270278
A placeholder to preserve dataflow while lowering to SME intrinsics (which
271-
do not take or return tile values). This operation is intended to be DCE'd
272-
once all ArmSME operations have been lowered.
279+
do not take or return SME virtual tile values). This operation is intended
280+
to be DCE'd once all ArmSME operations have been lowered.
273281

274282
This operation is not intended to be used outside of the ArmSME -> LLVM
275283
conversion.

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ bool isValidSMETileElementType(Type type);
3535
/// otherwise.
3636
bool isValidSMETileVectorType(VectorType vType);
3737

38+
/// Returns the type of SME tile this vector type corresponds to or none.
3839
std::optional<ArmSMETileType> getSMETileType(VectorType);
3940

41+
/// Verifies the tile ID (if set) on this tile operation is valid.
42+
LogicalResult verifyOperationHasValidTileId(Operation *);
43+
4044
} // namespace mlir::arm_sme
4145

4246
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,21 @@ using namespace mlir;
3232

3333
namespace {
3434

35-
IntegerAttr getTileIdOrError(ArmSMETileOpInterface op) {
35+
IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
3636
auto tileId = op.getTileId();
3737
if (!tileId)
3838
op.emitOpError(
3939
"expected tile ID to be allocated before conversion to LLVM");
4040
return tileId;
4141
}
4242

43-
struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTile> {
44-
using ConvertOpToLLVMPattern<arm_sme::GetTile>::ConvertOpToLLVMPattern;
43+
struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTileOp> {
44+
using ConvertOpToLLVMPattern<arm_sme::GetTileOp>::ConvertOpToLLVMPattern;
4545

4646
LogicalResult
47-
matchAndRewrite(arm_sme::GetTile getTile, OpAdaptor,
47+
matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
4848
ConversionPatternRewriter &rewriter) const override {
49-
rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
49+
rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
5050
getTile, getTile.getTileType());
5151
return success();
5252
}
@@ -140,7 +140,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
140140
loc, rewriter.getI32IntegerAttr(zeroMask));
141141

142142
// Create a placeholder op to preserve dataflow.
143-
rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
143+
rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
144144
zero, zero.getVectorType());
145145

146146
return success();
@@ -558,7 +558,7 @@ struct ConvertArmSMEToLLVMPass
558558
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
559559
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
560560
target.addLegalOp<
561-
arm_sme::MaterializeSSATile, arm_sme::aarch64_sme_zero,
561+
arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
562562
arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
563563
arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
564564
arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,

mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ add_mlir_conversion_library(MLIRArmSMEToLLVM
1010
LINK_LIBS PUBLIC
1111
MLIRArmSMETransforms
1212
MLIRArmSMEDialect
13-
MLIRArmSMEUtils
1413
MLIRTransforms
1514
MLIRLLVMCommonConversion
1615
MLIRLLVMDialect)

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
8989
auto tileElementType = tileType.getElementType();
9090

9191
// Allocate a new SME tile.
92-
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
92+
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
9393
rewriter, loc, tileType);
9494

9595
// Create a loop that loads each ZA tile slice from memory.
@@ -299,7 +299,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
299299
loc, rewriter.getI32Type(), numCols);
300300

301301
// Allocate a new SME tile.
302-
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
302+
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
303303
rewriter, loc, tileType);
304304

305305
// Create a loop that loads each ZA tile slice from memory.

mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,5 @@ add_mlir_conversion_library(MLIRArmSMEToSCF
99

1010
LINK_LIBS PUBLIC
1111
MLIRArmSMEDialect
12-
MLIRArmSMEUtils
1312
MLIRTransforms
1413
)

mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,5 @@ add_mlir_conversion_library(MLIRVectorToArmSME
1010

1111
LINK_LIBS PUBLIC
1212
MLIRArmSMEDialect
13-
MLIRArmSMEUtils
1413
MLIRLLVMCommonConversion
1514
)

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
253253
tileSliceType, denseAttr.getSplatValue<Attribute>());
254254
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
255255

256-
auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
256+
auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
257257

258258
auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
259259
auto tileSliceIndex = forOp.getInductionVar();
@@ -315,7 +315,7 @@ struct BroadcastOpToArmSMELowering
315315
else
316316
return failure();
317317

318-
auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
318+
auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
319319

320320
// Create a loop over ZA tile slices.
321321
auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
@@ -371,7 +371,7 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
371371
Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
372372
loc, tileSliceType, splatOp.getInput());
373373

374-
auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
374+
auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
375375

376376
// Next, create a loop over ZA tile slices and "move" the generated 1-d
377377
// vector to each slice.
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3-
add_subdirectory(Utils)

mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRArmSMEDialect
22
ArmSME.cpp
3+
Utils.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME
@@ -15,5 +16,4 @@ add_mlir_dialect_library(MLIRArmSMEDialect
1516
MLIRSCFDialect
1617
MLIRSideEffectInterfaces
1718
MLIRVectorDialect
18-
MLIRArmSMEUtils
1919
)

mlir/lib/Dialect/ArmSME/Utils/Utils.cpp renamed to mlir/lib/Dialect/ArmSME/IR/Utils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
14+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
1415

1516
namespace mlir::arm_sme {
1617

@@ -59,4 +60,17 @@ std::optional<ArmSMETileType> getSMETileType(VectorType type) {
5960
}
6061
}
6162

63+
LogicalResult verifyOperationHasValidTileId(Operation *op) {
64+
auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
65+
if (!tileOp)
66+
return success(); // Not a tile op (no need to check).
67+
auto tileId = tileOp.getTileId();
68+
if (!tileId)
69+
return success(); // Not having a tile ID (yet) is okay.
70+
if (!tileId.getType().isSignlessInteger(32))
71+
return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
72+
// TODO: Verify value of tile ID is in range.
73+
return success();
74+
}
75+
6276
} // namespace mlir::arm_sme

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms
1111

1212
LINK_LIBS PUBLIC
1313
MLIRArmSMEDialect
14-
MLIRArmSMEUtils
1514
MLIRFuncDialect
1615
MLIRLLVMCommonConversion
1716
MLIRVectorDialect

mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
136136
}
137137
}
138138

139-
/// Allocates a tile to 'tileId' or returns an error if there are no tiles left.
140-
static FailureOr<unsigned> getTile(ArmSMETileType tileType,
141-
TileMask &tilesInUse) {
139+
/// Allocates and returns a tile ID. Returns an error if there are no tiles
140+
/// left.
141+
static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
142+
TileMask &tilesInUse) {
142143
auto masks = getMasks(tileType);
143144
for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
144145
if ((tilesInUse & tileMask) == TileMask::kNone) {
@@ -168,7 +169,7 @@ struct AssignTileIDsPattern
168169
else
169170
tilesInUse = TileMask::kNone;
170171

171-
auto tileId = getTile(*tileType, tilesInUse);
172+
auto tileId = allocateTileId(*tileType, tilesInUse);
172173
if (failed(tileId))
173174
return tileOp.emitError("ran out of SME virtual tiles!");
174175

mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt

Lines changed: 0 additions & 9 deletions
This file was deleted.

mlir/tools/mlir-query/mlir-query.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
using namespace mlir;
2222

2323
// This is needed because these matchers are defined as overloaded functions.
24-
using HasOpAttrName = mlir::detail::AttrOpMatcher(StringRef);
25-
using HasOpName = mlir::detail::NameOpMatcher(StringRef);
26-
using IsConstantOp = mlir::detail::constant_op_matcher();
24+
using HasOpAttrName = detail::AttrOpMatcher(StringRef);
25+
using HasOpName = detail::NameOpMatcher(StringRef);
26+
using IsConstantOp = detail::constant_op_matcher();
2727

2828
namespace test {
2929
#ifdef MLIR_INCLUDE_TESTS

0 commit comments

Comments
 (0)