Skip to content

[mlir] Expose AffineExpr.shift_dims/shift_symbols through C and Python bindings #131521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 16, 2025

Conversation

Hardcode84
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Mar 16, 2025

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/131521.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/AffineExpr.h (+12)
  • (modified) mlir/lib/Bindings/Python/IRAffine.cpp (+19)
  • (modified) mlir/lib/CAPI/IR/AffineExpr.cpp (+12)
  • (modified) mlir/test/python/ir/affine_expr.py (+11)
diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h
index 14e951ddee9ad..ab768eb2ec870 100644
--- a/mlir/include/mlir-c/AffineExpr.h
+++ b/mlir/include/mlir-c/AffineExpr.h
@@ -92,6 +92,18 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
 MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose(
     MlirAffineExpr affineExpr, struct MlirAffineMap affineMap);
 
+/// Replace dims[offset ... numDims)
+/// by dims[offset + shift ... shift + numDims).
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirAffineExprShiftDims(MlirAffineExpr affineExpr, uint32_t numDims,
+                        uint32_t shift, uint32_t offset);
+
+/// Replace symbols[offset ... numSymbols)
+/// by symbols[offset + shift ... shift + numSymbols).
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
+                           uint32_t shift, uint32_t offset);
+
 //===----------------------------------------------------------------------===//
 // Affine Dimension Expression.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index a2df824f59a53..3c95d29c4bcca 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -580,6 +580,25 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
              return PyAffineExpr(self.getContext(),
                                  mlirAffineExprCompose(self, other));
            })
+      .def(
+          "shift_dims",
+          [](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
+             uint32_t offset) {
+            return PyAffineExpr(
+                self.getContext(),
+                mlirAffineExprShiftDims(self, numDims, shift, offset));
+          },
+          nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0)
+      .def(
+          "shift_symbols",
+          [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift,
+             uint32_t offset) {
+            return PyAffineExpr(
+                self.getContext(),
+                mlirAffineExprShiftSymbols(self, numSymbols, shift, offset));
+          },
+          nb::arg("num_symbols"), nb::arg("shift"),
+          nb::arg("offset").none() = 0)
       .def_static(
           "get_add", &PyAffineAddExpr::get,
           "Gets an affine expression containing a sum of two expressions.")
diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp
index 6e3328b65cb08..bc3dcd4174736 100644
--- a/mlir/lib/CAPI/IR/AffineExpr.cpp
+++ b/mlir/lib/CAPI/IR/AffineExpr.cpp
@@ -61,6 +61,18 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr,
   return wrap(unwrap(affineExpr).compose(unwrap(affineMap)));
 }
 
+MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr,
+                                       uint32_t numDims, uint32_t shift,
+                                       uint32_t offset) {
+  return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset));
+}
+
+MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
+                                          uint32_t numSymbols, uint32_t shift,
+                                          uint32_t offset) {
+  return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Dimension Expression.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index c7861c1acfe12..2f64aff143420 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -405,3 +405,14 @@ def testHash():
         dictionary[s1] = 1
         assert d0 in dictionary
         assert s1 in dictionary
+
+
+# CHECK-LABEL: TEST: testAffineExprShift
+@run
+def testAffineExprShift():
+    with Context() as ctx:
+        dims = [AffineExpr.get_dim(i) for i in range(4)]
+        syms = [AffineExpr.get_symbol(i) for i in range(4)]
+
+        assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
+        assert (syms[2] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 0)

@Hardcode84 Hardcode84 merged commit 7c98cdd into llvm:main Mar 16, 2025
13 checks passed
@Hardcode84 Hardcode84 deleted the pyaffine-shift branch March 16, 2025 16:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants