Skip to content

[mlir][python] python binding wrapper for the affine.AffineForOp #74408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions mlir/python/mlir/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,141 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._affine_ops_gen import *
from ._affine_ops_gen import _Dialect, AffineForOp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: why do we need to import AffineForOp if we import * above? (This is different for _Dialect that is private due to underscore in the name and is therefore not imported as part of *.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think explicitly importing names used in the file is not a bad thing. I also think this is the natural result of essentially treating these files as __init__.pys for modules that don't exist structurally (i.e., it should be affine/__init__.py) and as implementation files as well. Anyway I'll take it out.

from .arith import constant

try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class AffineForOp(AffineForOp):
"""Specialization for the Affine for op class"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: terminate Python comments with a full stop. Same below for some longer descriptions.


def __init__(
self,
lower_bound,
upper_bound,
step,
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
lower_bound_operands=[],
upper_bound_operands=[],
Comment on lines +28 to +34
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These would really benefit from having type annotations.

loc=None,
ip=None,
):
"""Creates an Affine `for` operation.

- `lower_bound` is the affine map to use as lower bound of the loop.
- `upper_bound` is the affine map to use as upper bound of the loop.
- `step` is the value to use as loop step.
- `iter_args` is a list of additional loop-carried arguments or an operation
producing them as results.
- `lower_bound_operands` is the list of arguments to substitute the dimensions,
then symbols in the `lower_bound` affine map, in an increasing order
- `upper_bound_operands` is the list of arguments to substitute the dimensions,
then symbols in the `upper_bound` affine map, in an increasing order
"""

if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
# Each bound map must receive exactly one operand per map input. The error
# message reports the same quantity (`n_inputs`) that the check compares
# against, and the `f` prefix sits on the literal that actually contains the
# placeholders so they are interpolated.
if len(lower_bound_operands) != lower_bound.n_inputs:
    raise ValueError(
        "Wrong number of lower bound operands passed to AffineForOp. "
        f"Expected {lower_bound.n_inputs}, got {len(lower_bound_operands)}."
    )

if len(upper_bound_operands) != upper_bound.n_inputs:
    raise ValueError(
        "Wrong number of upper bound operands passed to AffineForOp. "
        f"Expected {upper_bound.n_inputs}, got {len(upper_bound_operands)}."
    )

results = [arg.type for arg in iter_args]
super().__init__(
results_=results,
lowerBoundOperands=_get_op_results_or_values(lower_bound_operands),
upperBoundOperands=_get_op_results_or_values(upper_bound_operands),
inits=list(iter_args),
lowerBoundMap=AffineMapAttr.get(lower_bound),
upperBoundMap=AffineMapAttr.get(upper_bound),
Comment on lines +72 to +73
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to support integers as lower/upper bound and have maps constructed for them on-the-fly.

step=IntegerAttr.get(IndexType.get(), step),
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(IndexType.get(), *results)

@property
def body(self):
    """Returns the first block of the op's first region, i.e. the loop body."""
    loop_region = self.regions[0]
    return loop_region.blocks[0]

@property
def induction_variable(self):
    """Returns the induction variable, the first argument of the body block."""
    loop_body = self.body
    return loop_body.arguments[0]

@property
def inner_iter_args(self):
    """Returns the loop-carried values as seen from inside the loop.

    These are the body block arguments that follow the induction variable.
    To obtain the loop-carried operands, use `iter_args`.
    """
    block_arguments = self.body.arguments
    return block_arguments[1:]


def for_(
start,
stop=None,
step=None,
iter_args: Optional[Sequence[Value]] = None,
*,
loc=None,
ip=None,
):
"""Generator that builds an `AffineForOp` for use in a Python `for` statement.

- `start` and `stop` are the loop bounds. Python `int` values are turned into
  index-typed constants via `constant`; a `float` raises `ValueError`; any
  other value is passed through unchanged as the bound operand.
- If `stop` is omitted, the single given bound is used as the upper bound
  and the loop starts at 0.
- `step` defaults to 1 and is forwarded to `AffineForOp`.
- `iter_args`, `loc` and `ip` are forwarded to `AffineForOp`.

Yields exactly once, with the insertion point set to the loop body: the
induction variable alone when there are no loop-carried values, the pair
`(iv, arg)` for a single loop-carried value, or `(iv, args_tuple)` for
several. This function does not itself create a terminator for the body.
"""
# A missing step means a unit step.
if step is None:
step = 1
# With a single bound, treat it as the upper bound and start from zero.
if stop is None:
stop = start
start = 0
Comment on lines +108 to +112
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be doing this in the "main" builder, not in the wrapper.

params = [start, stop]
for i, p in enumerate(params):
if isinstance(p, int):
p = constant(IntegerAttr.get(IndexType.get(), p))
elif isinstance(p, float):
raise ValueError(f"{p=} must be int.")
params[i] = p

start, stop = params
s0 = AffineSymbolExpr.get(0)
lbmap = AffineMap.get(0, 1, [s0])
ubmap = AffineMap.get(0, 1, [s0])
Comment on lines +116 to +124
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need to create a constant operation and feed as a symbol into an affine map. Affine maps can have constant expressions and we should use that. I also suspect canonicalization or constant folding will immediately remove those constants.

for_op = AffineForOp(
lbmap,
ubmap,
step,
iter_args=iter_args,
lower_bound_operands=[start],
upper_bound_operands=[stop],
loc=loc,
ip=ip,
)
iv = for_op.induction_variable
iter_args = tuple(for_op.inner_iter_args)
with InsertionPoint(for_op.body):
if len(iter_args) > 1:
yield iv, iter_args
elif len(iter_args) == 1:
yield iv, iter_args[0]
else:
yield iv
182 changes: 155 additions & 27 deletions mlir/test/python/dialects/affine.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,172 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.arith as arith
import mlir.dialects.affine as affine
import mlir.dialects.memref as memref
from mlir.dialects import func
from mlir.dialects import arith
from mlir.dialects import memref
from mlir.dialects import affine


def run(f):
def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
f()
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
Comment on lines +10 to +16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

return f


# CHECK-LABEL: TEST: testAffineStoreOp
@run
@constructAndPrintInModule
def testAffineStoreOp():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f32 = F32Type.get()
index_type = IndexType.get()
memref_type_out = MemRefType.get([12, 12], f32)
f32 = F32Type.get()
index_type = IndexType.get()
memref_type_out = MemRefType.get([12, 12], f32)

# CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
@func.FuncOp.from_py_func(index_type)
def affine_store_test(arg0):
# CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
mem = memref.AllocOp(memref_type_out, [], []).result
# CHECK: func.func @affine_store_test(%[[ARG0:.*]]: index) -> memref<12x12xf32> {
@func.FuncOp.from_py_func(index_type)
def affine_store_test(arg0):
# CHECK: %[[O_VAR:.*]] = memref.alloc() : memref<12x12xf32>
mem = memref.AllocOp(memref_type_out, [], []).result

d0 = AffineDimExpr.get(0)
s0 = AffineSymbolExpr.get(0)
map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])
d0 = AffineDimExpr.get(0)
s0 = AffineSymbolExpr.get(0)
map = AffineMap.get(1, 1, [s0 * 3, d0 + s0 + 1])

# CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
a1 = arith.ConstantOp(f32, 2.1)
# CHECK: %[[A1:.*]] = arith.constant 2.100000e+00 : f32
a1 = arith.ConstantOp(f32, 2.1)

# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)

return mem
return mem

print(module)

# CHECK-LABEL: TEST: testAffineLoadOp
@constructAndPrintInModule
def testAffineLoadOp():
    """Builds a function that loads from a memref through a 1-dim, 1-symbol affine map."""
    element_type = F32Type.get()
    source_type = MemRefType.get([10, 10], element_type)
    index_type = IndexType.get()

    # CHECK: func.func @affine_load_test(%[[I_VAR:.*]]: memref<10x10xf32>, %[[ARG0:.*]]: index) -> f32 {
    @func.FuncOp.from_py_func(source_type, index_type)
    def affine_load_test(I, arg0):
        # The same function argument is used both as the dimension and as
        # the symbol operand of the access map.
        dim0 = AffineDimExpr.get(0)
        sym0 = AffineSymbolExpr.get(0)
        access_map = AffineMap.get(1, 1, [sym0 * 3, dim0 + sym0 + 1])

        # CHECK: {{.*}} = affine.load %[[I_VAR]][symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<10x10xf32>
        loaded = affine.AffineLoadOp(
            element_type, I, indices=[arg0, arg0], map=access_map
        )

        return loaded


# CHECK-LABEL: TEST: testAffineForOp
@constructAndPrintInModule
def testAffineForOp():
    """Builds an `affine.for` with map-based bounds, a step of 2 and one f32 iter_arg."""
    f32 = F32Type.get()
    index_type = IndexType.get()
    memref_type = MemRefType.get([1024], f32)

    # CHECK: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (0, d0 + s0)>
    # CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 - 2, d1 * 32)>
    # CHECK: func.func @affine_for_op_test(%[[BUFFER:.*]]: memref<1024xf32>) {
    @func.FuncOp.from_py_func(memref_type)
    def affine_for_op_test(buffer):
        # CHECK: %[[C1:.*]] = arith.constant 1 : index
        c1 = arith.ConstantOp(index_type, 1)
        # CHECK: %[[C2:.*]] = arith.constant 2 : index
        c2 = arith.ConstantOp(index_type, 2)
        # CHECK: %[[C3:.*]] = arith.constant 3 : index
        c3 = arith.ConstantOp(index_type, 3)
        # CHECK: %[[C9:.*]] = arith.constant 9 : index
        c9 = arith.ConstantOp(index_type, 9)

        # These lines only build affine expressions and maps; the CHECK for
        # the f32 zero constant belongs to `sum_0` further down.
        ac0 = AffineConstantExpr.get(0)
        d0 = AffineDimExpr.get(0)
        d1 = AffineDimExpr.get(1)
        s0 = AffineSymbolExpr.get(0)
        lb = AffineMap.get(1, 1, [ac0, d0 + s0])
        ub = AffineMap.get(2, 0, [d0 - 2, 32 * d1])

        # CHECK: %[[AC0:.*]] = arith.constant 0.000000e+00 : f32
        sum_0 = arith.ConstantOp(f32, 0.0)

        # CHECK: %0 = affine.for %[[INDVAR:.*]] = max #[[MAP0]](%[[C2]])[%[[C3]]] to min #[[MAP1]](%[[C9]], %[[C1]]) step 2 iter_args(%[[SUM0:.*]] = %[[AC0]]) -> (f32) {
        # Named `loop` rather than `sum` so the builtin is not shadowed.
        loop = affine.AffineForOp(
            lb,
            ub,
            2,
            iter_args=[sum_0],
            lower_bound_operands=[c2, c3],
            upper_bound_operands=[c9, c1],
        )

        with InsertionPoint(loop.body):
            # CHECK: %[[TMP:.*]] = memref.load %[[BUFFER]][%[[INDVAR]]] : memref<1024xf32>
            tmp = memref.LoadOp(buffer, [loop.induction_variable])
            sum_next = arith.AddFOp(loop.inner_iter_args[0], tmp)

            affine.AffineYieldOp([sum_next])

        return


@constructAndPrintInModule
def testForSugar():
index_type = IndexType.get()
memref_t = MemRefType.get([10], index_type)
range = affine.for_
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks rather confusing/scary to me...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the range = affine.for_? In the scf test I have range_ (easy to have missed the trailing underscore...).


# CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
# CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
# CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step 2 {
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
# CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
# CHECK: }
# CHECK: return
# CHECK: }
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
def range_loop_1(lb, ub, step, memref_v):
    # `range` is the enclosing test's alias for `affine.for_`, not the builtin.
    # Lower bound is a function argument, upper bound and step are Python ints.
    sym0 = AffineSymbolExpr.get(0)
    identity_map = AffineMap.get(0, 1, [sym0])
    for i in range(lb, 10, 2):
        doubled = arith.addi(i, i)
        affine.store(doubled, memref_v, [i], map=identity_map)
        affine.AffineYieldOp([])

# CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
# CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
# CHECK: affine.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] {
# CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
# CHECK: affine.store %[[VAL_8]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_7]]{{\)\]}} : memref<10xindex>
# CHECK: }
# CHECK: return
# CHECK: }
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
def range_loop_2(lb, ub, step, memref_v):
    # `range` is the enclosing test's alias for `affine.for_`, not the builtin.
    # Both bounds and the step are Python ints here.
    sym0 = AffineSymbolExpr.get(0)
    identity_map = AffineMap.get(0, 1, [sym0])
    for i in range(0, 10, 1):
        doubled = arith.addi(i, i)
        affine.store(doubled, memref_v, [i], map=identity_map)
        affine.AffineYieldOp([])

# CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
# CHECK: affine.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] {
# CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
# CHECK: affine.store %[[VAL_7]], %[[VAL_3]]{{\[symbol\(}}%[[VAL_6]]{{\)\]}} : memref<10xindex>
# CHECK: }
# CHECK: return
# CHECK: }
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
def range_loop_3(lb, ub, step, memref_v):
    # `range` is the enclosing test's alias for `affine.for_`, not the builtin.
    # Lower bound is a Python int, upper bound is a function argument.
    sym0 = AffineSymbolExpr.get(0)
    identity_map = AffineMap.get(0, 1, [sym0])
    for i in range(0, ub, 1):
        doubled = arith.addi(i, i)
        affine.store(doubled, memref_v, [i], map=identity_map)
        affine.AffineYieldOp([])