Skip to content

Commit 697aa99

Browse files
authored
[mlir][SMT] add python bindings (#135674)
This PR adds "rich" python bindings to SMT dialect.
1 parent 7623501 commit 697aa99

File tree

10 files changed

+378
-112
lines changed

10 files changed

+378
-112
lines changed

mlir/include/mlir-c/Dialect/SMT.h

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,82 +26,83 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt);
2626
//===----------------------------------------------------------------------===//
2727

2828
/// Checks if the given type is any non-func SMT value type.
29-
MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type);
29+
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type);
3030

3131
/// Checks if the given type is any SMT value type.
32-
MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type);
32+
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type);
3333

3434
/// Checks if the given type is a smt::ArrayType.
35-
MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type);
35+
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type);
3636

3737
/// Creates an array type with the given domain and range types.
38-
MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx,
39-
MlirType domainType,
40-
MlirType rangeType);
38+
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx,
39+
MlirType domainType,
40+
MlirType rangeType);
4141

4242
/// Checks if the given type is a smt::BitVectorType.
43-
MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type);
43+
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type);
4444

4545
/// Creates a smt::BitVectorType with the given width.
46-
MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width);
46+
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx,
47+
int32_t width);
4748

4849
/// Checks if the given type is a smt::BoolType.
49-
MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type);
50+
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type);
5051

5152
/// Creates a smt::BoolType.
52-
MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx);
53+
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx);
5354

5455
/// Checks if the given type is a smt::IntType.
55-
MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type);
56+
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type);
5657

5758
/// Creates a smt::IntType.
58-
MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx);
59+
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx);
5960

6061
/// Checks if the given type is a smt::FuncType.
61-
MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type);
62+
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type);
6263

6364
/// Creates a smt::FuncType with the given domain and range types.
64-
MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx,
65-
size_t numberOfDomainTypes,
66-
const MlirType *domainTypes,
67-
MlirType rangeType);
65+
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx,
66+
size_t numberOfDomainTypes,
67+
const MlirType *domainTypes,
68+
MlirType rangeType);
6869

6970
/// Checks if the given type is a smt::SortType.
70-
MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type);
71+
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type);
7172

7273
/// Creates a smt::SortType with the given identifier and sort parameters.
73-
MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx,
74-
MlirIdentifier identifier,
75-
size_t numberOfSortParams,
76-
const MlirType *sortParams);
74+
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx,
75+
MlirIdentifier identifier,
76+
size_t numberOfSortParams,
77+
const MlirType *sortParams);
7778

7879
//===----------------------------------------------------------------------===//
7980
// Attribute API.
8081
//===----------------------------------------------------------------------===//
8182

8283
/// Checks if the given string is a valid smt::BVCmpPredicate.
83-
MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx,
84-
MlirStringRef str);
84+
MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx,
85+
MlirStringRef str);
8586

8687
/// Checks if the given string is a valid smt::IntPredicate.
87-
MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx,
88-
MlirStringRef str);
88+
MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx,
89+
MlirStringRef str);
8990

9091
/// Checks if the given attribute is a smt::SMTAttribute.
91-
MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr);
92+
MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr);
9293

9394
/// Creates a smt::BitVectorAttr with the given value and width.
94-
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx,
95-
uint64_t value,
96-
unsigned width);
95+
MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx,
96+
uint64_t value,
97+
unsigned width);
9798

9899
/// Creates a smt::BVCmpPredicateAttr with the given string.
99-
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx,
100-
MlirStringRef str);
100+
MLIR_CAPI_EXPORTED MlirAttribute
101+
mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str);
101102

102103
/// Creates a smt::IntPredicateAttr with the given string.
103-
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx,
104-
MlirStringRef str);
104+
MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx,
105+
MlirStringRef str);
105106

106107
#ifdef __cplusplus
107108
}

mlir/include/mlir-c/Target/ExportSMTLIB.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@ extern "C" {
2121

2222
/// Emits SMTLIB for the specified module using the provided callback and user
2323
/// data
24-
MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule,
25-
MlirStringCallback,
26-
void *userData);
24+
MLIR_CAPI_EXPORTED MlirLogicalResult
25+
mlirTranslateModuleToSMTLIB(MlirModule, MlirStringCallback, void *userData,
26+
bool inlineSingleUseValues, bool indentLetBody);
27+
28+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB(
29+
MlirOperation, MlirStringCallback, void *userData,
30+
bool inlineSingleUseValues, bool indentLetBody);
2731

2832
#ifdef __cplusplus
2933
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//===- DialectSMT.cpp - Pybind module for SMT dialect API support ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "NanobindUtils.h"
10+
11+
#include "mlir-c/Dialect/SMT.h"
12+
#include "mlir-c/IR.h"
13+
#include "mlir-c/Support.h"
14+
#include "mlir-c/Target/ExportSMTLIB.h"
15+
#include "mlir/Bindings/Python/Diagnostics.h"
16+
#include "mlir/Bindings/Python/Nanobind.h"
17+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
18+
19+
namespace nb = nanobind;
20+
21+
using namespace nanobind::literals;
22+
23+
using namespace mlir;
24+
using namespace mlir::python;
25+
using namespace mlir::python::nanobind_adaptors;
26+
27+
void populateDialectSMTSubmodule(nanobind::module_ &m) {
28+
29+
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
30+
.def_classmethod(
31+
"get",
32+
[](const nb::object &, MlirContext context) {
33+
return mlirSMTTypeGetBool(context);
34+
},
35+
"cls"_a, "context"_a.none() = nb::none());
36+
auto smtBitVectorType =
37+
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
38+
.def_classmethod(
39+
"get",
40+
[](const nb::object &, int32_t width, MlirContext context) {
41+
return mlirSMTTypeGetBitVector(context, width);
42+
},
43+
"cls"_a, "width"_a, "context"_a.none() = nb::none());
44+
45+
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
46+
bool indentLetBody) {
47+
mlir::python::CollectDiagnosticsToStringScope scope(
48+
mlirOperationGetContext(module));
49+
PyPrintAccumulator printAccum;
50+
MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
51+
module, printAccum.getCallback(), printAccum.getUserData(),
52+
inlineSingleUseValues, indentLetBody);
53+
if (mlirLogicalResultIsSuccess(result))
54+
return printAccum.join();
55+
throw nb::value_error(
56+
("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage())
57+
.c_str());
58+
};
59+
60+
m.def(
61+
"export_smtlib",
62+
[&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues,
63+
bool indentLetBody) {
64+
return exportSMTLIB(module, inlineSingleUseValues, indentLetBody);
65+
},
66+
"module"_a, "inline_single_use_values"_a = false,
67+
"indent_let_body"_a = false);
68+
m.def(
69+
"export_smtlib",
70+
[&exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
71+
bool indentLetBody) {
72+
return exportSMTLIB(mlirModuleGetOperation(module),
73+
inlineSingleUseValues, indentLetBody);
74+
},
75+
"module"_a, "inline_single_use_values"_a = false,
76+
"indent_let_body"_a = false);
77+
}
78+
79+
NB_MODULE(_mlirDialectsSMT, m) {
80+
m.doc() = "MLIR SMT Dialect";
81+
82+
populateDialectSMTSubmodule(m);
83+
}

mlir/lib/CAPI/Dialect/SMT.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,46 +25,49 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect)
2525
// Type API.
2626
//===----------------------------------------------------------------------===//
2727

28-
bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) {
28+
bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) {
2929
return isAnyNonFuncSMTValueType(unwrap(type));
3030
}
3131

32-
bool smtTypeIsAnySMTValueType(MlirType type) {
32+
bool mlirSMTTypeIsAnySMTValueType(MlirType type) {
3333
return isAnySMTValueType(unwrap(type));
3434
}
3535

36-
bool smtTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
36+
bool mlirSMTTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
3737

38-
MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType,
39-
MlirType rangeType) {
38+
MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType,
39+
MlirType rangeType) {
4040
return wrap(
4141
ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType)));
4242
}
4343

44-
bool smtTypeIsABitVector(MlirType type) {
44+
bool mlirSMTTypeIsABitVector(MlirType type) {
4545
return isa<BitVectorType>(unwrap(type));
4646
}
4747

48-
MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) {
48+
MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) {
4949
return wrap(BitVectorType::get(unwrap(ctx), width));
5050
}
5151

52-
bool smtTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
52+
bool mlirSMTTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
5353

54-
MlirType smtTypeGetBool(MlirContext ctx) {
54+
MlirType mlirSMTTypeGetBool(MlirContext ctx) {
5555
return wrap(BoolType::get(unwrap(ctx)));
5656
}
5757

58-
bool smtTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
58+
bool mlirSMTTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
5959

60-
MlirType smtTypeGetInt(MlirContext ctx) {
60+
MlirType mlirSMTTypeGetInt(MlirContext ctx) {
6161
return wrap(IntType::get(unwrap(ctx)));
6262
}
6363

64-
bool smtTypeIsASMTFunc(MlirType type) { return isa<SMTFuncType>(unwrap(type)); }
64+
bool mlirSMTTypeIsASMTFunc(MlirType type) {
65+
return isa<SMTFuncType>(unwrap(type));
66+
}
6567

66-
MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
67-
const MlirType *domainTypes, MlirType rangeType) {
68+
MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
69+
const MlirType *domainTypes,
70+
MlirType rangeType) {
6871
SmallVector<Type> domainTypesVec;
6972
domainTypesVec.reserve(numberOfDomainTypes);
7073

@@ -74,10 +77,11 @@ MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
7477
return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType)));
7578
}
7679

77-
bool smtTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
80+
bool mlirSMTTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
7881

79-
MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
80-
size_t numberOfSortParams, const MlirType *sortParams) {
82+
MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
83+
size_t numberOfSortParams,
84+
const MlirType *sortParams) {
8185
SmallVector<Type> sortParamsVec;
8286
sortParamsVec.reserve(numberOfSortParams);
8387

@@ -91,31 +95,31 @@ MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
9195
// Attribute API.
9296
//===----------------------------------------------------------------------===//
9397

94-
bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
98+
bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
9599
return symbolizeBVCmpPredicate(unwrap(str)).has_value();
96100
}
97101

98-
bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
102+
bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
99103
return symbolizeIntPredicate(unwrap(str)).has_value();
100104
}
101105

102-
bool smtAttrIsASMTAttribute(MlirAttribute attr) {
106+
bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) {
103107
return isa<BitVectorAttr, BVCmpPredicateAttr, IntPredicateAttr>(unwrap(attr));
104108
}
105109

106-
MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value,
107-
unsigned width) {
110+
MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value,
111+
unsigned width) {
108112
return wrap(BitVectorAttr::get(unwrap(ctx), value, width));
109113
}
110114

111-
MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
115+
MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
112116
auto predicate = symbolizeBVCmpPredicate(unwrap(str));
113117
assert(predicate.has_value() && "invalid predicate");
114118

115119
return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value()));
116120
}
117121

118-
MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
122+
MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
119123
auto predicate = symbolizeIntPredicate(unwrap(str));
120124
assert(predicate.has_value() && "invalid predicate");
121125

mlir/lib/CAPI/Target/ExportSMTLIB.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,24 @@
1919

2020
using namespace mlir;
2121

22-
MlirLogicalResult mlirExportSMTLIB(MlirModule module,
23-
MlirStringCallback callback,
24-
void *userData) {
22+
MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module,
23+
MlirStringCallback callback,
24+
void *userData,
25+
bool inlineSingleUseValues,
26+
bool indentLetBody) {
2527
mlir::detail::CallbackOstream stream(callback, userData);
28+
smt::SMTEmissionOptions options;
29+
options.inlineSingleUseValues = inlineSingleUseValues;
30+
options.indentLetBody = indentLetBody;
2631
return wrap(smt::exportSMTLIB(unwrap(module), stream));
2732
}
33+
34+
MlirLogicalResult mlirTranslateModuleToSMTLIB(MlirModule module,
35+
MlirStringCallback callback,
36+
void *userData,
37+
bool inlineSingleUseValues,
38+
bool indentLetBody) {
39+
return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module),
40+
callback, userData,
41+
inlineSingleUseValues, indentLetBody);
42+
}

0 commit comments

Comments
 (0)