Skip to content

Commit 9f7fff7

Browse files
authored
[mlir][ArmSME] Add arith-to-arm-sme conversion pass (#78197)
The existing 'arith::ConstantOp' conversion and its tests are moved from VectorToArmSME. Only a single op is converted at the moment, but this will grow in the future as things like in-tile add are implemented. Also, 'createLoopOverTileSlices' is moved to the ArmSME utils since it's relevant for both conversions.
1 parent 11c0dc3 commit 9f7fff7

File tree

18 files changed

+257
-114
lines changed

18 files changed

+257
-114
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- ArithToArmSME.h - Arith to ArmSME dialect conversion -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
10+
#define MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
16+
class RewritePatternSet;
17+
class Pass;
18+
19+
#define GEN_PASS_DECL_ARITHTOARMSMECONVERSIONPASS
20+
#include "mlir/Conversion/Passes.h.inc"
21+
22+
namespace arith {
23+
/// Adds the Arith to ArmSME conversion patterns to `patterns`.
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns);
24+
} // namespace arith
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_ARITHTOARMSME_ARITHTOARMSME_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
15+
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
1516
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1617
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
1718
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,15 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
164164
];
165165
}
166166

167+
//===----------------------------------------------------------------------===//
168+
// ArithToArmSME
169+
//===----------------------------------------------------------------------===//
170+
171+
// Declares the `convert-arith-to-arm-sme` pass. The ArmSME dialect is listed as
// a dependent dialect of the pass.
def ArithToArmSMEConversionPass : Pass<"convert-arith-to-arm-sme"> {
172+
let summary = "Convert Arith dialect to ArmSME dialect";
173+
let dependentDialects = ["arm_sme::ArmSMEDialect"];
174+
}
175+
167176
//===----------------------------------------------------------------------===//
168177
// ArmNeon2dToIntr
169178
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,16 @@
1616
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
1717

1818
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
19+
#include "mlir/Dialect/SCF/IR/SCF.h"
1920
#include "mlir/IR/BuiltinTypes.h"
2021
#include <optional>
2122

23+
namespace mlir {
24+
class Location;
25+
class PatternRewriter;
26+
class Value;
27+
} // namespace mlir
28+
2229
namespace mlir::arm_sme {
2330

2431
constexpr unsigned MinStreamingVectorLengthInBits = 128;
@@ -42,6 +49,13 @@ std::optional<ArmSMETileType> getSMETileType(VectorType);
4249
/// Verifies the tile ID (if set) on this tile operation is valid.
4350
LogicalResult verifyOperationHasValidTileId(Operation *);
4451

52+
/// Generates a for loop over ZA tile slices where the induction variable is
53+
/// the tile slice index and each iteration yields a new tile. Loop body is
54+
/// built via `makeLoopBody`, which returns the next tile value.
///
/// `initTile` is the tile value the loop starts from. `makeLoopBody` is
/// called with a builder, a location, the tile slice index and the current
/// tile, in that order. The returned `scf.for` carries the tile as a loop
/// result; presumably result 0 is the final tile.
55+
scf::ForOp createLoopOverTileSlices(
56+
PatternRewriter &rewriter, Location loc, Value initTile,
57+
std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
58+
4559
} // namespace mlir::arm_sme
4660

4761
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
13+
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
17+
namespace mlir {
18+
#define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19+
#include "mlir/Conversion/Passes.h.inc"
20+
} // namespace mlir
21+
22+
#define DEBUG_TYPE "arith-to-arm-sme"
23+
24+
using namespace mlir;
25+
26+
//===----------------------------------------------------------------------===//
27+
// Conversion helpers
28+
//===----------------------------------------------------------------------===//
29+
30+
/// Returns true if 'val' is a splat of zero, false otherwise.
31+
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
32+
if (llvm::isa<FloatType>(elemType))
33+
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
34+
if (llvm::isa<IntegerType>(elemType))
35+
return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
36+
return false;
37+
}
38+
39+
namespace {
40+
41+
//===----------------------------------------------------------------------===//
42+
// ConstantOp
43+
//===----------------------------------------------------------------------===//
44+
45+
/// Conversion pattern for dense arith.constant.
46+
struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
47+
using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
48+
49+
LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50+
PatternRewriter &rewriter) const final {
51+
auto tileType = dyn_cast<VectorType>(constantOp.getType());
52+
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
53+
return failure();
54+
55+
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56+
if (!denseAttr || !denseAttr.isSplat())
57+
return failure();
58+
59+
auto tileElementType = tileType.getElementType();
60+
61+
// Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62+
if (isSplatZero(tileElementType, denseAttr)) {
63+
rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
64+
return success();
65+
}
66+
67+
// Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
68+
// ops that broadcast the constant to each tile slice.
69+
auto loc = constantOp.getLoc();
70+
71+
// To fill a tile with a constant, we create a 1-D splat of the constant,
72+
// then move that into each tile slice (the largest unit we can set at once,
73+
// outside of operations like the outerproduct).
74+
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
75+
auto denseAttr1D = DenseElementsAttr::get(
76+
tileSliceType, denseAttr.getSplatValue<Attribute>());
77+
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
78+
79+
auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
80+
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
81+
Value currentTile) {
82+
// Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
83+
// slice.
84+
auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
85+
loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86+
return nextTile.getResult();
87+
};
88+
auto forOp = mlir::arm_sme::createLoopOverTileSlices(
89+
rewriter, loc, initTile, makeLoopBody);
90+
rewriter.replaceOp(constantOp, forOp.getResult(0));
91+
92+
return success();
93+
}
94+
};
95+
96+
} // namespace
97+
98+
//===----------------------------------------------------------------------===//
99+
// Pattern population
100+
//===----------------------------------------------------------------------===//
101+
102+
/// Registers the Arith to ArmSME rewrite patterns with `patterns`. Only the
/// 'arith.constant' lowering is added.
void mlir::arith::populateArithToArmSMEConversionPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<ConstantOpToArmSMELowering>(context);
}
106+
107+
//===----------------------------------------------------------------------===//
108+
// Pass definition
109+
//===----------------------------------------------------------------------===//
110+
111+
namespace {
112+
struct ArithToArmSMEConversionPass final
113+
: impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
114+
using impl::ArithToArmSMEConversionPassBase<
115+
ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
116+
117+
void runOnOperation() override {
118+
RewritePatternSet patterns(&getContext());
119+
arith::populateArithToArmSMEConversionPatterns(patterns);
120+
if (failed(
121+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
122+
return signalPassFailure();
123+
}
124+
};
125+
} // namespace
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_conversion_library(MLIRArithToArmSME
2+
ArithToArmSME.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToArmSME
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRArmSMEDialect
15+
MLIRArithDialect
16+
MLIRPass
17+
MLIRTransforms
18+
)

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
22
add_subdirectory(AMDGPUToROCDL)
33
add_subdirectory(ArithCommon)
44
add_subdirectory(ArithToAMDGPU)
5+
add_subdirectory(ArithToArmSME)
56
add_subdirectory(ArithToLLVM)
67
add_subdirectory(ArithToSPIRV)
78
add_subdirectory(ArmNeon2dToIntr)

0 commit comments

Comments
 (0)