Skip to content

Commit 9c0d0fe

Browse files
committed
fix lint
1 parent d25b57b commit 9c0d0fe

File tree

14 files changed

+51
-59
lines changed

14 files changed

+51
-59
lines changed

codegen/tools/test/test_gen_oplist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,9 @@ def test_dump_operator_from_ops_schema_yaml_with_mix_syntax(self) -> None:
231231
self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])
232232

233233
def test_get_kernel_metadata_from_ops_yaml(self) -> None:
234-
metadata: Dict[str, List[str]] = (
235-
gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
236-
)
234+
metadata: Dict[
235+
str, List[str]
236+
] = gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
237237

238238
self.assertEqual(len(metadata), 2)
239239

examples/models/llama2/export_llama_lib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901
387387
# to_backend
388388
partitioners = {}
389389
if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
390-
partitioners[XnnpackDynamicallyQuantizedPartitioner.__name__] = (
391-
XnnpackDynamicallyQuantizedPartitioner()
392-
)
390+
partitioners[
391+
XnnpackDynamicallyQuantizedPartitioner.__name__
392+
] = XnnpackDynamicallyQuantizedPartitioner()
393393
modelname = f"xnnpack_dq_{modelname}"
394394

395395
if args.xnnpack:

exir/backend/backend_details.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def enforcedmethod(func):
2121
@dataclass
2222
class PreprocessResult:
2323
processed_bytes: bytes = bytes()
24-
debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = (
25-
None
26-
)
24+
debug_handle_map: Optional[
25+
Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]
26+
] = None
2727

2828

2929
"""

exir/backend/test/test_backends.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -720,11 +720,7 @@ def forward(self, x_raw, h, c):
720720
composite_m = CompositeModel(3)
721721
orig_res = composite_m(*inputs)
722722

723-
traced = exir.capture(
724-
composite_m,
725-
inputs,
726-
exir.CaptureConfig(),
727-
).to_edge(
723+
traced = exir.capture(composite_m, inputs, exir.CaptureConfig(),).to_edge(
728724
# torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
729725
exir.EdgeCompileConfig(_check_ir_validity=False)
730726
)

exir/backend/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,9 @@ def __init__(self, generated_identifiers: bool = False):
277277

278278
# Note that the internal struct has a Set value, while the getter
279279
# function returns the values as a tuple
280-
self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = (
281-
defaultdict(set)
282-
)
280+
self._debug_handle_map: Union[
281+
Dict[int, Set[int]], Dict[str, Set[int]]
282+
] = defaultdict(set)
283283
self._next_index: int = 0
284284

285285
def get_delegate_mapping(

exir/capture/_config.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ class CaptureConfig:
2121
pt2_mode: bool = True
2222
enable_functionalization: bool = True
2323
enable_dynamic_shape: bool = False # This flag does nothing if enable_aot is True
24-
enable_aot: bool = (
25-
False # When it's true it implies automatic dynamic shapes via default dynamo config
26-
)
24+
enable_aot: bool = False # When it's true it implies automatic dynamic shapes via default dynamo config
2725
_dynamo_config: "ExirDynamoConfig" = field(default_factory=ExirDynamoConfig)
2826
_unlift: bool = False # This flag does nothing if enable_aot is False.
2927
_use_old_decomp_table: bool = False

exir/capture/_unlift.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,17 +135,17 @@ def unlift_exported_program_lifted_states(
135135
if node.name in ep.graph_signature.inputs_to_buffers:
136136
buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
137137
if buffer_name in param_buffer_name_to_corrected_name:
138-
inp_pos_to_param_buffer_name[count] = (
139-
param_buffer_name_to_corrected_name[buffer_name]
140-
)
138+
inp_pos_to_param_buffer_name[
139+
count
140+
] = param_buffer_name_to_corrected_name[buffer_name]
141141
else:
142142
inp_pos_to_param_buffer_name[count] = buffer_name
143143
if node.name in ep.graph_signature.inputs_to_parameters:
144144
param_name = ep.graph_signature.inputs_to_parameters[node.name]
145145
if param_name in param_buffer_name_to_corrected_name:
146-
inp_pos_to_param_buffer_name[count] = (
147-
param_buffer_name_to_corrected_name[param_name]
148-
)
146+
inp_pos_to_param_buffer_name[
147+
count
148+
] = param_buffer_name_to_corrected_name[param_name]
149149
else:
150150
inp_pos_to_param_buffer_name[count] = param_name
151151
count += 1

exir/emit/_emit_program.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ def emit_program(
186186
plans.append(emitter.plan())
187187

188188
debug_handle_map[name] = emitter.debug_handle_map
189-
method_to_delegate_debug_id_map[name] = (
190-
emitter.instr_id_to_delegate_debug_id_map
191-
)
189+
method_to_delegate_debug_id_map[
190+
name
191+
] = emitter.instr_id_to_delegate_debug_id_map
192192

193193
# emit any primitive getters
194194
if prim_getters is not None:

exir/operator/convert.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
2424
"""
2525

26+
2627
import dataclasses
2728
import logging
2829
from typing import Dict, Optional, Tuple
@@ -206,9 +207,9 @@ def set_mapping_for_op(op: OpOverload) -> None:
206207
mismatched_out_schema: Optional[FunctionSchema] = next(
207208
(s for s in all_schemas if s.kind() == SchemaKind.out), None
208209
)
209-
_schema_mismatch_map[schema_to_opoverload(func_op_schema)] = (
210-
mismatched_out_schema
211-
)
210+
_schema_mismatch_map[
211+
schema_to_opoverload(func_op_schema)
212+
] = mismatched_out_schema
212213

213214
# update hte map even if scratch_schema is None to cache the negative
214215
# case

exir/passes/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -483,14 +483,14 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
483483
]
484484
).passes
485485

486-
base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
487-
PassManager(
488-
passes=[
489-
dead_code_elimination_pass,
490-
DebugHandleGeneratorPass(),
491-
]
492-
).passes
493-
)
486+
base_post_op_replace_passes: List[
487+
Callable[[torch.nn.Module], PassResult]
488+
] = PassManager(
489+
passes=[
490+
dead_code_elimination_pass,
491+
DebugHandleGeneratorPass(),
492+
]
493+
).passes
494494

495495

496496
def propagate_dynamic_shape(

exir/passes/constant_prop_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
112112
)
113113
prop_constant_data.append(prop_constant_node_input_spec)
114114
buffers.append(prop_constant_tensor_fqn)
115-
exported_program.state_dict[prop_constant_tensor_fqn] = (
116-
prop_constant_tensor
117-
)
115+
exported_program.state_dict[
116+
prop_constant_tensor_fqn
117+
] = prop_constant_tensor
118118
exported_program.graph_signature.input_specs.append(
119119
prop_constant_node_input_spec
120120
)

exir/tests/test_quant_lowering_custom_backend_pass.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -975,10 +975,7 @@ def test_qat_linear(self) -> None:
975975
backend_config=get_executorch_backend_config(),
976976
)
977977
print("converted:", converted_mod)
978-
captured_mod = exir.capture(
979-
converted_mod,
980-
example_inputs,
981-
).to_edge(
978+
captured_mod = exir.capture(converted_mod, example_inputs,).to_edge(
982979
exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True)
983980
)
984981

exir/verification/arg_validator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ class EdgeOpArgValidator(torch.fx.Interpreter):
3737

3838
def __init__(self, graph_module: torch.fx.GraphModule) -> None:
3939
super().__init__(graph_module)
40-
self.violating_ops: Dict[EdgeOpOverload, Dict[str, Optional[torch.dtype]]] = (
41-
defaultdict(dict)
42-
)
40+
self.violating_ops: Dict[
41+
EdgeOpOverload, Dict[str, Optional[torch.dtype]]
42+
] = defaultdict(dict)
4343

4444
def run_node(self, n: torch.fx.Node) -> None:
4545
self.node = n

sdk/inspector/_inspector.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def gen_from_events(run_events: List[flatcc.Event]) -> List["InstructionEvent"]:
8080
and DebugEvents by instruction id and return a list of InstructionEvents
8181
constructed from collated events (ignoring run_output events)
8282
"""
83-
instruction_events: Dict[InstructionEventSignature, InstructionEvent] = (
84-
OrderedDict()
85-
)
83+
instruction_events: Dict[
84+
InstructionEventSignature, InstructionEvent
85+
] = OrderedDict()
8686
for event in run_events:
8787
# Find the event that was logged
8888
populated_event: Union[DebugEvent, ProfileEvent] = find_populated_event(
@@ -668,9 +668,9 @@ class GroupedRunInstances:
668668
continue
669669

670670
# Collate the run_events into InstructionEvents
671-
instruction_events: List[InstructionEvent] = (
672-
InstructionEvent.gen_from_events(run_events)
673-
)
671+
instruction_events: List[
672+
InstructionEvent
673+
] = InstructionEvent.gen_from_events(run_events)
674674

675675
# Map EventSignatures to the InstructionEvents
676676
event_signatures: Dict[EventSignature, InstructionEvent] = OrderedDict()
@@ -720,9 +720,9 @@ class GroupedRunInstances:
720720
TIME_SCALE_DICT[source_time_scale] / TIME_SCALE_DICT[target_time_scale]
721721
)
722722
for run_signature, grouped_run_instance in run_groups.items():
723-
run_group: OrderedDict[EventSignature, List[InstructionEvent]] = (
724-
grouped_run_instance.events
725-
)
723+
run_group: OrderedDict[
724+
EventSignature, List[InstructionEvent]
725+
] = grouped_run_instance.events
726726
run_outputs: ProgramOutput = grouped_run_instance.run_output
727727

728728
# Construct the Events

0 commit comments

Comments
 (0)