Skip to content

Commit 1f194ff

Browse files
authored
[mlir] Expose simplifyAffineExpr through python api (#133926)
1 parent 7e25b24 commit 1f194ff

File tree

4 files changed

+33
-0
lines changed

4 files changed

+33
-0
lines changed

mlir/include/mlir-c/AffineExpr.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,16 @@ MLIR_CAPI_EXPORTED MlirAffineExpr
104104
mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
105105
uint32_t shift, uint32_t offset);
106106

107+
/// Simplify an affine expression by flattening and some amount of simple
108+
/// analysis. This has complexity linear in the number of nodes in 'expr'.
109+
/// Returns the simplified expression, which is the same as the input expression
110+
/// if it can't be simplified. When `expr` is semi-affine, a simplified
111+
/// semi-affine expression is constructed in the sorted order of dimension and
112+
/// symbol positions.
113+
MLIR_CAPI_EXPORTED MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr,
114+
uint32_t numDims,
115+
uint32_t numSymbols);
116+
107117
//===----------------------------------------------------------------------===//
108118
// Affine Dimension Expression.
109119
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRAffine.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,16 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
599599
},
600600
nb::arg("num_symbols"), nb::arg("shift"),
601601
nb::arg("offset").none() = 0)
602+
.def_static(
603+
"simplify_affine_expr",
604+
[](PyAffineExpr &self, uint32_t numDims, uint32_t numSymbols) {
605+
return PyAffineExpr(
606+
self.getContext(),
607+
mlirSimplifyAffineExpr(self, numDims, numSymbols));
608+
},
609+
nb::arg("expr"), nb::arg("num_dims"), nb::arg("num_symbols"),
610+
"Simplify an affine expression by flattening and some amount of "
611+
"simple analysis.")
602612
.def_static(
603613
"get_add", &PyAffineAddExpr::get,
604614
"Gets an affine expression containing a sum of two expressions.")

mlir/lib/CAPI/IR/AffineExpr.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
7373
return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
7474
}
7575

76+
MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr, uint32_t numDims,
77+
uint32_t numSymbols) {
78+
return wrap(simplifyAffineExpr(unwrap(expr), numDims, numSymbols));
79+
}
80+
7681
//===----------------------------------------------------------------------===//
7782
// Affine Dimension Expression.
7883
//===----------------------------------------------------------------------===//

mlir/test/python/ir/affine_expr.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,11 @@ def testAffineExprShift():
416416

417417
assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
418418
assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)
419+
420+
421+
# CHECK-LABEL: TEST: testAffineExprSimplify
422+
@run
423+
def testAffineExprSimplify():
424+
with Context() as ctx:
425+
expr = AffineExpr.get_dim(0) + AffineExpr.get_symbol(0)
426+
assert expr == AffineExpr.simplify_affine_expr(expr, 1, 1)

0 commit comments

Comments
 (0)