Skip to content

Adding expand_dims for xtensor #1449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: labeled_tensors
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from pytensor.graph import node_rewriter
from pytensor.raise_op import Assert
from pytensor.tensor import (
broadcast_to,
get_scalar_constant_value,
gt,
join,
moveaxis,
specify_shape,
squeeze,
)
from pytensor.tensor import (
shape as tensor_shape,
)
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import (
Concat,
ExpandDims,
Squeeze,
Stack,
Transpose,
Expand Down Expand Up @@ -132,3 +139,32 @@ def local_squeeze_reshape(fgraph, node):

new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims)
return [new_out]


@register_lower_xtensor
@node_rewriter([ExpandDims])
def local_expand_dims_reshape(fgraph, node):
"""Rewrite ExpandDims to tensor.expand_dims and optionally broadcast_to or specify shape."""
x, size = node.inputs
out = node.outputs[0]
# Lower to tensor.expand_dims(x, axis=0)
from pytensor.tensor import expand_dims as tensor_expand_dims

expanded = tensor_expand_dims(tensor_from_xtensor(x), 0)
# Optionally broadcast to the correct shape if size is not 1
from pytensor.tensor import broadcast_to

# Ensure size is positive
expanded = Assert(msg="size must be positive")(expanded, gt(size, 0))
# If size is not 1, broadcast
try:
static_size = get_scalar_constant_value(size)
except Exception:
static_size = None
if static_size is not None and static_size == 1:
result = expanded
else:
# Broadcast to (size, ...)
new_shape = (size,) + tuple(tensor_shape(expanded))[1:]
result = broadcast_to(expanded, new_shape)
return [xtensor_from_tensor(result, out.type.dims)]
74 changes: 74 additions & 0 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from types import EllipsisType
from typing import Literal

import numpy as np

from pytensor.graph import Apply
from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
Expand Down Expand Up @@ -380,3 +382,75 @@ def squeeze(x, dim=None):
return x # no-op if nothing to squeeze

return Squeeze(dims=dims)(x)


class ExpandDims(XOp):
"""Add a new dimension to an XTensorVariable."""

__props__ = ("dims",)

def __init__(self, dim):
self.dims = dim

def make_node(self, x, size):
x = as_xtensor(x)
# Insert new dim at front
new_dims = (self.dims, *x.type.dims)

# Determine shape
try:
static_size = get_scalar_constant_value(size)
except NotScalarConstantError:
static_size = None
if static_size is not None:
new_shape = (int(static_size), *x.type.shape)
else:
new_shape = (None, *x.type.shape) # symbolic size

out = xtensor(
dtype=x.type.dtype,
shape=new_shape,
dims=new_dims,
)
return Apply(self, [x, size], [out])


def expand_dims(x, dim: str | None, size=1):
"""Add a new dimension to an XTensorVariable.

Parameters
----------
x : XTensorVariable
Input tensor
dim : str or None
Name of new dimension. If None, returns x unchanged.
size : int or symbolic, optional
Size of the new dimension (default 1)

Returns
-------
XTensorVariable
Tensor with the new dimension inserted
"""
x = as_xtensor(x)

if dim is None:
return x # No-op

if not isinstance(dim, str):
raise TypeError(f"`dim` must be a string or None, got: {type(dim)}")

if dim in x.type.dims:
raise ValueError(f"Dimension {dim} already exists in {x.type.dims}")

if isinstance(size, int | np.integer):
if size <= 0:
raise ValueError(f"size must be positive, got: {size}")
elif not (
hasattr(size, "ndim") and getattr(size, "ndim", None) == 0 # symbolic scalar
):
raise TypeError(f"size must be an int or scalar variable, got: {type(size)}")

# Always convert size to a PyTensor scalar variable
size_var = as_tensor(size, ndim=0)
return ExpandDims(dim)(x, size_var)
205 changes: 204 additions & 1 deletion tests/xtensor/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from itertools import chain, combinations

import numpy as np
import pytest
import xarray as xr
from xarray import DataArray
from xarray import concat as xr_concat

from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
concat,
expand_dims,
squeeze,
stack,
transpose,
Expand Down Expand Up @@ -301,6 +303,15 @@ def test_squeeze_explicit_dims():
fn3d = xr_function([x3], y3d)
xr_assert_allclose(fn3d(x3_test), x3_test)

# Reversibility with expand_dims
x6 = xtensor("x6", dims=("a", "b", "c"), shape=(2, 1, 3))
y6 = squeeze(x6, "b")
# First expand_dims adds at front, then transpose puts it in the right place
z6 = transpose(expand_dims(y6, "b"), "a", "b", "c")
fn6 = xr_function([x6], z6)
x6_test = xr_arange_like(x6)
xr_assert_allclose(fn6(x6_test), x6_test)


def test_squeeze_implicit_dims():
"""Test squeeze with implicit dim=None (all size-1 dimensions)."""
Expand Down Expand Up @@ -369,3 +380,195 @@ def test_squeeze_errors():
fn2 = xr_function([x2], y2)
with pytest.raises(Exception):
fn2(x2_test)


def test_expand_dims_explicit():
"""Test expand_dims with explicitly named dimensions and sizes."""

# 1D case
x = xtensor("x", dims=("city",), shape=(3,))
y = expand_dims(x, "country")
fn = xr_function([x], y)
x_xr = xr_arange_like(x)
xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country"))

# 2D case
x = xtensor("x", dims=("city", "year"), shape=(2, 2))
y = expand_dims(x, "country")
fn = xr_function([x], y)
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country"))

# 3D case
x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2))
y = expand_dims(x, "country")
fn = xr_function([x], y)
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country"))

# Prepending various dims
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
for new_dim in ("x", "y", "z"):
y = expand_dims(x, new_dim)
assert y.type.dims == (new_dim, "a", "b")
assert y.type.shape == (1, 2, 3)

# Explicit size=1 behaves like default
y1 = expand_dims(x, "batch", size=1)
y2 = expand_dims(x, "batch")
fn1 = xr_function([x], y1)
fn2 = xr_function([x], y2)
x_test = xr_arange_like(x)
xr_assert_allclose(fn1(x_test), fn2(x_test))

# Scalar expansion
x = xtensor("x", dims=(), shape=())
y = expand_dims(x, "batch")
assert y.type.dims == ("batch",)
assert y.type.shape == (1,)
fn = xr_function([x], y)
xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch"))

# Static size > 1: broadcast
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = expand_dims(x, "batch", size=4)
fn = xr_function([x], y)
expected = xr.DataArray(
np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)),
dims=("batch", "a", "b"),
coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]},
)
xr_assert_allclose(fn(xr_arange_like(x)), expected)

# Insert new dim between existing dims
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = expand_dims(x, "new")
# Insert new dim between a and b: ("a", "new", "b")
y = transpose(y, "a", "new", "b")
fn = xr_function([x], y)
x_test = xr_arange_like(x)
expected = x_test.expand_dims("new").transpose("a", "new", "b")
xr_assert_allclose(fn(x_test), expected)

# Expand with multiple dims
x = xtensor("x", dims=(), shape=())
y = expand_dims(expand_dims(x, "a"), "b")
fn = xr_function([x], y)
expected = xr_arange_like(x).expand_dims("a").expand_dims("b")
xr_assert_allclose(fn(xr_arange_like(x)), expected)


def test_expand_dims_implicit():
"""Test expand_dims with default or symbolic sizes and dim=None."""

# Symbolic size=1: same as default
size_sym_1 = scalar("size_sym_1", dtype="int64")
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = expand_dims(x, "batch", size=size_sym_1)
fn = xr_function([x, size_sym_1], y, on_unused_input="ignore")
x_test = xr_arange_like(x)
xr_assert_allclose(fn(x_test, 1), x_test.expand_dims("batch"))

# Test using symbolic size from an existing dimension of the same tensor
# This verifies that expand_dims can use the size of one dimension to create another
x = xtensor(dims=("a", "b", "c"))
y = expand_dims(x, "d", size=x.sizes["b"])
fn = xr_function([x], y)
x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5)))
res = fn(x_test)
expected = x_test.expand_dims({"d": 3}) # 3 is the size of dimension "b"
xr_assert_allclose(res, expected)

# Test broadcasting with symbolic size from a different tensor
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
other = xtensor("other", dims=("c",), shape=(4,))
y = expand_dims(x, "batch", size=other.sizes["c"])
fn = xr_function([x, other], y)
x_test = xr_arange_like(x)
other_test = xr_arange_like(other)
res = fn(x_test, other_test)
expected = x_test.expand_dims(
{"batch": 4}
) # 4 is the size of dimension "c" in other
xr_assert_allclose(res, expected)

# Test behavior with symbolic size > 1
# NOTE: This test documents our current behavior where expand_dims broadcasts to the requested size.
# This differs from xarray's behavior where expand_dims always adds a size-1 dimension.
size_sym_4 = scalar("size_sym_4", dtype="int64")
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = expand_dims(x, "batch", size=size_sym_4)
fn = xr_function([x, size_sym_4], y, on_unused_input="ignore")
x_test = xr_arange_like(x)
res = fn(x_test, 4)
# Our current behavior: broadcasts to size 4
expected = x_test.expand_dims({"batch": 4})
xr_assert_allclose(res, expected)
# xarray's behavior would be:
# expected = x_test.expand_dims("batch") # always size 1
# xr_assert_allclose(res, expected)

# Test using symbolic size from a reduction operation
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
reduced = x.sum("a") # shape: (b: 3)
y = expand_dims(x, "batch", size=reduced.sizes["b"])
fn = xr_function([x], y)
x_test = xr_arange_like(x)
res = fn(x_test)
expected = x_test.expand_dims({"batch": 3}) # 3 is the size of dimension "b"
xr_assert_allclose(res, expected)

# Test chaining expand_dims with symbolic sizes
x = xtensor("x", dims=("a",), shape=(2,))
y = expand_dims(x, "b", size=x.sizes["a"]) # shape: (a: 2, b: 2)
z = expand_dims(y, "c", size=y.sizes["b"]) # shape: (a: 2, b: 2, c: 2)
fn = xr_function([x], z)
x_test = xr_arange_like(x)
res = fn(x_test)
expected = x_test.expand_dims({"b": 2}).expand_dims({"c": 2})
xr_assert_allclose(res, expected)

# Test bidirectional broadcasting with symbolic sizes
x = xtensor("x", dims=("a",), shape=(2,))
y = xtensor("y", dims=("b",), shape=(3,))
# Expand x with size from y, then add y
expanded = expand_dims(x, "b", size=y.sizes["b"])
z = expanded + y # Should broadcast x to match y's size
fn = xr_function([x, y], z)
x_test = xr_arange_like(x)
y_test = xr_arange_like(y)
res = fn(x_test, y_test)
expected = x_test.expand_dims({"b": 3}) + y_test
xr_assert_allclose(res, expected)


def test_expand_dims_errors():
"""Test error handling in expand_dims."""

# Expanding existing dim
x = xtensor("x", dims=("city",), shape=(3,))
y = expand_dims(x, "country")
with pytest.raises(ValueError, match="already exists"):
expand_dims(y, "city")

# Size = 0 is invalid
with pytest.raises(ValueError, match="size must be.*positive"):
expand_dims(x, "batch", size=0)

# Invalid dim type
with pytest.raises(TypeError):
expand_dims(x, 123)

# Invalid size type
with pytest.raises(TypeError):
expand_dims(x, "new", size=[1])

# Duplicate dimension creation
y = expand_dims(x, "new")
with pytest.raises(ValueError):
expand_dims(y, "new")

# Symbolic size with invalid runtime value
size_sym = scalar("size_sym", dtype="int64")
y = expand_dims(x, "batch", size=size_sym)
fn = xr_function([x, size_sym], y, on_unused_input="ignore")
with pytest.raises(Exception):
fn(xr_arange_like(x), 0)