Skip to content

Commit f5c7c03

Browse files
committed
[mlir] Add C API for IntegerSet
Depends On D95357 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D95368
1 parent 27cc4a8 commit f5c7c03

File tree

6 files changed

+351
-1
lines changed

6 files changed

+351
-1
lines changed

mlir/include/mlir-c/IntegerSet.h

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
//===-- mlir-c/IntegerSet.h - C API for MLIR Affine maps ----------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_C_INTEGERSET_H
11+
#define MLIR_C_INTEGERSET_H
12+
13+
#include "mlir-c/AffineExpr.h"
14+
15+
#ifdef __cplusplus
16+
extern "C" {
17+
#endif
18+
19+
//===----------------------------------------------------------------------===//
20+
// Opaque type declarations.
21+
//
22+
// Types are exposed to C bindings as structs containing opaque pointers. They
23+
// are not supposed to be inspected from C. This allows the underlying
24+
// representation to change without affecting the API users. The use of structs
25+
// instead of typedefs enables some type safety as structs are not implicitly
26+
// convertible to each other.
27+
//
28+
// Instances of these types may or may not own the underlying object. The
29+
// ownership semantics is defined by how an instance of the type was obtained.
30+
//===----------------------------------------------------------------------===//
31+
32+
#define DEFINE_C_API_STRUCT(name, storage) \
33+
struct name { \
34+
storage *ptr; \
35+
}; \
36+
typedef struct name name
37+
38+
DEFINE_C_API_STRUCT(MlirIntegerSet, const void);
39+
40+
#undef DEFINE_C_API_STRUCT
41+
42+
/// Gets the context in which the given integer set lives.
43+
MLIR_CAPI_EXPORTED MlirContext mlirIntegerSetGetContext(MlirIntegerSet set);
44+
45+
/// Checks whether an integer set is a null object.
46+
static inline bool mlirIntegerSetIsNull(MlirIntegerSet set) { return !set.ptr; }
47+
48+
/// Checks if two integer set objects are equal. This is a "shallow" comparison
49+
/// of two objects. Only the sets with some small number of constraints are
50+
/// uniqued and compare equal here. Set objects that represent the same integer
51+
/// set with different constraints may be considered non-equal by this check.
52+
/// Set difference followed by an (expensive) emptiness check should be used to
53+
/// check equivalence of the underlying integer sets.
54+
MLIR_CAPI_EXPORTED bool mlirIntegerSetEqual(MlirIntegerSet s1,
55+
MlirIntegerSet s2);
56+
57+
/// Prints an integer set by sending chunks of the string representation and
58+
/// forwarding `userData to `callback`. Note that the callback may be called
59+
/// several times with consecutive chunks of the string.
60+
MLIR_CAPI_EXPORTED void mlirIntegerSetPrint(MlirIntegerSet set,
61+
MlirStringCallback callback,
62+
void *userData);
63+
64+
/// Prints an integer set to the standard error stream.
65+
MLIR_CAPI_EXPORTED void mlirIntegerSetDump(MlirIntegerSet set);
66+
67+
/// Gets or creates a new canonically empty integer set with the give number of
68+
/// dimensions and symbols in the given context.
69+
MLIR_CAPI_EXPORTED MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context,
70+
intptr_t numDims,
71+
intptr_t numSymbols);
72+
73+
/// Gets or creates a new integer set in the given context. The set is defined
74+
/// by a list of affine constraints, with the given number of input dimensions
75+
/// and symbols, which are treated as either equalities (eqFlags is 1) or
76+
/// inequalities (eqFlags is 0). Both `constraints` and `eqFlags` are expected
77+
/// to point to at least `numConstraint` consecutive values.
78+
MLIR_CAPI_EXPORTED MlirIntegerSet
79+
mlirIntegerSetGet(MlirContext context, intptr_t numDims, intptr_t numSymbols,
80+
intptr_t numConstraints, const MlirAffineExpr *constraints,
81+
const bool *eqFlags);
82+
83+
/// Gets or creates a new integer set in which the values and dimensions of the
84+
/// given set are replaced with the given affine expressions. `dimReplacements`
85+
/// and `symbolReplacements` are expected to point to at least as many
86+
/// consecutive expressions as the given set has dimensions and symbols,
87+
/// respectively. The new set will have `numResultDims` and `numResultSymbols`
88+
/// dimensions and symbols, respectively.
89+
MLIR_CAPI_EXPORTED MlirIntegerSet mlirIntegerSetReplaceGet(
90+
MlirIntegerSet set, const MlirAffineExpr *dimReplacements,
91+
const MlirAffineExpr *symbolReplacements, intptr_t numResultDims,
92+
intptr_t numResultSymbols);
93+
94+
/// Checks whether the given set is a canonical empty set, e.g., the set
95+
/// returned by mlirIntegerSetEmptyGet.
96+
MLIR_CAPI_EXPORTED bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set);
97+
98+
/// Returns the number of dimensions in the given set.
99+
MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set);
100+
101+
/// Returns the number of symbols in the given set.
102+
MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set);
103+
104+
/// Returns the number of inputs (dimensions + symbols) in the given set.
105+
MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set);
106+
107+
/// Returns the number of constraints (equalities + inequalities) in the given
108+
/// set.
109+
MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set);
110+
111+
/// Returns the number of equalities in the given set.
112+
MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set);
113+
114+
/// Returns the number of inequalities in the given set.
115+
MLIR_CAPI_EXPORTED intptr_t
116+
mlirIntegerSetGetNumInequalities(MlirIntegerSet set);
117+
118+
/// Returns `pos`-th constraint of the set.
119+
MLIR_CAPI_EXPORTED MlirAffineExpr
120+
mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos);
121+
122+
/// Returns `true` of the `pos`-th constraint of the set is an equality
123+
/// constraint, `false` otherwise.
124+
MLIR_CAPI_EXPORTED bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set,
125+
intptr_t pos);
126+
127+
#ifdef __cplusplus
128+
}
129+
#endif
130+
131+
#endif // MLIR_C_INTEGERSET_H

mlir/include/mlir/CAPI/IntegerSet.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===- IntegerSet.h - C API Utils for Integer Sets --------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains declarations of implementation details of the C API for
10+
// MLIR IntegerSets. This file should not be included from C++ code other than C
11+
// API implementation nor from C code.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_CAPI_INTEGERSET_H
16+
#define MLIR_CAPI_INTEGERSET_H
17+
18+
#include "mlir-c/IntegerSet.h"
19+
#include "mlir/CAPI/Wrap.h"
20+
#include "mlir/IR/IntegerSet.h"
21+
22+
DEFINE_C_API_METHODS(MlirIntegerSet, mlir::IntegerSet);
23+
24+
#endif // MLIR_CAPI_INTEGERSET_H

mlir/include/mlir/IR/IntegerSet.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ class IntegerSet {
104104

105105
friend ::llvm::hash_code hash_value(IntegerSet arg);
106106

107+
/// Methods supporting C API.
108+
const void *getAsOpaquePointer() const {
109+
return static_cast<const void *>(set);
110+
}
111+
static IntegerSet getFromOpaquePointer(const void *pointer) {
112+
return IntegerSet(
113+
reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
114+
}
115+
107116
private:
108117
ImplType *set;
109118
/// Sets with constraints fewer than kUniquingThreshold are uniqued.

mlir/lib/CAPI/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_mlir_public_c_api_library(MLIRCAPIIR
55
BuiltinAttributes.cpp
66
BuiltinTypes.cpp
77
Diagnostics.cpp
8+
IntegerSet.cpp
89
IR.cpp
910
Pass.cpp
1011
Support.cpp

mlir/lib/CAPI/IR/IntegerSet.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//===- IntegerSet.cpp - C API for MLIR Integer Sets -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/IntegerSet.h"
10+
#include "mlir-c/AffineExpr.h"
11+
#include "mlir/CAPI/AffineExpr.h"
12+
#include "mlir/CAPI/IR.h"
13+
#include "mlir/CAPI/IntegerSet.h"
14+
#include "mlir/CAPI/Utils.h"
15+
#include "mlir/IR/IntegerSet.h"
16+
17+
using namespace mlir;
18+
19+
MlirContext mlirIntegerSetGetContext(MlirIntegerSet set) {
20+
return wrap(unwrap(set).getContext());
21+
}
22+
23+
bool mlirIntegerSetEqual(MlirIntegerSet s1, MlirIntegerSet s2) {
24+
return unwrap(s1) == unwrap(s2);
25+
}
26+
27+
void mlirIntegerSetPrint(MlirIntegerSet set, MlirStringCallback callback,
28+
void *userData) {
29+
mlir::detail::CallbackOstream stream(callback, userData);
30+
unwrap(set).print(stream);
31+
}
32+
33+
void mlirIntegerSetDump(MlirIntegerSet set) { unwrap(set).dump(); }
34+
35+
MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context, intptr_t numDims,
36+
intptr_t numSymbols) {
37+
return wrap(IntegerSet::getEmptySet(static_cast<unsigned>(numDims),
38+
static_cast<unsigned>(numSymbols),
39+
unwrap(context)));
40+
}
41+
42+
MlirIntegerSet mlirIntegerSetGet(MlirContext context, intptr_t numDims,
43+
intptr_t numSymbols, intptr_t numConstraints,
44+
const MlirAffineExpr *constraints,
45+
const bool *eqFlags) {
46+
SmallVector<AffineExpr> mlirConstraints;
47+
(void)unwrapList(static_cast<size_t>(numConstraints), constraints,
48+
mlirConstraints);
49+
return wrap(IntegerSet::get(
50+
static_cast<unsigned>(numDims), static_cast<unsigned>(numSymbols),
51+
mlirConstraints,
52+
llvm::makeArrayRef(eqFlags, static_cast<size_t>(numConstraints))));
53+
}
54+
55+
MlirIntegerSet
56+
mlirIntegerSetReplaceGet(MlirIntegerSet set,
57+
const MlirAffineExpr *dimReplacements,
58+
const MlirAffineExpr *symbolReplacements,
59+
intptr_t numResultDims, intptr_t numResultSymbols) {
60+
SmallVector<AffineExpr> mlirDims, mlirSymbols;
61+
(void)unwrapList(unwrap(set).getNumDims(), dimReplacements, mlirDims);
62+
(void)unwrapList(unwrap(set).getNumSymbols(), symbolReplacements,
63+
mlirSymbols);
64+
return wrap(unwrap(set).replaceDimsAndSymbols(
65+
mlirDims, mlirSymbols, static_cast<unsigned>(numResultDims),
66+
static_cast<unsigned>(numResultSymbols)));
67+
}
68+
69+
bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set) {
70+
return unwrap(set).isEmptyIntegerSet();
71+
}
72+
73+
intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set) {
74+
return static_cast<intptr_t>(unwrap(set).getNumDims());
75+
}
76+
77+
intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set) {
78+
return static_cast<intptr_t>(unwrap(set).getNumSymbols());
79+
}
80+
81+
intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set) {
82+
return static_cast<intptr_t>(unwrap(set).getNumInputs());
83+
}
84+
85+
intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set) {
86+
return static_cast<intptr_t>(unwrap(set).getNumConstraints());
87+
}
88+
89+
intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set) {
90+
return static_cast<intptr_t>(unwrap(set).getNumEqualities());
91+
}
92+
93+
intptr_t mlirIntegerSetGetNumInequalities(MlirIntegerSet set) {
94+
return static_cast<intptr_t>(unwrap(set).getNumInequalities());
95+
}
96+
97+
MlirAffineExpr mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos) {
98+
return wrap(unwrap(set).getConstraint(static_cast<unsigned>(pos)));
99+
}
100+
101+
bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set, intptr_t pos) {
102+
return unwrap(set).isEq(pos);
103+
}

mlir/test/CAPI/ir.c

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir-c/BuiltinTypes.h"
1818
#include "mlir-c/Diagnostics.h"
1919
#include "mlir-c/Dialect/Standard.h"
20+
#include "mlir-c/IntegerSet.h"
2021
#include "mlir-c/Registration.h"
2122

2223
#include <assert.h>
@@ -1325,6 +1326,85 @@ int affineMapFromExprs(MlirContext ctx) {
13251326
return 0;
13261327
}
13271328

1329+
int printIntegerSet(MlirContext ctx) {
1330+
MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
1331+
1332+
// CHECK-LABEL: @printIntegerSet
1333+
fprintf(stderr, "@printIntegerSet");
1334+
1335+
// CHECK: (d0, d1)[s0] : (1 == 0)
1336+
mlirIntegerSetDump(emptySet);
1337+
1338+
if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
1339+
return 1;
1340+
1341+
MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
1342+
if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
1343+
return 2;
1344+
1345+
// Construct a set constrained by:
1346+
// d0 - s0 == 0,
1347+
// d1 - 42 >= 0.
1348+
MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
1349+
MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
1350+
MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
1351+
MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
1352+
MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
1353+
MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
1354+
MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
1355+
MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
1356+
MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
1357+
bool flags[] = {true, false};
1358+
1359+
MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
1360+
// CHECK: (d0, d1)[s0] : (
1361+
// CHECK-DAG: d0 - s0 == 0
1362+
// CHECK-DAG: d1 - 42 >= 0
1363+
mlirIntegerSetDump(set);
1364+
1365+
// Transform d1 into s0.
1366+
MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
1367+
MlirAffineExpr repl[] = {d0, s1};
1368+
MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
1369+
// CHECK: (d0)[s0, s1] : (
1370+
// CHECK-DAG: d0 - s0 == 0
1371+
// CHECK-DAG: s1 - 42 >= 0
1372+
mlirIntegerSetDump(replaced);
1373+
1374+
if (mlirIntegerSetGetNumDims(set) != 2)
1375+
return 3;
1376+
if (mlirIntegerSetGetNumDims(replaced) != 1)
1377+
return 4;
1378+
1379+
if (mlirIntegerSetGetNumSymbols(set) != 1)
1380+
return 5;
1381+
if (mlirIntegerSetGetNumSymbols(replaced) != 2)
1382+
return 6;
1383+
1384+
if (mlirIntegerSetGetNumInputs(set) != 3)
1385+
return 7;
1386+
1387+
if (mlirIntegerSetGetNumConstraints(set) != 2)
1388+
return 8;
1389+
1390+
if (mlirIntegerSetGetNumEqualities(set) != 1)
1391+
return 9;
1392+
1393+
if (mlirIntegerSetGetNumInequalities(set) != 1)
1394+
return 10;
1395+
1396+
MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
1397+
MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
1398+
bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
1399+
bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
1400+
if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
1401+
return 11;
1402+
if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
1403+
return 12;
1404+
1405+
return 0;
1406+
}
1407+
13281408
int registerOnlyStd() {
13291409
MlirContext ctx = mlirContextCreate();
13301410
// The built-in dialect is always loaded.
@@ -1429,8 +1509,10 @@ int main() {
14291509
return 6;
14301510
if (affineMapFromExprs(ctx))
14311511
return 7;
1432-
if (registerOnlyStd())
1512+
if (printIntegerSet(ctx))
14331513
return 8;
1514+
if (registerOnlyStd())
1515+
return 9;
14341516

14351517
mlirContextDestroy(ctx);
14361518

0 commit comments

Comments
 (0)