-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] python binding wrapper for the affine.AffineForOp #74408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Amy Wang (kaitingwang) ChangesThis PR creates the wrapper class AffineForOp and adds a testcase for it. A testcase for the AffineLoadOp is also added. Full diff: https://github.com/llvm/llvm-project/pull/74408.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 80d3873e19a05..71b44e5492716 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,3 +3,94 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
+from ._affine_ops_gen import _Dialect, AffineForOp
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineForOp(AffineForOp):
+ """Specialization for the Affine for op class"""
+
+ def __init__(
+ self,
+ lower_bound,
+ upper_bound,
+ step,
+ iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ *,
+ lower_bound_operands=[],
+ upper_bound_operands=[],
+ loc=None,
+ ip=None
+ ):
+ """Creates an Affine `for` operation.
+
+ - `lower_bound` is the affine map to use as lower bound of the loop.
+ - `upper_bound` is the affine map to use as upper bound of the loop.
+ - `step` is the value to use as loop step.
+ - `iter_args` is a list of additional loop-carried arguments or an operation
+ producing them as results.
+ - `lower_bound_operands` is the list of arguments to substitute the dimensions,
+ then symbols in the `lower_bound` affine map, in an increasing order
+ - `upper_bound_operands` is the list of arguments to substitute the dimensions,
+ then symbols in the `upper_bound` affine map, in an increasing order
+ """
+
+ if iter_args is None:
+ iter_args = []
+ iter_args = _get_op_results_or_values(iter_args)
+
+ if len(lower_bound_operands) != lower_bound.n_inputs:
+ raise ValueError(
+ f"Wrong number of lower bound operands passed to AffineForOp. " +
+ "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}.")
+
+ if len(upper_bound_operands) != upper_bound.n_inputs:
+ raise ValueError(
+ f"Wrong number of upper bound operands passed to AffineForOp. " +
+ "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}.")
+
+ results = [arg.type for arg in iter_args]
+ super().__init__(
+ results_=results,
+ lowerBoundOperands=[_get_op_result_or_value(
+ o) for o in lower_bound_operands],
+ upperBoundOperands=[_get_op_result_or_value(
+ o) for o in upper_bound_operands],
+ inits=list(iter_args),
+ lowerBoundMap=AffineMapAttr.get(lower_bound),
+ upperBoundMap=AffineMapAttr.get(upper_bound),
+ step=IntegerAttr.get(IndexType.get(), step),
+ loc=loc,
+ ip=ip
+ )
+ self.regions[0].blocks.append(IndexType.get(), *results)
+
+ @property
+ def body(self):
+ """Returns the body (block) of the loop."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def induction_variable(self):
+ """Returns the induction variable of the loop."""
+ return self.body.arguments[0]
+
+ @property
+ def inner_iter_args(self):
+ """Returns the loop-carried arguments usable within the loop.
+
+ To obtain the loop-carried operands, use `iter_args`.
+ """
+ return self.body.arguments[1:]
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index c5ec85457493b..c780bf01beac2 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -1,10 +1,10 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
-import mlir.dialects.func as func
-import mlir.dialects.arith as arith
-import mlir.dialects.affine as affine
-import mlir.dialects.memref as memref
+from mlir.dialects import func
+from mlir.dialects import arith
+from mlir.dialects import memref
+from mlir.dialects import affine
def run(f):
@@ -42,3 +42,82 @@ def affine_store_test(arg0):
return mem
print(module)
+
+
+# CHECK-LABEL: TEST: testAffineLoadOp
+@run
+def testAffineLoadOp():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f32 = F32Type.get()
+ index_type = IndexType.get()
+ memref_type_in = MemRefType.get([10, 10], f32)
+
+ # CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
+ @func.FuncOp.from_py_func(memref_type_in, index_type)
+ def affine_load_test(I, arg0):
+
+ d0 = AffineDimExpr.get(0)
+ s0 = AffineSymbolExpr.get(0)
+ map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
+
+ # CHECK: {{.*}} = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
+ a1 = affine.AffineLoadOp(f32, I, indices=[arg0, arg0], map=map)
+
+ return a1
+
+ print(module)
+
+
+# CHECK-LABEL: TEST: testAffineForOp
+@run
+def testAffineForOp():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f32 = F32Type.get()
+ index_type = IndexType.get()
+ memref_type = MemRefType.get([1024], f32)
+
+ # CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
+ # CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
+ # CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
+ @func.FuncOp.from_py_func(memref_type)
+ def affine_for_op_test(buffer):
+ # CHECK: %[[C1:.*]] = arith.constant 1 : index
+ c1 = arith.ConstantOp(index_type, 1)
+ # CHECK: %[[C2:.*]] = arith.constant 2 : index
+ c2 = arith.ConstantOp(index_type, 2)
+ # CHECK: %[[C3:.*]] = arith.constant 3 : index
+ c3 = arith.ConstantOp(index_type, 3)
+ # CHECK: %[[C9:.*]] = arith.constant 9 : index
+ c9 = arith.ConstantOp(index_type, 9)
+ # CHECK: %[[AC0:.*]] = arith.constant 0.000000e+00 : f32
+ ac0 = AffineConstantExpr.get(0)
+
+ d0 = AffineDimExpr.get(0)
+ d1 = AffineDimExpr.get(1)
+ s0 = AffineSymbolExpr.get(0)
+ lb = AffineMap.get(1, 1, [ac0, d0 + s0])
+ ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])
+ sum_0 = arith.ConstantOp(f32, 0.0)
+
+ # CHECK: %0 = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[AC0]]) -> (f32) {
+ sum = affine.AffineForOp(
+ lb, ub, 2,
+ iter_args=[sum_0],
+ lower_bound_operands=[c2, c3],
+ upper_bound_operands=[c9, c1]
+ )
+
+ with InsertionPoint(sum.body):
+
+ # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
+ tmp = memref.LoadOp(buffer, [sum.induction_variable])
+ sum_next = arith.AddFOp(sum.inner_iter_args[0], tmp)
+
+ affine.AffineYieldOp([sum_next])
+
+ return
+ print(module)
|
✅ With the latest revision this PR passed the Python code formatter. |
Hey thanks @kaitingwang. I'll take a look soon but can you take a look here https://github.com/llvm/llvm-project/blob/baf42cdea2f243a1c9d2f89c456589e351c50521/mlir/python/mlir/dialects/scf.py#L106 and see if you can add the same thing for affine? |
baf42cd
to
79e7139
Compare
Thanks for the suggestion! I added the syntatic suger form of affine.for_ as well as some testcases (mostly adapted from scf.for) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo that one suggestion (not necessary ofc). Well organized code too. Thanks!
mlir/python/mlir/dialects/affine.py
Outdated
lowerBoundOperands=[ | ||
_get_op_result_or_value(o) for o in lower_bound_operands | ||
], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lowerBoundOperands=[ | |
_get_op_result_or_value(o) for o in lower_bound_operands | |
], | |
lowerBoundOperands=_get_op_results_or_values(lower_bound_operands), |
I think this should work - same for upperBoundOperands
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Yes, it does work. I changed it and pushed the change.
def constructAndPrintInModule(f): | ||
print("\nTEST:", f.__name__) | ||
f() | ||
with Context(), Location.unknown(): | ||
module = Module.create() | ||
with InsertionPoint(module.body): | ||
f() | ||
print(module) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
79e7139
to
c2bef0a
Compare
This PR creates the wrapper class AffineForOp and adds a testcase for it. A testcase for AffineLoadOp is also added as well as some syntatic suger tests.
c2bef0a
to
21e9a60
Compare
CI was all green so I merged the PR. Thanks @makslevental for the review! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would have requested changes on this one. Please address promptly in a follow-up PR.
@@ -3,3 +3,141 @@ | |||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
|
|||
from ._affine_ops_gen import * | |||
from ._affine_ops_gen import _Dialect, AffineForOp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: why do we need to import AffineForOp
if we import *
above? (This is different for _Dialect
that is private due to underscore in the name and is therefore not imported as part of *
.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think explicitly importing names used in the file is not a bad thing. I also think this is the natural result of essentially treating these files as __init__.py
s for modules that don't exist structurally (i.e., it should be affine/__init__.py
) and as implementation files as well. Anyway I'll take it out.
|
||
@_ods_cext.register_operation(_Dialect, replace=True) | ||
class AffineForOp(AffineForOp): | ||
"""Specialization for the Affine for op class""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: terminate Python comments with a full stop. Same below for some longer descriptions.
lower_bound, | ||
upper_bound, | ||
step, | ||
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, | ||
*, | ||
lower_bound_operands=[], | ||
upper_bound_operands=[], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These would really benefit from having type annotations.
lowerBoundMap=AffineMapAttr.get(lower_bound), | ||
upperBoundMap=AffineMapAttr.get(upper_bound), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to support integers as lower/upper bound and have maps constructed for them on-the-fly.
if step is None: | ||
step = 1 | ||
if stop is None: | ||
stop = start | ||
start = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should be doing this in the "main" builder, not in the wrapper.
p = constant(IntegerAttr.get(IndexType.get(), p)) | ||
elif isinstance(p, float): | ||
raise ValueError(f"{p=} must be int.") | ||
params[i] = p | ||
|
||
start, stop = params | ||
s0 = AffineSymbolExpr.get(0) | ||
lbmap = AffineMap.get(0, 1, [s0]) | ||
ubmap = AffineMap.get(0, 1, [s0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is no need to create a constant operation and feed as a symbol into an affine map. Affine maps can have constant expressions and we should use that. I also suspect canonicalization or constant folding will immediately remove those constants.
def testForSugar(): | ||
index_type = IndexType.get() | ||
memref_t = MemRefType.get([10], index_type) | ||
range = affine.for_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks rather confusing/scary to me...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean the range = affine.for_
? In the scf
test I have range_
(easy to have missed the trailing underscore...).
WIP #74495. |
Really appreciate your help to address @ftynse 's comments (next time I'll wait for your approval too)! Thank you for your reviews. I'll go over the left-over situation after work today and will be happy to address any remaining issues needed. (At @makslevental's fast speed, there may not be anything left for me to address!) :) |
This PR creates the wrapper class AffineForOp and adds a testcase for it. A testcase for the AffineLoadOp is also added.