
Commit c7b500c

aorenste authored and xuhancn committed
typing fake_tensor.py (pytorch#128041)
Pull Request resolved: pytorch#128041
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#129182
1 parent 3f6ca86 commit c7b500c

File tree

14 files changed: +397 -223 lines changed


tools/pyi/gen_pyi.py

Lines changed: 1 addition & 1 deletion
@@ -1204,7 +1204,7 @@ def replace_special_case(hint: str) -> str:
         ],
         "set_": [
             "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
-            "offset: _int, size: _size, stride: _size) -> Tensor: ...",
+            "offset: _int, size: _symsize, stride: _symsize) -> Tensor: ...",
             "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
         ],
         "split": [

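Note on the gen_pyi.py hunk: the generated stub for Tensor.set_ now types size and stride as _symsize rather than _size, presumably so that symbolic (SymInt-valued) shapes pass the type checker. A minimal sketch of the idea, assuming _symsize is roughly a sequence of int-or-SymInt (the real alias lives in torch.types and is not part of this diff):

# Hedged sketch: SymSizeLike is a hypothetical stand-in for _symsize.
from typing import Sequence, Union

import torch

SymSizeLike = Sequence[Union[int, torch.SymInt]]

def set_like(offset: int, size: SymSizeLike, stride: SymSizeLike) -> None:
    # mirrors the stub's intent: accept plain ints and symbolic ints alike
    print(offset, list(size), list(stride))

set_like(0, torch.empty(2, 3).shape, (3, 1))
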
torch/_C/__init__.pyi.in

Lines changed: 13 additions & 0 deletions
@@ -56,6 +56,7 @@ from torch.types import (
     _qscheme,
     _size,
     _str,
+    _symsize,
 )
 from torch.utils._python_dispatch import TorchDispatchMode

@@ -1661,6 +1662,18 @@ class _SetExcludeDispatchKeyGuard:
     def __enter__(self): ...
     def __exit__(self, exc_type, exc_value, traceback): ...

+# Defined in torch/csrc/utils/schema_info.h
+
+class _SchemaInfo:
+    def __init__(self, schema: _int) -> None: ...
+
+    @overload
+    def is_mutable(self) -> _bool: ...
+    @overload
+    def is_mutable(self, name: str) -> _bool: ...
+
+    def has_argument(self, name: str) -> _bool: ...
+
 # Defined in torch/csrc/utils/init.cpp
 class BenchmarkConfig:
     num_calling_threads: _int
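
Note on the _SchemaInfo stub: typing.overload gives is_mutable two checker-visible signatures backed by one native method. An illustrative pure-Python sketch of the same pattern (SchemaInfoLike is hypothetical, not torch._C._SchemaInfo):

# Overload pattern sketch: two typed signatures, one runtime implementation.
from typing import Optional, overload

class SchemaInfoLike:
    @overload
    def is_mutable(self) -> bool: ...
    @overload
    def is_mutable(self, name: str) -> bool: ...

    def is_mutable(self, name: Optional[str] = None) -> bool:
        # the single implementation that backs both overloads at runtime
        return name is not None

info = SchemaInfoLike()
print(info.is_mutable(), info.is_mutable("self"))  # False True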

torch/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -36,6 +36,9 @@
 )
 from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard

+if TYPE_CHECKING:
+    from .types import IntLikeType
+

 # multipy/deploy is setting this import before importing torch, this is the most
 # reliable way we have to detect if we're running within deploy.
@@ -471,6 +474,9 @@ def __ge__(self, other) -> builtins.bool:
     def __add__(self, other) -> "SymInt":
         raise TypeError("type stub not overridden")

+    def __mod__(self, other: "IntLikeType") -> "SymInt":
+        raise TypeError("type stub not overridden")
+
     def __mul__(self, other) -> "SymInt":
         raise TypeError("type stub not overridden")

@@ -504,6 +510,9 @@ def __sym_float__(self):
     def __neg__(self):
         raise TypeError("type stub not overridden")

+    def __sub__(self, other: "IntLikeType") -> "SymInt":
+        raise TypeError("type stub not overridden")
+
     def __repr__(self):
         return self.node._graph_repr()

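
Note on the SymInt stubs: __mod__ and __sub__ are now annotated against IntLikeType from torch.types. A small sketch, assuming IntLikeType is essentially a union of int and SymInt (its definition is not part of this diff):

# Hedged sketch: IntLike approximates the assumed torch.types.IntLikeType.
from typing import Union

import torch

IntLike = Union[int, torch.SymInt]

def floor_to_multiple(total: IntLike, chunk: int) -> IntLike:
    # uses the operators annotated above; % and - work for both int and SymInt
    return total - (total % chunk)

print(floor_to_multiple(10, 4))  # 8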

torch/_dynamo/decorators.py

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@ def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
     assert is_traceable_wrapper_subclass(t)

     attrs, ctx = t.__tensor_flatten__()
+    assert isinstance(t, torch.Tensor)
     for attr in attrs:
         inner = getattr(t, attr)
         if inner.dim() == t.dim():
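
Note on the added assert: asserting isinstance both checks at runtime and narrows t for the type checker, so the attribute access below it type-checks. A minimal sketch of the idiom:

# Sketch of assert-based narrowing: after the isinstance assert, static
# checkers treat x as torch.Tensor rather than a looser type.
import torch

def ndim_of(x: object) -> int:
    assert isinstance(x, torch.Tensor)  # narrows object -> Tensor
    return x.dim()

print(ndim_of(torch.zeros(2, 3)))  # 2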

torch/_export/non_strict_utils.py

Lines changed: 3 additions & 0 deletions
@@ -83,6 +83,7 @@ def fakify(
         constraint_sizes=[None] * n_dims,
     )
     t_id = id(t)
+    assert mode.shape_env is not None
     if t_id in t_constraints:
         for i, constraint in t_constraints[t_id].items():
             symbolic_context.constraint_sizes[i] = constraint.constraint_range
@@ -256,6 +257,7 @@ def produce_guards_and_solve_constraints(
         _disable_forced_specializations: if True, avoids forced specializations
     """
     shape_env = fake_mode.shape_env
+    assert shape_env is not None
     assert shape_env.tracked_fakes is not None

     placeholders = [tf.fake for tf in shape_env.tracked_fakes]
@@ -322,6 +324,7 @@ def make_constraints(
     """

     shape_env = fake_mode.shape_env
+    assert shape_env is not None
     inline_constraints = gm.meta.get("inline_constraints", [])
     range_constraints = {
         symbol: inline_constraints[symbol] for symbol in inline_constraints
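
Note on the added asserts: shape_env is declared Optional, so asserting it is not None narrows it before attribute access, which is the standard way to satisfy mypy here. A toy sketch with hypothetical stand-in classes (not the real FakeTensorMode/ShapeEnv):

# Sketch of Optional narrowing via assert.
from typing import List, Optional

class ShapeEnvLike:
    def __init__(self) -> None:
        self.tracked_fakes: List[object] = []

class FakeModeLike:
    def __init__(self) -> None:
        self.shape_env: Optional[ShapeEnvLike] = ShapeEnvLike()

def tracked(mode: FakeModeLike) -> List[object]:
    shape_env = mode.shape_env
    assert shape_env is not None  # Optional[ShapeEnvLike] -> ShapeEnvLike
    return shape_env.tracked_fakes

print(tracked(FakeModeLike()))  # []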

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 16 additions & 12 deletions
@@ -12,7 +12,7 @@
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

 import torch
 import torch.utils.dlpack
@@ -1450,7 +1450,7 @@ def coerce_runtime_tangent(x, metadata):

 Runtime metadata: {str(runtime_tangent_metadata)}

-shape: {str(x.shape)}
+shape: {str(cast(torch.Tensor, x).shape)}
 To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.
 """
             )
@@ -1830,14 +1830,16 @@ def get_types_for_tangents(tangents):
             )
             assert CompiledFunction.metadata.traced_tangent_metas is not None
             all_args = [
-                AOTDispatchAutograd.coerce_runtime_tangent(
-                    t,
-                    CompiledFunction.metadata.traced_tangent_metas[
-                        i - tangents_start_idx
-                    ],
+                (
+                    AOTDispatchAutograd.coerce_runtime_tangent(
+                        t,
+                        CompiledFunction.metadata.traced_tangent_metas[
+                            i - tangents_start_idx
+                        ],
+                    )
+                    if tangents_start_idx <= i < tangents_end_idx
+                    else t
                 )
-                if tangents_start_idx <= i < tangents_end_idx
-                else t
                 for i, t in enumerate(all_args)
             ]
             all_args = unwrap_tensor_subclasses(
@@ -1849,9 +1851,11 @@ def get_types_for_tangents(tangents):
             # Make the tangents contiguous. Note that we must do this after subclass desugaring
             # because inputs to inductor have to be contiguous
             all_args = [
-                AOTDispatchAutograd._force_contiguous(t)
-                if (tangents_start_idx <= i < tangents_end_idx)
-                else t
+                (
+                    AOTDispatchAutograd._force_contiguous(t)
+                    if (tangents_start_idx <= i < tangents_end_idx)
+                    else t
+                )
                 for i, t in enumerate(all_args)
             ]

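
Note on this file: the functional change is typing.cast (imported above), which is a runtime no-op that only tells the checker to treat a value as the given type; the comprehension hunks merely wrap the conditional expressions in parentheses for formatting. A small sketch of cast:

# typing.cast performs no conversion or check; it just returns its argument
# while changing the type the checker sees.
from typing import cast

import torch

def shape_str(x: object) -> str:
    # assumes the caller guarantees x is tensor-like, as the surrounding code does
    return str(cast(torch.Tensor, x).shape)

print(shape_str(torch.zeros(4, 5)))  # torch.Size([4, 5])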

torch/_functorch/_aot_autograd/subclass_utils.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 and this includes tensor subclasses that implement __torch_dispatch__.
 """

+import typing
 from typing import Any, List, Optional, Tuple, Union

 import torch.utils._pytree as pytree
@@ -115,7 +116,7 @@ def concat_inner_tensors_from_subclasses(xs):
     xs_inner = []
     for x in xs:
         if is_traceable_wrapper_subclass(x):
-            xs_inner.extend(get_plain_tensors(x))
+            xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
         else:
             xs_inner.append(x)
     return xs_inner
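
Note on the cast here: is_traceable_wrapper_subclass does not narrow x for the type checker, hence the explicit typing.cast. A hypothetical alternative (illustration only, not how torch defines this predicate) is a TypeGuard-returning check, which narrows without a cast:

# Hypothetical TypeGuard variant; the real torch check also verifies
# __tensor_flatten__/__tensor_unflatten__ support.
from typing import List

import torch
from typing_extensions import TypeGuard

def is_wrapper_subclass(x: object) -> TypeGuard[torch.Tensor]:
    return isinstance(x, torch.Tensor) and type(x) is not torch.Tensor

def collect_tensors(xs: List[object]) -> List[torch.Tensor]:
    out: List[torch.Tensor] = []
    for x in xs:
        if is_wrapper_subclass(x):
            out.append(x)  # x is narrowed to Tensor here, no cast needed
        elif isinstance(x, torch.Tensor):
            out.append(x)
    return out

print(collect_tensors([torch.zeros(1), 3]))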

torch/_functorch/config.py

Lines changed: 2 additions & 2 deletions
@@ -16,15 +16,15 @@
 functionalize_rng_ops = False

 # can be useful for debugging if we are incorrectly creating meta fake tensors
-fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
+fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"

 # Enables optional asserts in hotpath code to check for errors. If
 # you are seeing weird accuracy problems, try turning this on.
 # This is currently off by default as it will harm tracing time,
 # but it is on by default for aot_eager.
 debug_assert = False

-debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False)
+debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"

 # Today, if you are in a situation where there is "false aliasing"
 # (e.g. you have a bunch of model parameters that all alias the same underlying buffer),
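
Note on the config change: os.environ.get returns a string whenever the variable is set, and any non-empty string (including "0" or "False") is truthy, so the old True/False defaults never yielded a reliable boolean. Comparing against "0" does. A quick demonstration:

# Why the boolean default was wrong and the != "0" form works.
import os

os.environ["AOT_PARTITIONER_DEBUG"] = "0"  # user intends "off"

broken = os.environ.get("AOT_PARTITIONER_DEBUG", False)      # returns "0", which is truthy
fixed = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"  # False, as intended

print(bool(broken), fixed)  # True False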
