Skip to content

Commit 09b1b11

Browse files
jbschlosserpytorchmergebot
authored andcommitted
Cache min / max seq len for torch.nested.as_nested_tensor(t) (pytorch#130766)
For the `torch.nested.as_nested_tensor(t)` constructor, computing min / max seq len is trivial since the sequence lengths are all the same. Might as well cache them during construction. Pull Request resolved: pytorch#130766 Approved by: https://github.com/YuqingJ, https://github.com/soulitzer
1 parent 408c921 commit 09b1b11

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

test/test_nestedtensor.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4512,9 +4512,8 @@ def test_as_nested_tensor_from_tensor(
45124512
dim=dim,
45134513
batch_size=expected_batch_size,
45144514
contiguous=True,
4515-
# TODO: compute min / max during construction for this case since it's easy
4516-
cached_min_seqlen=None,
4517-
cached_max_seqlen=None,
4515+
cached_min_seqlen=expected_seqlen,
4516+
cached_max_seqlen=expected_seqlen,
45184517
)
45194518

45204519
if torch.device(device) == t.device and dtype == t.dtype and contiguous:
@@ -4549,9 +4548,8 @@ def test_as_nested_tensor_from_tensor(
45494548
dim=dim,
45504549
batch_size=expected_batch_size,
45514550
contiguous=True,
4552-
# TODO: compute min / max during construction for this case since it's easy
4553-
cached_min_seqlen=None,
4554-
cached_max_seqlen=None,
4551+
cached_min_seqlen=expected_seqlen,
4552+
cached_max_seqlen=expected_seqlen,
45554553
)
45564554

45574555
# we don't support conversion between layouts this way atm

torch/nested/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def as_nested_tensor(
119119

120120
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
121121

122-
return nested_view_from_values_offsets(values, offsets)
122+
return nested_view_from_values_offsets(
123+
values, offsets, min_seqlen=seq_len, max_seqlen=seq_len
124+
)
123125
else:
124126
from torch.nested._internal.nested_tensor import jagged_from_list
125127

0 commit comments

Comments
 (0)