Fix type-safety of torch.nn.Module instances #6922

Merged 1 commit on Nov 21, 2024
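Most of the suppressions in this diff share one root cause: when a variable is typed only as the base class torch.nn.Module, attribute lookups resolve through Module.__getattr__, which is annotated as returning Union[Tensor, Module]. Pyre therefore cannot recover the concrete attribute types and flags calls ([29]), attribute access ([16]), argument passing ([6]), and arithmetic ([58]) on those values. A minimal sketch of the issue; the model class and factory below are illustrative, not code from this PR:

import torch
from torch import nn


class TinyModel(nn.Module):
    # Illustrative stand-in for the models touched by this PR.
    def __init__(self) -> None:
        super().__init__()
        self.max_seq_len = 2048
        self.output = nn.Linear(16, 32)


def load_model() -> nn.Module:  # factories typically return the base class
    return TinyModel()


model = load_model()
# Pyre only knows `model` as `nn.Module`, so these lookups fall back to the
# `Module.__getattr__` annotation and are typed `Union[Tensor, Module]`:
max_len = model.max_seq_len       # not `int` to the checker  -> pyre-fixme[6]
head = model.output               # not `Linear` to the checker -> pyre-fixme[16]
out = head(torch.randn(1, 16))    # "... is not a function"     -> pyre-fixme[29]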
8 changes: 8 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
@@ -75,6 +75,7 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
addmm_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
@@ -107,6 +108,7 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
bmm_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
@@ -127,6 +129,7 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
conv1d_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
@@ -165,6 +168,7 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
conv2d_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
@@ -203,6 +207,7 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
layer_norm_node = fused_partition[0].nodes[-1]

others = [(layer_norm_node, 1)]
@@ -237,6 +242,7 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
linear_node = fused_partition[0].nodes[-1]

bias_qspec = DerivedQuantizationSpec(
@@ -275,6 +281,7 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
matmul_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
@@ -297,6 +304,7 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
relu_node = fused_partition[0].nodes[-1]

return PartitionAnchors(
(additional file, name not captured)
@@ -73,7 +73,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
)
out_to_dim_node_user.args = tuple(out_to_dim_node_user_new_args)

# pyre-fixme[29]: `Union[torch._tensor.Tensor,
# torch.nn.modules.module.Module]` is not a function.
graph_module.erase_node(out_to_dim_node)
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
# torch.nn.modules.module.Module]` is not a function.
graph_module.erase_node(node)
# TODO: Handle other merging rules, including 1->N, N->1, N->N
return PassResult(graph_module, True)
1 change: 1 addition & 0 deletions backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -136,6 +136,7 @@ def _vk_replace_linear_int4(
scales_precision=scales_precision,
)
if copy_weights and child.weight.device != torch.device("meta"):
# pyre-fixme[16]: `Module` has no attribute `weight`.
new_linear.weight = child.weight
setattr(module, name, new_linear)
else:
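The pyre-fixme[16] above covers the weight assignment, where the checker cannot see a concrete weight attribute on a value typed as nn.Module. For contrast, a hedged sketch of the usual narrowing alternatives; this is not what the PR does, and model.n_layers below is an illustrative attribute:

from typing import cast

import torch
from torch import nn


def weight_of(child: nn.Module) -> torch.Tensor:
    # A runtime isinstance check narrows `Union[Tensor, Module]` to `Tensor`.
    w = child.weight
    assert isinstance(w, torch.Tensor)
    return w


def n_layers_of(model: nn.Module) -> int:
    # For plain-Python attributes, an explicit (unchecked) cast states the intent.
    return cast(int, model.n_layers)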
2 changes: 2 additions & 0 deletions backends/xnnpack/test/ops/conv2d.py
@@ -155,6 +155,8 @@ def _test(
conv_count=1,
dtype: torch.dtype = torch.float,
):
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
# torch.nn.modules.module.Module]` is not a function.
tester = Tester(m.eval(), m.get_inputs())

if quant_config is not None:
8 changes: 8 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -960,8 +960,14 @@ def _load_llama_model(
use_kv_cache,
use_sdpa_with_kv_cache,
enable_dynamic_shape,
# pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
# `Union[Tensor, Module]`.
model.max_seq_len,
# pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor,
# Module]`.
model.n_layers,
# pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
# Module]`.
model.vocab_size,
metadata_str,
),
@@ -1045,6 +1051,7 @@ def _get_source_transforms( # noqa
transforms.append(replace_attention_to_attention_sha)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
transforms.append(convert_linear_to_conv2d)
else:
transforms.append(replace_kv_cache_with_simple_kv_cache)
@@ -1056,6 +1063,7 @@ def _get_source_transforms( # noqa
transforms.append(
get_model_with_r1_r2(args.optimized_rotation_path)
)
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
transforms.append(convert_linear_to_conv2d)

elif args.mps:
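The two pyre-fixme[16] lines added in _get_source_transforms suppress a different class of error: convert_linear_to_conv2d comes from the optional Qualcomm backend, which is invisible to Pyre when that package is not part of the analyzed environment, hence "Module `backends` has no attribute `qualcomm`". A hedged sketch of the deferred-import shape that triggers this; the import path and argument handling are simplified assumptions, not the actual code:

def _get_source_transforms_sketch(args) -> list:
    transforms = []
    if getattr(args, "qnn", False):
        # Deferred import: only resolvable when the optional Qualcomm backend
        # is installed, so Pyre flags the attribute at the use site below.
        from executorch.backends.qualcomm.utils.utils import (  # assumed path
            convert_linear_to_conv2d,
        )

        transforms.append(convert_linear_to_conv2d)
    return transforms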
(additional file, name not captured)
@@ -101,13 +101,17 @@ def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str)
optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
R1 = optimized_rotation["R1"].to(torch.float32)
config = model.params
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
# `n_heads`.
num_heads = config.n_heads
head_dim = config.dim // num_heads

rotate_embeddings(model, R1)
rotate_head(model, R1)
cleanup_memory()

# pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
# `Union[Tensor, Module]`.
for idx, layer in enumerate(model.layers):
key = f"model.layers.{idx}.self_attn.R2"
R2 = optimized_rotation[key].to(torch.float32)
@@ -130,6 +134,8 @@ def fuse_ln_linear(

# Calculating new weight and bias
W_ = linear.weight.data.to(dtype=torch.float32)
# pyre-fixme[58]: `*` is not supported for operand types `Tensor` and
# `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`.
linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to(
linear_dtype
)
@@ -140,6 +146,8 @@ def fuse_ln_linear(
torch.zeros(linear.out_features, dtype=torch.float32)
)
linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul(
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got
# `Union[Tensor, Module]`.
W_, layernorm.bias.to(dtype=torch.float32)
)
linear.bias.data = linear.bias.data.to(linear_dtype)
@@ -148,10 +156,18 @@ def fuse_ln_linear(
def fuse_layer_norms(model: torch.nn.Module):
# Embedding fusion
for W in [model.tok_embeddings]:
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
# `weight`.
W_ = W.weight.data.to(dtype=torch.float32)
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
# `weight`.
W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)

# Fuse the linear operations in Layernorm into the adjacent linear blocks.
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch._tensor.Tensor.__iter__)[[Named(self,
# torch._tensor.Tensor)], typing.Any], torch._tensor.Tensor],
# torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function.
for layer in model.layers:
# fuse the input layernorms into the linear layers
fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1])
@@ -170,9 +186,15 @@ def fuse_layer_norms(model: torch.nn.Module):
layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

fuse_ln_linear(
# pyre-fixme[6]: For 1st argument expected `Module` but got `Union[Tensor,
# Module]`.
model.norm,
# pyre-fixme[6]: For 2nd argument expected `Iterable[Linear]` but got
# `Iterable[Union[Tensor, Module]]`.
[model.output],
)
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
# `weight`.
W_norm = model.norm.weight.data
model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

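The fuse_ln_linear hunks above only annotate where Pyre loses track of parameter types; the underlying math, folding a norm layer's affine scale and bias into the adjacent linear layer, is unchanged. A quick numerical check of that fusion identity, written independently of the PR code:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4)
g = torch.randn(4)      # layernorm.weight (scale)
beta = torch.randn(4)   # layernorm.bias
W = torch.randn(3, 4)   # linear.weight
b = torch.randn(3)      # linear.bias

# (x * g + beta) @ W.T + b  ==  x @ (W * g).T + (W @ beta + b)
y_ref = F.linear(x * g + beta, W, b)
y_fused = F.linear(x, W * g, b + W @ beta)
assert torch.allclose(y_ref, y_fused, atol=1e-5)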
6 changes: 6 additions & 0 deletions examples/models/llama/source_transformation/attention.py
@@ -165,13 +165,19 @@ def __init__(self, attention_mha: nn.Module):

for i in range(self.n_heads):
self.wq[i].weight.data.copy_(
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
# attribute `weight`.
attention_mha.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
for i in range(self.n_kv_heads):
self.wk[i].weight.data.copy_(
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
# attribute `weight`.
attention_mha.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
self.wv[i].weight.data.copy_(
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no
# attribute `weight`.
attention_mha.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
)
self.wo = attention_mha.wo
10 changes: 10 additions & 0 deletions examples/models/llama/source_transformation/lora.py
@@ -110,13 +110,23 @@ def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt4WeightLinearLoRA(
# pyre-fixme[6]: For 1st argument expected `int` but got `Union[Module,
# Tensor]`.
child.in_features,
# pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
# Tensor]`.
child.out_features,
lora_rank=lora_rank,
bias=False,
device=child.weight.device,
# pyre-fixme[6]: For 6th argument expected `int` but got `Union[Module,
# Tensor]`.
groupsize=child.groupsize,
# pyre-fixme[6]: For 7th argument expected `dtype` but got
# `Union[Module, Tensor]`.
precision=child.precision,
# pyre-fixme[6]: For 8th argument expected `dtype` but got
# `Union[Module, dtype, Tensor]`.
scales_precision=child.scales.dtype,
)
return new_linear
(additional file, name not captured)
@@ -38,7 +38,11 @@ def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt4WeightLinear(
# pyre-fixme[6]: For 1st argument expected `int` but got `Union[Module,
# Tensor]`.
child.in_features,
# pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
# Tensor]`.
child.out_features,
bias=False,
device=child.weight.device,
@@ -99,7 +103,11 @@ def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt8WeightLinear(
device=child.weight.device,
# pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
# Tensor]`.
in_features=child.in_features,
# pyre-fixme[6]: For 3rd argument expected `int` but got `Union[Module,
# Tensor]`.
out_features=child.out_features,
precision=dtype,
bias=False,
6 changes: 6 additions & 0 deletions exir/capture/_capture.py
@@ -120,7 +120,11 @@ def _capture_legacy_do_not_use(f, args) -> ExirExportedProgram:
signature=ModuleCallSignature(
inputs=[],
outputs=[],
# pyre-fixme[6]: For 3rd argument expected `TreeSpec` but got
# `Union[Tensor, Module]`.
in_spec=in_spec,
# pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
# `Union[Tensor, Module]`.
out_spec=out_spec,
),
)
@@ -350,6 +354,8 @@ def convert_to_fake(x):
inputs=[],
outputs=[],
in_spec=in_spec,
# pyre-fixme[6]: For 4th argument expected `TreeSpec` but got
# `Union[None, TreeSpec, Tensor, Module]`.
out_spec=out_spec,
),
)
1 change: 1 addition & 0 deletions exir/control_flow.py
@@ -121,6 +121,7 @@ def _make_submodule(
)
output.args = tuple(output.args[0])
gm.recompile()
# pyre-fixme[16]: `GraphModule` has no attribute `__tracing_inputs__`.
gm.__tracing_inputs__ = args
return gm

1 change: 1 addition & 0 deletions exir/emit/_emitter.py
@@ -1634,6 +1634,7 @@ def plan(self) -> ExecutionPlan:
# missing in scenarios like unit test that does not enable memory planning, assume an
# empty list.
non_const_buffer_sizes=typing.cast(
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorB...
List[int], self.module.meta["non_const_buffer_sizes"]
),
container_meta_type=self.container_meta_type,
1 change: 1 addition & 0 deletions exir/lowered_backend_module.py
@@ -102,6 +102,7 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule"
processed_bytes=self._processed_bytes,
compile_specs=copy.deepcopy(self._compile_specs, memo),
)
# pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
res.meta = copy.copy(getattr(self, "meta", {}))
return res

3 changes: 3 additions & 0 deletions exir/memory_planning.py
@@ -581,7 +581,10 @@ def greedy(
for mem_id in shared_objects:
input_total_size = 0
if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
# pyre-fixme[6]: For 1st argument expected
# `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
if len(bufsizes) > mem_id:
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
input_total_size = bufsizes[mem_id]
total_sizes[mem_id] = materialize_buffer(
shared_objects[mem_id], input_total_size
2 changes: 2 additions & 0 deletions exir/passes/__init__.py
@@ -372,6 +372,8 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
else:
out_var_target, out_args_names = to_out_variant(target)
except RuntimeError as e:
# pyre-fixme[16]: `GraphModule` has no attribute
# `encounter_to_out_var_failure`.
graph_module.encounter_to_out_var_failure = True
logging.info(
f"Failed converting '{target}' to its out variant with error: '{e}'"
2 changes: 2 additions & 0 deletions exir/program/test/test_program.py
@@ -591,6 +591,8 @@ def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
# This is the pre-dispatch export that we will be switching to primarily
# in the near future. The input to to_edge_transform_and_lower needs to
# be a graph generated by this pre dispatch export.
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
# torch.nn.modules.module.Module]` is not a function.
ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
non_decomp_partitioner = NonDecompTestPartitioner()

2 changes: 2 additions & 0 deletions exir/tests/test_capture.py
@@ -19,6 +19,8 @@ class TestCapture(unittest.TestCase):
# pyre-ignore
@parameterized.expand(models.MODELS)
def test_module_call(self, model_name: str, model: torch.nn.Module) -> None:
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
# torch.nn.modules.module.Module]` is not a function.
inputs = model.get_random_inputs()
expected = model(*inputs)
# TODO(ycao): Replace it with capture_multiple
2 changes: 2 additions & 0 deletions exir/tests/test_memory_planning.py
@@ -234,6 +234,8 @@ def wrapper(self: "TestMemoryPlanning") -> None:
f"algo {getattr(algo, '__name__', repr(algo))}, expect_reuse {expect_reuse}"
)
eager_module = module_cls().eval()
# pyre-fixme[29]: `Union[nn.modules.module.Module,
# torch._tensor.Tensor]` is not a function.
inputs = eager_module.get_random_inputs()
graph_module = (
to_edge(
2 changes: 2 additions & 0 deletions exir/tracer.py
@@ -725,7 +725,9 @@ def dispatch_trace(

_, out_spec = pytree.tree_flatten(tree_out)

# pyre-fixme[16]: `GraphModule` has no attribute `in_spec`.
gm.in_spec = in_spec
# pyre-fixme[16]: `GraphModule` has no attribute `out_spec`.
gm.out_spec = out_spec

# TODO (tmanlaibaatar) This is bit clowny, but our
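The exir/tracer.py suppressions cover attributes attached to a GraphModule after tracing: in_spec and out_spec are not declared on the class, so pyre-fixme[16] is needed even though the assignments are fine at runtime. A minimal sketch of the pattern; the traced module here is illustrative:

import torch
import torch.fx as fx


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


gm = fx.symbolic_trace(Add())
# `in_spec` / `out_spec` are not declared attributes of GraphModule, so a type
# checker flags these dynamic assignments; at runtime they are ordinary
# instance attributes, like `gm.__tracing_inputs__` in exir/control_flow.py.
gm.in_spec = None
gm.out_spec = None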
4 changes: 4 additions & 0 deletions extension/pybindings/test/make_test.py
@@ -91,8 +91,12 @@ def create_program(
"""Returns an executorch program based on ModuleAdd, along with inputs."""

# Trace the test module and create a serialized ExecuTorch program.
# pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
# is not a function.
inputs = eager_module.get_inputs()
input_map = {}
# pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
# is not a function.
for method in eager_module.get_methods_to_export():
input_map[method] = inputs
