Skip to content

Commit 1a21196

Browse files
authored
[MLIR] reverse int8 type's printing logic (llvm#69361)
Specializing for 8-bit integers to ensure values are printed as integers Fixes llvm#69310
1 parent a587f42 commit 1a21196

File tree

4 files changed

+12
-24
lines changed

4 files changed

+12
-24
lines changed

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
5858

5959
let parameters = (ins
6060
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
61-
ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
62-
OptionalArrayRefParameter<"int8_t">:$partial_axes,
61+
ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
62+
OptionalArrayRefParameter<"int32_t">:$partial_axes,
6363
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
6464
);
6565

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
7070
}];
7171
let arguments = (ins
7272
SymbolNameAttr:$sym_name,
73-
I8Attr:$rank,
73+
I64Attr:$rank,
7474
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
7575
);
7676
let assemblyFormat = [{

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,7 @@ template <typename AsmPrinterT, typename T,
350350
!std::is_convertible<T &, Attribute &>::value &&
351351
!std::is_convertible<T &, ValueRange>::value &&
352352
!std::is_convertible<T &, APFloat &>::value &&
353-
!llvm::is_one_of<T, bool, int8_t, uint8_t, float,
354-
double>::value,
353+
!llvm::is_one_of<T, bool, float, double>::value,
355354
T> * = nullptr>
356355
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
357356
AsmPrinterT &>
@@ -367,17 +366,6 @@ operator<<(AsmPrinterT &p, bool value) {
367366
return p << (value ? StringRef("true") : "false");
368367
}
369368

370-
/// Specialization for 8-bit integers to ensure values are printed as integers
371-
// and not characters.
372-
template <
373-
typename AsmPrinterT, typename T,
374-
std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
375-
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
376-
AsmPrinterT &>
377-
operator<<(AsmPrinterT &p, T value) {
378-
return p << static_cast<int16_t>(value);
379-
}
380-
381369
template <typename AsmPrinterT, typename ValueRangeT>
382370
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
383371
AsmPrinterT &>

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
4747

4848
LogicalResult ClusterOp::verify() {
4949
ArrayRef<int64_t> dimSizes = getDimSizes();
50-
uint8_t rank = getRank();
50+
uint64_t rank = getRank();
5151

5252
if (rank == 0)
5353
return emitOpError("rank of cluster is expected to be a positive integer");
@@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() {
7171

7272
LogicalResult
7373
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
74-
SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
75-
ArrayRef<int8_t> partialAxes, Partial) {
74+
SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
75+
ArrayRef<int32_t> partialAxes, Partial) {
7676
// TODO: At present cluster symbol ref is not verified. This is due to the
7777
// difficulty in fetching the corresponding symbol op based on an attribute.
7878

79-
llvm::SmallSet<int8_t, 4> visitedAxes;
79+
llvm::SmallSet<int32_t, 4> visitedAxes;
8080

81-
auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
82-
for (int8_t axis : axesArray) {
81+
auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
82+
for (int32_t axis : axesArray) {
8383
if (axis < 0)
8484
return emitError() << "mesh axis is expected to be non-negative";
8585
if (!visitedAxes.insert(axis).second)
@@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
8888
return success();
8989
};
9090

91-
for (DenseI8ArrayAttr subAxes : splitAxes) {
92-
ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
91+
for (DenseI32ArrayAttr subAxes : splitAxes) {
92+
ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
9393
if (failed(checkMeshAxis(subAxesArray)))
9494
return failure();
9595
}

0 commit comments

Comments
 (0)