@@ -1,6 +1,7 @@
 from collections.abc import Callable
-from typing import Literal
+from typing import Literal, cast

+from pytensor.compile.builders import OpFromGraph
 from pytensor.scan import scan
 from pytensor.tensor import TensorLike
 from pytensor.tensor.basic import (
@@ -34,6 +35,19 @@
 ]
 stat_funcs = {"maximum": pt_max, "minimum": pt_min, "mean": mean}

+allowed_kwargs = {
+    "edge": [],
+    "wrap": [],
+    "constant": ["constant_values"],
+    "linear_ramp": ["end_values"],
+    "maximum": ["stat_length"],
+    "mean": ["stat_length"],
+    "median": ["stat_length"],
+    "minimum": ["stat_length"],
+    "reflect": ["reflect_type"],
+    "symmetric": ["reflect_type"],
+}
+

 def _slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
     """
@@ -225,17 +239,20 @@ def _get_stats(


 def _stat_pad(
-    x: TensorVariable, pad_width: TensorVariable, stat_func, stat_length=None
+    x: TensorVariable,
+    pad_width: TensorVariable,
+    stat_func: Callable,
+    stat_length: TensorVariable | None,
 ):
     padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
     if stat_length is None:
-        stat_length = [[None, None]] * padded.ndim
+        stat_length = [[None, None]] * padded.ndim  # type: ignore
     else:
         stat_length = broadcast_to(stat_length, as_tensor((padded.ndim, 2)))

     for axis in range(padded.ndim):
         width_pair = pad_width[axis]
-        length_pair = stat_length[axis]
+        length_pair = stat_length[axis]  # type: ignore
         dim_shape = padded.shape[axis]

         left_stat, right_stat = _get_stats(
@@ -311,6 +328,10 @@ def inner_func(i, x):
             # Delay creation of this function to here because we want to use the axis global inside the scan
             def inner_func(i, x):
                 return switch(eq(i % 2, 0), flip(x, axis=axis), x)
+        else:
+            raise ValueError(
+                "You should not have gotten here. Open an issue on github!"
+            )  # pragma: no cover

         size = x.shape[axis]
         repeats, (left_remainder, right_remainder) = pt_divmod(pad_width[axis], size)
@@ -330,55 +351,81 @@ def inner_func(i, x):
     return x


-def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
-    allowed_kwargs = {
-        "edge": [],
-        "wrap": [],
-        "constant": ["constant_values"],
-        "linear_ramp": ["end_values"],
-        "maximum": ["stat_length"],
-        "mean": ["stat_length"],
-        "median": ["stat_length"],
-        "minimum": ["stat_length"],
-        "reflect": ["reflect_type"],
-        "symmetric": ["reflect_type"],
-    }
+class Pad(OpFromGraph):
+    """
+    Wrapper Op for Pad graphs
+    """

+
+def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
     if any(value not in allowed_kwargs[mode] for value in kwargs.keys()):
         raise ValueError(
             f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
         )
-    x = as_tensor(x)
-    pad_width = as_tensor(pad_width)
+    x = as_tensor(x, name="x")
+    pad_width = as_tensor(pad_width, name="pad_width")
+    inputs = [x, pad_width]
+    attrs = {}

     if mode == "constant":
-        constant_values = as_tensor(kwargs.pop("constant_values", 0))
-        return _constant_pad(x, pad_width, constant_values)
+        constant_values = as_tensor(
+            kwargs.pop("constant_values", 0), name="constant_values"
+        )
+        inputs += [constant_values]
+        outputs = _constant_pad(x, pad_width, constant_values)
+
     elif mode == "edge":
-        return _edge_pad(x, pad_width)
+        outputs = _edge_pad(x, pad_width)
+
     elif mode in ["maximum", "minimum", "mean", "median"]:
         if mode == "median":
             # TODO: pt.quantile? pt.median?
             raise NotImplementedError("Median padding not implemented")
-        stat_func = stat_funcs[mode]
-        return _stat_pad(x, pad_width, stat_func, **kwargs)
+        stat_func = cast(Callable, stat_funcs[mode])
+        stat_length = kwargs.get("stat_length")
+        if stat_length is not None:
+            stat_length = as_tensor(stat_length, name="stat_length")
+            inputs += [stat_length]
+
+        attrs.update(
+            {"stat_func": stat_func, "stat_length_input": stat_length is not None}
+        )
+        outputs = _stat_pad(x, pad_width, stat_func, stat_length)
+
     elif mode == "linear_ramp":
         end_values = kwargs.pop("end_values", 0)
-        return _linear_ramp_pad(x, pad_width, end_values)
+        end_values = as_tensor(end_values)
+
+        inputs += [end_values]
+        outputs = _linear_ramp_pad(x, pad_width, end_values)
+
     elif mode == "wrap":
-        return _looping_pad(x, pad_width, kind="wrap")
+        attrs.update({"kind": "wrap"})
+        outputs = _looping_pad(x, pad_width, kind="wrap")
+
     elif mode == "symmetric":
         reflect_type = kwargs.pop("reflect_type", "even")
         if reflect_type == "odd":
             raise NotImplementedError("Odd reflection not implemented")
-        return _looping_pad(x, pad_width, kind="symmetric")
+
+        attrs.update({"kind": reflect_type})
+        outputs = _looping_pad(x, pad_width, kind="symmetric")
+
     elif mode == "reflect":
         reflect_type = kwargs.pop("reflect_type", "even")
         if reflect_type == "odd":
             raise NotImplementedError("Odd reflection not implemented")
+        attrs.update({"reflect_type": reflect_type})
         raise NotImplementedError("Reflect padding not implemented")
     else:
         raise ValueError(f"Invalid mode: {mode}")

+    op = Pad(inputs=inputs, outputs=[outputs])(*inputs)  # type: ignore
+
+    setattr(op, "pad_mode", mode)
+    for pad_arg, value in attrs.items():
+        setattr(op, pad_arg, value)
+    return op
+

 __all__ = ["pad"]
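
A minimal end-to-end sketch of the wrapped Op. Assumptions not stated in the diff: the module is importable as pytensor.tensor.pad, and a scalar pad_width broadcasts as in numpy.pad:

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.pad import pad  # module path assumed

x = pt.matrix("x")
y = pad(x, pad_width=1, mode="constant", constant_values=0)  # scalar pad_width assumed to broadcast

# The whole padding graph is wrapped in a single Pad (OpFromGraph) node,
# and pad() attaches its settings to the output variable via setattr:
print(type(y.owner.op).__name__)  # Pad
print(y.pad_mode)                 # constant
print(y.eval({x: np.ones((2, 2))}))  # 4x4 array of ones with a zero border

Wrapping the graph in an OpFromGraph gives downstream rewrites and backend dispatchers a single node to match on, with the mode-specific settings recoverable from the attributes set above.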