Skip to content

Commit 4fba3b9

Browse files
committed
[mlir][SMT] C APIs
1 parent acf964b commit 4fba3b9

File tree

8 files changed

+491
-0
lines changed

8 files changed

+491
-0
lines changed

mlir/include/mlir-c/Dialect/SMT.h

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//===- SMT.h - C interface for the SMT dialect --------------------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_C_DIALECT_SMT_H
10+
#define MLIR_C_DIALECT_SMT_H
11+
12+
#include "mlir-c/IR.h"
13+
14+
#ifdef __cplusplus
15+
extern "C" {
16+
#endif
17+
18+
//===----------------------------------------------------------------------===//
19+
// Dialect API.
20+
//===----------------------------------------------------------------------===//
21+
22+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt);
23+
24+
//===----------------------------------------------------------------------===//
25+
// Type API.
26+
//===----------------------------------------------------------------------===//
27+
28+
/// Checks if the given type is any non-func SMT value type.
29+
MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type);
30+
31+
/// Checks if the given type is any SMT value type.
32+
MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type);
33+
34+
/// Checks if the given type is a smt::ArrayType.
35+
MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type);
36+
37+
/// Creates an array type with the given domain and range types.
38+
MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx,
39+
MlirType domainType,
40+
MlirType rangeType);
41+
42+
/// Checks if the given type is a smt::BitVectorType.
43+
MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type);
44+
45+
/// Creates a smt::BitVectorType with the given width.
46+
MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width);
47+
48+
/// Checks if the given type is a smt::BoolType.
49+
MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type);
50+
51+
/// Creates a smt::BoolType.
52+
MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx);
53+
54+
/// Checks if the given type is a smt::IntType.
55+
MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type);
56+
57+
/// Creates a smt::IntType.
58+
MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx);
59+
60+
/// Checks if the given type is a smt::FuncType.
61+
MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type);
62+
63+
/// Creates a smt::FuncType with the given domain and range types.
64+
MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx,
65+
size_t numberOfDomainTypes,
66+
const MlirType *domainTypes,
67+
MlirType rangeType);
68+
69+
/// Checks if the given type is a smt::SortType.
70+
MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type);
71+
72+
/// Creates a smt::SortType with the given identifier and sort parameters.
73+
MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx,
74+
MlirIdentifier identifier,
75+
size_t numberOfSortParams,
76+
const MlirType *sortParams);
77+
78+
//===----------------------------------------------------------------------===//
79+
// Attribute API.
80+
//===----------------------------------------------------------------------===//
81+
82+
/// Checks if the given string is a valid smt::BVCmpPredicate.
83+
MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx,
84+
MlirStringRef str);
85+
86+
/// Checks if the given string is a valid smt::IntPredicate.
87+
MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx,
88+
MlirStringRef str);
89+
90+
/// Checks if the given attribute is a smt::SMTAttribute.
91+
MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr);
92+
93+
/// Creates a smt::BitVectorAttr with the given value and width.
94+
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx,
95+
uint64_t value,
96+
unsigned width);
97+
98+
/// Creates a smt::BVCmpPredicateAttr with the given string.
99+
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx,
100+
MlirStringRef str);
101+
102+
/// Creates a smt::IntPredicateAttr with the given string.
103+
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx,
104+
MlirStringRef str);
105+
106+
#ifdef __cplusplus
107+
}
108+
#endif
109+
110+
#endif // MLIR_C_DIALECT_SMT_H
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- MLIR-c/ExportSMTLIB.h - C API for emitting SMTLIB ---------*- C -*-===//
2+
//
3+
// This header declares the C interface for emitting SMTLIB from a MLIR MLIR
4+
// module.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#ifndef MLIR_C_EXPORTSMTLIB_H
9+
#define MLIR_C_EXPORTSMTLIB_H
10+
11+
#include "mlir-c/IR.h"
12+
13+
#ifdef __cplusplus
14+
extern "C" {
15+
#endif
16+
17+
/// Emits SMTLIB for the specified module using the provided callback and user
18+
/// data
19+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule,
20+
MlirStringCallback,
21+
void *userData);
22+
23+
#ifdef __cplusplus
24+
}
25+
#endif
26+
27+
#endif // MLIR_C_EXPORTSMTLIB_H

mlir/lib/CAPI/Dialect/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,11 @@ add_mlir_upstream_c_api_library(MLIRCAPIVector
269269
MLIRCAPIIR
270270
MLIRVectorDialect
271271
)
272+
273+
add_mlir_upstream_c_api_library(MLIRCAPISMT
274+
SMT.cpp
275+
276+
LINK_LIBS PUBLIC
277+
MLIRCAPIIR
278+
MLIRSMT
279+
)

mlir/lib/CAPI/Dialect/SMT.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//===- SMT.cpp - C interface for the SMT dialect --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/SMT.h"
10+
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
11+
#include "mlir/Dialect/SMT/IR/SMTOps.h"
12+
#include "mlir/CAPI/Registration.h"
13+
14+
using namespace mlir;
15+
using namespace smt;
16+
17+
//===----------------------------------------------------------------------===//
18+
// Dialect API.
19+
//===----------------------------------------------------------------------===//
20+
21+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect)
22+
23+
//===----------------------------------------------------------------------===//
24+
// Type API.
25+
//===----------------------------------------------------------------------===//
26+
27+
bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) {
28+
return isAnyNonFuncSMTValueType(unwrap(type));
29+
}
30+
31+
bool smtTypeIsAnySMTValueType(MlirType type) {
32+
return isAnySMTValueType(unwrap(type));
33+
}
34+
35+
bool smtTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
36+
37+
MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType,
38+
MlirType rangeType) {
39+
return wrap(
40+
ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType)));
41+
}
42+
43+
bool smtTypeIsABitVector(MlirType type) {
44+
return isa<BitVectorType>(unwrap(type));
45+
}
46+
47+
MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) {
48+
return wrap(BitVectorType::get(unwrap(ctx), width));
49+
}
50+
51+
bool smtTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
52+
53+
MlirType smtTypeGetBool(MlirContext ctx) {
54+
return wrap(BoolType::get(unwrap(ctx)));
55+
}
56+
57+
bool smtTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
58+
59+
MlirType smtTypeGetInt(MlirContext ctx) {
60+
return wrap(IntType::get(unwrap(ctx)));
61+
}
62+
63+
bool smtTypeIsASMTFunc(MlirType type) { return isa<SMTFuncType>(unwrap(type)); }
64+
65+
MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
66+
const MlirType *domainTypes, MlirType rangeType) {
67+
SmallVector<Type> domainTypesVec;
68+
domainTypesVec.reserve(numberOfDomainTypes);
69+
70+
for (size_t i = 0; i < numberOfDomainTypes; i++)
71+
domainTypesVec.push_back(unwrap(domainTypes[i]));
72+
73+
return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType)));
74+
}
75+
76+
bool smtTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
77+
78+
MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
79+
size_t numberOfSortParams, const MlirType *sortParams) {
80+
SmallVector<Type> sortParamsVec;
81+
sortParamsVec.reserve(numberOfSortParams);
82+
83+
for (size_t i = 0; i < numberOfSortParams; i++)
84+
sortParamsVec.push_back(unwrap(sortParams[i]));
85+
86+
return wrap(SortType::get(unwrap(ctx), unwrap(identifier), sortParamsVec));
87+
}
88+
89+
//===----------------------------------------------------------------------===//
90+
// Attribute API.
91+
//===----------------------------------------------------------------------===//
92+
93+
bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
94+
return symbolizeBVCmpPredicate(unwrap(str)).has_value();
95+
}
96+
97+
bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
98+
return symbolizeIntPredicate(unwrap(str)).has_value();
99+
}
100+
101+
bool smtAttrIsASMTAttribute(MlirAttribute attr) {
102+
return isa<BitVectorAttr, BVCmpPredicateAttr, IntPredicateAttr>(unwrap(attr));
103+
}
104+
105+
MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value,
106+
unsigned width) {
107+
return wrap(BitVectorAttr::get(unwrap(ctx), value, width));
108+
}
109+
110+
MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
111+
auto predicate = symbolizeBVCmpPredicate(unwrap(str));
112+
assert(predicate.has_value() && "invalid predicate");
113+
114+
return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value()));
115+
}
116+
117+
MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
118+
auto predicate = symbolizeIntPredicate(unwrap(str));
119+
assert(predicate.has_value() && "invalid predicate");
120+
121+
return wrap(IntPredicateAttr::get(unwrap(ctx), predicate.value()));
122+
}

mlir/lib/CAPI/Target/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
add_mlir_upstream_c_api_library(MLIRCAPITarget
22
LLVMIR.cpp
33

4+
PARTIAL_SOURCES_INTENDED
5+
46
LINK_COMPONENTS
57
Core
68

@@ -11,3 +13,13 @@ add_mlir_upstream_c_api_library(MLIRCAPITarget
1113
MLIRLLVMIRToLLVMTranslation
1214
MLIRSupport
1315
)
16+
17+
add_mlir_upstream_c_api_library(MLIRCAPIExportSMTLIB
18+
ExportSMTLIB.cpp
19+
20+
PARTIAL_SOURCES_INTENDED
21+
22+
LINK_LIBS PUBLIC
23+
MLIRCAPIIR
24+
MLIRExportSMTLIB
25+
)

mlir/lib/CAPI/Target/ExportSMTLIB.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- ExportSMTLIB.cpp - C Interface to ExportSMTLIB ---------------------===//
2+
//
3+
// Implements a C Interface for export SMTLIB.
4+
//
5+
//===----------------------------------------------------------------------===//
6+
7+
#include "mlir-c/Target/ExportSMTLIB.h"
8+
#include "mlir/CAPI/IR.h"
9+
#include "mlir/CAPI/Support.h"
10+
#include "mlir/CAPI/Utils.h"
11+
#include "mlir/Target/SMTLIB/ExportSMTLIB.h"
12+
13+
using namespace mlir;
14+
15+
MlirLogicalResult mlirExportSMTLIB(MlirModule module,
16+
MlirStringCallback callback,
17+
void *userData) {
18+
mlir::detail::CallbackOstream stream(callback, userData);
19+
return wrap(smt::exportSMTLIB(unwrap(module), stream));
20+
}

mlir/test/CAPI/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,13 @@ _add_capi_test_executable(mlir-capi-translation-test
123123
MLIRCAPIRegisterEverything
124124
MLIRCAPITarget
125125
)
126+
127+
_add_capi_test_executable(mlir-capi-smt-test
128+
smt.c
129+
130+
LINK_LIBS PRIVATE
131+
MLIRCAPIIR
132+
MLIRCAPIFunc
133+
MLIRCAPISMT
134+
MLIRCAPIExportSMTLIB
135+
)

0 commit comments

Comments
 (0)