Skip to content

[mlir][python] python binding wrapper for the affine.AffineForOp #74408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 5, 2023

Conversation

kaitingwang
Copy link
Contributor

This PR creates the wrapper class AffineForOp and adds a testcase for it. A testcase for the AffineLoadOp is also added.

@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Dec 5, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2023

@llvm/pr-subscribers-mlir

Author: Amy Wang (kaitingwang)

Changes

This PR creates the wrapper class AffineForOp and adds a testcase for it. A testcase for the AffineLoadOp is also added.


Full diff: https://github.com/llvm/llvm-project/pull/74408.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/affine.py (+91)
  • (modified) mlir/test/python/dialects/affine.py (+83-4)
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 80d3873e19a05..71b44e5492716 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,3 +3,94 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._affine_ops_gen import *
+from ._affine_ops_gen import _Dialect, AffineForOp
+
+try:
+    from ..ir import *
+    from ._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        get_op_results_or_values as _get_op_results_or_values,
+        _cext as _ods_cext,
+    )
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineForOp(AffineForOp):
+    """Specialization for the Affine for op class"""
+
+    def __init__(
+        self,
+        lower_bound,
+        upper_bound,
+        step,
+        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        lower_bound_operands=[],
+        upper_bound_operands=[],
+        loc=None,
+        ip=None
+    ):
+        """Creates an Affine `for` operation.
+
+        - `lower_bound` is the affine map to use as lower bound of the loop.
+        - `upper_bound` is the affine map to use as upper bound of the loop.
+        - `step` is the value to use as loop step.
+        - `iter_args` is a list of additional loop-carried arguments or an operation
+          producing them as results.
+        - `lower_bound_operands` is the list of arguments to substitute the dimensions, 
+          then symbols in the `lower_bound` affine map, in an increasing order
+        - `upper_bound_operands` is the list of arguments to substitute the dimensions, 
+          then symbols in the `upper_bound` affine map, in an increasing order
+        """
+
+        if iter_args is None:
+            iter_args = []
+        iter_args = _get_op_results_or_values(iter_args)
+
+        if len(lower_bound_operands) != lower_bound.n_inputs:
+            raise ValueError(
+                f"Wrong number of lower bound operands passed to AffineForOp. " +
+                "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}.")
+
+        if len(upper_bound_operands) != upper_bound.n_inputs:
+            raise ValueError(
+                f"Wrong number of upper bound operands passed to AffineForOp. " +
+                "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}.")
+
+        results = [arg.type for arg in iter_args]
+        super().__init__(
+            results_=results,
+            lowerBoundOperands=[_get_op_result_or_value(
+                o) for o in lower_bound_operands],
+            upperBoundOperands=[_get_op_result_or_value(
+                o) for o in upper_bound_operands],
+            inits=list(iter_args),
+            lowerBoundMap=AffineMapAttr.get(lower_bound),
+            upperBoundMap=AffineMapAttr.get(upper_bound),
+            step=IntegerAttr.get(IndexType.get(), step),
+            loc=loc,
+            ip=ip
+        )
+        self.regions[0].blocks.append(IndexType.get(), *results)
+
+    @property
+    def body(self):
+        """Returns the body (block) of the loop."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def induction_variable(self):
+        """Returns the induction variable of the loop."""
+        return self.body.arguments[0]
+
+    @property
+    def inner_iter_args(self):
+        """Returns the loop-carried arguments usable within the loop.
+
+        To obtain the loop-carried operands, use `iter_args`.
+        """
+        return self.body.arguments[1:]
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index c5ec85457493b..c780bf01beac2 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -1,10 +1,10 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-import mlir.dialects.func as func
-import mlir.dialects.arith as arith
-import mlir.dialects.affine as affine
-import mlir.dialects.memref as memref
+from mlir.dialects import func
+from mlir.dialects import arith
+from mlir.dialects import memref
+from mlir.dialects import affine
 
 
 def run(f):
@@ -42,3 +42,82 @@ def affine_store_test(arg0):
                 return mem
 
         print(module)
+
+
+# CHECK-LABEL: TEST: testAffineLoadOp
+@run
+def testAffineLoadOp():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f32 = F32Type.get()
+            index_type = IndexType.get()
+            memref_type_in = MemRefType.get([10, 10], f32)
+
+            # CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
+            @func.FuncOp.from_py_func(memref_type_in, index_type)
+            def affine_load_test(I, arg0):
+
+                d0 = AffineDimExpr.get(0)
+                s0 = AffineSymbolExpr.get(0)
+                map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
+
+                # CHECK: {{.*}} = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
+                a1 = affine.AffineLoadOp(f32, I, indices=[arg0, arg0], map=map)
+
+                return a1
+
+        print(module)
+
+
+# CHECK-LABEL: TEST: testAffineForOp
+@run
+def testAffineForOp():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f32 = F32Type.get()
+            index_type = IndexType.get()
+            memref_type = MemRefType.get([1024], f32)
+
+            # CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
+            # CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
+            # CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
+            @func.FuncOp.from_py_func(memref_type)
+            def affine_for_op_test(buffer):
+                # CHECK: %[[C1:.*]] = arith.constant 1 : index
+                c1 = arith.ConstantOp(index_type, 1)
+                # CHECK: %[[C2:.*]] = arith.constant 2 : index
+                c2 = arith.ConstantOp(index_type, 2)
+                # CHECK: %[[C3:.*]] = arith.constant 3 : index
+                c3 = arith.ConstantOp(index_type, 3)
+                # CHECK: %[[C9:.*]] = arith.constant 9 : index
+                c9 = arith.ConstantOp(index_type, 9)
+                # CHECK: %[[AC0:.*]] = arith.constant 0.000000e+00 : f32
+                ac0 = AffineConstantExpr.get(0)
+
+                d0 = AffineDimExpr.get(0)
+                d1 = AffineDimExpr.get(1)
+                s0 = AffineSymbolExpr.get(0)
+                lb = AffineMap.get(1, 1, [ac0, d0 + s0])
+                ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])
+                sum_0 = arith.ConstantOp(f32, 0.0)
+
+                # CHECK: %0 = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[AC0]]) -> (f32) {
+                sum = affine.AffineForOp(
+                    lb, ub, 2,
+                    iter_args=[sum_0],
+                    lower_bound_operands=[c2, c3],
+                    upper_bound_operands=[c9, c1]
+                )
+
+                with InsertionPoint(sum.body):
+
+                    # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
+                    tmp = memref.LoadOp(buffer, [sum.induction_variable])
+                    sum_next = arith.AddFOp(sum.inner_iter_args[0], tmp)
+
+                    affine.AffineYieldOp([sum_next])
+
+                return
+        print(module)

Copy link

github-actions bot commented Dec 5, 2023

✅ With the latest revision this PR passed the Python code formatter.

@makslevental
Copy link
Contributor

makslevental commented Dec 5, 2023

Hey thanks @kaitingwang. I'll take a look soon but can you take a look here https://github.com/llvm/llvm-project/blob/baf42cdea2f243a1c9d2f89c456589e351c50521/mlir/python/mlir/dialects/scf.py#L106 and see if you can add the same thing for affine?

@kaitingwang kaitingwang force-pushed the affine-python-binding-2 branch from baf42cd to 79e7139 Compare December 5, 2023 06:34
@kaitingwang
Copy link
Contributor Author

Hey thanks @kaitingwang. I'll take a look soon but can you take a look here

https://github.com/llvm/llvm-project/blob/baf42cdea2f243a1c9d2f89c456589e351c50521/mlir/python/mlir/dialects/scf.py#L106
and see if you can add the same thing for affine?

Thanks for the suggestion! I added the syntatic suger form of affine.for_ as well as some testcases (mostly adapted from scf.for)

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo that one suggestion (not necessary ofc). Well organized code too. Thanks!

Comment on lines 69 to 71
lowerBoundOperands=[
_get_op_result_or_value(o) for o in lower_bound_operands
],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lowerBoundOperands=[
_get_op_result_or_value(o) for o in lower_bound_operands
],
lowerBoundOperands=_get_op_results_or_values(lower_bound_operands),

I think this should work - same for upperBoundOperands.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Yes, it does work. I changed it and pushed the change.

Comment on lines +10 to +16
def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
f()
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@kaitingwang kaitingwang force-pushed the affine-python-binding-2 branch from 79e7139 to c2bef0a Compare December 5, 2023 06:45
This PR creates the wrapper class AffineForOp and adds a testcase
for it. A testcase for AffineLoadOp is also added as well
as some syntatic suger tests.
@kaitingwang kaitingwang force-pushed the affine-python-binding-2 branch from c2bef0a to 21e9a60 Compare December 5, 2023 06:51
@kaitingwang kaitingwang merged commit 543589a into llvm:main Dec 5, 2023
@kaitingwang
Copy link
Contributor Author

CI was all green so I merged the PR. Thanks @makslevental for the review!

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have requested changes on this one. Please address promptly in a follow-up PR.

@@ -3,3 +3,141 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._affine_ops_gen import *
from ._affine_ops_gen import _Dialect, AffineForOp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: why do we need to import AffineForOp if we import * above? (This is different for _Dialect that is private due to underscore in the name and is therefore not imported as part of *.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think explicitly importing names used in the file is not a bad thing. I also think this is the natural result of essentially treating these files as __init__.pys for modules that don't exist structurally (i.e., it should be affine/__init__.py) and as implementation files as well. Anyway I'll take it out.


@_ods_cext.register_operation(_Dialect, replace=True)
class AffineForOp(AffineForOp):
"""Specialization for the Affine for op class"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: terminate Python comments with a full stop. Same below for some longer descriptions.

Comment on lines +28 to +34
lower_bound,
upper_bound,
step,
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
lower_bound_operands=[],
upper_bound_operands=[],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These would really benefit from having type annotations.

Comment on lines +72 to +73
lowerBoundMap=AffineMapAttr.get(lower_bound),
upperBoundMap=AffineMapAttr.get(upper_bound),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to support integers as lower/upper bound and have maps constructed for them on-the-fly.

Comment on lines +108 to +112
if step is None:
step = 1
if stop is None:
stop = start
start = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be doing this in the "main" builder, not in the wrapper.

Comment on lines +116 to +124
p = constant(IntegerAttr.get(IndexType.get(), p))
elif isinstance(p, float):
raise ValueError(f"{p=} must be int.")
params[i] = p

start, stop = params
s0 = AffineSymbolExpr.get(0)
lbmap = AffineMap.get(0, 1, [s0])
ubmap = AffineMap.get(0, 1, [s0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need to create a constant operation and feed as a symbol into an affine map. Affine maps can have constant expressions and we should use that. I also suspect canonicalization or constant folding will immediately remove those constants.

def testForSugar():
index_type = IndexType.get()
memref_t = MemRefType.get([10], index_type)
range = affine.for_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks rather confusing/scary to me...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the range = affine.for_? In the scf test I have range_ (easy to have missed the trailing underscore...).

@makslevental
Copy link
Contributor

makslevental commented Dec 5, 2023

WIP #74495.

@kaitingwang
Copy link
Contributor Author

@makslevental

WIP #74495.

Really appreciate your help to address @ftynse 's comments (next time I'll wait for your approval too)! Thank you for your reviews. I'll go over the left-over situation after work today and will be happy to address any remaining issues needed. (At @makslevental's fast speed, there may not be anything left for me to address!) :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants