@@ -498,16 +498,10 @@ def median_filter(
498
498
kernel = kernel .to (in_tensor )
499
499
# map the local window to single vector
500
500
conv = [F .conv1d , F .conv2d , F .conv3d ][spatial_dims - 1 ] # type: ignore
501
- if "padding" not in kwargs :
502
- if pytorch_after (1 , 10 ):
503
- kwargs ["padding" ] = "same"
504
- else :
505
- # even-sized kernels are not supported
506
- kwargs ["padding" ] = [(k - 1 ) // 2 for k in kernel .shape [2 :]]
507
- elif kwargs ["padding" ] == "same" and not pytorch_after (1 , 10 ):
508
- # even-sized kernels are not supported
509
- kwargs ["padding" ] = [(k - 1 ) // 2 for k in kernel .shape [2 :]]
510
- features : torch .Tensor = conv (in_tensor .reshape (oprod , 1 , * sshape ), kernel , stride = 1 , ** kwargs ) # type: ignore
501
+ # even-sized kernels are not supported
502
+ padding = [(k - 1 ) // 2 for k in kernel .shape [2 :]]
503
+ padded_input : torch .Tensor = F .pad (in_tensor .reshape (oprod , 1 , * sshape ), pad = padding , mode = "replicate" )
504
+ features : torch .Tensor = conv (padded_input , kernel , stride = 1 , ** kwargs ) # type: ignore
511
505
features = features .view (oprod , - 1 , * sshape ) # type: ignore
512
506
513
507
# compute the median along the feature axis
0 commit comments