|
20 | 20 | from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
|
21 | 21 | from torchrec.pt2.checks import (
|
22 | 22 | is_non_strict_exporting,
|
| 23 | + is_pt2_compiling, |
23 | 24 | is_torchdynamo_compiling,
|
24 | 25 | pt2_check_size_nonzero,
|
25 | 26 | pt2_checks_all_is_size,
|
@@ -1201,62 +1202,63 @@ def _maybe_compute_kjt_to_jt_dict(
|
1201 | 1202 | if not length_per_key:
|
1202 | 1203 | return {}
|
1203 | 1204 |
|
1204 |
| - if jt_dict is None: |
1205 |
| - _jt_dict: Dict[str, JaggedTensor] = {} |
| 1205 | + if jt_dict is not None: |
| 1206 | + return jt_dict |
| 1207 | + |
| 1208 | + _jt_dict: Dict[str, JaggedTensor] = {} |
| 1209 | + if not torch.jit.is_scripting() and is_pt2_compiling(): |
| 1210 | + cat_size = 0 |
| 1211 | + total_size = values.size(0) |
| 1212 | + for i in length_per_key: |
| 1213 | + cat_size += i |
| 1214 | + torch._check(cat_size <= total_size) |
| 1215 | + torch._check(cat_size == total_size) |
| 1216 | + torch._check_is_size(stride) |
| 1217 | + values_list = torch.split(values, length_per_key) |
| 1218 | + if variable_stride_per_key: |
| 1219 | + split_lengths = torch.split(lengths, stride_per_key) |
| 1220 | + split_offsets = [ |
| 1221 | + torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) |
| 1222 | + for lengths in split_lengths |
| 1223 | + ] |
| 1224 | + elif pt2_guard_size_oblivious(lengths.numel() > 0): |
| 1225 | + strided_lengths = lengths.view(len(keys), stride) |
1206 | 1226 | if not torch.jit.is_scripting() and is_torchdynamo_compiling():
|
1207 |
| - cat_size = 0 |
1208 |
| - total_size = values.size(0) |
1209 |
| - for i in length_per_key: |
1210 |
| - cat_size += i |
1211 |
| - torch._check(cat_size <= total_size) |
1212 |
| - torch._check(cat_size == total_size) |
1213 |
| - values_list = torch.split(values, length_per_key) |
1214 |
| - if variable_stride_per_key: |
1215 |
| - split_lengths = torch.split(lengths, stride_per_key) |
1216 |
| - split_offsets = [ |
1217 |
| - torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) |
1218 |
| - for lengths in split_lengths |
1219 |
| - ] |
1220 |
| - else: |
1221 |
| - split_lengths = torch.unbind( |
1222 |
| - ( |
1223 |
| - lengths.view(-1, stride) |
1224 |
| - if pt2_guard_size_oblivious(lengths.numel() != 0) |
1225 |
| - else lengths |
1226 |
| - ), |
1227 |
| - dim=0, |
| 1227 | + torch._check(strided_lengths.size(0) > 0) |
| 1228 | + torch._check(strided_lengths.size(1) > 0) |
| 1229 | + split_lengths = torch.unbind( |
| 1230 | + strided_lengths, |
| 1231 | + dim=0, |
| 1232 | + ) |
| 1233 | + split_offsets = torch.unbind( |
| 1234 | + _batched_lengths_to_offsets(strided_lengths), |
| 1235 | + dim=0, |
| 1236 | + ) |
| 1237 | + else: |
| 1238 | + split_lengths = torch.unbind(lengths, dim=0) |
| 1239 | + split_offsets = torch.unbind(lengths, dim=0) |
| 1240 | + |
| 1241 | + if weights is not None: |
| 1242 | + weights_list = torch.split(weights, length_per_key) |
| 1243 | + for idx, key in enumerate(keys): |
| 1244 | + length = split_lengths[idx] |
| 1245 | + offset = split_offsets[idx] |
| 1246 | + _jt_dict[key] = JaggedTensor( |
| 1247 | + lengths=length, |
| 1248 | + offsets=offset, |
| 1249 | + values=values_list[idx], |
| 1250 | + weights=weights_list[idx], |
1228 | 1251 | )
|
1229 |
| - split_offsets = torch.unbind( |
1230 |
| - ( |
1231 |
| - _batched_lengths_to_offsets(lengths.view(-1, stride)) |
1232 |
| - if pt2_guard_size_oblivious(lengths.numel() != 0) |
1233 |
| - else lengths |
1234 |
| - ), |
1235 |
| - dim=0, |
| 1252 | + else: |
| 1253 | + for idx, key in enumerate(keys): |
| 1254 | + length = split_lengths[idx] |
| 1255 | + offset = split_offsets[idx] |
| 1256 | + _jt_dict[key] = JaggedTensor( |
| 1257 | + lengths=length, |
| 1258 | + offsets=offset, |
| 1259 | + values=values_list[idx], |
1236 | 1260 | )
|
1237 |
| - |
1238 |
| - if weights is not None: |
1239 |
| - weights_list = torch.split(weights, length_per_key) |
1240 |
| - for idx, key in enumerate(keys): |
1241 |
| - length = split_lengths[idx] |
1242 |
| - offset = split_offsets[idx] |
1243 |
| - _jt_dict[key] = JaggedTensor( |
1244 |
| - lengths=length, |
1245 |
| - offsets=offset, |
1246 |
| - values=values_list[idx], |
1247 |
| - weights=weights_list[idx], |
1248 |
| - ) |
1249 |
| - else: |
1250 |
| - for idx, key in enumerate(keys): |
1251 |
| - length = split_lengths[idx] |
1252 |
| - offset = split_offsets[idx] |
1253 |
| - _jt_dict[key] = JaggedTensor( |
1254 |
| - lengths=length, |
1255 |
| - offsets=offset, |
1256 |
| - values=values_list[idx], |
1257 |
| - ) |
1258 |
| - jt_dict = _jt_dict |
1259 |
| - return jt_dict |
| 1261 | + return _jt_dict |
1260 | 1262 |
|
1261 | 1263 |
|
1262 | 1264 | @torch.fx.wrap
|
|
0 commit comments