@@ -43,8 +43,8 @@ def visit_affine_exprs(expr):
43
43
if isinstance (expr , TensorUse ):
44
44
for ind in expr .indices :
45
45
ind .visit_affine_exprs (visit_dim_def )
46
- if isinstance (expr , ReduceApply ):
47
- for ind in expr .reduce .reduce_dims :
46
+ if isinstance (expr , TensorReduceFn ):
47
+ for ind in expr .reduce_fn .reduce_dims :
48
48
ind .visit_affine_exprs (visit_dim_def )
49
49
50
50
self .visit_tensor_exprs (visit_affine_exprs )
@@ -114,8 +114,8 @@ def tensor_name(self) -> str:
114
114
assert name is not None , "TensorDef not attached"
115
115
return name
116
116
117
- def __iadd__ (self , rhs : TensorExpression ) -> TensorExpression :
118
- return ReduceFn .add ( * self ._compute_reduce_dims (rhs ))(rhs )
117
+ def __iadd__ (self , rhs : TensorExpression ) -> "TensorReduceFn" :
118
+ return ReduceFnUse ( ArithFn .add , * self ._compute_reduce_dims (rhs ))(rhs )
119
119
120
120
def _compute_reduce_dims (self , rhs : TensorExpression ) -> Set [DimDef ]:
121
121
"""For implicit reductions, computes default reduction dims.
@@ -285,7 +285,7 @@ def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
285
285
286
286
# Find the lhs to reduction rhs.
287
287
for assign , value in bindings :
288
- if isinstance (value , ReduceApply ):
288
+ if isinstance (value , TensorReduceFn ):
289
289
if value .lhs :
290
290
raise ValueError (f"Reduction expression already assigns: { value } " )
291
291
value .lhs = assign
@@ -297,8 +297,8 @@ def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
297
297
"""Gets the reduction dims for the comprehension or None."""
298
298
result = set ()
299
299
for use in self .values :
300
- if isinstance (use , ReduceApply ):
301
- result .add (use .reduce .reduce_dims )
300
+ if isinstance (use , TensorReduceFn ):
301
+ result .add (use .reduce_use .reduce_dims )
302
302
else :
303
303
result .add (tuple ())
304
304
return result
@@ -360,10 +360,6 @@ def __init__(self, fn_name: str):
360
360
def __call__ (self , * args ) -> "TensorArithFn" :
361
361
return TensorArithFn (self , args )
362
362
363
- def reduce (self , * reduce_dims : DimDef ):
364
- """Shortcut to create a Reduce operation from this function."""
365
- return ReduceFnType (self , * reduce_dims )
366
-
367
363
def __repr__ (self ):
368
364
return f"{ self .fn_name } "
369
365
@@ -389,31 +385,49 @@ class ArithFn:
389
385
min_unsigned = ArithFnType ("min_unsigned" )
390
386
391
387
392
- class ReduceFnType :
393
- """A reduction operator that reduces into its LHS from its RHS."""
388
+ class ReduceFnUse :
389
+ """Reduction function use.
390
+
391
+ A reduction use specifies the reduction function and dimensions.
392
+ """
394
393
395
- def __init__ (self , operator : ArithFnType , * reduce_dims : DimDef ):
396
- """Initializes the ReduceFn with an airthmetic function and dims."""
397
- if not isinstance (operator , ArithFnType ):
398
- raise ValueError (f"Reduce expected a ArithFnType but got { operator } " )
399
- self .operator = operator
400
- self .reduce_dims = tuple (reduce_dims )
394
+ def __init__ (self , arith_fn : ArithFnType , * reduce_dims : DimDef ):
395
+ self .arith_fn = arith_fn
396
+ self .reduce_dims = reduce_dims
401
397
402
398
def __call__ (self , * args : TensorExpression ):
403
- return ReduceApply (self , args )
399
+ return TensorReduceFn (self , args )
404
400
405
401
def __repr__ (self ):
406
- return (f"reduce_{ self .operator .fn_name } "
402
+ return (f"reduce_{ self .arith_fn .fn_name } "
407
403
f"({ ', ' .join (repr (d ) for d in self .reduce_dims )} )" )
408
404
409
405
406
+ class ReduceFnType :
407
+ """Reduction function.
408
+
409
+ An arithmetic function that reduces its RHS into its LHS.
410
+ """
411
+
412
+ def __init__ (self , arith_fn : ArithFnType ):
413
+ if not isinstance (arith_fn , ArithFnType ):
414
+ raise ValueError (f"Reduce expected a ArithFnType but got { arith_fn } " )
415
+ self .arith_fn = arith_fn
416
+
417
+ def __getitem__ (self , reduce_dims : Tuple [DimDef ]) -> ReduceFnUse :
418
+ return ReduceFnUse (self .arith_fn , * reduce_dims )
419
+
420
+ def __repr__ (self ):
421
+ return (f"reduce_{ self .arith_fn .fn_name } " )
422
+
423
+
410
424
class ReduceFn :
411
- add = ArithFn .add . reduce
412
- mul = ArithFn .mul . reduce
413
- max = ArithFn .max . reduce
414
- min = ArithFn .min . reduce
415
- max_unsigned = ArithFn .max_unsigned . reduce
416
- min_unsigned = ArithFn .min_unsigned . reduce
425
+ add = ReduceFnType ( ArithFn .add )
426
+ mul = ReduceFnType ( ArithFn .mul )
427
+ max = ReduceFnType ( ArithFn .max )
428
+ min = ReduceFnType ( ArithFn .min )
429
+ max_unsigned = ReduceFnType ( ArithFn .max_unsigned )
430
+ min_unsigned = ReduceFnType ( ArithFn .min_unsigned )
417
431
418
432
419
433
class TensorArithFn (TensorExpression ):
@@ -499,31 +513,31 @@ def __repr__(self):
499
513
return f"index({ repr (self .dim )} )"
500
514
501
515
502
- class ReduceApply (TensorExpression ):
503
- """Application of a reduction.
516
+ class TensorReduceFn (TensorExpression ):
517
+ """Application of a reduction function .
504
518
505
- This captures the lhs separately (initial value) separately from the rhs.
519
+ This captures the lhs (initial value) separately from the rhs.
506
520
"""
507
521
508
- def __init__ (self , reduce : ReduceFnType , args : Sequence [TensorExpression ]):
509
- self .reduce = reduce
522
+ def __init__ (self , reduce_use : ReduceFnUse , args : Sequence [TensorExpression ]):
523
+ self .reduce_use = reduce_use
510
524
self .lhs = None # type: Optional[TensorUse]
511
525
self .args = tuple (args )
512
526
513
527
def to_scalar_expression (self ) -> ScalarExpression :
514
528
if self .lhs is None :
515
- raise ValueError (f"Cannot scalarize a ReduceApply that has not been "
529
+ raise ValueError (f"Cannot scalarize a TensorReduceFn that has not been "
516
530
f"bound to its lhs: { self } " )
517
531
full_args = [self .lhs .to_scalar_expression ()
518
532
] + [arg .to_scalar_expression () for arg in self .args ]
519
- return ScalarArithFn (self .reduce . operator .fn_name , * full_args ).expr ()
533
+ return ScalarArithFn (self .reduce_use . arith_fn .fn_name , * full_args ).expr ()
520
534
521
535
def visit_tensor_exprs (self , callback ):
522
536
for arg in self .args :
523
537
arg .visit_tensor_exprs (callback )
524
538
525
539
def __repr__ (self ):
526
- return f"{ repr (self .reduce )} ({ ', ' .join (repr (a ) for a in self .args )} )"
540
+ return f"{ repr (self .reduce_use )} ({ ', ' .join (repr (a ) for a in self .args )} )"
527
541
528
542
529
543
class OpInterfaceDef :
0 commit comments