Skip to content

Commit 7c98cdd

Browse files
authored
[mlir] Expose AffineExpr.shift_dims/shift_symbols through C and Python bindings (#131521)
1 parent 93ce345 commit 7c98cdd

File tree

4 files changed

+54
-0
lines changed

4 files changed

+54
-0
lines changed

mlir/include/mlir-c/AffineExpr.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
9292
MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose(
9393
MlirAffineExpr affineExpr, struct MlirAffineMap affineMap);
9494

95+
/// Replace dims[offset ... numDims)
96+
/// by dims[offset + shift ... shift + numDims).
97+
MLIR_CAPI_EXPORTED MlirAffineExpr
98+
mlirAffineExprShiftDims(MlirAffineExpr affineExpr, uint32_t numDims,
99+
uint32_t shift, uint32_t offset);
100+
101+
/// Replace symbols[offset ... numSymbols)
102+
/// by symbols[offset + shift ... shift + numSymbols).
103+
MLIR_CAPI_EXPORTED MlirAffineExpr
104+
mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
105+
uint32_t shift, uint32_t offset);
106+
95107
//===----------------------------------------------------------------------===//
96108
// Affine Dimension Expression.
97109
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRAffine.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,25 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
580580
return PyAffineExpr(self.getContext(),
581581
mlirAffineExprCompose(self, other));
582582
})
583+
.def(
584+
"shift_dims",
585+
[](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
586+
uint32_t offset) {
587+
return PyAffineExpr(
588+
self.getContext(),
589+
mlirAffineExprShiftDims(self, numDims, shift, offset));
590+
},
591+
nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0)
592+
.def(
593+
"shift_symbols",
594+
[](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift,
595+
uint32_t offset) {
596+
return PyAffineExpr(
597+
self.getContext(),
598+
mlirAffineExprShiftSymbols(self, numSymbols, shift, offset));
599+
},
600+
nb::arg("num_symbols"), nb::arg("shift"),
601+
nb::arg("offset").none() = 0)
583602
.def_static(
584603
"get_add", &PyAffineAddExpr::get,
585604
"Gets an affine expression containing a sum of two expressions.")

mlir/lib/CAPI/IR/AffineExpr.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr,
6161
return wrap(unwrap(affineExpr).compose(unwrap(affineMap)));
6262
}
6363

64+
MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr,
65+
uint32_t numDims, uint32_t shift,
66+
uint32_t offset) {
67+
return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset));
68+
}
69+
70+
MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
71+
uint32_t numSymbols, uint32_t shift,
72+
uint32_t offset) {
73+
return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
74+
}
75+
6476
//===----------------------------------------------------------------------===//
6577
// Affine Dimension Expression.
6678
//===----------------------------------------------------------------------===//

mlir/test/python/ir/affine_expr.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,3 +405,14 @@ def testHash():
405405
dictionary[s1] = 1
406406
assert d0 in dictionary
407407
assert s1 in dictionary
408+
409+
410+
# CHECK-LABEL: TEST: testAffineExprShift
411+
@run
412+
def testAffineExprShift():
413+
with Context() as ctx:
414+
dims = [AffineExpr.get_dim(i) for i in range(4)]
415+
syms = [AffineExpr.get_symbol(i) for i in range(4)]
416+
417+
assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
418+
assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)

0 commit comments

Comments
 (0)