|
48 | 48 | unsafe_remove_auto_functionalized_pass,
|
49 | 49 | )
|
50 | 50 | from torch.export.exported_program import (
|
51 |
| - _get_updated_range_constraints, |
52 | 51 | ConstantArgument,
|
53 | 52 | ExportGraphSignature,
|
54 | 53 | InputKind,
|
|
64 | 63 | Val = Any
|
65 | 64 |
|
66 | 65 |
|
| 66 | +def _get_updated_range_constraints(gm): |
| 67 | + def get_shape_env(gm): |
| 68 | + vals = [ |
| 69 | + node.meta["val"] |
| 70 | + for node in gm.graph.nodes |
| 71 | + if node.meta.get("val", None) is not None |
| 72 | + ] |
| 73 | + from torch._guards import detect_fake_mode # type: ignore[21] |
| 74 | + |
| 75 | + fake_mode = detect_fake_mode(vals) |
| 76 | + if fake_mode is not None: |
| 77 | + return fake_mode.shape_env |
| 78 | + for v in vals: |
| 79 | + if isinstance(v, torch.SymInt): |
| 80 | + return v.node.shape_env |
| 81 | + |
| 82 | + shape_env = get_shape_env(gm) |
| 83 | + if shape_env is None: |
| 84 | + return {} |
| 85 | + range_constraints = { |
| 86 | + k: v |
| 87 | + for k, v in shape_env.var_to_range.items() |
| 88 | + if k not in shape_env.replacements |
| 89 | + } |
| 90 | + # Only when we have an unbacked symint, and it's used as constructor inputs, |
| 91 | + # runtime_var_to_range will make a difference compated to var_to_range. |
| 92 | + # e.g. [2, oo) -> [0, oo) |
| 93 | + for k, v in shape_env.var_to_range.items(): |
| 94 | + if k not in shape_env.replacements: |
| 95 | + range_constraints[k] = v |
| 96 | + return range_constraints |
| 97 | + |
| 98 | + |
67 | 99 | def _get_updated_graph_signature(
|
68 | 100 | old_signature: ExportGraphSignature,
|
69 | 101 | new_gm: torch.fx.GraphModule,
|
|
0 commit comments