Skip to content

Commit 9442b44

Browse files
[mlir][linalg][transform][python] Fix optional args of PadOp mix-in.
The mix-in did not allow to *not* set many of the arguments, even though they represent optional attributes. Instead, it set default values, which have different semantics in some cases. In other cases, setting the default values is already done by the C++ layer, in which case they are currently redundant and may be wrong in some potential future change in the TD or C++ files. With this patch, `None` is preserved until the generated binding, which handles them as desired. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D158844
1 parent e257c0a commit 9442b44

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

mlir/python/mlir/dialects/_structured_transform_ops_ext.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _get_value_list(
125125

126126
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
127127
if values is None:
128-
return ArrayAttr.get([])
128+
return None
129129

130130
# Turn into a Python list of Python ints.
131131
values = _get_value_list(values)
@@ -148,7 +148,7 @@ def _get_int_array_array_attr(
148148
If the input is None, an empty ArrayAttr is returned.
149149
"""
150150
if values is None:
151-
return ArrayAttr.get([])
151+
return None
152152

153153
# Make sure the outer level is a list.
154154
values = _get_value_list(values)
@@ -493,9 +493,7 @@ def __init__(
493493
self,
494494
target: Union[Operation, OpView, Value],
495495
*,
496-
padding_values: Optional[
497-
Union[ArrayAttr, Sequence[Union[bool, int, float, Attribute]]]
498-
] = None,
496+
padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
499497
padding_dimensions: OptionalIntList = None,
500498
pad_to_multiple_of: OptionalIntList = None,
501499
pack_paddings: OptionalIntList = None,
@@ -506,17 +504,6 @@ def __init__(
506504
loc=None,
507505
ip=None,
508506
):
509-
if padding_values is None:
510-
padding_values = []
511-
if padding_dimensions is None:
512-
padding_dimensions = []
513-
if pad_to_multiple_of is None:
514-
pad_to_multiple_of = []
515-
if pack_paddings is None:
516-
pack_paddings = []
517-
if transpose_paddings is None:
518-
transpose_paddings = []
519-
520507
padding_dimensions = _get_int_array_attr(padding_dimensions)
521508
pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of)
522509
pack_paddings = _get_int_array_attr(pack_paddings)

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,22 +314,41 @@ def testMultitileSizes():
314314

315315

316316
@run
317-
def testPad():
317+
def testPadOpNoArgs():
318+
sequence = transform.SequenceOp(
319+
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
320+
)
321+
with InsertionPoint(sequence.body):
322+
structured.PadOp(sequence.bodyTarget)
323+
transform.YieldOp()
324+
# CHECK-LABEL: TEST: testPadOpNoArgs
325+
# CHECK: transform.sequence
326+
# CHECK: transform.structured.pad
327+
# CHECK-NOT: copy_back_op
328+
# CHECK-NOT: pack_paddings
329+
# CHECK-NOT: pad_to_multiple_of
330+
# CHECK-NOT: padding_dimensions
331+
# CHECK-NOT: padding_values
332+
# CHECK-NOT: transpose_paddings
333+
334+
335+
@run
336+
def testPadOpArgs():
318337
sequence = transform.SequenceOp(
319338
transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
320339
)
321340
with InsertionPoint(sequence.body):
322341
structured.PadOp(
323342
sequence.bodyTarget,
324-
padding_values=[FloatAttr.get_f32(42.0)],
343+
padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
325344
padding_dimensions=Attribute.parse("[1]"),
326345
pad_to_multiple_of=[128],
327346
pack_paddings=[0],
328347
transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
329348
copy_back_op="linalg.copy",
330349
)
331350
transform.YieldOp()
332-
# CHECK-LABEL: TEST: testPad
351+
# CHECK-LABEL: TEST: testPadOpArgs
333352
# CHECK: transform.sequence
334353
# CHECK: transform.structured.pad
335354
# CHECK-DAG: copy_back_op = "linalg.copy"

0 commit comments

Comments
 (0)