
Commit e3b442b

Author: gysit
[mlir][OpDSL] Separate ReduceFn and ReduceFnUse.
The revision distinguishes `ReduceFn` and `ReduceFnUse`. The latter has the
reduction dimensions attached while the former specifies the arithmetic
function only. This separation allows us to adapt the reduction syntax a
little bit and specify the reduction dimensions using square brackets (in
contrast to the round brackets used for the values to reduce). It is also a
preparation to add reduction function attributes to OpDSL. A reduction
function attribute shall only specify the arithmetic function and not the
reduction dimensions.

Example:

```
ReduceFn.max_unsigned(D.kh, D.kw)(...)
```

changes to:

```
ReduceFn.max_unsigned[D.kh, D.kw](...)
```

Depends On D115240

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D115241
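For orientation, here is a minimal sketch of how the new bracket syntax reads in a complete OpDSL definition. It is modeled on the `pooling_max_poly` test case updated below; the op name, the star import, and the tensor shapes are illustrative assumptions, not part of the commit itself:

```python
# A minimal sketch of the new reduction syntax, modeled on the
# pooling_max_poly test updated in this commit. The op name, star import,
# and tensor shapes are illustrative assumptions.
from mlir.dialects.linalg.opdsl.lang import *


@linalg_structured_op
def pooling_max_sketch(
    I=TensorDef(T1, S.N, S.H, S.W, S.C),
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
    strides=IndexAttrDef(S.SH, S.SW),
    dilations=IndexAttrDef(S.DH, S.DW)):
  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
  # Square brackets carry the reduction dimensions; round brackets carry
  # the expression being reduced.
  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
      TypeFn.cast(
          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
               D.c]))
```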
1 parent cf05668 commit e3b442b

File tree

4 files changed: +68 additions, −47 deletions


mlir/docs/Dialects/Linalg/OpDSL.md

Lines changed: 8 additions & 1 deletion
```diff
@@ -192,11 +192,18 @@ A number of arithmetic functions are supported:
 As the integer types are signless, signedness is implemented by different
 functions that treat integers as signed or unsigned values.
 
-Reduction functions can appear as the outer-most function on the RHS:
+A subset of the arithmetic functions are supported in reductions. These
+reduction functions can appear as the outermost function on the RHS:
 
 * `ReduceFn.add` (also overloading the inplace `+=` on a LHS)
 * `ReduceFn.mul`
 * `ReduceFn.max`
+* `ReduceFn.min`
+* `ReduceFn.max_unsigned`
+* `ReduceFn.min_unsigned`
+
+As the integer types are signless, signedness is implemented by different
+functions that treat integers as signed or unsigned values.
 
 Additionally, type conversion functions cast an operand to a target type:
```
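The `+=` overload mentioned in the first bullet and an explicit `ReduceFn.add` use are two spellings of the same reduction under the new syntax. A hedged matmul-style sketch (the op names and tensor shapes are illustrative; the usual OpDSL lang prelude is assumed):

```python
# Two equivalent spellings of an add-reduction (illustrative op names and
# shapes; assumes the usual OpDSL lang prelude is imported).
@linalg_structured_op
def matmul_implicit_sketch(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True)):
  # += infers the reduction dimension (D.k): the dim used on the RHS but
  # not on the LHS.
  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])


@linalg_structured_op
def matmul_explicit_sketch(
    A=TensorDef(T1, S.M, S.K),
    B=TensorDef(T2, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True)):
  # The explicit form names the reduction dimension in square brackets.
  C[D.m, D.n] = ReduceFn.add[D.k](
      TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]))
```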

mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

Lines changed: 49 additions & 35 deletions
```diff
@@ -43,8 +43,8 @@ def visit_affine_exprs(expr):
       if isinstance(expr, TensorUse):
         for ind in expr.indices:
           ind.visit_affine_exprs(visit_dim_def)
-      if isinstance(expr, ReduceApply):
-        for ind in expr.reduce.reduce_dims:
+      if isinstance(expr, TensorReduceFn):
+        for ind in expr.reduce_fn.reduce_dims:
           ind.visit_affine_exprs(visit_dim_def)
 
     self.visit_tensor_exprs(visit_affine_exprs)
@@ -114,8 +114,8 @@ def tensor_name(self) -> str:
     assert name is not None, "TensorDef not attached"
     return name
 
-  def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
-    return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
+  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+    return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
 
   def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
     """For implicit reductions, computes default reduction dims.
@@ -285,7 +285,7 @@ def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
 
     # Find the lhs to reduction rhs.
     for assign, value in bindings:
-      if isinstance(value, ReduceApply):
+      if isinstance(value, TensorReduceFn):
         if value.lhs:
           raise ValueError(f"Reduction expression already assigns: {value}")
         value.lhs = assign
@@ -297,8 +297,8 @@ def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
     """Gets the reduction dims for the comprehension or None."""
     result = set()
     for use in self.values:
-      if isinstance(use, ReduceApply):
-        result.add(use.reduce.reduce_dims)
+      if isinstance(use, TensorReduceFn):
+        result.add(use.reduce_use.reduce_dims)
       else:
         result.add(tuple())
     return result
@@ -360,10 +360,6 @@ def __init__(self, fn_name: str):
   def __call__(self, *args) -> "TensorArithFn":
     return TensorArithFn(self, args)
 
-  def reduce(self, *reduce_dims: DimDef):
-    """Shortcut to create a Reduce operation from this function."""
-    return ReduceFnType(self, *reduce_dims)
-
   def __repr__(self):
     return f"{self.fn_name}"
 
@@ -389,31 +385,49 @@ class ArithFn:
   min_unsigned = ArithFnType("min_unsigned")
 
 
-class ReduceFnType:
-  """A reduction operator that reduces into its LHS from its RHS."""
+class ReduceFnUse:
+  """Reduction function use.
+
+  A reduction use specifies the reduction function and dimensions.
+  """
 
-  def __init__(self, operator: ArithFnType, *reduce_dims: DimDef):
-    """Initializes the ReduceFn with an airthmetic function and dims."""
-    if not isinstance(operator, ArithFnType):
-      raise ValueError(f"Reduce expected a ArithFnType but got {operator}")
-    self.operator = operator
-    self.reduce_dims = tuple(reduce_dims)
+  def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
+    self.arith_fn = arith_fn
+    self.reduce_dims = reduce_dims
 
   def __call__(self, *args: TensorExpression):
-    return ReduceApply(self, args)
+    return TensorReduceFn(self, args)
 
   def __repr__(self):
-    return (f"reduce_{self.operator.fn_name}"
+    return (f"reduce_{self.arith_fn.fn_name}"
             f"({', '.join(repr(d) for d in self.reduce_dims)})")
 
 
+class ReduceFnType:
+  """Reduction function.
+
+  An arithmetic function that reduces its RHS into its LHS.
+  """
+
+  def __init__(self, arith_fn: ArithFnType):
+    if not isinstance(arith_fn, ArithFnType):
+      raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
+    self.arith_fn = arith_fn
+
+  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+    return ReduceFnUse(self.arith_fn, *reduce_dims)
+
+  def __repr__(self):
+    return (f"reduce_{self.arith_fn.fn_name}")
+
+
 class ReduceFn:
-  add = ArithFn.add.reduce
-  mul = ArithFn.mul.reduce
-  max = ArithFn.max.reduce
-  min = ArithFn.min.reduce
-  max_unsigned = ArithFn.max_unsigned.reduce
-  min_unsigned = ArithFn.min_unsigned.reduce
+  add = ReduceFnType(ArithFn.add)
+  mul = ReduceFnType(ArithFn.mul)
+  max = ReduceFnType(ArithFn.max)
+  min = ReduceFnType(ArithFn.min)
+  max_unsigned = ReduceFnType(ArithFn.max_unsigned)
+  min_unsigned = ReduceFnType(ArithFn.min_unsigned)
 
 
 class TensorArithFn(TensorExpression):
@@ -499,31 +513,31 @@ def __repr__(self):
     return f"index({repr(self.dim)})"
 
 
-class ReduceApply(TensorExpression):
-  """Application of a reduction.
+class TensorReduceFn(TensorExpression):
+  """Application of a reduction function.
 
-  This captures the lhs separately (initial value) separately from the rhs.
+  This captures the lhs (initial value) separately from the rhs.
   """
 
-  def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]):
-    self.reduce = reduce
+  def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]):
+    self.reduce_use = reduce_use
     self.lhs = None  # type: Optional[TensorUse]
     self.args = tuple(args)
 
   def to_scalar_expression(self) -> ScalarExpression:
     if self.lhs is None:
-      raise ValueError(f"Cannot scalarize a ReduceApply that has not been "
+      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
                        f"bound to its lhs: {self}")
     full_args = [self.lhs.to_scalar_expression()
                 ] + [arg.to_scalar_expression() for arg in self.args]
-    return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr()
+    return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
 
   def visit_tensor_exprs(self, callback):
     for arg in self.args:
       arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
-    return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"
+    return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
 
 
 class OpInterfaceDef:
```
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -479,7 +479,7 @@ def pooling_nhwc_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -499,7 +499,7 @@ def pooling_nhwc_max_unsigned(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -519,7 +519,7 @@ def pooling_nchw_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,]))
@@ -540,7 +540,7 @@ def pooling_nhwc_min(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -560,7 +560,7 @@ def pooling_nhwc_min_unsigned(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -600,7 +600,7 @@ def pooling_ndhwc_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)(
+  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -621,7 +621,7 @@ def pooling_ndhwc_min(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)(
+  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW, D.c]))
```
mlir/test/python/dialects/linalg/opdsl/emit_pooling.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -19,7 +19,7 @@ def pooling_max_poly(
     strides=IndexAttrDef(S.SH, S.SW),
     dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -32,7 +32,7 @@ def pooling_max_unsigned_poly(
     strides=IndexAttrDef(S.SH, S.SW),
     dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -45,7 +45,7 @@ def pooling_min_poly(
     strides=IndexAttrDef(S.SH, S.SW),
     dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -58,7 +58,7 @@ def pooling_min_unsigned_poly(
     strides=IndexAttrDef(S.SH, S.SW),
     dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
```