Commit ebe9075

typing for decorators - fx/_compatibility
Differential Revision: D61493706
Pull Request resolved: #4810
1 parent 8471c22 commit ebe9075
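
Context: torch.fx marks its public API with the `compatibility` decorator from `torch.fx._compatibility`. Once that decorator is typed so that it preserves the signature of whatever it wraps, Pyre can see the real parameter and return types of the decorated torch.fx functions again, and pre-existing mismatches at ExecuTorch call sites surface as new errors. This commit suppresses those errors with targeted `# pyre-fixme[...]` comments instead of changing runtime behavior. The sketch below only illustrates the mechanism; `untyped_marker` and `typed_marker` are made-up names, not the actual torch.fx implementation.

from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def untyped_marker(is_backward_compatible):
    # No annotations: a type checker treats the decorated object as `Any`,
    # so bad arguments at call sites are not reported.
    def decorator(fn):
        return fn

    return decorator


def typed_marker(is_backward_compatible: bool) -> Callable[[F], F]:
    # Returning `Callable[[F], F]` preserves the wrapped signature, so the
    # checker validates every call against the real parameter types again.
    def decorator(fn: F) -> F:
        return fn

    return decorator


@typed_marker(is_backward_compatible=True)
def takes_int(x: int) -> int:
    return x


takes_int("oops")  # runs at runtime, but a typed decorator lets Pyre flag this as error [6]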

File tree

19 files changed: +49 -45 lines changed

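Every hunk below follows the same mechanical pattern: a `# pyre-fixme[N]` comment is attached to the offending line, where `N` is the Pyre error code being suppressed (for example, 6 for an incompatible parameter type, 14 and 15 for an inconsistent override, 16 for an undefined attribute, 21 for an undefined import). A minimal, self-contained sketch of the two placements used in this diff, assuming nothing beyond standard Pyre comment handling: a trailing comment suppresses the error on its own line, and a comment on its own line suppresses the error on the line that follows (the form used in exir/passes/__init__.py and extension/llm/export/builder.py).

from typing import Optional


def expects_int(x: int) -> int:
    return x + 1


maybe_int: Optional[int] = 1

# Trailing form: suppresses Pyre error [6] (incompatible parameter type)
# on this line only; runtime behavior is unchanged.
expects_int(maybe_int)  # pyre-fixme[6]

# Standalone form: a fixme on its own line applies to the next line.
# pyre-fixme[6]
expects_int(maybe_int)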

backends/cadence/aot/compiler.py

Lines changed: 3 additions & 3 deletions
@@ -60,13 +60,13 @@ def convert_pt2(
     # Export with dynamo
     model_gm = capture_pre_autograd_graph(model, inputs)

-    if model_gm_has_SDPA(model_gm):
+    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)
+        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]

         # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
         # for details).
-        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
         assert result is not None
         model_gm = result.graph_module

backends/transforms/addmm_mm_to_linear.py

Lines changed: 2 additions & 2 deletions
@@ -130,7 +130,7 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
             "call_function", ops.aten.linear.default, args
         )
         node.replace_all_uses_with(linear_node)
-        output_val = linear_node.target(
+        output_val = linear_node.target(  # pyre-fixme[29]
             args[0].meta["val"], args[1].meta["val"], args[2].meta["val"]
         )
     else:
@@ -147,7 +147,7 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
             "call_function", ops.aten.linear.default, args
         )
         node.replace_all_uses_with(linear_node)
-        output_val = linear_node.target(
+        output_val = linear_node.target(  # pyre-fixme[29]
             args[0].meta["val"], args[1].meta["val"]
         )
     linear_node.meta = node.meta

backends/transforms/decompose_sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def call(
         # refer to pytorch/test/test_decomp.py
         decomposed_module = make_fx(
             node.target,
-            decomposition_table=get_decompositions(
+            decomposition_table=get_decompositions(  # pyre-fixme[6]
                 [
                     torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
                 ]

backends/xnnpack/passes/channels_last_tagged_reshape_pass.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def create_call_function_node(
             "call_function",
             target=target,
             args=args,
-            kwargs=(
+            kwargs=(  # pyre-fixme[6]
                 {"memory_format": memory_format} if memory_format is not None else {}
             ),
         )

backends/xnnpack/passes/convert_to_sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def create_sdpa(
             kwargs={"scale": scale},
         )

-        sdpa_node.meta["val"] = sdpa_node.target(
+        sdpa_node.meta["val"] = sdpa_node.target(  # pyre-fixme[29]
             *[n.meta["val"] for n in match.placeholder_nodes],
             scale=scale,
         )

examples/models/llama2/source_transformation/quantize.py

Lines changed: 2 additions & 2 deletions
@@ -96,7 +96,7 @@ def quantize(

     try:
         # torchao 0.3+
-        from torchao._eval import InputRecorder
+        from torchao._eval import InputRecorder  # pyre-fixme[21]
     except ImportError:
         from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore

@@ -110,7 +110,7 @@ def quantize(
     )

     inputs = (
-        InputRecorder(
+        InputRecorder(  # pyre-fixme[16]
             tokenizer,
             calibration_seq_length,
             None,  # input_prep_func

examples/models/phi-3-mini/export_phi-3-mini.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def export(args) -> None:
     model = capture_pre_autograd_graph(
         model, example_inputs, dynamic_shapes=dynamic_shapes
     )
-    model = prepare_pt2e(model, xnnpack_quantizer)
+    model = prepare_pt2e(model, xnnpack_quantizer)  # pyre-fixme[6]
     model(*example_inputs)
     model = convert_pt2e(model, fold_quantize=False)
     DuplicateDynamicQuantChainPass()(model)

exir/emit/_emitter.py

Lines changed: 7 additions & 7 deletions
@@ -1270,7 +1270,7 @@ def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan

     def fetch_attr(self, target: _Target) -> _AbstractValue:
         """Fetch weights and other module parameters. If the attribute is a tensor, emit it."""
-        attr = super().fetch_attr(target)
+        attr = super().fetch_attr(target)  # pyre-fixme[6]

         if isinstance(attr, torch.Tensor):
             return self._emit_evalue(
@@ -1286,23 +1286,23 @@ def fetch_attr(self, target: _Target) -> _AbstractValue:
         else:
             return attr

-    def call_module(
+    def call_module(  # pyre-fixme[14]
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> None:
         """Unsupported in execution IR, so unhandled by the emitter."""
         raise InternalError(
             self._emit_node_specific_error(self.node, "call_module is not supported")
         )

-    def call_method(
+    def call_method(  # pyre-fixme[14]
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _EmitterValue:
         """Unsupported in execution IR, so unhandled by the emitter."""
         raise InternalError(
             self._emit_node_specific_error(self.node, "call_method is not supported")
         )

-    def placeholder(
+    def placeholder(  # pyre-fixme[14]
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _AbstractValue:
         """Performs actions for the placeholder node of a graph module.
@@ -1324,7 +1324,7 @@ def placeholder(
         self.placeholder_count += 1
         return value

-    def output(
+    def output(  # pyre-fixme[14]
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> None:
         """Performs actions for the output node of a graph module.
@@ -1354,7 +1354,7 @@ def output(
         )
         self.chain.instructions.append(instruction)

-    def call_function(
+    def call_function(  # pyre-fixme[14]
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> _EmitterValue:
         """Performs actions for the call_function node of a graph module.
@@ -1412,7 +1412,7 @@ def call_function(
             )
         )

-    def run(
+    def run(  # pyre-fixme[14]
         self,
         *args: _Argument,
         initial_env: Optional[Dict[torch.fx.Node, _Argument]] = None,

exir/lowered_backend_module.py

Lines changed: 2 additions & 2 deletions
@@ -139,7 +139,7 @@ def buffer(
         segment_alignment: int = 4096,
         constant_tensor_alignment: Optional[int] = None,
         delegate_alignment: Optional[int] = None,
-        memory_planning: MemoryPlanningPass = None,
+        memory_planning: MemoryPlanningPass = None,  # pyre-fixme[9]
     ) -> bytes:
         """
         Returns a buffer containing the serialized ExecuTorch binary.
@@ -161,7 +161,7 @@ def buffer(
     def program(
         self,
         emit_stacktrace: bool = False,
-        memory_planning: MemoryPlanningPass = None,
+        memory_planning: MemoryPlanningPass = None,  # pyre-fixme[9]
     ) -> Program:
         # Fix autodpes introuces cyclic dependencies:
         # program -> verifier -> lowered_backend_module -> program

exir/pass_base.py

Lines changed: 4 additions & 4 deletions
@@ -177,7 +177,7 @@ def __init__(self, callback: "_ExportPassBase", codegen: CodeGen) -> None:
         self.fake_tensor_mode: Optional[FakeTensorMode] = None
         self.submodules: Dict[torch.nn.Module, str] = {}

-    def trace(self) -> None:
+    def trace(self) -> None:  # pyre-fixme[14,15]
         raise ExportPassBaseError("ExportTracer doesn't support trace().")

     def create_arg(self, a: Argument) -> torch.fx.Node:
@@ -290,7 +290,7 @@ def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None:
         self.callback = callback
         self.node: torch.fx.Node = next(iter(gm.graph.nodes))

-    def placeholder(
+    def placeholder(  # pyre-fixme[14]
         self,
         target: str,
         args: Tuple[Argument, ...],
@@ -351,7 +351,7 @@ def call_function(
         else:
             raise ExportPassBaseError(f"Unsupported target type: {target}")

-    def get_attr(
+    def get_attr(  # pyre-fixme[14]
         self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
     ) -> Argument:
         return super().get_attr(target, args, kwargs)
@@ -364,7 +364,7 @@ def call_module(
     ) -> None:
         raise ExportPassBaseError("call_module is not supported.")

-    def call_method(
+    def call_method(  # pyre-fixme[14]
         self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
     ) -> None:
         raise ExportPassBaseError("call_method is not supported.")

exir/passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -302,6 +302,7 @@ def make_alloc_node(
             "Memory allocator node needs FakeTensor val or TensorMetadata to proceed"
         )

+    # pyre-fixme[6]
     alloc = graph_module.graph.call_function(memory.alloc, (alloc_spec,))
     alloc.meta["val"] = val
     alloc.meta["tensor_meta"] = tensor_meta

exir/passes/remove_noop_pass.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def eliminate_dq_q(
             qparams_q = list(user.args)[1:]
             if qparams_dq != qparams_q:
                 continue
-            user.replace_all_uses_with(node.args[0])
+            user.replace_all_uses_with(node.args[0])  # pyre-fixme[6]


 class RemoveNoopPass(ExportPass):

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
@@ -1421,7 +1421,7 @@ def quantize_model(
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config()
         quantizer.set_global(quantization_config)
-        m = prepare_pt2e(m, quantizer)
+        m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
         m = convert_pt2e(m, fold_quantize=True)
         ep = torch.export.export(m, example_inputs)
         dq_nodes_pre = count_dq_nodes(ep.graph_module)

exir/tests/test_quantization.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def test_resnet(self) -> None:
         quantizer = XNNPACKQuantizer()
         operator_config = get_symmetric_quantization_config(is_per_channel=True)
         quantizer.set_global(operator_config)
-        m = prepare_pt2e(m, quantizer)
+        m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
         self.assertEqual(
             id(m.activation_post_process_3), id(m.activation_post_process_2)
         )

exir/tracer.py

Lines changed: 3 additions & 3 deletions
@@ -272,7 +272,7 @@ def __torch_function__(
             kwargs = {}
         if torch.is_inference_mode_enabled():
             if func is torch.nn.functional.layer_norm:
-                args, kwargs = normalize_function(func, args, kwargs)
+                args, kwargs = normalize_function(func, args, kwargs)  # pyre-fixme[23]
                 input, normalized_shape = args
                 normalized_shape = list(normalized_shape)
                 return cls.__torch_dispatch__(
@@ -470,13 +470,13 @@ def create_arg(self, a: Value) -> torch.fx.Node:  # noqa: C901
             self.submodules[a] = name_submodule
             return self.create_node("get_attr", self.submodules[a], (), {})

-        return super().create_arg(a)
+        return super().create_arg(a)  # pyre-fixme[7]

     @staticmethod
     def get() -> "DispatchTracer":
         return TRACER

-    def trace(
+    def trace(  # pyre-fixme[14,15]
         self,
         root: Callable[..., Value],
         concrete_args: Tuple[Value, ...] = (),

exir/verification/arg_validator.py

Lines changed: 3 additions & 3 deletions
@@ -62,7 +62,7 @@ def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs):

         return kernel_arg

-    def call_function(  # noqa: C901
+    def call_function(  # noqa: C901 # pyre-fixme[14]
         self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
     ) -> Any:
         """
@@ -73,7 +73,7 @@ def call_function(  # noqa: C901
         ):
             if isinstance(target, HigherOrderOperator):
                 raise RunHigherOrderOperatorError("Can't run delegate")
-            return super().call_function(target, args, kwargs)
+            return super().call_function(target, args, kwargs)  # pyre-fixme[6]

         # TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist.
         tensor_arg_types: Dict[str, Optional[torch.dtype]] = {}
@@ -126,4 +126,4 @@ def call_function(  # noqa: C901
         valid = target._schema.dtype_constraint.validate(tensor_arg_types)
         if not valid:
             self.violating_ops[target] = tensor_arg_types
-        return super().call_function(target, args, kwargs)
+        return super().call_function(target, args, kwargs)  # pyre-fixme[6]

extension/llm/export/builder.py

Lines changed: 3 additions & 1 deletion
@@ -161,6 +161,7 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            # pyre-fixme[8]
             self.pre_autograd_graph_module = capture_pre_autograd_graph(
                 self.model, self.example_inputs, dynamic_shapes=dynamic_shape
             )
@@ -209,11 +210,12 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
             if self.pre_autograd_graph_module is None:
+                # pyre-fixme[8]
                 self.pre_autograd_graph_module = capture_pre_autograd_graph(
                     self.model, self.example_inputs, dynamic_shapes=dynamic_shape
                 )
             self.edge_manager = export_to_edge(
-                self.pre_autograd_graph_module,
+                self.pre_autograd_graph_module,  # pyre-fixme[6]
                 self.example_inputs,
                 dynamic_shapes=dynamic_shape,
                 edge_constant_methods=self.metadata,

extension/llm/export/partitioner_lib.py

Lines changed: 8 additions & 7 deletions
@@ -52,7 +52,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):
     )

     compile_specs = [CompileSpec("use_fp16", bytes([True]))]
-    return MPSPartitioner(compile_specs)
+    return MPSPartitioner(compile_specs)  # pyre-fixme[16]


 def get_coreml_partitioner(
@@ -92,14 +92,14 @@ def get_coreml_partitioner(
     # if use_kv_cache:
     #     minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)

-    compile_specs = CoreMLBackend.generate_compile_specs(
+    compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=minimum_deployment_target,
         compute_precision=ct.precision(ct.precision.FLOAT16.value),
         # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
         compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()],
-        model_type=CoreMLBackend.MODEL_TYPE.MODEL,
+        model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
     )
-    return CoreMLPartitioner(
+    return CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
     )

@@ -136,9 +136,10 @@ def get_qnn_partitioner(
     if pt2e_quantize is not None:
         use_fp16 = False

-    return QnnPartitioner(
-        generate_qnn_executorch_compiler_spec(
-            soc_model=QcomChipset.SM8650,  # default to SM8650
+    return QnnPartitioner(  # pyre-fixme[16]
+        generate_qnn_executorch_compiler_spec(  # pyre-fixme[16]
+            soc_model=QcomChipset.SM8650,  # default to SM8650 # pyre-fixme[16]
+            # pyre-fixme[16]
             backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
             debug=False,
             saver=False,

extension/llm/export/quantizer_lib.py

Lines changed: 4 additions & 4 deletions
@@ -146,7 +146,7 @@ def get_qnn_quantizer(
     quantization_mode: Optional[str] = None,
 ):
     try:
-        from executorch.backends.qualcomm.quantizer.custom_annotation import (
+        from executorch.backends.qualcomm.quantizer.custom_annotation import (  # pyre-fixme[21]
             custom_annotate_llama_matmul_16a8w,
         )

@@ -168,15 +168,15 @@ def get_qnn_quantizer(
     assert (
         backend == "qnn"
     ), f"The quantization config is for backend {backend} instead of qnn."
-    qnn_quantizer = QnnQuantizer()
+    qnn_quantizer = QnnQuantizer()  # pyre-fixme[16]
     qnn_quantizer.set_per_channel_conv_quant(enable=True)
     qnn_quantizer.set_per_channel_linear_quant(enable=True)
     # more custom quantization are supported including 16a4w etc. default to 8bit quantized
     custom_annotations = ()
     if quant_config == "8a8w":
-        quant_dtype = QuantDtype.use_8a8w
+        quant_dtype = QuantDtype.use_8a8w  # pyre-fixme[16]
     elif quant_config == "16a16w":
-        quant_dtype = QuantDtype.use_16a16w
+        quant_dtype = QuantDtype.use_16a16w  # pyre-fixme[16]
         qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
         qnn_quantizer.set_bit16_op_quant_config(
             # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
