Commit 91afa6e

Fix type-safety of torch.nn.Module instances
Differential Revision: D52890934
Pull Request resolved: #6922

1 parent: 96a9d35

22 files changed: +105 −0 lines changed
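Background on the suppressed errors (a minimal illustrative sketch, not part of this commit): pyre sees attribute access on a value typed as torch.nn.Module as going through nn.Module.__getattr__, which it types as returning Union[Tensor, Module] (as the messages below show). Reading a sub-attribute then triggers pyre-fixme[16] ("has no attribute"), calling the result triggers pyre-fixme[29] ("is not a function"), and passing it where a concrete type is expected triggers pyre-fixme[6]. This commit suppresses those errors at each call site; the sketch below (with a hypothetical use_child helper) shows that pattern next to the alternative of narrowing or casting the type.

    from typing import cast

    import torch
    from torch import nn


    def use_child(child: nn.Module) -> torch.Tensor:
        # `child.weight` resolves through nn.Module.__getattr__, which pyre types
        # as returning Union[Tensor, Module], so strict checking reports e.g.
        # [16] "`Module` has no attribute `weight`".
        # Option 1: suppress at the call site, as this commit does.
        # pyre-fixme[16]: `Module` has no attribute `weight`.
        weight = child.weight

        # Option 2: narrow with isinstance (or cast) so pyre sees a concrete type.
        if isinstance(child, nn.Linear):
            weight = child.weight
        return cast(torch.Tensor, weight)


    print(use_child(nn.Linear(4, 4)).shape)  # torch.Size([4, 4])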

backends/cadence/aot/quantizer/patterns.py

Lines changed: 8 additions & 0 deletions

@@ -75,6 +75,7 @@ def partition_types(self) -> List[OpOverload]:
  def get_anchors(
  self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
  ) -> PartitionAnchors:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
  addmm_node = fused_partition[0].nodes[-1]

  bias_qspec = DerivedQuantizationSpec(
@@ -107,6 +108,7 @@ def partition_types(self) -> List[OpOverload]:
  def get_anchors(
  self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
  ) -> PartitionAnchors:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
  bmm_node = fused_partition[0].nodes[-1]

  return PartitionAnchors(
@@ -127,6 +129,7 @@ def partition_types(self) -> List[OpOverload]:
  def get_anchors(
  self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
  ) -> PartitionAnchors:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
  conv1d_node = fused_partition[0].nodes[-1]

  bias_qspec = DerivedQuantizationSpec(
@@ -165,6 +168,7 @@ def partition_types(self) -> List[OpOverload]:
  def get_anchors(
  self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
  ) -> PartitionAnchors:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
  conv2d_node = fused_partition[0].nodes[-1]

  bias_qspec = DerivedQuantizationSpec(
@@ -203,6 +207,7 @@ def partition_types(self) -> List[OpOverload]:
  def get_anchors(
  self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
  ) -> PartitionAnchors:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
  layer_norm_node = fused_partition[0].nodes[-1]

  others = [(layer_norm_node, 1)]
@@ -237,6 +242,7 @@ def partition_types(self) -> List[OpOverload]:
  def get_anchors(
  self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
  ) -> PartitionAnchors:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
  linear_node = fused_partition[0].nodes[-1]

  bias_qspec = DerivedQuantizationSpec(
@@ -275,6 +281,7 @@ def partition_types(self) -> List[OpOverload]:
  def get_anchors(
  self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
  ) -> PartitionAnchors:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
  matmul_node = fused_partition[0].nodes[-1]

  return PartitionAnchors(
@@ -297,6 +304,7 @@ def partition_types(self) -> List[OpOverload]:
  def get_anchors(
  self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
  ) -> PartitionAnchors:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
  relu_node = fused_partition[0].nodes[-1]

  return PartitionAnchors(

backends/example/example_backend_delegate_passes/merge_to_dim_pass.py

Lines changed: 4 additions & 0 deletions

@@ -73,7 +73,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
  )
  out_to_dim_node_user.args = tuple(out_to_dim_node_user_new_args)

+ # pyre-fixme[29]: `Union[torch._tensor.Tensor,
+ # torch.nn.modules.module.Module]` is not a function.
  graph_module.erase_node(out_to_dim_node)
+ # pyre-fixme[29]: `Union[torch._tensor.Tensor,
+ # torch.nn.modules.module.Module]` is not a function.
  graph_module.erase_node(node)
  # TODO: Handle other merging rules, including 1->N, N->1, N->N
  return PassResult(graph_module, True)

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 1 addition & 0 deletions

@@ -136,6 +136,7 @@ def _vk_replace_linear_int4(
  scales_precision=scales_precision,
  )
  if copy_weights and child.weight.device != torch.device("meta"):
+ # pyre-fixme[16]: `Module` has no attribute `weight`.
  new_linear.weight = child.weight
  setattr(module, name, new_linear)
  else:

backends/xnnpack/test/ops/conv2d.py

Lines changed: 2 additions & 0 deletions

@@ -155,6 +155,8 @@ def _test(
  conv_count=1,
  dtype: torch.dtype = torch.float,
  ):
+ # pyre-fixme[29]: `Union[torch._tensor.Tensor,
+ # torch.nn.modules.module.Module]` is not a function.
  tester = Tester(m.eval(), m.get_inputs())

  if quant_config is not None:

examples/models/llama/export_llama_lib.py

Lines changed: 8 additions & 0 deletions

@@ -960,8 +960,14 @@ def _load_llama_model(
  use_kv_cache,
  use_sdpa_with_kv_cache,
  enable_dynamic_shape,
+ # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
+ # `Union[Tensor, Module]`.
  model.max_seq_len,
+ # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor,
+ # Module]`.
  model.n_layers,
+ # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
+ # Module]`.
  model.vocab_size,
  metadata_str,
  ),
@@ -1045,6 +1051,7 @@ def _get_source_transforms( # noqa
  transforms.append(replace_attention_to_attention_sha)
  transforms.append(replace_causal_mask)
  transforms.append(replace_rms_norm_with_native_rms_norm)
+ # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
  transforms.append(convert_linear_to_conv2d)
  else:
  transforms.append(replace_kv_cache_with_simple_kv_cache)
@@ -1056,6 +1063,7 @@ def _get_source_transforms( # noqa
  transforms.append(
  get_model_with_r1_r2(args.optimized_rotation_path)
  )
+ # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
  transforms.append(convert_linear_to_conv2d)

  elif args.mps:

examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py

Lines changed: 22 additions & 0 deletions

@@ -101,13 +101,17 @@ def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str)
  optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
  R1 = optimized_rotation["R1"].to(torch.float32)
  config = model.params
+ # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
+ # `n_heads`.
  num_heads = config.n_heads
  head_dim = config.dim // num_heads

  rotate_embeddings(model, R1)
  rotate_head(model, R1)
  cleanup_memory()

+ # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
+ # `Union[Tensor, Module]`.
  for idx, layer in enumerate(model.layers):
  key = f"model.layers.{idx}.self_attn.R2"
  R2 = optimized_rotation[key].to(torch.float32)
@@ -130,6 +134,8 @@ def fuse_ln_linear(

  # Calculating new weight and bias
  W_ = linear.weight.data.to(dtype=torch.float32)
+ # pyre-fixme[58]: `*` is not supported for operand types `Tensor` and
+ # `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`.
  linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to(
  linear_dtype
  )
@@ -140,6 +146,8 @@ def fuse_ln_linear(
  torch.zeros(linear.out_features, dtype=torch.float32)
  )
  linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul(
+ # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+ # `Union[Tensor, Module]`.
  W_, layernorm.bias.to(dtype=torch.float32)
  )
  linear.bias.data = linear.bias.data.to(linear_dtype)
@@ -148,10 +156,18 @@ def fuse_ln_linear(
  def fuse_layer_norms(model: torch.nn.Module):
  # Embedding fusion
  for W in [model.tok_embeddings]:
+ # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
+ # `weight`.
  W_ = W.weight.data.to(dtype=torch.float32)
+ # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
+ # `weight`.
  W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)

  # Fuse the linear operations in Layernorm into the adjacent linear blocks.
+ # pyre-fixme[29]:
+ # `Union[BoundMethod[typing.Callable(torch._tensor.Tensor.__iter__)[[Named(self,
+ # torch._tensor.Tensor)], typing.Any], torch._tensor.Tensor],
+ # torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function.
  for layer in model.layers:
  # fuse the input layernorms into the linear layers
  fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1])
@@ -170,9 +186,15 @@ def fuse_layer_norms(model: torch.nn.Module):
  layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

  fuse_ln_linear(
+ # pyre-fixme[6]: For 1st argument expected `Module` but got `Union[Tensor,
+ # Module]`.
  model.norm,
+ # pyre-fixme[6]: For 2nd argument expected `Iterable[Linear]` but got
+ # `Iterable[Union[Tensor, Module]]`.
  [model.output],
  )
+ # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
+ # `weight`.
  W_norm = model.norm.weight.data
  model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

examples/models/llama/source_transformation/attention.py

Lines changed: 6 additions & 0 deletions

@@ -165,13 +165,19 @@ def __init__(self, attention_mha: nn.Module):

  for i in range(self.n_heads):
  self.wq[i].weight.data.copy_(
+ # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
+ # attribute `weight`.
  attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
  )
  for i in range(self.n_kv_heads):
  self.wk[i].weight.data.copy_(
+ # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
+ # attribute `weight`.
  attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
  )
  self.wv[i].weight.data.copy_(
+ # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
+ # attribute `weight`.
  attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
  )
  self.wo = attention_mha.wo

examples/models/llama/source_transformation/lora.py

Lines changed: 10 additions & 0 deletions

@@ -110,13 +110,23 @@ def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:

  def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
  new_linear = Int8DynActInt4WeightLinearLoRA(
+ # pyre-fixme[6]: For 1st argument expected `int` but got `Union[Module,
+ # Tensor]`.
  child.in_features,
+ # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
+ # Tensor]`.
  child.out_features,
  lora_rank=lora_rank,
  bias=False,
  device=child.weight.device,
+ # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Module,
+ # Tensor]`.
  groupsize=child.groupsize,
+ # pyre-fixme[6]: For 7th argument expected `dtype` but got
+ # `Union[Module, Tensor]`.
  precision=child.precision,
+ # pyre-fixme[6]: For 8th argument expected `dtype` but got
+ # `Union[Module, dtype, Tensor]`.
  scales_precision=child.scales.dtype,
  )
  return new_linear

examples/models/llama/source_transformation/pre_quantization.py

Lines changed: 8 additions & 0 deletions

@@ -38,7 +38,11 @@ def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:

  def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
  new_linear = Int8DynActInt4WeightLinear(
+ # pyre-fixme[6]: For 1st argument expected `int` but got `Union[Module,
+ # Tensor]`.
  child.in_features,
+ # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
+ # Tensor]`.
  child.out_features,
  bias=False,
  device=child.weight.device,
@@ -99,7 +103,11 @@ def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
  def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
  new_linear = Int8DynActInt8WeightLinear(
  device=child.weight.device,
+ # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
+ # Tensor]`.
  in_features=child.in_features,
+ # pyre-fixme[6]: For 3rd argument expected `int` but got `Union[Module,
+ # Tensor]`.
  out_features=child.out_features,
  precision=dtype,
  bias=False,

exir/capture/_capture.py

Lines changed: 6 additions & 0 deletions

@@ -120,7 +120,11 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
  signature=ModuleCallSignature(
  inputs=[],
  outputs=[],
+ # pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got
+ # `Union[Tensor, Module]`.
  in_spec=in_spec,
+ # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
+ # `Union[Tensor, Module]`.
  out_spec=out_spec,
  ),
  )
@@ -350,6 +354,8 @@ def convert_to_fake(x):
  inputs=[],
  outputs=[],
  in_spec=in_spec,
+ # pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
+ # `Union[None, TreeSpec, Tensor, Module]`.
  out_spec=out_spec,
  ),
  )

exir/control_flow.py

Lines changed: 1 addition & 0 deletions

@@ -121,6 +121,7 @@ def _make_submodule(
  )
  output.args = tuple(output.args[0])
  gm.recompile()
+ # pyre-fixme[16]: `GraphModule` has no attribute `__tracing_inputs__`.
  gm.__tracing_inputs__ = args
  return gm

exir/emit/_emitter.py

Lines changed: 1 addition & 0 deletions

@@ -1634,6 +1634,7 @@ def plan(self) -> ExecutionPlan:
  # missing in scenarios like unit test that does not enable memory planning, assume an
  # empty list.
  non_const_buffer_sizes=typing.cast(
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorB...
  List[int], self.module.meta["non_const_buffer_sizes"]
  ),
  container_meta_type=self.container_meta_type,

exir/lowered_backend_module.py

Lines changed: 1 addition & 0 deletions

@@ -102,6 +102,7 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule"
  processed_bytes=self._processed_bytes,
  compile_specs=copy.deepcopy(self._compile_specs, memo),
  )
+ # pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
  res.meta = copy.copy(getattr(self, "meta", {}))
  return res

exir/memory_planning.py

Lines changed: 3 additions & 0 deletions

@@ -581,7 +581,10 @@ def greedy(
  for mem_id in shared_objects:
  input_total_size = 0
  if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
+ # pyre-fixme[6]: For 1st argument expected
+ # `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
  if len(bufsizes) > mem_id:
+ # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
  input_total_size = bufsizes[mem_id]
  total_sizes[mem_id] = materialize_buffer(
  shared_objects[mem_id], input_total_size

exir/passes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -372,6 +372,8 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
  else:
  out_var_target, out_args_names = to_out_variant(target)
  except RuntimeError as e:
+ # pyre-fixme[16]: `GraphModule` has no attribute
+ # `encounter_to_out_var_failure`.
  graph_module.encounter_to_out_var_failure = True
  logging.info(
  f"Failed converting '{target}' to its out variant with error: '{e}'"

exir/program/test/test_program.py

Lines changed: 2 additions & 0 deletions

@@ -591,6 +591,8 @@ def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
  # This is the pre-dispatch export that we will be switching to primarily
  # in the near future. The input to to_edge_transform_and_lower needs to
  # be a graph generated by this pre dispatch export.
+ # pyre-fixme[29]: `Union[torch._tensor.Tensor,
+ # torch.nn.modules.module.Module]` is not a function.
  ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
  non_decomp_partitioner = NonDecompTestPartitioner()

exir/tests/test_capture.py

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,8 @@ class TestCapture(unittest.TestCase):
  # pyre-ignore
  @parameterized.expand(models.MODELS)
  def test_module_call(self, model_name: str, model: torch.nn.Module) -> None:
+ # pyre-fixme[29]: `Union[torch._tensor.Tensor,
+ # torch.nn.modules.module.Module]` is not a function.
  inputs = model.get_random_inputs()
  expected = model(*inputs)
  # TODO(ycao): Replace it with capture_multiple

exir/tests/test_memory_planning.py

Lines changed: 2 additions & 0 deletions

@@ -234,6 +234,8 @@ def wrapper(self: "TestMemoryPlanning") -> None:
  f"algo {getattr(algo, '__name__', repr(algo))}, expect_reuse {expect_reuse}"
  )
  eager_module = module_cls().eval()
+ # pyre-fixme[29]: `Union[nn.modules.module.Module,
+ # torch._tensor.Tensor]` is not a function.
  inputs = eager_module.get_random_inputs()
  graph_module = (
  to_edge(

exir/tracer.py

Lines changed: 2 additions & 0 deletions

@@ -725,7 +725,9 @@ def dispatch_trace(

  _, out_spec = pytree.tree_flatten(tree_out)

+ # pyre-fixme[16]: `GraphModule` has no attribute `in_spec`.
  gm.in_spec = in_spec
+ # pyre-fixme[16]: `GraphModule` has no attribute `out_spec`.
  gm.out_spec = out_spec

  # TODO (tmanlaibaatar) This is bit clowny, but our

extension/pybindings/test/make_test.py

Lines changed: 4 additions & 0 deletions

@@ -91,8 +91,12 @@ def create_program(
  """Returns an executorch program based on ModuleAdd, along with inputs."""

  # Trace the test module and create a serialized ExecuTorch program.
+ # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
+ # is not a function.
  inputs = eager_module.get_inputs()
  input_map = {}
+ # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
+ # is not a function.
  for method in eager_module.get_methods_to_export():
  input_map[method] = inputs
