Skip to content

Commit d864f80

Browse files
committed
bind more stuff
1 parent cc99ede commit d864f80

File tree

7 files changed

+235
-18
lines changed

7 files changed

+235
-18
lines changed

mlir/include/mlir-c/Target/ExportSMTLIB.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,13 @@ extern "C" {
2121

2222
/// Emits SMTLIB for the specified module using the provided callback and user
2323
/// data
24-
MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule,
25-
MlirStringCallback,
26-
void *userData);
24+
MLIR_CAPI_EXPORTED MlirLogicalResult
25+
mlirTranslateModuleSMTLIB(MlirModule, MlirStringCallback, void *userData,
26+
bool inlineSingleUseValues, bool indentLetBody);
27+
28+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB(
29+
MlirOperation, MlirStringCallback, void *userData,
30+
bool inlineSingleUseValues, bool indentLetBody);
2731

2832
#ifdef __cplusplus
2933
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//===- DialectSMT.cpp - Pybind module for SMT dialect API support ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "NanobindUtils.h"
10+
11+
#include "mlir-c/Dialect/SMT.h"
12+
#include "mlir-c/IR.h"
13+
#include "mlir-c/Support.h"
14+
#include "mlir-c/Target/ExportSMTLIB.h"
15+
#include "mlir/Bindings/Python/Diagnostics.h"
16+
#include "mlir/Bindings/Python/Nanobind.h"
17+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
18+
19+
namespace nb = nanobind;
20+
21+
using namespace nanobind::literals;
22+
23+
using namespace mlir;
24+
using namespace mlir::python;
25+
using namespace mlir::python::nanobind_adaptors;
26+
27+
void populateDialectSMTSubmodule(nanobind::module_ &m) {
28+
29+
auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
30+
.def_classmethod(
31+
"get",
32+
[](const nb::object &, MlirContext context) {
33+
return mlirSMTTypeGetBool(context);
34+
},
35+
"cls"_a, "context"_a.none() = nb::none());
36+
auto smtBitVectorType =
37+
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
38+
.def_classmethod(
39+
"get",
40+
[](const nb::object &, int32_t width, MlirContext context) {
41+
return mlirSMTTypeGetBitVector(context, width);
42+
},
43+
"cls"_a, "width"_a, "context"_a.none() = nb::none());
44+
45+
auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
46+
bool indentLetBody) {
47+
mlir::python::CollectDiagnosticsToStringScope scope(
48+
mlirOperationGetContext(module));
49+
PyPrintAccumulator printAccum;
50+
MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
51+
module, printAccum.getCallback(), printAccum.getUserData(),
52+
inlineSingleUseValues, indentLetBody);
53+
if (mlirLogicalResultIsSuccess(result))
54+
return printAccum.join();
55+
throw nb::value_error(
56+
("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage())
57+
.c_str());
58+
};
59+
60+
m.def(
61+
"export_smtlib",
62+
[&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues,
63+
bool indentLetBody) {
64+
return exportSMTLIB(module, inlineSingleUseValues, indentLetBody);
65+
},
66+
"module"_a, "inline_single_use_values"_a = false,
67+
"indent_let_body"_a = false);
68+
m.def(
69+
"export_smtlib",
70+
[&exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
71+
bool indentLetBody) {
72+
return exportSMTLIB(mlirModuleGetOperation(module),
73+
inlineSingleUseValues, indentLetBody);
74+
},
75+
"module"_a, "inline_single_use_values"_a = false,
76+
"indent_let_body"_a = false);
77+
}
78+
79+
NB_MODULE(_mlirDialectsSMT, m) {
80+
m.doc() = "MLIR SMT Dialect";
81+
82+
populateDialectSMTSubmodule(m);
83+
}

mlir/lib/CAPI/Target/ExportSMTLIB.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,24 @@
1919

2020
using namespace mlir;
2121

22-
MlirLogicalResult mlirExportSMTLIB(MlirModule module,
23-
MlirStringCallback callback,
24-
void *userData) {
22+
MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module,
23+
MlirStringCallback callback,
24+
void *userData,
25+
bool inlineSingleUseValues,
26+
bool indentLetBody) {
2527
mlir::detail::CallbackOstream stream(callback, userData);
28+
smt::SMTEmissionOptions options;
29+
options.inlineSingleUseValues = inlineSingleUseValues;
30+
options.indentLetBody = indentLetBody;
2631
return wrap(smt::exportSMTLIB(unwrap(module), stream));
2732
}
33+
34+
MlirLogicalResult mlirTranslateModuleSMTLIB(MlirModule module,
35+
MlirStringCallback callback,
36+
void *userData,
37+
bool inlineSingleUseValues,
38+
bool indentLetBody) {
39+
return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module),
40+
callback, userData,
41+
inlineSingleUseValues, indentLetBody);
42+
}

mlir/python/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,21 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
673673
MLIRCAPILinalg
674674
)
675675

676+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
677+
MODULE_NAME _mlirDialectsSMT
678+
ADD_TO_PARENT MLIRPythonSources.Dialects.smt
679+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
680+
PYTHON_BINDINGS_LIBRARY nanobind
681+
SOURCES
682+
DialectSMT.cpp
683+
PRIVATE_LINK_LIBS
684+
LLVMSupport
685+
EMBED_CAPI_LINK_LIBS
686+
MLIRCAPIIR
687+
MLIRCAPISMT
688+
MLIRCAPIExportSMTLIB
689+
)
690+
676691
declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
677692
MODULE_NAME _mlirSparseTensorPasses
678693
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor

mlir/python/mlir/dialects/smt.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,31 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
from ._smt_ops_gen import *
6+
7+
from .._mlir_libs._mlirDialectsSMT import *
8+
from ..extras.meta import region_op
9+
10+
11+
def bool_t():
12+
return BoolType.get()
13+
14+
15+
def bv_t(width):
16+
return BitVectorType.get(width)
17+
18+
19+
def _solver(
20+
inputs=None,
21+
results=None,
22+
loc=None,
23+
ip=None,
24+
):
25+
if inputs is None:
26+
inputs = []
27+
if results is None:
28+
results = []
29+
30+
return SolverOp(results, inputs, loc=loc, ip=ip)
31+
32+
33+
solver = region_op(_solver, terminator=YieldOp)

mlir/test/CAPI/smt.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ void testExportSMTLIB(MlirContext ctx) {
3434
MlirModule module =
3535
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(testSMT));
3636

37-
MlirLogicalResult result = mlirExportSMTLIB(module, dumpCallback, NULL);
37+
MlirLogicalResult result =
38+
mlirTranslateModuleSMTLIB(module, dumpCallback, NULL, false, false);
3839
(void)result;
3940
assert(mlirLogicalResultIsSuccess(result));
4041

mlir/test/python/dialects/smt.py

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,87 @@
1-
# REQUIRES: bindings_python
2-
# RUN: %PYTHON% %s | FileCheck %s
1+
# RUN: %PYTHON %s | FileCheck %s
32

4-
import mlir
3+
from mlir.dialects import smt, arith
4+
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
55

6-
from mlir.dialects import smt
7-
from mlir.ir import Context, Location, Module, InsertionPoint
86

9-
with Context() as ctx, Location.unknown():
10-
m = Module.create()
11-
with InsertionPoint(m.body):
12-
true = smt.constant(True)
13-
false = smt.constant(False)
7+
def run(f):
8+
print("\nTEST:", f.__name__)
9+
with Context(), Location.unknown():
10+
module = Module.create()
11+
with InsertionPoint(module.body):
12+
f(module)
13+
print(module)
14+
assert module.operation.verify()
15+
16+
17+
# CHECK-LABEL: TEST: test_smoke
18+
@run
19+
def test_smoke(_module):
20+
true = smt.constant(True)
21+
false = smt.constant(False)
1422
# CHECK: smt.constant true
1523
# CHECK: smt.constant false
16-
print(m)
24+
25+
26+
# CHECK-LABEL: TEST: test_types
27+
@run
28+
def test_types(_module):
29+
bool_t = smt.bool_t()
30+
bitvector_t = smt.bv_t(5)
31+
# CHECK: !smt.bool
32+
print(bool_t)
33+
# CHECK: !smt.bv<5>
34+
print(bitvector_t)
35+
36+
37+
# CHECK-LABEL: TEST: test_solver_op
38+
@run
39+
def test_solver_op(_module):
40+
@smt.solver
41+
def foo1():
42+
true = smt.constant(True)
43+
false = smt.constant(False)
44+
45+
# CHECK: smt.solver() : () -> () {
46+
# CHECK: %true = smt.constant true
47+
# CHECK: %false = smt.constant false
48+
# CHECK: }
49+
50+
f32 = F32Type.get()
51+
52+
@smt.solver(results=[f32])
53+
def foo2():
54+
return arith.ConstantOp(f32, 1.0)
55+
56+
# CHECK: %{{.*}} = smt.solver() : () -> f32 {
57+
# CHECK: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
58+
# CHECK: smt.yield %[[CST1]] : f32
59+
# CHECK: }
60+
61+
two = arith.ConstantOp(f32, 2.0)
62+
# CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
63+
print(two)
64+
65+
@smt.solver(inputs=[two], results=[f32])
66+
def foo3(z: f32):
67+
return z
68+
69+
# CHECK: %{{.*}} = smt.solver(%[[CST2]]) : (f32) -> f32 {
70+
# CHECK: ^bb0(%[[ARG0:.*]]: f32):
71+
# CHECK: smt.yield %[[ARG0]] : f32
72+
# CHECK: }
73+
74+
75+
# CHECK-LABEL: TEST: test_export_smtlib
76+
@run
77+
def test_export_smtlib(module):
78+
@smt.solver
79+
def foo1():
80+
true = smt.constant(True)
81+
smt.assert_(true)
82+
83+
query = smt.export_smtlib(module.operation)
84+
# CHECK: ; solver scope 0
85+
# CHECK: (assert true)
86+
# CHECK: (reset)
87+
print(query)

0 commit comments

Comments
 (0)