Skip to content

[mlir][SMT] add python bindings #135674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 35 additions & 34 deletions mlir/include/mlir-c/Dialect/SMT.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,82 +26,83 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt);
//===----------------------------------------------------------------------===//

/// Checks if the given type is any non-func SMT value type.
MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type);

/// Checks if the given type is any SMT value type.
MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type);

/// Checks if the given type is a smt::ArrayType.
MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type);

/// Creates an array type with the given domain and range types.
MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx,
MlirType domainType,
MlirType rangeType);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx,
MlirType domainType,
MlirType rangeType);

/// Checks if the given type is a smt::BitVectorType.
MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type);

/// Creates a smt::BitVectorType with the given width.
MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx,
int32_t width);

/// Checks if the given type is a smt::BoolType.
MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type);

/// Creates a smt::BoolType.
MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx);

/// Checks if the given type is a smt::IntType.
MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type);

/// Creates a smt::IntType.
MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx);

/// Checks if the given type is a smt::FuncType.
MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type);

/// Creates a smt::FuncType with the given domain and range types.
MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx,
size_t numberOfDomainTypes,
const MlirType *domainTypes,
MlirType rangeType);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx,
size_t numberOfDomainTypes,
const MlirType *domainTypes,
MlirType rangeType);

/// Checks if the given type is a smt::SortType.
MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type);
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type);

/// Creates a smt::SortType with the given identifier and sort parameters.
MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx,
MlirIdentifier identifier,
size_t numberOfSortParams,
const MlirType *sortParams);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx,
MlirIdentifier identifier,
size_t numberOfSortParams,
const MlirType *sortParams);

//===----------------------------------------------------------------------===//
// Attribute API.
//===----------------------------------------------------------------------===//

/// Checks if the given string is a valid smt::BVCmpPredicate.
MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx,
MlirStringRef str);
MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx,
MlirStringRef str);

/// Checks if the given string is a valid smt::IntPredicate.
MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx,
MlirStringRef str);
MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx,
MlirStringRef str);

/// Checks if the given attribute is a smt::SMTAttribute.
MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr);
MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr);

/// Creates a smt::BitVectorAttr with the given value and width.
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx,
uint64_t value,
unsigned width);
MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx,
uint64_t value,
unsigned width);

/// Creates a smt::BVCmpPredicateAttr with the given string.
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx,
MlirStringRef str);
MLIR_CAPI_EXPORTED MlirAttribute
mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str);

/// Creates a smt::IntPredicateAttr with the given string.
MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx,
MlirStringRef str);
MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx,
MlirStringRef str);

#ifdef __cplusplus
}
Expand Down
10 changes: 7 additions & 3 deletions mlir/include/mlir-c/Target/ExportSMTLIB.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ extern "C" {

/// Emits SMTLIB for the specified module using the provided callback and user
/// data
MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule,
MlirStringCallback,
void *userData);
MLIR_CAPI_EXPORTED MlirLogicalResult
mlirTranslateModuleToSMTLIB(MlirModule, MlirStringCallback, void *userData,
bool inlineSingleUseValues, bool indentLetBody);

MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB(
MlirOperation, MlirStringCallback, void *userData,
bool inlineSingleUseValues, bool indentLetBody);

#ifdef __cplusplus
}
Expand Down
83 changes: 83 additions & 0 deletions mlir/lib/Bindings/Python/DialectSMT.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//===- DialectSMT.cpp - Pybind module for SMT dialect API support ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "NanobindUtils.h"

#include "mlir-c/Dialect/SMT.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir-c/Target/ExportSMTLIB.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace nb = nanobind;

using namespace nanobind::literals;

using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;

void populateDialectSMTSubmodule(nanobind::module_ &m) {

auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
.def_classmethod(
"get",
[](const nb::object &, MlirContext context) {
return mlirSMTTypeGetBool(context);
},
"cls"_a, "context"_a.none() = nb::none());
auto smtBitVectorType =
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
.def_classmethod(
"get",
[](const nb::object &, int32_t width, MlirContext context) {
return mlirSMTTypeGetBitVector(context, width);
},
"cls"_a, "width"_a, "context"_a.none() = nb::none());

auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
mlir::python::CollectDiagnosticsToStringScope scope(
mlirOperationGetContext(module));
PyPrintAccumulator printAccum;
MlirLogicalResult result = mlirTranslateOperationToSMTLIB(
module, printAccum.getCallback(), printAccum.getUserData(),
inlineSingleUseValues, indentLetBody);
if (mlirLogicalResultIsSuccess(result))
return printAccum.join();
throw nb::value_error(
("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage())
.c_str());
};

m.def(
"export_smtlib",
[&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
return exportSMTLIB(module, inlineSingleUseValues, indentLetBody);
},
"module"_a, "inline_single_use_values"_a = false,
"indent_let_body"_a = false);
m.def(
"export_smtlib",
[&exportSMTLIB](MlirModule module, bool inlineSingleUseValues,
bool indentLetBody) {
return exportSMTLIB(mlirModuleGetOperation(module),
inlineSingleUseValues, indentLetBody);
},
"module"_a, "inline_single_use_values"_a = false,
"indent_let_body"_a = false);
}

NB_MODULE(_mlirDialectsSMT, m) {
m.doc() = "MLIR SMT Dialect";

populateDialectSMTSubmodule(m);
}
52 changes: 28 additions & 24 deletions mlir/lib/CAPI/Dialect/SMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,46 +25,49 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect)
// Type API.
//===----------------------------------------------------------------------===//

bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) {
bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) {
return isAnyNonFuncSMTValueType(unwrap(type));
}

bool smtTypeIsAnySMTValueType(MlirType type) {
bool mlirSMTTypeIsAnySMTValueType(MlirType type) {
return isAnySMTValueType(unwrap(type));
}

bool smtTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }
bool mlirSMTTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); }

MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType,
MlirType rangeType) {
MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType,
MlirType rangeType) {
return wrap(
ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType)));
}

bool smtTypeIsABitVector(MlirType type) {
bool mlirSMTTypeIsABitVector(MlirType type) {
return isa<BitVectorType>(unwrap(type));
}

MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) {
MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) {
return wrap(BitVectorType::get(unwrap(ctx), width));
}

bool smtTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
bool mlirSMTTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }

MlirType smtTypeGetBool(MlirContext ctx) {
MlirType mlirSMTTypeGetBool(MlirContext ctx) {
return wrap(BoolType::get(unwrap(ctx)));
}

bool smtTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
bool mlirSMTTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }

MlirType smtTypeGetInt(MlirContext ctx) {
MlirType mlirSMTTypeGetInt(MlirContext ctx) {
return wrap(IntType::get(unwrap(ctx)));
}

bool smtTypeIsASMTFunc(MlirType type) { return isa<SMTFuncType>(unwrap(type)); }
bool mlirSMTTypeIsASMTFunc(MlirType type) {
return isa<SMTFuncType>(unwrap(type));
}

MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
const MlirType *domainTypes, MlirType rangeType) {
MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
const MlirType *domainTypes,
MlirType rangeType) {
SmallVector<Type> domainTypesVec;
domainTypesVec.reserve(numberOfDomainTypes);

Expand All @@ -74,10 +77,11 @@ MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes,
return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType)));
}

bool smtTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }
bool mlirSMTTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); }

MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
size_t numberOfSortParams, const MlirType *sortParams) {
MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
size_t numberOfSortParams,
const MlirType *sortParams) {
SmallVector<Type> sortParamsVec;
sortParamsVec.reserve(numberOfSortParams);

Expand All @@ -91,31 +95,31 @@ MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier,
// Attribute API.
//===----------------------------------------------------------------------===//

bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
return symbolizeBVCmpPredicate(unwrap(str)).has_value();
}

bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) {
return symbolizeIntPredicate(unwrap(str)).has_value();
}

bool smtAttrIsASMTAttribute(MlirAttribute attr) {
bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) {
return isa<BitVectorAttr, BVCmpPredicateAttr, IntPredicateAttr>(unwrap(attr));
}

MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value,
unsigned width) {
MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value,
unsigned width) {
return wrap(BitVectorAttr::get(unwrap(ctx), value, width));
}

MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) {
auto predicate = symbolizeBVCmpPredicate(unwrap(str));
assert(predicate.has_value() && "invalid predicate");

return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value()));
}

MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) {
auto predicate = symbolizeIntPredicate(unwrap(str));
assert(predicate.has_value() && "invalid predicate");

Expand Down
21 changes: 18 additions & 3 deletions mlir/lib/CAPI/Target/ExportSMTLIB.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,24 @@

using namespace mlir;

MlirLogicalResult mlirExportSMTLIB(MlirModule module,
MlirStringCallback callback,
void *userData) {
MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module,
MlirStringCallback callback,
void *userData,
bool inlineSingleUseValues,
bool indentLetBody) {
mlir::detail::CallbackOstream stream(callback, userData);
smt::SMTEmissionOptions options;
options.inlineSingleUseValues = inlineSingleUseValues;
options.indentLetBody = indentLetBody;
return wrap(smt::exportSMTLIB(unwrap(module), stream));
}

MlirLogicalResult mlirTranslateModuleToSMTLIB(MlirModule module,
MlirStringCallback callback,
void *userData,
bool inlineSingleUseValues,
bool indentLetBody) {
return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module),
callback, userData,
inlineSingleUseValues, indentLetBody);
}
Loading
Loading