Skip to content

Commit 91311d5

Browse files
Refactor linspace, logspace, and geomspace to match numpy implementation
1 parent 05d376f commit 91311d5

File tree

2 files changed

+319
-40
lines changed

2 files changed

+319
-40
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 297 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Collection, Iterable
23

34
import numpy as np
@@ -20,14 +21,25 @@
2021
from pytensor.raise_op import Assert
2122
from pytensor.scalar import int32 as int_t
2223
from pytensor.scalar import upcast
23-
from pytensor.tensor import as_tensor_variable
24+
from pytensor.tensor import TensorLike, as_tensor_variable
2425
from pytensor.tensor import basic as ptb
2526
from pytensor.tensor.basic import alloc, second
2627
from pytensor.tensor.exceptions import NotScalarConstantError
2728
from pytensor.tensor.math import abs as pt_abs
2829
from pytensor.tensor.math import all as pt_all
30+
from pytensor.tensor.math import (
31+
bitwise_and,
32+
ge,
33+
gt,
34+
log,
35+
lt,
36+
maximum,
37+
minimum,
38+
prod,
39+
sign,
40+
switch,
41+
)
2942
from pytensor.tensor.math import eq as pt_eq
30-
from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch
3143
from pytensor.tensor.math import max as pt_max
3244
from pytensor.tensor.math import sum as pt_sum
3345
from pytensor.tensor.shape import specify_broadcastable
@@ -1585,27 +1597,294 @@ def broadcast_shape_iter(
15851597
return tuple(result_dims)
15861598

15871599

1588-
def geomspace(start, end, steps, base=10.0):
1589-
from pytensor.tensor.math import log
1600+
def _check_deprecated_inputs(stop, end, num, steps):
1601+
if end is not None:
1602+
warnings.warn(
1603+
"The 'end' parameter is deprecated and will be removed in a future version. Use 'stop' instead.",
1604+
DeprecationWarning,
1605+
)
1606+
stop = end
1607+
if steps is not None:
1608+
warnings.warn(
1609+
"The 'steps' parameter is deprecated and will be removed in a future version. Use 'num' instead.",
1610+
DeprecationWarning,
1611+
)
1612+
num = steps
1613+
1614+
return stop, num
1615+
1616+
1617+
def _linspace_core(
1618+
start: TensorVariable,
1619+
stop: TensorVariable,
1620+
num: int,
1621+
dtype: str,
1622+
endpoint=True,
1623+
retstep=False,
1624+
axis=0,
1625+
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
1626+
div = (num - 1) if endpoint else num
1627+
delta = (stop - start).astype(dtype)
1628+
samples = ptb.arange(0, num, dtype=dtype).reshape((-1,) + (1,) * delta.ndim)
1629+
1630+
step = switch(gt(div, 0), delta / div, np.nan)
1631+
samples = switch(gt(div, 0), samples * delta / div + start, samples * delta + start)
1632+
samples = switch(
1633+
bitwise_and(gt(num, 1), np.asarray(endpoint)),
1634+
set_subtensor(samples[-1, ...], stop),
1635+
samples,
1636+
)
1637+
1638+
if axis != 0:
1639+
samples = ptb.moveaxis(samples, 0, axis)
1640+
1641+
if retstep:
1642+
return samples, step
1643+
1644+
return samples
1645+
1646+
1647+
def _broadcast_inputs_and_dtypes(*args, dtype=None):
1648+
args = map(ptb.as_tensor_variable, args)
1649+
args = broadcast_arrays(*args)
1650+
1651+
if dtype is None:
1652+
dtype = pytensor.config.floatX
1653+
1654+
return args, dtype
1655+
1656+
1657+
def _broadcast_base_with_inputs(start, stop, base, dtype, axis):
1658+
"""
1659+
Broadcast the base tensor with the start and stop tensors if base is not a scalar. This is important because it
1660+
may change how the axis argument is interpreted in the final output.
1661+
1662+
Parameters
1663+
----------
1664+
start
1665+
stop
1666+
base
1667+
dtype
1668+
axis
1669+
1670+
Returns
1671+
-------
1672+
1673+
"""
1674+
base = ptb.as_tensor_variable(base, dtype=dtype)
1675+
if base.ndim > 0:
1676+
ndmax = len(broadcast_shape(start, stop, base))
1677+
start, stop, base = (
1678+
ptb.shape_padleft(a, ndmax - a.ndim) for a in (start, stop, base)
1679+
)
1680+
base = ptb.expand_dims(base, axis=(axis,))
1681+
1682+
return start, stop, base
1683+
1684+
1685+
def linspace(
1686+
start: TensorLike,
1687+
stop: TensorLike,
1688+
num: TensorLike = 50,
1689+
endpoint: bool = True,
1690+
retstep: bool = False,
1691+
dtype: str | None = None,
1692+
axis: int = 0,
1693+
end: TensorLike | None = None,
1694+
steps: TensorLike | None = None,
1695+
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
1696+
"""
1697+
Return evenly spaced numbers over a specified interval.
1698+
1699+
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
1700+
1701+
The endpoint of the interval can optionally be excluded.
15901702
1591-
start = ptb.as_tensor_variable(start)
1592-
end = ptb.as_tensor_variable(end)
1593-
return base ** linspace(log(start) / log(base), log(end) / log(base), steps)
1703+
Parameters
1704+
----------
1705+
start: int, float, or TensorVariable
1706+
The starting value of the sequence.
15941707
1708+
stop: int, float or TensorVariable
1709+
The end value of the sequence, unless `endpoint` is set to False.
1710+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
15951711
1596-
def logspace(start, end, steps, base=10.0):
1597-
start = ptb.as_tensor_variable(start)
1598-
end = ptb.as_tensor_variable(end)
1599-
return base ** linspace(start, end, steps)
1712+
num: int
1713+
Number of samples to generate. Must be non-negative.
16001714
1715+
endpoint: bool
1716+
Whether to include the endpoint in the range.
1717+
1718+
retstep: bool
1719+
If true, returns both the samples and an array of steps between samples.
1720+
1721+
dtype: str, optional
1722+
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
1723+
and `end` arguments.
1724+
1725+
axis: int
1726+
Axis along which to generate samples. Ignored if both `start` and `end` have dimension 0. By default, axis=0
1727+
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
1728+
1729+
end: int, float or TensorVariable
1730+
.. warning::
1731+
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
1732+
The end value of the sequence, unless `endpoint` is set to False.
1733+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
1734+
excluded.
1735+
1736+
steps: float, int, or TensorVariable
1737+
.. warning::
1738+
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
1739+
1740+
Number of samples to generate. Must be non-negative
1741+
1742+
Returns
1743+
-------
1744+
samples: TensorVariable
1745+
Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
1746+
1747+
step: TensorVariable
1748+
Tensor containing the spacing between samples. Only returned if `retstep` is True.
1749+
"""
1750+
end, num = _check_deprecated_inputs(stop, end, num, steps)
1751+
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1752+
1753+
return _linspace_core(
1754+
start=start,
1755+
stop=stop,
1756+
num=num,
1757+
dtype=dtype,
1758+
endpoint=endpoint,
1759+
retstep=retstep,
1760+
axis=axis,
1761+
)
1762+
1763+
1764+
def geomspace(
1765+
start: TensorLike,
1766+
stop: TensorLike,
1767+
num: int = 50,
1768+
base: float = 10.0,
1769+
endpoint: bool = True,
1770+
dtype: str | None = None,
1771+
axis: int = 0,
1772+
end: TensorLike | None = None,
1773+
steps: TensorLike | None = None,
1774+
) -> TensorVariable:
1775+
"""
1776+
Return numbers spaced evenly on a log scale (a geometric progression).
1777+
1778+
Parameters
1779+
----------
1780+
Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`].
1781+
1782+
The endpoint of the interval can optionally be excluded.
1783+
1784+
Parameters
1785+
----------
1786+
start: int, float, or TensorVariable
1787+
The starting value of the sequence.
1788+
1789+
stop: int, float or TensorVariable
1790+
The end value of the sequence, unless `endpoint` is set to False.
1791+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `stop` is excluded.
1792+
1793+
num: int
1794+
Number of samples to generate. Must be non-negative.
1795+
1796+
base: float
1797+
The base of the log space. The step size between the elements in ln(samples) / ln(base)
1798+
(or log_base(samples)) is uniform.
1799+
1800+
endpoint: bool
1801+
Whether to include the endpoint in the range.
1802+
1803+
dtype: str, optional
1804+
dtype of the output tensor(s). If None, the dtype is inferred from that of the values provided to the `start`
1805+
and `end` arguments.
1806+
1807+
axis: int
1808+
Axis along which to generate samples. Ignored if both `start` and `end` have dimension 0. By default, axis=0
1809+
will insert the samples on a new left-most dimension. To insert samples on a right-most dimension, use axis=-1.
1810+
1811+
end: int, float or TensorVariable
1812+
.. warning::
1813+
The "end" parameter is deprecated and will be removed in a future version. Use "stop" instead.
1814+
The end value of the sequence, unless `endpoint` is set to False.
1815+
In that case, the sequence consists of all but the last of `num + 1` evenly spaced samples, such that `end` is
1816+
excluded.
1817+
1818+
steps: float, int, or TensorVariable
1819+
.. warning::
1820+
The "steps" parameter is deprecated and will be removed in a future version. Use "num" instead.
1821+
1822+
Number of samples to generate. Must be non-negative
1823+
1824+
Returns
1825+
-------
1826+
samples: TensorVariable
1827+
Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
1828+
"""
1829+
stop, num = _check_deprecated_inputs(stop, end, num, steps)
1830+
(start, stop), dtype = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1831+
start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis)
1832+
1833+
out_sign = sign(start)
1834+
log_start, log_stop = (
1835+
log(start * out_sign) / log(base),
1836+
log(stop * out_sign) / log(base),
1837+
)
1838+
result = _linspace_core(
1839+
start=log_start,
1840+
stop=log_stop,
1841+
num=num,
1842+
endpoint=endpoint,
1843+
dtype=dtype,
1844+
axis=0,
1845+
retstep=False,
1846+
)
1847+
result = base**result
1848+
1849+
if num > 0:
1850+
set_subtensor(result[0, ...], start, inplace=True)
1851+
if num > 1 and endpoint:
1852+
set_subtensor(result[-1, ...], stop, inplace=True)
1853+
1854+
result = result * out_sign
1855+
1856+
if axis != 0:
1857+
result = ptb.moveaxis(result, 0, axis)
1858+
1859+
return result
1860+
1861+
1862+
def logspace(
1863+
start: TensorLike,
1864+
stop: TensorLike,
1865+
num: int = 50,
1866+
base: float = 10.0,
1867+
endpoint: bool = True,
1868+
dtype: str | None = None,
1869+
axis: int = 0,
1870+
end: TensorLike | None = None,
1871+
steps: TensorLike | None = None,
1872+
) -> TensorVariable:
1873+
stop, num = _check_deprecated_inputs(stop, end, num, steps)
1874+
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1875+
start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis)
1876+
1877+
ls = _linspace_core(
1878+
start=start,
1879+
stop=stop,
1880+
num=num,
1881+
endpoint=endpoint,
1882+
dtype=dtype,
1883+
axis=axis,
1884+
retstep=False,
1885+
)
16011886

1602-
def linspace(start, end, steps):
1603-
start = ptb.as_tensor_variable(start)
1604-
end = ptb.as_tensor_variable(end)
1605-
arr = ptb.arange(steps)
1606-
arr = ptb.shape_padright(arr, max(start.ndim, end.ndim))
1607-
multiplier = (end - start) / (steps - 1)
1608-
return start + arr * multiplier
1887+
return base**ls
16091888

16101889

16111890
def broadcast_to(

0 commit comments

Comments
 (0)