-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] fix up affine for #74495
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
14fa324
to
8786c94
Compare
8786c94
to
c333f86
Compare
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesFix up #74408. Full diff: https://github.com/llvm/llvm-project/pull/74495.diff 3 Files Affected:
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 60ce83c09f171..20ec08400d081 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -134,3 +134,6 @@ def get_op_result_or_op_results(
# see the typing.Type doc string.
_U = _TypeVar("_U", bound=_cext.ir.Value)
SubClassValueT = _Type[_U]
+
+ResultValueT = _Union[_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value]
+VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 26e827009bc04..834a8cccc7c71 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,8 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect, AffineForOp
-from .arith import constant
+from ._affine_ops_gen import _Dialect
try:
from ..ir import *
@@ -12,6 +11,8 @@
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
+ ResultValueT as _ResultValueT,
+ VariadicResultValueT as _VariadicResultValueT,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -21,17 +22,17 @@
@_ods_cext.register_operation(_Dialect, replace=True)
class AffineForOp(AffineForOp):
- """Specialization for the Affine for op class"""
+ """Specialization for the Affine for op class."""
def __init__(
self,
- lower_bound,
- upper_bound,
- step,
- iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ lower_bound: Union[int, _ResultValueT, AffineMap],
+ upper_bound: Optional[Union[int, _ResultValueT, AffineMap]] = None,
+ step: Optional[Union[int, _ResultValueT]] = None,
+ iter_args: Optional[_ResultValueT] = None,
*,
- lower_bound_operands=[],
- upper_bound_operands=[],
+ lower_bound_operands: Optional[_VariadicResultValueT] = None,
+ upper_bound_operands: Optional[_VariadicResultValueT] = None,
loc=None,
ip=None,
):
@@ -45,23 +46,40 @@ def __init__(
- `lower_bound_operands` is the list of arguments to substitute the dimensions,
then symbols in the `lower_bound` affine map, in an increasing order
- `upper_bound_operands` is the list of arguments to substitute the dimensions,
- then symbols in the `upper_bound` affine map, in an increasing order
+ then symbols in the `upper_bound` affine map, in an increasing order.
"""
+ if lower_bound_operands is None:
+ lower_bound_operands = []
+ if upper_bound_operands is None:
+ upper_bound_operands = []
+
+ if step is None:
+ step = 1
+ if upper_bound is None:
+ upper_bound, lower_bound = lower_bound, 0
+
+ if isinstance(lower_bound, int):
+ lower_bound = AffineMap.get_constant(lower_bound)
+ elif isinstance(lower_bound, _ResultValueT):
+ lower_bound_operands.append(lower_bound)
+ lower_bound = AffineMap.get_constant(1)
+
+ if not isinstance(lower_bound, AffineMap):
+ raise ValueError(f"{lower_bound=} must be int | ResultValueT | AffineMap")
+
+ if isinstance(upper_bound, int):
+ upper_bound = AffineMap.get_constant(upper_bound)
+ elif isinstance(upper_bound, _ResultValueT):
+ upper_bound_operands.append(upper_bound)
+ upper_bound = AffineMap.get_constant(1)
+
+ if not isinstance(upper_bound, AffineMap):
+ raise ValueError(f"{upper_bound=} must be int | ResultValueT | AffineMap")
+
if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
- if len(lower_bound_operands) != lower_bound.n_inputs:
- raise ValueError(
- f"Wrong number of lower bound operands passed to AffineForOp. "
- + "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}."
- )
-
- if len(upper_bound_operands) != upper_bound.n_inputs:
- raise ValueError(
- f"Wrong number of upper bound operands passed to AffineForOp. "
- + "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}."
- )
results = [arg.type for arg in iter_args]
super().__init__(
@@ -71,7 +89,7 @@ def __init__(
inits=list(iter_args),
lowerBoundMap=AffineMapAttr.get(lower_bound),
upperBoundMap=AffineMapAttr.get(upper_bound),
- step=IntegerAttr.get(IndexType.get(), step),
+ step=step,
loc=loc,
ip=ip,
)
@@ -105,30 +123,11 @@ def for_(
loc=None,
ip=None,
):
- if step is None:
- step = 1
- if stop is None:
- stop = start
- start = 0
- params = [start, stop]
- for i, p in enumerate(params):
- if isinstance(p, int):
- p = constant(IntegerAttr.get(IndexType.get(), p))
- elif isinstance(p, float):
- raise ValueError(f"{p=} must be int.")
- params[i] = p
-
- start, stop = params
- s0 = AffineSymbolExpr.get(0)
- lbmap = AffineMap.get(0, 1, [s0])
- ubmap = AffineMap.get(0, 1, [s0])
for_op = AffineForOp(
- lbmap,
- ubmap,
+ start,
+ stop,
step,
iter_args=iter_args,
- lower_bound_operands=[start],
- upper_bound_operands=[stop],
loc=loc,
ip=ip,
)
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index df42f8fcf1a57..737044b293f8c 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -5,6 +5,7 @@
from mlir.dialects import arith
from mlir.dialects import memref
from mlir.dialects import affine
+import mlir.extras.types as T
def constructAndPrintInModule(f):
@@ -115,58 +116,130 @@ def affine_for_op_test(buffer):
@constructAndPrintInModule
def testForSugar():
- index_type = IndexType.get()
- memref_t = MemRefType.get([10], index_type)
+ memref_t = T.memref(10, T.index())
range = affine.for_
- # CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
- # CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
- # CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
- # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
- # CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
- # CHECK: }
- # CHECK: return
- # CHECK: }
- @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
- def range_loop_1(lb, ub, step, memref_v):
- for i in range(lb, 10, 2):
+ # CHECK-LABEL: func.func @range_loop_1(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 1 to 1 iter_args() -> () {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: affine.yield
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_1(lb, ub, memref_v):
+ for i in range(lb, ub, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_2(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 1 to 10 iter_args() -> () {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: affine.yield
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_2(lb, ub, memref_v):
+ for i in range(lb, 10, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_3(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 1 iter_args() -> () {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: affine.yield
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_3(lb, ub, memref_v):
+ for i in range(0, ub, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_4(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_4(lb, ub, memref_v):
+ for i in range(0, 10, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_5(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_5(lb, ub, memref_v):
+ for i in range(0, 10, step=1):
+ add = arith.addi(i, i)
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_6(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_6(lb, ub, memref_v):
+ for i in range(0, 10):
add = arith.addi(i, i)
- s0 = AffineSymbolExpr.get(0)
- map = AffineMap.get(0, 1, [s0])
- affine.store(add, memref_v, [i], map=map)
- affine.AffineYieldOp([])
-
- # CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
- # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
- # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
- # CHECK: affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
- # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
- # CHECK: affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
- # CHECK: }
- # CHECK: return
- # CHECK: }
- @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
- def range_loop_2(lb, ub, step, memref_v):
- for i in range(0, 10, 1):
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_7(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: affine.for %[[VAL_3:.*]] = 0 to 10 {
+ # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
+ # CHECK: memref.store %[[VAL_4]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_7(lb, ub, memref_v):
+ for i in range(10):
add = arith.addi(i, i)
- s0 = AffineSymbolExpr.get(0)
- map = AffineMap.get(0, 1, [s0])
- affine.store(add, memref_v, [i], map=map)
- affine.AffineYieldOp([])
-
- # CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
- # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
- # CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
- # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
- # CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
- # CHECK: }
- # CHECK: return
- # CHECK: }
- @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
- def range_loop_3(lb, ub, step, memref_v):
- for i in range(0, ub, 1):
+ memref.store(add, memref_v, [i])
+ affine.yield_([])
+
+ # CHECK-LABEL: func.func @range_loop_8(
+ # CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: memref<10xindex>) {
+ # CHECK: %[[VAL_3:.*]] = affine.for %[[VAL_4:.*]] = 0 to 10 iter_args(%[[VAL_5:.*]] = %[[VAL_2]]) -> (memref<10xindex>) {
+ # CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
+ # CHECK: memref.store %[[VAL_6]], %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<10xindex>
+ # CHECK: affine.yield %[[VAL_5]] : memref<10xindex>
+ # CHECK: }
+ # CHECK: return
+ # CHECK: }
+ @func.FuncOp.from_py_func(T.index(), T.index(), memref_t)
+ def range_loop_8(lb, ub, memref_v):
+ for i, it in range(10, iter_args=[memref_v]):
add = arith.addi(i, i)
- s0 = AffineSymbolExpr.get(0)
- map = AffineMap.get(0, 1, [s0])
- affine.store(add, memref_v, [i], map=map)
- affine.AffineYieldOp([])
+ memref.store(add, it, [i])
+ affine.yield_([it])
|
6169f30
to
83be26d
Compare
83be26d
to
ba00415
Compare
ba00415
to
829df04
Compare
mlir/python/mlir/dialects/affine.py
Outdated
|
||
if isinstance(lower_bound, int): | ||
lower_bound = AffineMap.get_constant(lower_bound) | ||
elif isinstance(lower_bound, (Operation, OpView, Value)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I like _ResultValueT
in commons, some other dialects have local equivalents. Would it be possible to have ResultValueTypeTuple = (_cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value)
, ResultValueT = _Union[*ResultValueTypeTuple]
to avoid repetition here? Not sure if that works as I think it does.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You know what's a better solution? To bump the minimum version of python to 3.10 :) - I had isinstance(bound, _VariadicResultValueT)
but the windows bot failed where the Linux bot passed (because I guess the windows python version is lower than 3.10 but the linux version isn't). Anyway let me try.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Splat doesn't work for <3.10 but lucky for us Union
has a tuple "constructor" so just _Union[ResultValueTypeTuple]
works.
IIRC this is why I never upstreamed this one myself (the error checking is such a pain............). And I probably still missed some paths :( |
168ef88
to
2f12f48
Compare
Fix up #74408.
I believe I've satisfied all of the desiderata but let me know if I missed anything.