Skip to content

Move legacy range constraint calculator to executorch to unblock pytorch CI #2925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
unsafe_remove_auto_functionalized_pass,
)
from torch.export.exported_program import (
_get_updated_range_constraints,
ConstantArgument,
ExportGraphSignature,
InputKind,
Expand All @@ -64,6 +63,39 @@
Val = Any


def _get_updated_range_constraints(gm):
def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode # type: ignore[21]

fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env

shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: v
for k, v in shape_env.var_to_range.items()
if k not in shape_env.replacements
}
# Only when we have an unbacked symint, and it's used as constructor inputs,
# runtime_var_to_range will make a difference compated to var_to_range.
# e.g. [2, oo) -> [0, oo)
for k, v in shape_env.var_to_range.items():
if k not in shape_env.replacements:
range_constraints[k] = v
return range_constraints


def _get_updated_graph_signature(
old_signature: ExportGraphSignature,
new_gm: torch.fx.GraphModule,
Expand Down