Skip to content

[mlir][python] fix up affine for #74495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 7, 2023
Merged

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Dec 5, 2023

Fix up #74408.

I believe I've satisfied all of the desiderata but let me know if I missed anything.

@makslevental makslevental marked this pull request as ready for review December 5, 2023 19:41
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Dec 5, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2023

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Fix up #74408.


Full diff: https://github.com/llvm/llvm-project/pull/74495.diff

3 Files Affected:

  • (modified) mlir/python/mlir/dialects/_ods_common.py (+3)
  • (modified) mlir/python/mlir/dialects/affine.py (+42-43)
  • (modified) mlir/test/python/dialects/affine.py (+123-50)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 60ce83c09f171..20ec08400d081 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -134,3 +134,6 @@ def get_op_result_or_op_results(
 # see the typing.Type doc string.
 _U = _TypeVar("_U", bound=_cext.ir.Value)
 SubClassValueT = _Type[_U]
+
+ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
+VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 26e827009bc04..834a8cccc7c71 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,8 +3,7 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect, AffineForOp
-from .arith import constant
+from ._affine_ops_gen import _Dialect
 
 try:
     from ..ir import *
@@ -12,6 +11,8 @@
         get_op_result_or_value as _get_op_result_or_value,
         get_op_results_or_values as _get_op_results_or_values,
         _cext as _ods_cext,
+        ResultValueT as _ResultValueT,
+        VariadicResultValueT as _VariadicResultValueT,
     )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
@@ -21,17 +22,17 @@
 
 @_ods_cext.register_operation(_Dialect, replace=True)
 class AffineForOp(AffineForOp):
-    """Specialization for the Affine for op class"""
+    """Specialization for the Affine for op class."""
 
     def __init__(
         self,
-        lower_bound,
-        upper_bound,
-        step,
-        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        lower_bound: Union[int, _ResultValueT, AffineMap],
+        upper_bound: Optional[Union[int, _ResultValueT, AffineMap]] = None,
+        step: Optional[Union[int, _ResultValueT]] = None,
+        iter_args: Optional[_ResultValueT] = None,
         *,
-        lower_bound_operands=[],
-        upper_bound_operands=[],
+        lower_bound_operands: Optional[_VariadicResultValueT] = None,
+        upper_bound_operands: Optional[_VariadicResultValueT] = None,
         loc=None,
         ip=None,
     ):
@@ -45,23 +46,40 @@ def __init__(
         - `lower_bound_operands` is the list of arguments to substitute the dimensions,
           then symbols in the `lower_bound` affine map, in an increasing order
         - `upper_bound_operands` is the list of arguments to substitute the dimensions,
-          then symbols in the `upper_bound` affine map, in an increasing order
+          then symbols in the `upper_bound` affine map, in an increasing order.
         """
 
+        if lower_bound_operands is None:
+            lower_bound_operands = []
+        if upper_bound_operands is None:
+            upper_bound_operands = []
+
+        if step is None:
+            step = 1
+        if upper_bound is None:
+            upper_bound, lower_bound = lower_bound, 0
+
+        if isinstance(lower_bound, int):
+            lower_bound = AffineMap.get_constant(lower_bound)
+        elif isinstance(lower_bound, _ResultValueT):
+            lower_bound_operands.append(lower_bound)
+            lower_bound = AffineMap.get_constant(1)
+
+        if not isinstance(lower_bound, AffineMap):
+            raise ValueError(f"{lower_bound=} must be int | ResultValueT | AffineMap")
+
+        if isinstance(upper_bound, int):
+            upper_bound = AffineMap.get_constant(upper_bound)
+        elif isinstance(upper_bound, _ResultValueT):
+            upper_bound_operands.append(upper_bound)
+            upper_bound = AffineMap.get_constant(1)
+
+        if not isinstance(upper_bound, AffineMap):
+            raise ValueError(f"{upper_bound=} must be int | ResultValueT | AffineMap")
+
         if iter_args is None:
             iter_args = []
         iter_args = _get_op_results_or_values(iter_args)
-        if len(lower_bound_operands) != lower_bound.n_inputs:
-            raise ValueError(
-                f"Wrong number of lower bound operands passed to AffineForOp. "
-                + "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}."
-            )
-
-        if len(upper_bound_operands) != upper_bound.n_inputs:
-            raise ValueError(
-                f"Wrong number of upper bound operands passed to AffineForOp. "
-                + "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}."
-            )
 
         results = [arg.type for arg in iter_args]
         super().__init__(
@@ -71,7 +89,7 @@ def __init__(
             inits=list(iter_args),
             lowerBoundMap=AffineMapAttr.get(lower_bound),
             upperBoundMap=AffineMapAttr.get(upper_bound),
-            step=IntegerAttr.get(IndexType.get(), step),
+            step=step,
             loc=loc,
             ip=ip,
         )
@@ -105,30 +123,11 @@ def for_(
     loc=None,
     ip=None,
 ):
-    if step is None:
-        step = 1
-    if stop is None:
-        stop = start
-        start = 0
-    params = [start, stop]
-    for i, p in enumerate(params):
-        if isinstance(p, int):
-            p = constant(IntegerAttr.get(IndexType.get(), p))
-        elif isinstance(p, float):
-            raise ValueError(f"{p=} must be int.")
-        params[i] = p
-
-    start, stop = params
-    s0 = AffineSymbolExpr.get(0)
-    lbmap = AffineMap.get(0, 1, [s0])
-    ubmap = AffineMap.get(0, 1, [s0])
     for_op = AffineForOp(
-        lbmap,
-        ubmap,
+        start,
+        stop,
         step,
         iter_args=iter_args,
-        lower_bound_operands=[start],
-        upper_bound_operands=[stop],
         loc=loc,
         ip=ip,
     )
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index df42f8fcf1a57..737044b293f8c 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -5,6 +5,7 @@
 from mlir.dialects import arith
 from mlir.dialects import memref
 from mlir.dialects import affine
+import mlir.extras.types as T
 
 
 def constructAndPrintInModule(f):
@@ -115,58 +116,130 @@ def affine_for_op_test(buffer):
 
 @constructAndPrintInModule
 def testForSugar():
-    index_type = IndexType.get()
-    memref_t = MemRefType.get([10], index_type)
+    memref_t = T.memref(10, T.index())
     range = affine.for_
 
-    # CHECK:  func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
-    # CHECK:    %[[VAL_4:.*]] = arith.constant 10 : index
-    # CHECK:    affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
-    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
-    # CHECK:      affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
-    # CHECK:    }
-    # CHECK:    return
-    # CHECK:  }
-    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
-    def range_loop_1(lb, ub, step, memref_v):
-        for i in range(lb, 10, 2):
+    # CHECK-LABEL:   func.func @range_loop_1(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 1 to 1 iter_args() -> () {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:             affine.yield
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_1(lb, ub, memref_v):
+        for i in range(lb, ub, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_2(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 1 to 10 iter_args() -> () {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:             affine.yield
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_2(lb, ub, memref_v):
+        for i in range(lb, 10, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_3(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 1 iter_args() -> () {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:             affine.yield
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_3(lb, ub, memref_v):
+        for i in range(0, ub, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_4(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_4(lb, ub, memref_v):
+        for i in range(0, 10, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_5(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_5(lb, ub, memref_v):
+        for i in range(0, 10, step=1):
+            add = arith.addi(i, i)
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_6(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_6(lb, ub, memref_v):
+        for i in range(0, 10):
             add = arith.addi(i, i)
-            s0 = AffineSymbolExpr.get(0)
-            map = AffineMap.get(0, 1, [s0])
-            affine.store(add, memref_v, [i], map=map)
-            affine.AffineYieldOp([])
-
-    # CHECK:  func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
-    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
-    # CHECK:    %[[VAL_5:.*]] = arith.constant 10 : index
-    # CHECK:    affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
-    # CHECK:      %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
-    # CHECK:      affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
-    # CHECK:    }
-    # CHECK:    return
-    # CHECK:  }
-    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
-    def range_loop_2(lb, ub, step, memref_v):
-        for i in range(0, 10, 1):
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_7(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           affine.for %[[VAL_3:.*]] = 0 to 10 {
+    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+    # CHECK:             memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_7(lb, ub, memref_v):
+        for i in range(10):
             add = arith.addi(i, i)
-            s0 = AffineSymbolExpr.get(0)
-            map = AffineMap.get(0, 1, [s0])
-            affine.store(add, memref_v, [i], map=map)
-            affine.AffineYieldOp([])
-
-    # CHECK:  func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
-    # CHECK:    %[[VAL_4:.*]] = arith.constant 0 : index
-    # CHECK:    affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
-    # CHECK:      %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
-    # CHECK:      affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
-    # CHECK:    }
-    # CHECK:    return
-    # CHECK:  }
-    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
-    def range_loop_3(lb, ub, step, memref_v):
-        for i in range(0, ub, 1):
+            memref.store(add, memref_v, [i])
+            affine.yield_([])
+
+    # CHECK-LABEL:   func.func @range_loop_8(
+    # CHECK-SAME:                            %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+    # CHECK:           %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
+    # CHECK:             %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
+    # CHECK:             memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
+    # CHECK:             affine.yield %[[VAL_5]] : memref<10xindex>
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }
+    @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+    def range_loop_8(lb, ub, memref_v):
+        for i, it in range(10, iter_args=[memref_v]):
             add = arith.addi(i, i)
-            s0 = AffineSymbolExpr.get(0)
-            map = AffineMap.get(0, 1, [s0])
-            affine.store(add, memref_v, [i], map=map)
-            affine.AffineYieldOp([])
+            memref.store(add, it, [i])
+            affine.yield_([it])

@makslevental makslevental marked this pull request as draft December 5, 2023 20:28
@makslevental makslevental marked this pull request as ready for review December 5, 2023 22:06

if isinstance(lower_bound, int):
lower_bound = AffineMap.get_constant(lower_bound)
elif isinstance(lower_bound, (Operation, OpView, Value)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I like _ResultValueT in commons, some other dialects have local equivalents. Would it be possible to have ResultValueTypeTuple = (_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value), ResultValueT = _Union[*ResultValueTypeTuple] to avoid repetition here? Not sure if that works as I think it does.

Copy link
Contributor Author

@makslevental makslevental Dec 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You know what's a better solution? To bump the minimum version of python to 3.10 :) - I had isinstance(bound, _VariadicResultValueT) but the windows bot failed where the Linux bot passed (because I guess the windows python version is lower than 3.10 but the linux version isn't). Anyway let me try.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Splat doesn't work for <3.10 but lucky for us Union has a tuple "constructor" so just _Union[ResultValueTypeTuple] works.

@makslevental
Copy link
Contributor Author

makslevental commented Dec 6, 2023

IIRC this is why I never upstreamed this one myself (the error checking is such a pain............). And I probably still missed some paths :(

@makslevental makslevental requested a review from ftynse December 6, 2023 18:43
@makslevental makslevental merged commit db3bc49 into llvm:main Dec 7, 2023
@makslevental makslevental deleted the fix_affine_for branch December 7, 2023 16:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants