Skip to content

Commit 3766ba4

Browse files
authored
[mlir][python] Fix how the mlir variadic Python accessor _ods_equally_sized_accessor is used (llvm#101132) (llvm#106003)
As reported in llvm#101132, this fixes two bugs: 1. When accessing variadic operands inside an operation, it must be accessed as `self.operation.operands` instead of `operation.operands` 2. The implementation of the `equally_sized_accessor` function is doing wrong arithmetics when calculating the resulting index and group sizes. I have added a test for the `equally_sized_accessor` function, which did not have a test previously.
1 parent c6008ce commit 3766ba4

File tree

6 files changed

+269
-41
lines changed

6 files changed

+269
-41
lines changed

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,22 @@ def segmented_accessor(elements, raw_segments, idx):
5151

5252

5353
def equally_sized_accessor(
54-
elements, n_variadic, n_preceding_simple, n_preceding_variadic
54+
elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic
5555
):
5656
"""
5757
Returns a starting position and a number of elements per variadic group
5858
assuming equally-sized groups and the given numbers of preceding groups.
5959
6060
elements: a sequential container.
61+
n_simple: the number of non-variadic groups in the container.
6162
n_variadic: the number of variadic groups in the container.
6263
n_preceding_simple: the number of non-variadic groups preceding the current
6364
group.
6465
n_preceding_variadic: the number of variadic groups preceding the current
6566
group.
6667
"""
6768

68-
total_variadic_length = len(elements) - n_variadic + 1
69+
total_variadic_length = len(elements) - n_simple
6970
# This should be enforced by the C++-side trait verifier.
7071
assert total_variadic_length % n_variadic == 0
7172

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -480,18 +480,18 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
480480
[SameVariadicOperandSize]> {
481481
// CHECK: @builtins.property
482482
// CHECK: def variadic1(self):
483-
// CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 0)
484-
// CHECK: return self.operation.operands[start:start + pg]
483+
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 0)
484+
// CHECK: return self.operation.operands[start:start + elements_per_group]
485485
//
486486
// CHECK: @builtins.property
487487
// CHECK: def non_variadic(self):
488-
// CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 0, 1)
488+
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 0, 1)
489489
// CHECK: return self.operation.operands[start]
490490
//
491491
// CHECK: @builtins.property
492492
// CHECK: def variadic2(self):
493-
// CHECK: start, pg = _ods_equally_sized_accessor(operation.operands, 2, 1, 1)
494-
// CHECK: return self.operation.operands[start:start + pg]
493+
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.operands, 1, 2, 1, 1)
494+
// CHECK: return self.operation.operands[start:start + elements_per_group]
495495
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
496496
Variadic<AnyType>:$variadic2);
497497
}
@@ -506,18 +506,18 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
506506
[SameVariadicResultSize]> {
507507
// CHECK: @builtins.property
508508
// CHECK: def variadic1(self):
509-
// CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 0)
510-
// CHECK: return self.operation.results[start:start + pg]
509+
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 0)
510+
// CHECK: return self.operation.results[start:start + elements_per_group]
511511
//
512512
// CHECK: @builtins.property
513513
// CHECK: def non_variadic(self):
514-
// CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 0, 1)
514+
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 0, 1)
515515
// CHECK: return self.operation.results[start]
516516
//
517517
// CHECK: @builtins.property
518518
// CHECK: def variadic2(self):
519-
// CHECK: start, pg = _ods_equally_sized_accessor(operation.results, 2, 1, 1)
520-
// CHECK: return self.operation.results[start:start + pg]
519+
// CHECK: start, elements_per_group = _ods_equally_sized_accessor(self.operation.results, 1, 2, 1, 1)
520+
// CHECK: return self.operation.results[start:start + elements_per_group]
521521
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
522522
Variadic<AnyType>:$variadic2);
523523
}

mlir/test/python/dialects/ods_helpers.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import gc
44

55
from mlir.ir import *
6+
from mlir.dialects._ods_common import equally_sized_accessor
67

78

89
def run(f):
@@ -208,3 +209,70 @@ class TestOp(OpView):
208209

209210

210211
run(testOdsBuildDefaultCastError)
212+
213+
214+
def testOdsEquallySizedAccessor():
215+
class TestOpMultiResultSegments(OpView):
216+
OPERATION_NAME = "custom.test_op"
217+
_ODS_REGIONS = (1, True)
218+
219+
with Context() as ctx, Location.unknown():
220+
ctx.allow_unregistered_dialects = True
221+
m = Module.create()
222+
with InsertionPoint(m.body):
223+
v = add_dummy_value()
224+
ts = [IntegerType.get_signless(i * 8) for i in range(4)]
225+
226+
op = TestOpMultiResultSegments.build_generic(
227+
results=[ts[0], ts[1], ts[2], ts[3]], operands=[v]
228+
)
229+
start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 0)
230+
# CHECK: start: 1, elements_per_group: 1
231+
print(f"start: {start}, elements_per_group: {elements_per_group}")
232+
# CHECK: i8
233+
print(op.results[start].type)
234+
235+
start, elements_per_group = equally_sized_accessor(op.results, 1, 3, 1, 1)
236+
# CHECK: start: 2, elements_per_group: 1
237+
print(f"start: {start}, elements_per_group: {elements_per_group}")
238+
# CHECK: i16
239+
print(op.results[start].type)
240+
241+
242+
run(testOdsEquallySizedAccessor)
243+
244+
245+
def testOdsEquallySizedAccessorMultipleSegments():
246+
class TestOpMultiResultSegments(OpView):
247+
OPERATION_NAME = "custom.test_op"
248+
_ODS_REGIONS = (1, True)
249+
_ODS_RESULT_SEGMENTS = [0, -1, -1]
250+
251+
def types(lst):
252+
return [e.type for e in lst]
253+
254+
with Context() as ctx, Location.unknown():
255+
ctx.allow_unregistered_dialects = True
256+
m = Module.create()
257+
with InsertionPoint(m.body):
258+
v = add_dummy_value()
259+
ts = [IntegerType.get_signless(i * 8) for i in range(7)]
260+
261+
op = TestOpMultiResultSegments.build_generic(
262+
results=[ts[0], [ts[1], ts[2], ts[3]], [ts[4], ts[5], ts[6]]],
263+
operands=[v],
264+
)
265+
start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 0)
266+
# CHECK: start: 1, elements_per_group: 3
267+
print(f"start: {start}, elements_per_group: {elements_per_group}")
268+
# CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
269+
print(types(op.results[start : start + elements_per_group]))
270+
271+
start, elements_per_group = equally_sized_accessor(op.results, 1, 2, 1, 1)
272+
# CHECK: start: 4, elements_per_group: 3
273+
print(f"start: {start}, elements_per_group: {elements_per_group}")
274+
# CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
275+
print(types(op.results[start : start + elements_per_group]))
276+
277+
278+
run(testOdsEquallySizedAccessorMultipleSegments)

mlir/test/python/dialects/python_test.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,3 +555,123 @@ def testInferTypeOpInterface():
555555
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
556556
# CHECK: f32
557557
print(two_operands.result.type)
558+
559+
560+
# CHECK-LABEL: TEST: testVariadicOperandAccess
561+
@run
562+
def testVariadicOperandAccess():
563+
def values(lst):
564+
return [str(e) for e in lst]
565+
566+
with Context() as ctx, Location.unknown(ctx):
567+
module = Module.create()
568+
with InsertionPoint(module.body):
569+
i32 = IntegerType.get_signless(32)
570+
zero = arith.ConstantOp(i32, 0)
571+
one = arith.ConstantOp(i32, 1)
572+
two = arith.ConstantOp(i32, 2)
573+
three = arith.ConstantOp(i32, 3)
574+
four = arith.ConstantOp(i32, 4)
575+
576+
variadic_operands = test.SameVariadicOperandSizeOp(
577+
[zero, one], two, [three, four]
578+
)
579+
# CHECK: Value(%{{.*}} = arith.constant 2 : i32)
580+
print(variadic_operands.non_variadic)
581+
# CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
582+
print(values(variadic_operands.variadic1))
583+
# CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
584+
print(values(variadic_operands.variadic2))
585+
586+
587+
# CHECK-LABEL: TEST: testVariadicResultAccess
588+
@run
589+
def testVariadicResultAccess():
590+
def types(lst):
591+
return [e.type for e in lst]
592+
593+
with Context() as ctx, Location.unknown(ctx):
594+
module = Module.create()
595+
with InsertionPoint(module.body):
596+
i = [IntegerType.get_signless(k) for k in range(7)]
597+
598+
# Test Variadic-Fixed-Variadic
599+
op = test.SameVariadicResultSizeOpVFV([i[0], i[1]], i[2], [i[3], i[4]])
600+
# CHECK: i2
601+
print(op.non_variadic.type)
602+
# CHECK: [IntegerType(i0), IntegerType(i1)]
603+
print(types(op.variadic1))
604+
# CHECK: [IntegerType(i3), IntegerType(i4)]
605+
print(types(op.variadic2))
606+
607+
# Test Variadic-Variadic-Variadic
608+
op = test.SameVariadicResultSizeOpVVV(
609+
[i[0], i[1]], [i[2], i[3]], [i[4], i[5]]
610+
)
611+
# CHECK: [IntegerType(i0), IntegerType(i1)]
612+
print(types(op.variadic1))
613+
# CHECK: [IntegerType(i2), IntegerType(i3)]
614+
print(types(op.variadic2))
615+
# CHECK: [IntegerType(i4), IntegerType(i5)]
616+
print(types(op.variadic3))
617+
618+
# Test Fixed-Fixed-Variadic
619+
op = test.SameVariadicResultSizeOpFFV(i[0], i[1], [i[2], i[3], i[4]])
620+
# CHECK: i0
621+
print(op.non_variadic1.type)
622+
# CHECK: i1
623+
print(op.non_variadic2.type)
624+
# CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
625+
print(types(op.variadic))
626+
627+
# Test Variadic-Variadic-Fixed
628+
op = test.SameVariadicResultSizeOpVVF(
629+
[i[0], i[1], i[2]], [i[3], i[4], i[5]], i[6]
630+
)
631+
# CHECK: [IntegerType(i0), IntegerType(i1), IntegerType(i2)]
632+
print(types(op.variadic1))
633+
# CHECK: [IntegerType(i3), IntegerType(i4), IntegerType(i5)]
634+
print(types(op.variadic2))
635+
# CHECK: i6
636+
print(op.non_variadic.type)
637+
638+
# Test Fixed-Variadic-Fixed-Variadic-Fixed
639+
op = test.SameVariadicResultSizeOpFVFVF(
640+
i[0], [i[1], i[2]], i[3], [i[4], i[5]], i[6]
641+
)
642+
# CHECK: i0
643+
print(op.non_variadic1.type)
644+
# CHECK: [IntegerType(i1), IntegerType(i2)]
645+
print(types(op.variadic1))
646+
# CHECK: i3
647+
print(op.non_variadic2.type)
648+
# CHECK: [IntegerType(i4), IntegerType(i5)]
649+
print(types(op.variadic2))
650+
# CHECK: i6
651+
print(op.non_variadic3.type)
652+
653+
# Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 0
654+
op = test.SameVariadicResultSizeOpFVFVF(i[0], [], i[1], [], i[2])
655+
# CHECK: i0
656+
print(op.non_variadic1.type)
657+
# CHECK: []
658+
print(types(op.variadic1))
659+
# CHECK: i1
660+
print(op.non_variadic2.type)
661+
# CHECK: []
662+
print(types(op.variadic2))
663+
# CHECK: i2
664+
print(op.non_variadic3.type)
665+
666+
# Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 1
667+
op = test.SameVariadicResultSizeOpFVFVF(i[0], [i[1]], i[2], [i[3]], i[4])
668+
# CHECK: i0
669+
print(op.non_variadic1.type)
670+
# CHECK: [IntegerType(i1)]
671+
print(types(op.variadic1))
672+
# CHECK: i2
673+
print(op.non_variadic2.type)
674+
# CHECK: [IntegerType(i3)]
675+
print(types(op.variadic2))
676+
# CHECK: i4
677+
print(op.non_variadic3.type)

mlir/test/python/python_test_ops.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,4 +227,42 @@ def OptionalOperandOp : TestOp<"optional_operand_op"> {
227227
let results = (outs I32:$result);
228228
}
229229

230+
def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
231+
[SameVariadicOperandSize]> {
232+
let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
233+
Variadic<AnyType>:$variadic2);
234+
}
235+
236+
// Check different arrangements of variadic groups
237+
def SameVariadicResultSizeOpVFV : TestOp<"same_variadic_result_vfv",
238+
[SameVariadicResultSize]> {
239+
let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
240+
Variadic<AnyType>:$variadic2);
241+
}
242+
243+
def SameVariadicResultSizeOpVVV : TestOp<"same_variadic_result_vvv",
244+
[SameVariadicResultSize]> {
245+
let results = (outs Variadic<AnyType>:$variadic1, Variadic<AnyType>:$variadic2,
246+
Variadic<AnyType>:$variadic3);
247+
}
248+
249+
def SameVariadicResultSizeOpFFV : TestOp<"same_variadic_result_ffv",
250+
[SameVariadicResultSize]> {
251+
let results = (outs AnyType:$non_variadic1, AnyType:$non_variadic2,
252+
Variadic<AnyType>:$variadic);
253+
}
254+
255+
def SameVariadicResultSizeOpVVF : TestOp<"same_variadic_result_vvf",
256+
[SameVariadicResultSize]> {
257+
let results = (outs Variadic<AnyType>:$variadic1, Variadic<AnyType>:$variadic2,
258+
AnyType:$non_variadic);
259+
}
260+
261+
def SameVariadicResultSizeOpFVFVF : TestOp<"same_variadic_result_fvfvf",
262+
[SameVariadicResultSize]> {
263+
let results = (outs AnyType:$non_variadic1, Variadic<AnyType>:$variadic1,
264+
AnyType:$non_variadic2, Variadic<AnyType>:$variadic2,
265+
AnyType:$non_variadic3);
266+
}
267+
230268
#endif // PYTHON_TEST_OPS

0 commit comments

Comments
 (0)