|
16 | 16 | import numpy as np
|
17 | 17 | from numpy.core.multiarray import normalize_axis_index
|
18 | 18 | from numpy.core.numeric import normalize_axis_tuple
|
19 |
| -from numpy.lib.arraypad import _get_edges, _slice_at_axis |
20 | 19 |
|
21 | 20 | import pytensor
|
22 | 21 | import pytensor.scalar.sharedvar
|
|
51 | 50 | scalar_elemwise,
|
52 | 51 | )
|
53 | 52 | from pytensor.tensor.exceptions import NotScalarConstantError
|
54 |
| -from pytensor.tensor.extra_ops import broadcast_to, linspace |
55 | 53 | from pytensor.tensor.shape import (
|
56 | 54 | Shape,
|
57 | 55 | Shape_i,
|
|
63 | 61 | shape_tuple,
|
64 | 62 | specify_broadcastable,
|
65 | 63 | )
|
66 |
| -from pytensor.tensor.subtensor import set_subtensor |
67 | 64 | from pytensor.tensor.type import (
|
68 | 65 | TensorType,
|
69 | 66 | discrete_dtypes,
|
@@ -4345,125 +4342,6 @@ def ix_(*args):
|
4345 | 4342 | return tuple(out)
|
4346 | 4343 |
|
4347 | 4344 |
|
4348 |
| -def _symbolic_pad( |
4349 |
| - x: TensorVariable, pad_width: TensorVariable |
4350 |
| -) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]: |
4351 |
| - pad_width = broadcast_to(pad_width, (x.ndim, 2)) |
4352 |
| - new_shape = as_tensor( |
4353 |
| - [pad_width[i][0] + size + pad_width[i][1] for i, size in enumerate(x.shape)] |
4354 |
| - ) |
4355 |
| - original_area_slice = tuple( |
4356 |
| - slice(pad_width[i][0], pad_width[i][0] + size) for i, size in enumerate(x.shape) |
4357 |
| - ) |
4358 |
| - padded: TensorVariable = set_subtensor(zeros(new_shape)[original_area_slice], x) |
4359 |
| - return padded, original_area_slice, pad_width |
4360 |
| - |
4361 |
| - |
4362 |
| -def _get_padding_slices( |
4363 |
| - dim_shape: TensorVariable, width_pair: tuple[TensorVariable], axis: int |
4364 |
| -): |
4365 |
| - left_slice = _slice_at_axis(slice(None, width_pair[0]), axis) |
4366 |
| - right_slice = _slice_at_axis(slice(dim_shape - width_pair[1], None), axis) |
4367 |
| - |
4368 |
| - return left_slice, right_slice |
4369 |
| - |
4370 |
| - |
4371 |
| -def _constant_pad( |
4372 |
| - x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable |
4373 |
| -): |
4374 |
| - padded, area_slice, pad_width = _symbolic_pad(x, pad_width) |
4375 |
| - values = broadcast_to(constant_values, (padded.ndim, 2)) |
4376 |
| - |
4377 |
| - for axis in range(padded.ndim): |
4378 |
| - width_pair = pad_width[axis] |
4379 |
| - value_pair = values[axis] |
4380 |
| - dim_shape = padded.shape[axis] |
4381 |
| - |
4382 |
| - left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis) |
4383 |
| - padded = set_subtensor(padded[left_slice], value_pair[0]) |
4384 |
| - padded = set_subtensor(padded[right_slice], value_pair[1]) |
4385 |
| - |
4386 |
| - return padded |
4387 |
| - |
4388 |
| - |
4389 |
| -def _edge_pad(x: TensorVariable, pad_width: TensorVariable): |
4390 |
| - padded, area_slice, pad_width = _symbolic_pad(x, pad_width) |
4391 |
| - for axis in range(padded.ndim): |
4392 |
| - width_pair = pad_width[axis] |
4393 |
| - dim_shape = padded.shape[axis] |
4394 |
| - |
4395 |
| - left_edge, right_edge = _get_edges(padded, axis, width_pair) |
4396 |
| - left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis) |
4397 |
| - |
4398 |
| - padded = set_subtensor(padded[left_slice], left_edge) |
4399 |
| - padded = set_subtensor(padded[right_slice], right_edge) |
4400 |
| - |
4401 |
| - return padded |
4402 |
| - |
4403 |
| - |
4404 |
| -def _linear_ramp_pad( |
4405 |
| - x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable = 0 |
4406 |
| -): |
4407 |
| - padded, area_slice, pad_width = _symbolic_pad(x, pad_width) |
4408 |
| - end_values = broadcast_to(end_values, (padded.ndim, 2)) |
4409 |
| - for axis in range(padded.ndim): |
4410 |
| - width_pair = pad_width[axis] |
4411 |
| - end_value_pair = end_values[axis] |
4412 |
| - edge_pair = _get_edges(padded, axis, width_pair) |
4413 |
| - dim_shape = padded.shape[axis] |
4414 |
| - left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis) |
4415 |
| - |
4416 |
| - # pt.linspace doesn't have the endpoint kwarg, so need to take one extra step then slice it away |
4417 |
| - left_ramp = linspace( |
4418 |
| - start=end_value_pair[0], |
4419 |
| - end=specify_broadcastable(edge_pair[0], axis).squeeze(axis), |
4420 |
| - steps=width_pair[0] + 1, |
4421 |
| - )[:-1] |
4422 |
| - right_ramp = linspace( |
4423 |
| - start=end_value_pair[1], |
4424 |
| - end=specify_broadcastable(edge_pair[1], axis).squeeze(axis), |
4425 |
| - steps=width_pair[1] + 1, |
4426 |
| - )[:-1] |
4427 |
| - right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)] |
4428 |
| - |
4429 |
| - # FIXME: This swapaxes is needed because the shapes of the linspaces don't "rotate" with |
4430 |
| - # the different dimensions. But this makes the non-active dimensions backwards in the padding. |
4431 |
| - padded = set_subtensor(padded[left_slice], swapaxes(left_ramp, 0, axis)) |
4432 |
| - padded = set_subtensor(padded[right_slice], swapaxes(right_ramp, 0, axis)) |
4433 |
| - |
4434 |
| - return padded |
4435 |
| - |
4436 |
| - |
4437 |
| -def pad(x, pad_width, mode="constant", **kwargs): |
4438 |
| - allowed_kwargs = { |
4439 |
| - "empty": [], |
4440 |
| - "edge": [], |
4441 |
| - "wrap": [], |
4442 |
| - "constant": ["constant_values"], |
4443 |
| - "linear_ramp": ["end_values"], |
4444 |
| - "maximum": ["stat_length"], |
4445 |
| - "mean": ["stat_length"], |
4446 |
| - "median": ["stat_length"], |
4447 |
| - "minimum": ["stat_length"], |
4448 |
| - "reflect": ["reflect_type"], |
4449 |
| - "symmetric": ["reflect_type"], |
4450 |
| - } |
4451 |
| - |
4452 |
| - if any(value not in allowed_kwargs[mode] for value in kwargs.values()): |
4453 |
| - raise ValueError( |
4454 |
| - f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}" |
4455 |
| - ) |
4456 |
| - |
4457 |
| - if mode == "constant": |
4458 |
| - constant_values = kwargs.pop("constant_values", 0) |
4459 |
| - return _constant_pad(x, pad_width, constant_values) |
4460 |
| - elif mode == "edge": |
4461 |
| - return _edge_pad(x, pad_width) |
4462 |
| - elif mode == "linear_ramp": |
4463 |
| - end_values = kwargs.pop("end_values", 0) |
4464 |
| - return _linear_ramp_pad(x, pad_width, end_values) |
4465 |
| - |
4466 |
| - |
4467 | 4345 | __all__ = [
|
4468 | 4346 | "take_along_axis",
|
4469 | 4347 | "expand_dims",
|
|
0 commit comments