Skip to content

Commit 04686fc

Browse files
committed
[CIR] Refactor IntType constraints
- Adds `CIR_` prefixes to integer type constraints types to disambiguate their names from other dialects. - Renames `PrimitiveInt` to `CIR_AnyFundamentalIntType` to align more with constrian conventions. - Adds bunch of helper constraint classes to be able to define base types to reduce clutter of necessary type casts. - Reworks constraints to use `CIR_ConfinedType` to avoid repeating validation checks. - Adds `IntOfWidths` variadic bitwidth constraint to reduce boilerplate code needed to handle multi-bitwidth parameters. - Constraints are moved into a separate file, which starts decoupling of constraints and types to remove the cyclic dependency between types and attributes and will eventually help fix several outstanding TODOs. This mirrors incubator changes from llvm/clangir#1593
1 parent 9693bf4 commit 04686fc

File tree

7 files changed

+154
-66
lines changed

7 files changed

+154
-66
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,11 @@ def PtrStrideOp : CIR_Op<"ptr_stride",
254254
```
255255
}];
256256

257-
let arguments = (ins CIR_PointerType:$base, PrimitiveInt:$stride);
257+
let arguments = (ins
258+
CIR_PointerType:$base,
259+
CIR_AnyFundamentalIntType:$stride
260+
);
261+
258262
let results = (outs CIR_PointerType:$result);
259263

260264
let assemblyFormat = [{
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines the CIR dialect type constraints.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
14+
#define CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
15+
16+
include "mlir/IR/Constraints.td"
17+
include "mlir/IR/CommonTypeConstraints.td"
18+
19+
class CIR_IsTypePred<code type> : CPred<"::mlir::isa<" # type # ">($_self)">;
20+
21+
class CIR_TypeBase<code type, string summary = "">
22+
: Type<CIR_IsTypePred<type>, summary, type>;
23+
24+
class CIR_CastSelfToType<code type, Pred pred>
25+
: SubstLeaves<"$_self", "::mlir::cast<" # type # ">($_self)", pred>;
26+
27+
class CIR_CastedSelfsToType<code type, list<Pred> preds>
28+
: And<!foreach(pred, preds, CIR_CastSelfToType<type, pred>)>;
29+
30+
class CIR_ConfinedType<Type type, list<Pred> preds, string summary = "">
31+
: Type<And<[type.predicate, CIR_CastedSelfsToType<type.cppType, preds>]>,
32+
summary, type.cppType>;
33+
34+
//===----------------------------------------------------------------------===//
35+
// IntType predicates
36+
//===----------------------------------------------------------------------===//
37+
38+
def CIR_AnyIntType : CIR_TypeBase<"::cir::IntType", "integer type">;
39+
40+
def CIR_AnyUIntType : CIR_ConfinedType<CIR_AnyIntType, [
41+
CPred<"$_self.isUnsigned()">], "unsigned integer type">;
42+
43+
def CIR_AnySIntType : CIR_ConfinedType<CIR_AnyIntType, [
44+
CPred<"$_self.isSigned()">], "signed integer type">;
45+
46+
class CIR_HasWidthPred<int width> : CPred<"$_self.getWidth() == " # width>;
47+
48+
def CIR_HasFundamentalIntWidthPred
49+
: CPred<"::cir::isValidFundamentalIntWidth($_self.getWidth())">;
50+
51+
class CIR_IntOfWidthsPred<list<int> widths>
52+
: Or<!foreach(width, widths, CIR_HasWidthPred<width>)>;
53+
54+
class CIR_IntOfWidths<list<int> widths>
55+
: CIR_ConfinedType<CIR_AnyIntType, [CIR_IntOfWidthsPred<widths>],
56+
"integer type of widths " # !interleave(widths, "/")>;
57+
58+
class CIR_SIntOfWidths<list<int> widths>
59+
: CIR_ConfinedType<CIR_AnySIntType, [CIR_IntOfWidthsPred<widths>],
60+
"signed integer type of widths " # !interleave(widths, "/")>;
61+
62+
class CIR_UIntOfWidths<list<int> widths>
63+
: CIR_ConfinedType<CIR_AnyUIntType, [CIR_IntOfWidthsPred<widths>],
64+
"unsigned integer type of widths " # !interleave(widths, "/")>;
65+
66+
class CIR_UInt<int width>
67+
: CIR_ConfinedType<CIR_AnyUIntType, [CIR_HasWidthPred<width>],
68+
width # "-bit unsigned integer">,
69+
BuildableType<"$_builder.getType<" # cppType # ">(" #
70+
width # ", /*isSigned=*/false)">;
71+
72+
def CIR_UInt1 : CIR_UInt<1>;
73+
def CIR_UInt8 : CIR_UInt<8>;
74+
def CIR_UInt16 : CIR_UInt<16>;
75+
def CIR_UInt32 : CIR_UInt<32>;
76+
def CIR_UInt64 : CIR_UInt<64>;
77+
def CIR_UInt128 : CIR_UInt<128>;
78+
79+
class CIR_SInt<int width>
80+
: CIR_ConfinedType<CIR_AnySIntType, [CIR_HasWidthPred<width>],
81+
width # "-bit signed integer">,
82+
BuildableType<"$_builder.getType<" # cppType # ">(" #
83+
width # ", /*isSigned=*/true)">;
84+
85+
def CIR_SInt1 : CIR_SInt<1>;
86+
def CIR_SInt8 : CIR_SInt<8>;
87+
def CIR_SInt16 : CIR_SInt<16>;
88+
def CIR_SInt32 : CIR_SInt<32>;
89+
def CIR_SInt64 : CIR_SInt<64>;
90+
def CIR_SInt128 : CIR_SInt<128>;
91+
92+
// Fundamental integer types represent standard source-level integer types that
93+
// have a specified set of admissible bitwidths (8, 16, 32, 64).
94+
95+
def CIR_AnyFundamentalIntType
96+
: CIR_ConfinedType<CIR_AnyIntType, [CIR_HasFundamentalIntWidthPred],
97+
"fundamental integer type"> {
98+
let cppFunctionName = "isFundamentalIntType";
99+
}
100+
101+
def CIR_AnyFundamentalUIntType
102+
: CIR_ConfinedType<CIR_AnyUIntType, [CIR_HasFundamentalIntWidthPred],
103+
"fundamental unsigned integer type"> {
104+
let cppFunctionName = "isFundamentalUIntType";
105+
}
106+
107+
def CIR_AnyFundamentalSIntType
108+
: CIR_ConfinedType<CIR_AnySIntType, [CIR_HasFundamentalIntWidthPred],
109+
"fundamental signed integer type"> {
110+
let cppFunctionName = "isFundamentalSIntType";
111+
}
112+
113+
#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD

clang/include/clang/CIR/Dialect/IR/CIRTypes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ namespace detail {
2424
struct RecordTypeStorage;
2525
} // namespace detail
2626

27+
bool isValidFundamentalIntWidth(unsigned width);
28+
2729
bool isAnyFloatingPointType(mlir::Type t);
2830
bool isFPOrFPVectorTy(mlir::Type);
2931

@@ -33,6 +35,12 @@ bool isFPOrFPVectorTy(mlir::Type);
3335
// CIR Dialect Tablegen'd Types
3436
//===----------------------------------------------------------------------===//
3537

38+
namespace cir {
39+
40+
#include "clang/CIR/Dialect/IR/CIRTypeConstraints.h.inc"
41+
42+
} // namespace cir
43+
3644
#define GET_TYPEDEF_CLASSES
3745
#include "clang/CIR/Dialect/IR/CIROpsTypes.h.inc"
3846

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 11 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_CIR_DIALECT_CIR_TYPES
1515

1616
include "clang/CIR/Dialect/IR/CIRDialect.td"
17+
include "clang/CIR/Dialect/IR/CIRTypeConstraints.td"
1718
include "clang/CIR/Interfaces/CIRFPTypeInterface.td"
1819
include "mlir/Interfaces/DataLayoutInterfaces.td"
1920
include "mlir/IR/AttrTypeBase.td"
@@ -41,7 +42,7 @@ def CIR_IntType : CIR_Type<"Int", "int",
4142
such as `__int128`, and arbitrary width types such as `_BitInt(n)`.
4243

4344
Those integer types that are directly available in C/C++ standard are called
44-
primitive integer types. Said types are: `signed char`, `short`, `int`,
45+
fundamental integer types. Said types are: `signed char`, `short`, `int`,
4546
`long`, `long long`, and their unsigned variations.
4647
}];
4748
let parameters = (ins "unsigned":$width, "bool":$isSigned);
@@ -55,81 +56,26 @@ def CIR_IntType : CIR_Type<"Int", "int",
5556
std::string getAlias() const {
5657
return (isSigned() ? 's' : 'u') + std::to_string(getWidth()) + 'i';
5758
}
58-
/// Return true if this is a primitive integer type (i.e. signed or unsigned
59-
/// integer types whose bit width is 8, 16, 32, or 64).
60-
bool isPrimitive() const {
61-
return isValidPrimitiveIntBitwidth(getWidth());
59+
/// Return true if this is a fundamental integer type (i.e. signed or
60+
/// unsigned integer types whose bit width is 8, 16, 32, or 64).
61+
bool isFundamental() const {
62+
return isFundamentalIntType(*this);
6263
}
63-
bool isSignedPrimitive() const {
64-
return isPrimitive() && isSigned();
64+
bool isSignedFundamental() const {
65+
return isFundamentalSIntType(*this);
66+
}
67+
bool isUnsignedFundamental() const {
68+
return isFundamentalUIntType(*this);
6569
}
6670

6771
/// Returns a minimum bitwidth of cir::IntType
6872
static unsigned minBitwidth() { return 1; }
6973
/// Returns a maximum bitwidth of cir::IntType
7074
static unsigned maxBitwidth() { return 128; }
71-
72-
/// Returns true if cir::IntType that represents a primitive integer type
73-
/// can be constructed from the provided bitwidth.
74-
static bool isValidPrimitiveIntBitwidth(unsigned width) {
75-
return width == 8 || width == 16 || width == 32 || width == 64;
76-
}
7775
}];
7876
let genVerifyDecl = 1;
7977
}
8078

81-
// Constraints
82-
83-
// Unsigned integer type of a specific width.
84-
class UInt<int width>
85-
: Type<And<[
86-
CPred<"::mlir::isa<::cir::IntType>($_self)">,
87-
CPred<"::mlir::cast<::cir::IntType>($_self).isUnsigned()">,
88-
CPred<"::mlir::cast<::cir::IntType>($_self).getWidth() == " # width>
89-
]>, width # "-bit unsigned integer", "::cir::IntType">,
90-
BuildableType<
91-
"cir::IntType::get($_builder.getContext(), "
92-
# width # ", /*isSigned=*/false)"> {
93-
int bitwidth = width;
94-
}
95-
96-
def UInt1 : UInt<1>;
97-
def UInt8 : UInt<8>;
98-
def UInt16 : UInt<16>;
99-
def UInt32 : UInt<32>;
100-
def UInt64 : UInt<64>;
101-
102-
// Signed integer type of a specific width.
103-
class SInt<int width>
104-
: Type<And<[
105-
CPred<"::mlir::isa<::cir::IntType>($_self)">,
106-
CPred<"::mlir::cast<::cir::IntType>($_self).isSigned()">,
107-
CPred<"::mlir::cast<::cir::IntType>($_self).getWidth() == " # width>
108-
]>, width # "-bit signed integer", "::cir::IntType">,
109-
BuildableType<
110-
"cir::IntType::get($_builder.getContext(), "
111-
# width # ", /*isSigned=*/true)"> {
112-
int bitwidth = width;
113-
}
114-
115-
def SInt1 : SInt<1>;
116-
def SInt8 : SInt<8>;
117-
def SInt16 : SInt<16>;
118-
def SInt32 : SInt<32>;
119-
def SInt64 : SInt<64>;
120-
121-
def PrimitiveUInt
122-
: AnyTypeOf<[UInt8, UInt16, UInt32, UInt64], "primitive unsigned int",
123-
"::cir::IntType">;
124-
125-
def PrimitiveSInt
126-
: AnyTypeOf<[SInt8, SInt16, SInt32, SInt64], "primitive signed int",
127-
"::cir::IntType">;
128-
129-
def PrimitiveInt
130-
: AnyTypeOf<[UInt8, UInt16, UInt32, UInt64, SInt8, SInt16, SInt32, SInt64],
131-
"primitive int", "::cir::IntType">;
132-
13379
//===----------------------------------------------------------------------===//
13480
// FloatType
13581
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@ mlir_tablegen(CIROpsAttributes.cpp.inc -gen-attrdef-defs)
1919
mlir_tablegen(CIROpsEnums.h.inc -gen-enum-decls)
2020
mlir_tablegen(CIROpsEnums.cpp.inc -gen-enum-defs)
2121
add_public_tablegen_target(MLIRCIREnumsGen)
22+
23+
set(LLVM_TARGET_DEFINITIONS CIRTypeConstraints.td)
24+
mlir_tablegen(CIRTypeConstraints.h.inc -gen-type-constraint-decls)
25+
mlir_tablegen(CIRTypeConstraints.cpp.inc -gen-type-constraint-defs)
26+
add_public_tablegen_target(MLIRCIRTypeConstraintsIncGen)
27+
add_dependencies(mlir-headers MLIRCIRTypeConstraintsIncGen)

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ static void printFuncTypeParams(mlir::AsmPrinter &p,
3333
// Get autogenerated stuff
3434
//===----------------------------------------------------------------------===//
3535

36+
namespace cir {
37+
38+
#include "clang/CIR/Dialect/IR/CIRTypeConstraints.cpp.inc"
39+
40+
} // namespace cir
41+
3642
#define GET_TYPEDEF_CLASSES
3743
#include "clang/CIR/Dialect/IR/CIROpsTypes.cpp.inc"
3844

@@ -424,6 +430,10 @@ IntType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
424430
return mlir::success();
425431
}
426432

433+
bool cir::isValidFundamentalIntWidth(unsigned width) {
434+
return width == 8 || width == 16 || width == 32 || width == 64;
435+
}
436+
427437
//===----------------------------------------------------------------------===//
428438
// Floating-point type definitions
429439
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_clang_library(MLIRCIR
66

77
DEPENDS
88
MLIRCIROpsIncGen
9+
MLIRCIRTypeConstraintsIncGen
910
MLIRCIREnumsGen
1011
MLIRCIROpInterfacesIncGen
1112
MLIRCIRLoopOpInterfaceIncGen

0 commit comments

Comments
 (0)