fix lint #2235

Closed · wants to merge 1 commit

codegen/tools/test/test_gen_oplist.py (3 additions, 3 deletions)

@@ -231,9 +231,9 @@ def test_dump_operator_from_ops_schema_yaml_with_mix_syntax(self) -> None:
         self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])

     def test_get_kernel_metadata_from_ops_yaml(self) -> None:
-        metadata: Dict[str, List[str]] = (
-            gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
-        )
+        metadata: Dict[
+            str, List[str]
+        ] = gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)

         self.assertEqual(len(metadata), 2)

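The remaining files get the same mechanical treatment: an annotated assignment that one formatter wraps by parenthesizing the right-hand side is rewritten so the line break happens inside the subscript on the left-hand side, which is presumably the form the lint job in CI accepts. A minimal before/after sketch of the two styles, using hypothetical names (fetch_values is made up for illustration):

from typing import Dict, List

def fetch_values(key: str) -> Dict[str, List[str]]:
    # Hypothetical helper, for illustration only.
    return {key: ["a", "b"]}

# Style the PR removes: the long right-hand side wrapped in parentheses.
values: Dict[str, List[str]] = (
    fetch_values("example-key")
)

# Style the PR switches to: the break moved into the annotation's subscript.
values: Dict[
    str, List[str]
] = fetch_values("example-key")
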
examples/models/llama2/export_llama_lib.py (3 additions, 3 deletions)

@@ -387,9 +387,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901
     # to_backend
     partitioners = {}
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        partitioners[XnnpackDynamicallyQuantizedPartitioner.__name__] = (
-            XnnpackDynamicallyQuantizedPartitioner()
-        )
+        partitioners[
+            XnnpackDynamicallyQuantizedPartitioner.__name__
+        ] = XnnpackDynamicallyQuantizedPartitioner()
         modelname = f"xnnpack_dq_{modelname}"

     if args.xnnpack:

exir/backend/backend_details.py (3 additions, 3 deletions)

@@ -21,9 +21,9 @@ def enforcedmethod(func):
 @dataclass
 class PreprocessResult:
     processed_bytes: bytes = bytes()
-    debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = (
-        None
-    )
+    debug_handle_map: Optional[
+        Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]
+    ] = None


 """

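For context, PreprocessResult is the value a backend's preprocess step hands back to the framework; the field being reformatted here is its optional delegate debug-handle map. A minimal sketch of populating it, assuming the usual executorch import path and with a made-up payload and handle values:

from executorch.exir.backend.backend_details import PreprocessResult

# Made-up compiled payload a backend might emit.
blob: bytes = b"\x07example-delegate-payload"

# Map a delegate-internal debug id (int or str) to the original debug handles.
result = PreprocessResult(
    processed_bytes=blob,
    debug_handle_map={0: (1,)},
)
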
6 changes: 1 addition & 5 deletions exir/backend/test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,11 +720,7 @@ def forward(self, x_raw, h, c):
         composite_m = CompositeModel(3)
         orig_res = composite_m(*inputs)

-        traced = exir.capture(
-            composite_m,
-            inputs,
-            exir.CaptureConfig(),
-        ).to_edge(
+        traced = exir.capture(composite_m, inputs, exir.CaptureConfig(),).to_edge(
             # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
             exir.EdgeCompileConfig(_check_ir_validity=False)
         )

exir/backend/utils.py (3 additions, 3 deletions)

@@ -277,9 +277,9 @@ def __init__(self, generated_identifiers: bool = False):

         # Note that the internal struct has a Set value, while the getter
         # function returns the values as a tuple
-        self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = (
-            defaultdict(set)
-        )
+        self._debug_handle_map: Union[
+            Dict[int, Set[int]], Dict[str, Set[int]]
+        ] = defaultdict(set)
         self._next_index: int = 0

     def get_delegate_mapping(

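The comment kept in this hunk states the invariant worth remembering: handles are accumulated into sets internally, and the getter hands them back as tuples. A self-contained sketch of that pattern (illustrative only, not the actual class in exir/backend/utils.py):

from collections import defaultdict
from typing import Dict, Set, Tuple

class HandleMapSketch:
    def __init__(self) -> None:
        # Internal struct holds Set values so repeated handles deduplicate.
        self._debug_handle_map: Dict[int, Set[int]] = defaultdict(set)

    def add(self, identifier: int, handle: int) -> None:
        self._debug_handle_map[identifier].add(handle)

    def get_delegate_mapping(self) -> Dict[int, Tuple[int, ...]]:
        # The getter exposes the values as tuples, per the comment in the diff.
        return {k: tuple(sorted(v)) for k, v in self._debug_handle_map.items()}

mapping = HandleMapSketch()
mapping.add(0, 5)
mapping.add(0, 5)  # duplicate is absorbed by the set
print(mapping.get_delegate_mapping())  # {0: (5,)}
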
exir/capture/_config.py (1 addition, 3 deletions)

@@ -21,9 +21,7 @@ class CaptureConfig:
     pt2_mode: bool = True
     enable_functionalization: bool = True
     enable_dynamic_shape: bool = False  # This flag does nothing if enable_aot is True
-    enable_aot: bool = (
-        False  # When it's true it implies automatic dynamic shapes via default dynamo config
-    )
+    enable_aot: bool = False  # When it's true it implies automatic dynamic shapes via default dynamo config
     _dynamo_config: "ExirDynamoConfig" = field(default_factory=ExirDynamoConfig)
     _unlift: bool = False  # This flag does nothing if enable_aot is False.
     _use_old_decomp_table: bool = False

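The inline comments on these fields describe how they interact: enable_dynamic_shape is ignored once enable_aot is True, because AOT capture already implies automatic dynamic shapes through the default dynamo config. A hedged usage sketch; the keyword arguments are inferred from the dataclass fields above, and the import path is assumed (the tests in this PR only show the bare exir module):

from executorch import exir  # import path assumed

# Non-AOT capture: the dynamic-shape flag is honored here.
legacy_config = exir.CaptureConfig(enable_dynamic_shape=True)

# AOT capture: enable_dynamic_shape would be ignored, since enable_aot=True
# already implies automatic dynamic shapes via the default dynamo config.
aot_config = exir.CaptureConfig(enable_aot=True)
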
exir/capture/_unlift.py (6 additions, 6 deletions)

@@ -135,17 +135,17 @@ def unlift_exported_program_lifted_states(
             if node.name in ep.graph_signature.inputs_to_buffers:
                 buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
                 if buffer_name in param_buffer_name_to_corrected_name:
-                    inp_pos_to_param_buffer_name[count] = (
-                        param_buffer_name_to_corrected_name[buffer_name]
-                    )
+                    inp_pos_to_param_buffer_name[
+                        count
+                    ] = param_buffer_name_to_corrected_name[buffer_name]
                 else:
                     inp_pos_to_param_buffer_name[count] = buffer_name
             if node.name in ep.graph_signature.inputs_to_parameters:
                 param_name = ep.graph_signature.inputs_to_parameters[node.name]
                 if param_name in param_buffer_name_to_corrected_name:
-                    inp_pos_to_param_buffer_name[count] = (
-                        param_buffer_name_to_corrected_name[param_name]
-                    )
+                    inp_pos_to_param_buffer_name[
+                        count
+                    ] = param_buffer_name_to_corrected_name[param_name]
                 else:
                     inp_pos_to_param_buffer_name[count] = param_name
             count += 1

exir/emit/_emit_program.py (3 additions, 3 deletions)

@@ -186,9 +186,9 @@ def emit_program(
         plans.append(emitter.plan())

         debug_handle_map[name] = emitter.debug_handle_map
-        method_to_delegate_debug_id_map[name] = (
-            emitter.instr_id_to_delegate_debug_id_map
-        )
+        method_to_delegate_debug_id_map[
+            name
+        ] = emitter.instr_id_to_delegate_debug_id_map

     # emit any primitive getters
     if prim_getters is not None:

exir/operator/convert.py (4 additions, 3 deletions)

@@ -23,6 +23,7 @@

"""


import dataclasses
import logging
from typing import Dict, Optional, Tuple
Expand Down Expand Up @@ -206,9 +207,9 @@ def set_mapping_for_op(op: OpOverload) -> None:
mismatched_out_schema: Optional[FunctionSchema] = next(
(s for s in all_schemas if s.kind() == SchemaKind.out), None
)
_schema_mismatch_map[schema_to_opoverload(func_op_schema)] = (
mismatched_out_schema
)
_schema_mismatch_map[
schema_to_opoverload(func_op_schema)
] = mismatched_out_schema

# update hte map even if scratch_schema is None to cache the negative
# case
Expand Down
exir/passes/__init__.py (8 additions, 8 deletions)

@@ -483,14 +483,14 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
     ]
 ).passes

-base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
-    PassManager(
-        passes=[
-            dead_code_elimination_pass,
-            DebugHandleGeneratorPass(),
-        ]
-    ).passes
-)
+base_post_op_replace_passes: List[
+    Callable[[torch.nn.Module], PassResult]
+] = PassManager(
+    passes=[
+        dead_code_elimination_pass,
+        DebugHandleGeneratorPass(),
+    ]
+).passes


 def propagate_dynamic_shape(

exir/passes/constant_prop_pass.py (3 additions, 3 deletions)

@@ -112,9 +112,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
             )
             prop_constant_data.append(prop_constant_node_input_spec)
             buffers.append(prop_constant_tensor_fqn)
-            exported_program.state_dict[prop_constant_tensor_fqn] = (
-                prop_constant_tensor
-            )
+            exported_program.state_dict[
+                prop_constant_tensor_fqn
+            ] = prop_constant_tensor
             exported_program.graph_signature.input_specs.append(
                 prop_constant_node_input_spec
             )

exir/tests/test_quant_lowering_custom_backend_pass.py (1 addition, 4 deletions)

@@ -975,10 +975,7 @@ def test_qat_linear(self) -> None:
             backend_config=get_executorch_backend_config(),
         )
         print("converted:", converted_mod)
-        captured_mod = exir.capture(
-            converted_mod,
-            example_inputs,
-        ).to_edge(
+        captured_mod = exir.capture(converted_mod, example_inputs,).to_edge(
             exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True)
         )

exir/verification/arg_validator.py (3 additions, 3 deletions)

@@ -37,9 +37,9 @@ class EdgeOpArgValidator(torch.fx.Interpreter):

     def __init__(self, graph_module: torch.fx.GraphModule) -> None:
         super().__init__(graph_module)
-        self.violating_ops: Dict[EdgeOpOverload, Dict[str, Optional[torch.dtype]]] = (
-            defaultdict(dict)
-        )
+        self.violating_ops: Dict[
+            EdgeOpOverload, Dict[str, Optional[torch.dtype]]
+        ] = defaultdict(dict)

     def run_node(self, n: torch.fx.Node) -> None:
         self.node = n

sdk/inspector/_inspector.py (9 additions, 9 deletions)

@@ -80,9 +80,9 @@ def gen_from_events(run_events: List[flatcc.Event]) -> List["InstructionEvent"]:
         and DebugEvents by instruction id and return a list of InstructionEvents
         constructed from collated events (ignoring run_output events)
         """
-        instruction_events: Dict[InstructionEventSignature, InstructionEvent] = (
-            OrderedDict()
-        )
+        instruction_events: Dict[
+            InstructionEventSignature, InstructionEvent
+        ] = OrderedDict()
         for event in run_events:
             # Find the event that was logged
             populated_event: Union[DebugEvent, ProfileEvent] = find_populated_event(

@@ -668,9 +668,9 @@ class GroupedRunInstances:
                 continue

             # Collate the run_events into InstructionEvents
-            instruction_events: List[InstructionEvent] = (
-                InstructionEvent.gen_from_events(run_events)
-            )
+            instruction_events: List[
+                InstructionEvent
+            ] = InstructionEvent.gen_from_events(run_events)

             # Map EventSignatures to the InstructionEvents
             event_signatures: Dict[EventSignature, InstructionEvent] = OrderedDict()

@@ -720,9 +720,9 @@ class GroupedRunInstances:
             TIME_SCALE_DICT[source_time_scale] / TIME_SCALE_DICT[target_time_scale]
         )
         for run_signature, grouped_run_instance in run_groups.items():
-            run_group: OrderedDict[EventSignature, List[InstructionEvent]] = (
-                grouped_run_instance.events
-            )
+            run_group: OrderedDict[
+                EventSignature, List[InstructionEvent]
+            ] = grouped_run_instance.events
             run_outputs: ProgramOutput = grouped_run_instance.run_output

             # Construct the Events

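All three hunks in this file touch the same bookkeeping: events are collated into ordered dicts keyed by a signature so that debug and profile events emitted by the same instruction land in the same bucket. A generic, simplified sketch of that grouping pattern (stand-in types, not the real Inspector classes):

from collections import OrderedDict
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class SignatureSketch:
    # Stand-in for InstructionEventSignature / EventSignature.
    instruction_id: int

@dataclass
class EventSketch:
    # Stand-in for a flatcc DebugEvent / ProfileEvent.
    instruction_id: int
    payload: float

def collate(events: List[EventSketch]) -> "OrderedDict[SignatureSketch, List[EventSketch]]":
    grouped: "OrderedDict[SignatureSketch, List[EventSketch]]" = OrderedDict()
    for event in events:
        sig = SignatureSketch(instruction_id=event.instruction_id)
        # setdefault preserves first-seen order, mirroring the OrderedDict use in the diff.
        grouped.setdefault(sig, []).append(event)
    return grouped

runs = [EventSketch(0, 1.5), EventSketch(1, 0.2), EventSketch(0, 2.0)]
print(collate(runs))  # events with instruction_id 0 end up grouped together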