Skip to content

[mlir][python] Fix how the mlir variadic Python accessor _ods_equally_sized_accessor is used (#101132) #106003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/python/mlir/dialects/_ods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,22 @@ def segmented_accessor(elements, raw_segments, idx):


def equally_sized_accessor(
elements, n_variadic, n_preceding_simple, n_preceding_variadic
elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic
):
"""
Returns a starting position and a number of elements per variadic group
assuming equally-sized groups and the given numbers of preceding groups.

elements: a sequential container.
n_simple: the number of non-variadic groups in the container.
n_variadic: the number of variadic groups in the container.
n_preceding_simple: the number of non-variadic groups preceding the current
group.
n_preceding_variadic: the number of variadic groups preceding the current
group.
"""

total_variadic_length = len(elements) - n_variadic + 1
total_variadic_length = len(elements) - n_simple
# This should be enforced by the C++-side trait verifier.
assert total_variadic_length % n_variadic == 0

Expand Down
20 changes: 10 additions & 10 deletions mlir/test/mlir-tblgen/op-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -480,18 +480,18 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
[SameVariadicOperandSize]> {
// CHECK: @builtins.property
// CHECK: def variadic1(self):
// CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 0)
// CHECK: return self.operation.operands[start:start + pg]
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
// CHECK: return self.operation.operands[start:start + elements_per_group]
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
// CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 1)
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
// CHECK: return self.operation.operands[start]
//
// CHECK: @builtins.property
// CHECK: def variadic2(self):
// CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 1, 1)
// CHECK: return self.operation.operands[start:start + pg]
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
// CHECK: return self.operation.operands[start:start + elements_per_group]
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
}
Expand All @@ -506,18 +506,18 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
[SameVariadicResultSize]> {
// CHECK: @builtins.property
// CHECK: def variadic1(self):
// CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 0)
// CHECK: return self.operation.results[start:start + pg]
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
// CHECK: return self.operation.results[start:start + elements_per_group]
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
// CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 1)
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
// CHECK: return self.operation.results[start]
//
// CHECK: @builtins.property
// CHECK: def variadic2(self):
// CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 1, 1)
// CHECK: return self.operation.results[start:start + pg]
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
// CHECK: return self.operation.results[start:start + elements_per_group]
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
}
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/python/dialects/ods_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import gc

from mlir.ir import *
from mlir.dialects._ods_common import equally_sized_accessor


def run(f):
Expand Down Expand Up @@ -208,3 +209,70 @@ class TestOp(OpView):


run(testOdsBuildDefaultCastError)


def testOdsEquallySizedAccessor():
class TestOpMultiResultSegments(OpView):
OPERATION_NAME = "custom.test_op"
_ODS_REGIONS = (1, True)

with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
m = Module.create()
with InsertionPoint(m.body):
v = add_dummy_value()
ts = [IntegerType.get_signless(i * 8) for i in range(4)]

op = TestOpMultiResultSegments.build_generic(
results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
)
start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 0)
# CHECK: start: 1, elements_per_group: 1
print(f"start: {start}, elements_per_group: {elements_per_group}")
# CHECK: i8
print(op.results[start].type)

start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 1)
# CHECK: start: 2, elements_per_group: 1
print(f"start: {start}, elements_per_group: {elements_per_group}")
# CHECK: i16
print(op.results[start].type)


run(testOdsEquallySizedAccessor)


def testOdsEquallySizedAccessorMultipleSegments():
class TestOpMultiResultSegments(OpView):
OPERATION_NAME = "custom.test_op"
_ODS_REGIONS = (1, True)
_ODS_RESULT_SEGMENTS = [0, -1, -1]

def types(lst):
return [e.type for e in lst]

with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
m = Module.create()
with InsertionPoint(m.body):
v = add_dummy_value()
ts = [IntegerType.get_signless(i * 8) for i in range(7)]

op = TestOpMultiResultSegments.build_generic(
results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
operands=[v],
)
start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 0)
# CHECK: start: 1, elements_per_group: 3
print(f"start: {start}, elements_per_group: {elements_per_group}")
# CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
print(types(op.results[start : start + elements_per_group]))

start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 1)
# CHECK: start: 4, elements_per_group: 3
print(f"start: {start}, elements_per_group: {elements_per_group}")
# CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
print(types(op.results[start : start + elements_per_group]))


run(testOdsEquallySizedAccessorMultipleSegments)
120 changes: 120 additions & 0 deletions mlir/test/python/dialects/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,123 @@ def testInferTypeOpInterface():
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
# CHECK: f32
print(two_operands.result.type)


# CHECK-LABEL: TEST: testVariadicOperandAccess
@run
def testVariadicOperandAccess():
def values(lst):
return [str(e) for e in lst]

with Context() as ctx, Location.unknown(ctx):
module = Module.create()
with InsertionPoint(module.body):
i32 = IntegerType.get_signless(32)
zero = arith.ConstantOp(i32, 0)
one = arith.ConstantOp(i32, 1)
two = arith.ConstantOp(i32, 2)
three = arith.ConstantOp(i32, 3)
four = arith.ConstantOp(i32, 4)

variadic_operands = test.SameVariadicOperandSizeOp(
[zero, one], two, [three, four]
)
# CHECK: Value(%{{.*}} = arith.constant 2 : i32)
print(variadic_operands.non_variadic)
# CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
print(values(variadic_operands.variadic1))
# CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
print(values(variadic_operands.variadic2))


# CHECK-LABEL: TEST: testVariadicResultAccess
@run
def testVariadicResultAccess():
def types(lst):
return [e.type for e in lst]

with Context() as ctx, Location.unknown(ctx):
module = Module.create()
with InsertionPoint(module.body):
i = [IntegerType.get_signless(k) for k in range(7)]

# Test Variadic-Fixed-Variadic
op = test.SameVariadicResultSizeOpVFV([i[0], i[1]], i[2], [i[3], i[4]])
# CHECK: i2
print(op.non_variadic.type)
# CHECK: [IntegerType(i0), IntegerType(i1)]
print(types(op.variadic1))
# CHECK: [IntegerType(i3), IntegerType(i4)]
print(types(op.variadic2))

# Test Variadic-Variadic-Variadic
op = test.SameVariadicResultSizeOpVVV(
[i[0], i[1]], [i[2], i[3]], [i[4], i[5]]
)
# CHECK: [IntegerType(i0), IntegerType(i1)]
print(types(op.variadic1))
# CHECK: [IntegerType(i2), IntegerType(i3)]
print(types(op.variadic2))
# CHECK: [IntegerType(i4), IntegerType(i5)]
print(types(op.variadic3))

# Test Fixed-Fixed-Variadic
op = test.SameVariadicResultSizeOpFFV(i[0], i[1], [i[2], i[3], i[4]])
# CHECK: i0
print(op.non_variadic1.type)
# CHECK: i1
print(op.non_variadic2.type)
# CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
print(types(op.variadic))

# Test Variadic-Variadic-Fixed
op = test.SameVariadicResultSizeOpVVF(
[i[0], i[1], i[2]], [i[3], i[4], i[5]], i[6]
)
# CHECK: [IntegerType(i0), IntegerType(i1), IntegerType(i2)]
print(types(op.variadic1))
# CHECK: [IntegerType(i3), IntegerType(i4), IntegerType(i5)]
print(types(op.variadic2))
# CHECK: i6
print(op.non_variadic.type)

# Test Fixed-Variadic-Fixed-Variadic-Fixed
op = test.SameVariadicResultSizeOpFVFVF(
i[0], [i[1], i[2]], i[3], [i[4], i[5]], i[6]
)
# CHECK: i0
print(op.non_variadic1.type)
# CHECK: [IntegerType(i1), IntegerType(i2)]
print(types(op.variadic1))
# CHECK: i3
print(op.non_variadic2.type)
# CHECK: [IntegerType(i4), IntegerType(i5)]
print(types(op.variadic2))
# CHECK: i6
print(op.non_variadic3.type)

# Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 0
op = test.SameVariadicResultSizeOpFVFVF(i[0], [], i[1], [], i[2])
# CHECK: i0
print(op.non_variadic1.type)
# CHECK: []
print(types(op.variadic1))
# CHECK: i1
print(op.non_variadic2.type)
# CHECK: []
print(types(op.variadic2))
# CHECK: i2
print(op.non_variadic3.type)

# Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 1
op = test.SameVariadicResultSizeOpFVFVF(i[0], [i[1]], i[2], [i[3]], i[4])
# CHECK: i0
print(op.non_variadic1.type)
# CHECK: [IntegerType(i1)]
print(types(op.variadic1))
# CHECK: i2
print(op.non_variadic2.type)
# CHECK: [IntegerType(i3)]
print(types(op.variadic2))
# CHECK: i4
print(op.non_variadic3.type)
38 changes: 38 additions & 0 deletions mlir/test/python/python_test_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,42 @@ def OptionalOperandOp : TestOp<"optional_operand_op"> {
let results = (outs I32:$result);
}

def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
[SameVariadicOperandSize]> {
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
}

// Check different arrangements of variadic groups
def SameVariadicResultSizeOpVFV : TestOp<"same_variadic_result_vfv",
[SameVariadicResultSize]> {
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
Variadic<AnyType>:$variadic2);
}

def SameVariadicResultSizeOpVVV : TestOp<"same_variadic_result_vvv",
[SameVariadicResultSize]> {
let results = (outs Variadic<AnyType>:$variadic1, Variadic<AnyType>:$variadic2,
Variadic<AnyType>:$variadic3);
}

def SameVariadicResultSizeOpFFV : TestOp<"same_variadic_result_ffv",
[SameVariadicResultSize]> {
let results = (outs AnyType:$non_variadic1, AnyType:$non_variadic2,
Variadic<AnyType>:$variadic);
}

def SameVariadicResultSizeOpVVF : TestOp<"same_variadic_result_vvf",
[SameVariadicResultSize]> {
let results = (outs Variadic<AnyType>:$variadic1, Variadic<AnyType>:$variadic2,
AnyType:$non_variadic);
}

def SameVariadicResultSizeOpFVFVF : TestOp<"same_variadic_result_fvfvf",
[SameVariadicResultSize]> {
let results = (outs AnyType:$non_variadic1, Variadic<AnyType>:$variadic1,
AnyType:$non_variadic2, Variadic<AnyType>:$variadic2,
AnyType:$non_variadic3);
}

#endif // PYTHON_TEST_OPS
Loading
Loading