Skip to content

Commit 429d792

Browse files
committed
[mlir] Add support for generating dialect declarations via tablegen.
Summary: This generates the class declarations for dialects using the existing 'Dialect' tablegen classes. Differential Revision: https://reviews.llvm.org/D76185
1 parent 27f3039 commit 429d792

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+388
-281
lines changed

mlir/cmake/modules/AddMLIR.cmake

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ function(whole_archive_link target)
2828
endfunction(whole_archive_link)
2929

3030
# Declare a dialect in the include directory
31-
function(add_mlir_dialect dialect dialect_doc_filename)
31+
function(add_mlir_dialect dialect dialect_namespace dialect_doc_filename)
3232
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
3333
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
3434
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
35+
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
3536
add_public_tablegen_target(MLIR${dialect}IncGen)
3637
add_dependencies(mlir-headers MLIR${dialect}IncGen)
3738

mlir/docs/CreatingADialect.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ is declared using add_mlir_dialect().
3939

4040
```cmake
4141
42-
add_mlir_dialect(FooOps FooOps)
42+
add_mlir_dialect(FooOps foo FooOps)
4343
4444
```
4545

mlir/include/mlir/Dialect/AffineOps/AffineOps.h

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,6 @@ class OpBuilder;
3636
/// symbol.
3737
bool isTopLevelValue(Value value);
3838

39-
class AffineOpsDialect : public Dialect {
40-
public:
41-
AffineOpsDialect(MLIRContext *context);
42-
static StringRef getDialectNamespace() { return "affine"; }
43-
44-
/// Materialize a single constant operation from a given attribute value with
45-
/// the desired resultant type.
46-
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
47-
Location loc) override;
48-
};
49-
5039
/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
5140
/// from a source memref to a destination memref. The source and destination
5241
/// memref need not be of the same dimensionality, but need to have the same
@@ -504,6 +493,8 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
504493
void fullyComposeAffineMapAndOperands(AffineMap *map,
505494
SmallVectorImpl<Value> *operands);
506495

496+
#include "mlir/Dialect/AffineOps/AffineOpsDialect.h.inc"
497+
507498
#define GET_OP_CLASSES
508499
#include "mlir/Dialect/AffineOps/AffineOps.h.inc"
509500

mlir/include/mlir/Dialect/AffineOps/AffineOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@ include "mlir/Dialect/AffineOps/AffineOpsBase.td"
1717
include "mlir/Interfaces/LoopLikeInterface.td"
1818
include "mlir/Interfaces/SideEffects.td"
1919

20-
def Affine_Dialect : Dialect {
20+
def AffineOps_Dialect : Dialect {
2121
let name = "affine";
2222
let cppNamespace = "";
23+
let hasConstantMaterializer = 1;
2324
}
2425

2526
// Base class for Affine dialect ops.
2627
class Affine_Op<string mnemonic, list<OpTrait> traits = []> :
27-
Op<Affine_Dialect, mnemonic, traits> {
28+
Op<AffineOps_Dialect, mnemonic, traits> {
2829
// For every affine op, there needs to be a:
2930
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
3031
// * LogicalResult verify(${C++ class of Op} op)
@@ -290,7 +291,7 @@ def AffineIfOp : Affine_Op<"if",
290291
}
291292

292293
class AffineMinMaxOpBase<string mnemonic, list<OpTrait> traits = []> :
293-
Op<Affine_Dialect, mnemonic, traits> {
294+
Op<AffineOps_Dialect, mnemonic, traits> {
294295
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
295296
let results = (outs Index);
296297

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
add_mlir_dialect(AffineOps AffineOps)
1+
add_mlir_dialect(AffineOps affine AffineOps)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
add_mlir_dialect(FxpMathOps FxpMathOps)
1+
add_mlir_dialect(FxpMathOps fxpmath FxpMathOps)

mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,7 @@
1717
namespace mlir {
1818
namespace fxpmath {
1919

20-
/// Defines the 'FxpMathOps' dialect.
21-
class FxpMathOpsDialect : public Dialect {
22-
public:
23-
FxpMathOpsDialect(MLIRContext *context);
24-
};
20+
#include "mlir/Dialect/FxpMathOps/FxpMathOpsDialect.h.inc"
2521

2622
#define GET_OP_CLASSES
2723
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc"

mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
#define DIALECT_FXPMATHOPS_FXPMATH_OPS_
1616

1717
include "mlir/IR/OpBase.td"
18-
include "mlir/Dialect/QuantOps/QuantPredicates.td"
18+
include "mlir/Dialect/QuantOps/QuantOpsBase.td"
1919
include "mlir/Interfaces/SideEffects.td"
2020

21-
def fxpmath_Dialect : Dialect {
21+
def FxpMathOps_Dialect : Dialect {
2222
let name = "fxpmath";
2323
}
2424

@@ -78,7 +78,7 @@ def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn",
7878
//===----------------------------------------------------------------------===//
7979

8080
class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
81-
Op<fxpmath_Dialect, mnemonic, traits>;
81+
Op<FxpMathOps_Dialect, mnemonic, traits>;
8282

8383
//===----------------------------------------------------------------------===//
8484
// Fixed-point (fxp) arithmetic ops used by kernels.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
add_mlir_dialect(GPUOps GPUOps)
1+
add_mlir_dialect(GPUOps gpu GPUOps)

mlir/include/mlir/Dialect/GPU/GPUDialect.h

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -26,51 +26,6 @@ class FuncOp;
2626

2727
namespace gpu {
2828

29-
/// The dialect containing GPU kernel launching operations and related
30-
/// facilities.
31-
class GPUDialect : public Dialect {
32-
public:
33-
/// Create the dialect in the given `context`.
34-
explicit GPUDialect(MLIRContext *context);
35-
/// Get dialect namespace.
36-
static StringRef getDialectNamespace() { return "gpu"; }
37-
38-
/// Get the name of the attribute used to annotate the modules that contain
39-
/// kernel modules.
40-
static StringRef getContainerModuleAttrName() {
41-
return "gpu.container_module";
42-
}
43-
44-
/// Get the canonical string name of the dialect.
45-
static StringRef getDialectName();
46-
47-
/// Get the name of the attribute used to annotate external kernel functions.
48-
static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
49-
50-
/// Get the name of the attribute used to annotate kernel modules.
51-
static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
52-
53-
/// Returns whether the given function is a kernel function, i.e., has the
54-
/// 'gpu.kernel' attribute.
55-
static bool isKernel(Operation *op);
56-
57-
/// Returns the number of workgroup (thread, block) dimensions supported in
58-
/// the GPU dialect.
59-
// TODO(zinenko,herhut): consider generalizing this.
60-
static unsigned getNumWorkgroupDimensions() { return 3; }
61-
62-
/// Returns the numeric value used to identify the workgroup memory address
63-
/// space.
64-
static unsigned getWorkgroupAddressSpace() { return 3; }
65-
66-
/// Returns the numeric value used to identify the private memory address
67-
/// space.
68-
static unsigned getPrivateAddressSpace() { return 5; }
69-
70-
LogicalResult verifyOperationAttribute(Operation *op,
71-
NamedAttribute attr) override;
72-
};
73-
7429
/// Utility class for the GPU dialect to represent triples of `Value`s
7530
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
7631
struct KernelDim3 {
@@ -79,6 +34,8 @@ struct KernelDim3 {
7934
Value z;
8035
};
8136

37+
#include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
38+
8239
#define GET_OP_CLASSES
8340
#include "mlir/Dialect/GPU/GPUOps.h.inc"
8441

mlir/include/mlir/Dialect/GPU/GPUOps.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,39 @@ def IntLikeOrLLVMInt : TypeConstraint<
2828

2929
def GPU_Dialect : Dialect {
3030
let name = "gpu";
31+
let extraClassDeclaration = [{
32+
/// Get the name of the attribute used to annotate the modules that contain
33+
/// kernel modules.
34+
static StringRef getContainerModuleAttrName() {
35+
return "gpu.container_module";
36+
}
37+
/// Get the name of the attribute used to annotate external kernel
38+
/// functions.
39+
static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
40+
41+
/// Get the name of the attribute used to annotate kernel modules.
42+
static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
43+
44+
/// Returns whether the given function is a kernel function, i.e., has the
45+
/// 'gpu.kernel' attribute.
46+
static bool isKernel(Operation *op);
47+
48+
/// Returns the number of workgroup (thread, block) dimensions supported in
49+
/// the GPU dialect.
50+
// TODO(zinenko,herhut): consider generalizing this.
51+
static unsigned getNumWorkgroupDimensions() { return 3; }
52+
53+
/// Returns the numeric value used to identify the workgroup memory address
54+
/// space.
55+
static unsigned getWorkgroupAddressSpace() { return 3; }
56+
57+
/// Returns the numeric value used to identify the private memory address
58+
/// space.
59+
static unsigned getPrivateAddressSpace() { return 5; }
60+
61+
LogicalResult verifyOperationAttribute(Operation *op,
62+
NamedAttribute attr) override;
63+
}];
3164
}
3265

3366
class GPU_Op<string mnemonic, list<OpTrait> traits = []> :

mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
22
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
33
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
4+
mlir_tablegen(LLVMOpsDialect.h.inc -gen-dialect-decls)
45
mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
56
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
67
add_public_tablegen_target(MLIRLLVMOpsIncGen)
78

8-
add_mlir_dialect(NVVMOps NVVMOps)
9-
add_mlir_dialect(ROCDLOps ROCDLOps)
9+
add_mlir_dialect(NVVMOps nvvm NVVMOps)
10+
add_mlir_dialect(ROCDLOps rocdl ROCDLOps)
1011

1112
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
1213
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -201,32 +201,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
201201
#define GET_OP_CLASSES
202202
#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc"
203203

204-
class LLVMDialect : public Dialect {
205-
public:
206-
explicit LLVMDialect(MLIRContext *context);
207-
~LLVMDialect();
208-
static StringRef getDialectNamespace() { return "llvm"; }
209-
210-
llvm::LLVMContext &getLLVMContext();
211-
llvm::Module &getLLVMModule();
212-
213-
/// Parse a type registered to this dialect.
214-
Type parseType(DialectAsmParser &parser) const override;
215-
216-
/// Print a type registered to this dialect.
217-
void printType(Type type, DialectAsmPrinter &os) const override;
218-
219-
/// Verify a region argument attribute registered to this dialect.
220-
/// Returns failure if the verification failed, success otherwise.
221-
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
222-
unsigned argIdx,
223-
NamedAttribute argAttr) override;
224-
225-
private:
226-
friend LLVMType;
227-
228-
std::unique_ptr<detail::LLVMDialectImpl> impl;
229-
};
204+
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.h.inc"
230205

231206
/// Create an LLVM global containing the string "value" at the module containing
232207
/// surrounding the insertion point of builder. Obtain the address of that

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,28 @@ include "mlir/IR/OpBase.td"
1919
def LLVM_Dialect : Dialect {
2020
let name = "llvm";
2121
let cppNamespace = "LLVM";
22+
let extraClassDeclaration = [{
23+
~LLVMDialect();
24+
llvm::LLVMContext &getLLVMContext();
25+
llvm::Module &getLLVMModule();
26+
27+
/// Verify a region argument attribute registered to this dialect.
28+
/// Returns failure if the verification failed, success otherwise.
29+
LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
30+
unsigned argIdx,
31+
NamedAttribute argAttr) override;
32+
33+
private:
34+
friend LLVMType;
35+
36+
std::unique_ptr<detail::LLVMDialectImpl> impl;
37+
}];
2238
}
2339

2440
// LLVM IR type wrapped in MLIR.
25-
def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
26-
"LLVM dialect type">;
41+
def LLVM_Type : DialectType<LLVM_Dialect,
42+
CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
43+
"LLVM dialect type">;
2744

2845
// Type constraint accepting only wrapped LLVM integer types.
2946
def LLVMInt : TypeConstraint<

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,7 @@ namespace NVVM {
2525
#define GET_OP_CLASSES
2626
#include "mlir/Dialect/LLVMIR/NVVMOps.h.inc"
2727

28-
class NVVMDialect : public Dialect {
29-
public:
30-
explicit NVVMDialect(MLIRContext *context);
31-
32-
static StringRef getDialectNamespace() { return "nvvm"; }
33-
};
28+
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.h.inc"
3429

3530
} // namespace NVVM
3631
} // namespace mlir

mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,7 @@ namespace ROCDL {
3333
#define GET_OP_CLASSES
3434
#include "mlir/Dialect/LLVMIR/ROCDLOps.h.inc"
3535

36-
class ROCDLDialect : public Dialect {
37-
public:
38-
explicit ROCDLDialect(MLIRContext *context);
39-
40-
static StringRef getDialectNamespace() { return "rocdl"; }
41-
};
36+
#include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.h.inc"
4237

4338
} // namespace ROCDL
4439
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
add_mlir_dialect(LinalgOps LinalgDoc)
1+
add_mlir_dialect(LinalgOps linalg LinalgDoc)
22
set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td)
33
mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls)
44
mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs)

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ def Linalg_Dialect : Dialect {
3434

3535
// Whether a type is a RangeType.
3636
def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
37-
def Range : Type<LinalgIsRangeTypePred, "range">;
37+
def Range : DialectType<Linalg_Dialect, LinalgIsRangeTypePred, "range">;
3838

3939
#endif // LINALG_BASE

mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,7 @@ enum LinalgTypes {
2121
LAST_USED_LINALG_TYPE = Range,
2222
};
2323

24-
class LinalgDialect : public Dialect {
25-
public:
26-
explicit LinalgDialect(MLIRContext *context);
27-
static StringRef getDialectNamespace() { return "linalg"; }
28-
29-
/// Parse a type registered to this dialect.
30-
Type parseType(DialectAsmParser &parser) const override;
31-
32-
/// Print a type registered to this dialect.
33-
void printType(Type type, DialectAsmPrinter &os) const override;
34-
};
24+
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
3525

3626
/// A RangeType represents a minimal range abstraction (min, max, step).
3727
/// It is constructed by calling the linalg.range op with three values index of
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
add_mlir_dialect(LoopOps LoopOps)
1+
add_mlir_dialect(LoopOps loop LoopOps)

mlir/include/mlir/Dialect/LoopOps/LoopOps.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,7 @@ namespace loop {
2525

2626
class TerminatorOp;
2727

28-
class LoopOpsDialect : public Dialect {
29-
public:
30-
LoopOpsDialect(MLIRContext *context);
31-
static StringRef getDialectNamespace() { return "loop"; }
32-
};
28+
#include "mlir/Dialect/LoopOps/LoopOpsDialect.h.inc"
3329

3430
#define GET_OP_CLASSES
3531
#include "mlir/Dialect/LoopOps/LoopOps.h.inc"

mlir/include/mlir/Dialect/LoopOps/LoopOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
include "mlir/Interfaces/LoopLikeInterface.td"
1717
include "mlir/Interfaces/SideEffects.td"
1818

19-
def Loop_Dialect : Dialect {
19+
def LoopOps_Dialect : Dialect {
2020
let name = "loop";
2121
let cppNamespace = "";
2222
}
2323

2424
// Base class for Loop dialect ops.
2525
class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
26-
Op<Loop_Dialect, mnemonic, traits> {
26+
Op<LoopOps_Dialect, mnemonic, traits> {
2727
// For every standard op, there needs to be a:
2828
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
2929
// * LogicalResult verify(${C++ class of Op} op)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
add_mlir_dialect(OpenMPOps OpenMPOps)
1+
add_mlir_dialect(OpenMPOps omp OpenMPOps)

0 commit comments

Comments
 (0)