Skip to content

[mlir] Expose AffineExpr.shift_dims/shift_symbols through C and Python bindings #131521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir-c/AffineExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose(
MlirAffineExpr affineExpr, struct MlirAffineMap affineMap);

/// Replace dims[offset ... numDims)
/// by dims[offset + shift ... shift + numDims).
MLIR_CAPI_EXPORTED MlirAffineExpr
mlirAffineExprShiftDims(MlirAffineExpr affineExpr, uint32_t numDims,
uint32_t shift, uint32_t offset);

/// Replace symbols[offset ... numSymbols)
/// by symbols[offset + shift ... shift + numSymbols).
MLIR_CAPI_EXPORTED MlirAffineExpr
mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
uint32_t shift, uint32_t offset);

//===----------------------------------------------------------------------===//
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/Bindings/Python/IRAffine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,25 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
return PyAffineExpr(self.getContext(),
mlirAffineExprCompose(self, other));
})
.def(
"shift_dims",
[](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
uint32_t offset) {
return PyAffineExpr(
self.getContext(),
mlirAffineExprShiftDims(self, numDims, shift, offset));
},
nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0)
.def(
"shift_symbols",
[](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift,
uint32_t offset) {
return PyAffineExpr(
self.getContext(),
mlirAffineExprShiftSymbols(self, numSymbols, shift, offset));
},
nb::arg("num_symbols"), nb::arg("shift"),
nb::arg("offset").none() = 0)
.def_static(
"get_add", &PyAffineAddExpr::get,
"Gets an affine expression containing a sum of two expressions.")
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/CAPI/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr,
return wrap(unwrap(affineExpr).compose(unwrap(affineMap)));
}

MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr,
uint32_t numDims, uint32_t shift,
uint32_t offset) {
return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset));
}

MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
uint32_t numSymbols, uint32_t shift,
uint32_t offset) {
return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
}

//===----------------------------------------------------------------------===//
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/python/ir/affine_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,14 @@ def testHash():
dictionary[s1] = 1
assert d0 in dictionary
assert s1 in dictionary


# CHECK-LABEL: TEST: testAffineExprShift
@run
def testAffineExprShift():
with Context() as ctx:
dims = [AffineExpr.get_dim(i) for i in range(4)]
syms = [AffineExpr.get_symbol(i) for i in range(4)]

assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)