typing for decorators - fx/_compatibility #4810

Merged: 1 commit, Aug 26, 2024
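
Every change in this diff is the same kind of edit: a line that Pyre now flags, presumably because the decorators in torch.fx's _compatibility module gained type annotations upstream, receives an inline # pyre-fixme[N] comment, where N is the numeric code of the error being suppressed on that line (for example, [6] marks an incompatible parameter type and [14] an inconsistent override). A rough sketch of the mechanism, using hypothetical names rather than code from this repository:

from typing import Optional

def needs_int(x: int) -> int:
    return x + 1

def caller(value: Optional[int]) -> int:
    # Pyre reports error [6] (incompatible parameter type) because `value`
    # may be None; the trailing comment acknowledges and silences exactly
    # that error on this line without changing runtime behavior.
    return needs_int(value)  # pyre-fixme[6]

The # pyre-ignore comments already present in a few of these files work the same way; the fixme spelling just signals that the suppression is meant to be temporary.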
6 changes: 3 additions & 3 deletions backends/cadence/aot/compiler.py
@@ -60,13 +60,13 @@ def convert_pt2(
# Export with dynamo
model_gm = capture_pre_autograd_graph(model, inputs)

- if model_gm_has_SDPA(model_gm):
+ if model_gm_has_SDPA(model_gm): # pyre-fixme[6]
# Decompose SDPA
- DecomposeScaledDotProductAttention(False)(model_gm)
+ DecomposeScaledDotProductAttention(False)(model_gm) # pyre-fixme[6]

# Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
# for details).
- result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
+ result = ReplaceSafeSoftmaxWithSoftmax()(model_gm) # pyre-fixme[6]
assert result is not None
model_gm = result.graph_module
4 changes: 2 additions & 2 deletions backends/transforms/addmm_mm_to_linear.py
@@ -130,7 +130,7 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
"call_function", ops.aten.linear.default, args
)
node.replace_all_uses_with(linear_node)
- output_val = linear_node.target(
+ output_val = linear_node.target( # pyre-fixme[29]
args[0].meta["val"], args[1].meta["val"], args[2].meta["val"]
)
else:
@@ -147,7 +147,7 @@ def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
"call_function", ops.aten.linear.default, args
)
node.replace_all_uses_with(linear_node)
- output_val = linear_node.target(
+ output_val = linear_node.target( # pyre-fixme[29]
args[0].meta["val"], args[1].meta["val"]
)
linear_node.meta = node.meta
2 changes: 1 addition & 1 deletion backends/transforms/decompose_sdpa.py
@@ -34,7 +34,7 @@ def call(
# refer to pytorch/test/test_decomp.py
decomposed_module = make_fx(
node.target,
- decomposition_table=get_decompositions(
+ decomposition_table=get_decompositions( # pyre-fixme[6]
[
torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
]
@@ -124,7 +124,7 @@ def create_call_function_node(
"call_function",
target=target,
args=args,
- kwargs=(
+ kwargs=( # pyre-fixme[6]
{"memory_format": memory_format} if memory_format is not None else {}
),
)
2 changes: 1 addition & 1 deletion backends/xnnpack/passes/convert_to_sdpa.py
@@ -83,7 +83,7 @@ def create_sdpa(
kwargs={"scale": scale},
)

- sdpa_node.meta["val"] = sdpa_node.target(
+ sdpa_node.meta["val"] = sdpa_node.target( # pyre-fixme[29]
*[n.meta["val"] for n in match.placeholder_nodes],
scale=scale,
)
4 changes: 2 additions & 2 deletions examples/models/llama2/source_transformation/quantize.py
@@ -96,7 +96,7 @@ def quantize(

try:
# torchao 0.3+
- from torchao._eval import InputRecorder
+ from torchao._eval import InputRecorder # pyre-fixme[21]
except ImportError:
from torchao.quantization.GPTQ import InputRecorder # pyre-ignore

@@ -110,7 +110,7 @@ def quantize(
)

inputs = (
- InputRecorder(
+ InputRecorder( # pyre-fixme[16]
tokenizer,
calibration_seq_length,
None, # input_prep_func
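
The hunk above also shows the second recurring flavor of suppression: names Pyre cannot resolve. Imports of optional dependencies are silenced with # pyre-fixme[21] (unresolved import), and uses of the resulting unknown names with # pyre-fixme[16] (undefined attribute, the same code spelled out by the existing pyre-ignore comment in quantizer_lib.py further down). A minimal, hypothetical sketch of the import side, not taken from this repository:

try:
    # The optional dependency is invisible to the type checker, so the
    # unresolved-import error is acknowledged inline rather than "fixed".
    from optional_backend.eval import InputRecorder  # pyre-fixme[21]
except ImportError:
    class InputRecorder:  # local stand-in used when the package is absent
        def __init__(self, tokenizer: object) -> None:
            self.tokenizer = tokenizer

recorder = InputRecorder(tokenizer="dummy-tokenizer")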
2 changes: 1 addition & 1 deletion examples/models/phi-3-mini/export_phi-3-mini.py
@@ -67,7 +67,7 @@ def export(args) -> None:
model = capture_pre_autograd_graph(
model, example_inputs, dynamic_shapes=dynamic_shapes
)
- model = prepare_pt2e(model, xnnpack_quantizer)
+ model = prepare_pt2e(model, xnnpack_quantizer) # pyre-fixme[6]
model(*example_inputs)
model = convert_pt2e(model, fold_quantize=False)
DuplicateDynamicQuantChainPass()(model)
14 changes: 7 additions & 7 deletions exir/emit/_emitter.py
@@ -1270,7 +1270,7 @@ def _emit_prim_getters(self, prim_getters: Dict[str, Any]) -> List[ExecutionPlan

def fetch_attr(self, target: _Target) -> _AbstractValue:
"""Fetch weights and other module parameters. If the attribute is a tensor, emit it."""
- attr = super().fetch_attr(target)
+ attr = super().fetch_attr(target) # pyre-fixme[6]

if isinstance(attr, torch.Tensor):
return self._emit_evalue(
@@ -1286,23 +1286,23 @@ def fetch_attr(self, target: _Target) -> _AbstractValue:
else:
return attr

- def call_module(
+ def call_module( # pyre-fixme[14]
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
) -> None:
"""Unsupported in execution IR, so unhandled by the emitter."""
raise InternalError(
self._emit_node_specific_error(self.node, "call_module is not supported")
)

- def call_method(
+ def call_method( # pyre-fixme[14]
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
) -> _EmitterValue:
"""Unsupported in execution IR, so unhandled by the emitter."""
raise InternalError(
self._emit_node_specific_error(self.node, "call_method is not supported")
)

- def placeholder(
+ def placeholder( # pyre-fixme[14]
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
) -> _AbstractValue:
"""Performs actions for the placeholder node of a graph module.
@@ -1324,7 +1324,7 @@ def placeholder(
self.placeholder_count += 1
return value

- def output(
+ def output( # pyre-fixme[14]
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
) -> None:
"""Performs actions for the output node of a graph module.
@@ -1354,7 +1354,7 @@ def output(
)
self.chain.instructions.append(instruction)

- def call_function(
+ def call_function( # pyre-fixme[14]
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
) -> _EmitterValue:
"""Performs actions for the call_function node of a graph module.
@@ -1412,7 +1412,7 @@ def call_function(
)
)

- def run(
+ def run( # pyre-fixme[14]
self,
*args: _Argument,
initial_env: Optional[Dict[torch.fx.Node, _Argument]] = None,
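
Most of the suppressions in this file sit on method overrides (call_module, call_method, placeholder, output, call_function, run) whose parameter types apparently differ from those declared by the interpreter class being extended; Pyre reports that as error [14], an inconsistent override. A small, self-contained illustration of the kind of mismatch involved, using hypothetical classes rather than the emitter itself:

class Base:
    def call_function(self, target: object) -> object:
        return target

class Narrowed(Base):
    # Accepting only `str` where the base accepts any object makes the
    # override unsafe to call through a `Base` reference, which is what
    # an inconsistent-override report points at.
    def call_function(self, target: str) -> str:  # pyre-fixme[14]
        return target.upper()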
4 changes: 2 additions & 2 deletions exir/lowered_backend_module.py
@@ -139,7 +139,7 @@ def buffer(
segment_alignment: int = 4096,
constant_tensor_alignment: Optional[int] = None,
delegate_alignment: Optional[int] = None,
- memory_planning: MemoryPlanningPass = None,
+ memory_planning: MemoryPlanningPass = None, # pyre-fixme[9]
) -> bytes:
"""
Returns a buffer containing the serialized ExecuTorch binary.
@@ -161,7 +161,7 @@ def buffer(
def program(
self,
emit_stacktrace: bool = False,
- memory_planning: MemoryPlanningPass = None,
+ memory_planning: MemoryPlanningPass = None, # pyre-fixme[9]
) -> Program:
# Fix autodpes introuces cyclic dependencies:
# program -> verifier -> lowered_backend_module -> program
8 changes: 4 additions & 4 deletions exir/pass_base.py
@@ -177,7 +177,7 @@ def __init__(self, callback: "_ExportPassBase", codegen: CodeGen) -> None:
self.fake_tensor_mode: Optional[FakeTensorMode] = None
self.submodules: Dict[torch.nn.Module, str] = {}

- def trace(self) -> None:
+ def trace(self) -> None: # pyre-fixme[14,15]
raise ExportPassBaseError("ExportTracer doesn't support trace().")

def create_arg(self, a: Argument) -> torch.fx.Node:
@@ -290,7 +290,7 @@ def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None:
self.callback = callback
self.node: torch.fx.Node = next(iter(gm.graph.nodes))

- def placeholder(
+ def placeholder( # pyre-fixme[14]
self,
target: str,
args: Tuple[Argument, ...],
@@ -351,7 +351,7 @@ def call_function(
else:
raise ExportPassBaseError(f"Unsupported target type: {target}")

- def get_attr(
+ def get_attr( # pyre-fixme[14]
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
) -> Argument:
return super().get_attr(target, args, kwargs)
@@ -364,7 +364,7 @@ def call_module(
) -> None:
raise ExportPassBaseError("call_module is not supported.")

- def call_method(
+ def call_method( # pyre-fixme[14]
self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
) -> None:
raise ExportPassBaseError("call_method is not supported.")
1 change: 1 addition & 0 deletions exir/passes/__init__.py
@@ -302,6 +302,7 @@ def make_alloc_node(
"Memory allocator node needs FakeTensor val or TensorMetadata to proceed"
)

+ # pyre-fixme[6]
alloc = graph_module.graph.call_function(memory.alloc, (alloc_spec,))
alloc.meta["val"] = val
alloc.meta["tensor_meta"] = tensor_meta
2 changes: 1 addition & 1 deletion exir/passes/remove_noop_pass.py
@@ -40,7 +40,7 @@ def eliminate_dq_q(
qparams_q = list(user.args)[1:]
if qparams_dq != qparams_q:
continue
- user.replace_all_uses_with(node.args[0])
+ user.replace_all_uses_with(node.args[0]) # pyre-fixme[6]


class RemoveNoopPass(ExportPass):
2 changes: 1 addition & 1 deletion exir/tests/test_passes.py
@@ -1421,7 +1421,7 @@ def quantize_model(
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config()
quantizer.set_global(quantization_config)
- m = prepare_pt2e(m, quantizer)
+ m = prepare_pt2e(m, quantizer) # pyre-fixme[6]
m = convert_pt2e(m, fold_quantize=True)
ep = torch.export.export(m, example_inputs)
dq_nodes_pre = count_dq_nodes(ep.graph_module)
2 changes: 1 addition & 1 deletion exir/tests/test_quantization.py
@@ -58,7 +58,7 @@ def test_resnet(self) -> None:
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
- m = prepare_pt2e(m, quantizer)
+ m = prepare_pt2e(m, quantizer) # pyre-fixme[6]
self.assertEqual(
id(m.activation_post_process_3), id(m.activation_post_process_2)
)
6 changes: 3 additions & 3 deletions exir/tracer.py
@@ -272,7 +272,7 @@ def __torch_function__(
kwargs = {}
if torch.is_inference_mode_enabled():
if func is torch.nn.functional.layer_norm:
- args, kwargs = normalize_function(func, args, kwargs)
+ args, kwargs = normalize_function(func, args, kwargs) # pyre-fixme[23]
input, normalized_shape = args
normalized_shape = list(normalized_shape)
return cls.__torch_dispatch__(
@@ -470,13 +470,13 @@ def create_arg(self, a: Value) -> torch.fx.Node: # noqa: C901
self.submodules[a] = name_submodule
return self.create_node("get_attr", self.submodules[a], (), {})

- return super().create_arg(a)
+ return super().create_arg(a) # pyre-fixme[7]

@staticmethod
def get() -> "DispatchTracer":
return TRACER

- def trace(
+ def trace( # pyre-fixme[14,15]
self,
root: Callable[..., Value],
concrete_args: Tuple[Value, ...] = (),
6 changes: 3 additions & 3 deletions exir/verification/arg_validator.py
@@ -62,7 +62,7 @@ def _get_kernel_arg(self, schema_arg, schema_arg_idx, args, kwargs):

return kernel_arg

- def call_function( # noqa: C901
+ def call_function( # noqa: C901 # pyre-fixme[14]
self, target: _Target, args: Tuple[_Argument, ...], kwargs: Dict[str, _Argument]
) -> Any:
"""
@@ -73,7 +73,7 @@ def call_function( # noqa: C901
):
if isinstance(target, HigherOrderOperator):
raise RunHigherOrderOperatorError("Can't run delegate")
- return super().call_function(target, args, kwargs)
+ return super().call_function(target, args, kwargs) # pyre-fixme[6]

# TODO(gasoonjia): Update Optional[torch.dtype] to a concrete class to support mixed dtypes in tensorlist.
tensor_arg_types: Dict[str, Optional[torch.dtype]] = {}
@@ -126,4 +126,4 @@ def call_function( # noqa: C901
valid = target._schema.dtype_constraint.validate(tensor_arg_types)
if not valid:
self.violating_ops[target] = tensor_arg_types
- return super().call_function(target, args, kwargs)
+ return super().call_function(target, args, kwargs) # pyre-fixme[6]
4 changes: 3 additions & 1 deletion extension/llm/export/builder.py
@@ -161,6 +161,7 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+ # pyre-fixme[8]
self.pre_autograd_graph_module = capture_pre_autograd_graph(
self.model, self.example_inputs, dynamic_shapes=dynamic_shape
)
@@ -209,11 +210,12 @@ def export_to_edge(self) -> "LLMEdgeManager":
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
if self.pre_autograd_graph_module is None:
+ # pyre-fixme[8]
self.pre_autograd_graph_module = capture_pre_autograd_graph(
self.model, self.example_inputs, dynamic_shapes=dynamic_shape
)
self.edge_manager = export_to_edge(
- self.pre_autograd_graph_module,
+ self.pre_autograd_graph_module, # pyre-fixme[6]
self.example_inputs,
dynamic_shapes=dynamic_shape,
edge_constant_methods=self.metadata,
15 changes: 8 additions & 7 deletions extension/llm/export/partitioner_lib.py
@@ -52,7 +52,7 @@ def get_mps_partitioner(use_kv_cache: bool = False):
)

compile_specs = [CompileSpec("use_fp16", bytes([True]))]
- return MPSPartitioner(compile_specs)
+ return MPSPartitioner(compile_specs) # pyre-fixme[16]


def get_coreml_partitioner(
@@ -92,14 +92,14 @@ def get_coreml_partitioner(
# if use_kv_cache:
# minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18)

- compile_specs = CoreMLBackend.generate_compile_specs(
+ compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=minimum_deployment_target,
compute_precision=ct.precision(ct.precision.FLOAT16.value),
# using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU`
compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()],
- model_type=CoreMLBackend.MODEL_TYPE.MODEL,
+ model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16]
)
- return CoreMLPartitioner(
+ return CoreMLPartitioner( # pyre-fixme[16]
compile_specs=compile_specs,
)

@@ -136,9 +136,10 @@ def get_qnn_partitioner(
if pt2e_quantize is not None:
use_fp16 = False

- return QnnPartitioner(
- generate_qnn_executorch_compiler_spec(
- soc_model=QcomChipset.SM8650, # default to SM8650
+ return QnnPartitioner( # pyre-fixme[16]
+ generate_qnn_executorch_compiler_spec( # pyre-fixme[16]
+ soc_model=QcomChipset.SM8650, # default to SM8650 # pyre-fixme[16]
+ # pyre-fixme[16]
backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
debug=False,
saver=False,
8 changes: 4 additions & 4 deletions extension/llm/export/quantizer_lib.py
@@ -146,7 +146,7 @@ def get_qnn_quantizer(
quantization_mode: Optional[str] = None,
):
try:
- from executorch.backends.qualcomm.quantizer.custom_annotation import (
+ from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21]
custom_annotate_llama_matmul_16a8w,
)

@@ -168,15 +168,15 @@ def get_qnn_quantizer(
assert (
backend == "qnn"
), f"The quantization config is for backend {backend} instead of qnn."
- qnn_quantizer = QnnQuantizer()
+ qnn_quantizer = QnnQuantizer() # pyre-fixme[16]
qnn_quantizer.set_per_channel_conv_quant(enable=True)
qnn_quantizer.set_per_channel_linear_quant(enable=True)
# more custom quantization are supported including 16a4w etc. default to 8bit quantized
custom_annotations = ()
if quant_config == "8a8w":
- quant_dtype = QuantDtype.use_8a8w
+ quant_dtype = QuantDtype.use_8a8w # pyre-fixme[16]
elif quant_config == "16a16w":
- quant_dtype = QuantDtype.use_16a16w
+ quant_dtype = QuantDtype.use_16a16w # pyre-fixme[16]
qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
qnn_quantizer.set_bit16_op_quant_config(
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.