
Commit 85f994d

TroyGarden authored and facebook-github-bot committed
refactor _maybe_compute_kjt_to_jt_dict (#2326)
Summary:
Pull Request resolved: #2326

# context
* We want to resolve this graph break ([failures_and_restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpKJM3FI/failures_and_restarts.html), P1537573230):
```
Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time. You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs.

Could not guard on data-dependent expression Eq(((2*u48)//(u48 + u49)), 0) (unhinted: Eq(((2*u48)//(u48 + u49)), 0)). (Size-like symbols: u49, u48)

Potential framework code culprit (scroll up for full backtrace):
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_refs/__init__.py", line 3950, in unbind
    if guard_size_oblivious(t.shape[dim] == 0):

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u49,u48"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  ...
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 2241, in to_dict
    _jt_dict = _maybe_compute_kjt_to_jt_dict(
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 1226, in _maybe_compute_kjt_to_jt_dict
    split_lengths = torch.unbind(
```
* We first added a [shape check](https://fburl.com/code/p02u4mck) to hint the compiler:
```python
if pt2_guard_size_oblivious(lengths.numel() > 0):
    strided_lengths = lengths.view(-1, stride)
    if not torch.jit.is_scripting() and is_torchdynamo_compiling():
        torch._check(strided_lengths.shape[0] > 0)
        torch._check(strided_lengths.shape[1] > 0)
    split_lengths = torch.unbind(
        strided_lengths,
        dim=0,
    )
```
* However, the error is still there:
```
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_refs/__init__.py", line 3950, in unbind
    if guard_size_oblivious(t.shape[dim] == 0):
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/fx/experimental/symbolic_shapes.py", line 253, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/fx/experimental/sym_node.py", line 503, in guard_size_oblivious
    r = self.shape_env.evaluate_expr(
```
* The [implementation](https://fburl.com/code/20iue1ib) of the `unbind` decomposition shows why: it guards on the size of the dimension being unbound:
```python
@register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dim = utils.canonicalize_dim(t.ndim, dim)
    torch._check_index(
        len(t.shape) > 0,
        lambda: "Dimension specified as 0 but tensor has no dimensions",
    )
    if guard_size_oblivious(t.shape[dim] == 0):  # <------- here
        return ()
    else:
        return tuple(
            torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
        )
```
* D61677207 therefore refactors `_maybe_compute_kjt_to_jt_dict` to view `lengths` as `(len(keys), stride)` instead of `(-1, stride)`, so `unbind` runs over a dimension whose size is a plain Python int known at trace time: [no graph break at _maybe_compute_kjt_to_jt_dict](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpNcI14t/failures_and_restarts.html).

Reviewed By: IvanKobzarev

Differential Revision: D55277785

fbshipit-source-id: d3bbb2439427677dc8e9a58560679acbf6f872d4
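The error message above points at `torch._check` / `torch._check_is_size` as the way to hint the compiler about data-dependent values. Here is a minimal, hypothetical sketch of that pattern; the padding task, function name, and values are illustrative and not from this PR:

```python
import torch
import torch._dynamo

# Allow .item() to produce an unbacked SymInt under torch.compile.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def pad_to(x: torch.Tensor, n_tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical example, not PR code.
    n = n_tensor.item()               # unbacked SymInt: value unknown at trace time
    torch._check_is_size(n)           # hint: n is a valid, non-negative size
    torch._check(n >= x.size(0))      # hint: the pad length below cannot be negative
    pad = x.new_zeros(n - x.size(0))  # would otherwise need a data-dependent guard
    return torch.cat([x, pad])

print(pad_to(torch.ones(3), torch.tensor(5)))  # tensor([1., 1., 1., 0., 0.])
```

Each `torch._check` records a fact about the unbacked symbol, so later size arithmetic on `n` does not force a guard (and hence a graph break) on data the tracer cannot see.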
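In this PR the hints alone were not enough, because the row count of `lengths.view(-1, stride)` stays data-dependent and `unbind` still guards on it. Below is a hypothetical sketch of why the viewing trick in the fix works; `NUM_KEYS` and the function are illustrative stand-ins, while the actual change in the diff further down is `lengths.view(len(keys), stride)`:

```python
import torch

NUM_KEYS = 3  # static Python int, playing the role of len(keys) in the PR

@torch.compile(fullgraph=True, dynamic=True)
def unbind_per_key(lengths: torch.Tensor):
    # Hypothetical example, not PR code. The unbound dimension has the
    # constant size NUM_KEYS, so the unbind decomposition's
    # `t.shape[dim] == 0` check evaluates against a known value instead of
    # a data-dependent expression.
    strided = lengths.view(NUM_KEYS, -1)
    return torch.unbind(strided, dim=0)

per_key = unbind_per_key(torch.arange(12))
print(len(per_key), per_key[0].tolist())  # 3 [0, 1, 2, 3]
```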
1 parent 97585b8 commit 85f994d

File tree

1 file changed: +55 -53 lines changed


torchrec/sparse/jagged_tensor.py

Lines changed: 55 additions & 53 deletions
```diff
@@ -20,6 +20,7 @@
 from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
 from torchrec.pt2.checks import (
     is_non_strict_exporting,
+    is_pt2_compiling,
     is_torchdynamo_compiling,
     pt2_check_size_nonzero,
     pt2_checks_all_is_size,
@@ -1201,62 +1202,63 @@ def _maybe_compute_kjt_to_jt_dict(
     if not length_per_key:
         return {}
 
-    if jt_dict is None:
-        _jt_dict: Dict[str, JaggedTensor] = {}
+    if jt_dict is not None:
+        return jt_dict
+
+    _jt_dict: Dict[str, JaggedTensor] = {}
+    if not torch.jit.is_scripting() and is_pt2_compiling():
+        cat_size = 0
+        total_size = values.size(0)
+        for i in length_per_key:
+            cat_size += i
+            torch._check(cat_size <= total_size)
+        torch._check(cat_size == total_size)
+        torch._check_is_size(stride)
+    values_list = torch.split(values, length_per_key)
+    if variable_stride_per_key:
+        split_lengths = torch.split(lengths, stride_per_key)
+        split_offsets = [
+            torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+            for lengths in split_lengths
+        ]
+    elif pt2_guard_size_oblivious(lengths.numel() > 0):
+        strided_lengths = lengths.view(len(keys), stride)
         if not torch.jit.is_scripting() and is_torchdynamo_compiling():
-            cat_size = 0
-            total_size = values.size(0)
-            for i in length_per_key:
-                cat_size += i
-                torch._check(cat_size <= total_size)
-            torch._check(cat_size == total_size)
-        values_list = torch.split(values, length_per_key)
-        if variable_stride_per_key:
-            split_lengths = torch.split(lengths, stride_per_key)
-            split_offsets = [
-                torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-                for lengths in split_lengths
-            ]
-        else:
-            split_lengths = torch.unbind(
-                (
-                    lengths.view(-1, stride)
-                    if pt2_guard_size_oblivious(lengths.numel() != 0)
-                    else lengths
-                ),
-                dim=0,
+            torch._check(strided_lengths.size(0) > 0)
+            torch._check(strided_lengths.size(1) > 0)
+        split_lengths = torch.unbind(
+            strided_lengths,
+            dim=0,
+        )
+        split_offsets = torch.unbind(
+            _batched_lengths_to_offsets(strided_lengths),
+            dim=0,
+        )
+    else:
+        split_lengths = torch.unbind(lengths, dim=0)
+        split_offsets = torch.unbind(lengths, dim=0)
+
+    if weights is not None:
+        weights_list = torch.split(weights, length_per_key)
+        for idx, key in enumerate(keys):
+            length = split_lengths[idx]
+            offset = split_offsets[idx]
+            _jt_dict[key] = JaggedTensor(
+                lengths=length,
+                offsets=offset,
+                values=values_list[idx],
+                weights=weights_list[idx],
             )
-            split_offsets = torch.unbind(
-                (
-                    _batched_lengths_to_offsets(lengths.view(-1, stride))
-                    if pt2_guard_size_oblivious(lengths.numel() != 0)
-                    else lengths
-                ),
-                dim=0,
+    else:
+        for idx, key in enumerate(keys):
+            length = split_lengths[idx]
+            offset = split_offsets[idx]
+            _jt_dict[key] = JaggedTensor(
+                lengths=length,
+                offsets=offset,
+                values=values_list[idx],
             )
-
-        if weights is not None:
-            weights_list = torch.split(weights, length_per_key)
-            for idx, key in enumerate(keys):
-                length = split_lengths[idx]
-                offset = split_offsets[idx]
-                _jt_dict[key] = JaggedTensor(
-                    lengths=length,
-                    offsets=offset,
-                    values=values_list[idx],
-                    weights=weights_list[idx],
-                )
-        else:
-            for idx, key in enumerate(keys):
-                length = split_lengths[idx]
-                offset = split_offsets[idx]
-                _jt_dict[key] = JaggedTensor(
-                    lengths=length,
-                    offsets=offset,
-                    values=values_list[idx],
-                )
-        jt_dict = _jt_dict
-    return jt_dict
+    return _jt_dict
 
 
 @torch.fx.wrap
```
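For context, a small usage sketch of the code path this diff changes; the key names and values are made up for illustration. `KeyedJaggedTensor.to_dict()` is the public entry point that lands in `_maybe_compute_kjt_to_jt_dict`, as the stack trace in the summary shows:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two keys, two samples each (stride = 2); values are made up for illustration.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([2, 0, 1, 3]),  # f1: [2, 0], f2: [1, 3]
)

jt_dict = kjt.to_dict()  # Dict[str, JaggedTensor], built by _maybe_compute_kjt_to_jt_dict
for key, jt in jt_dict.items():
    print(key, jt.lengths().tolist(), jt.values().tolist())
# f1 [2, 0] [1, 2]
# f2 [1, 3] [3, 4, 5, 6]
```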

0 commit comments
