Skip to content

Commit 75352ad

Browse files
amyreese authored and facebook-github-bot committed
apply Black 2024 style in fbcode (9/16)
Summary: Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447729

fbshipit-source-id: fc781322b254f7027c24888cdadd5f1e90325ba8
1 parent 2966e38 commit 75352ad

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

44 files changed

+152
-106
lines changed

backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,11 @@ def call( # noqa: suprress function is too complex (13)
127127
# if user is in to_dim_op_set, it means the user's arg is already set to_dim op
128128
if user not in to_dim_op_set:
129129
user_new_arg = [
130-
input_node_map[user_arg]
131-
if user_arg in input_node_map
132-
else user_arg
130+
(
131+
input_node_map[user_arg]
132+
if user_arg in input_node_map
133+
else user_arg
134+
)
133135
for user_arg in user.args
134136
]
135137
# Update input node's users arg

backends/transforms/duplicate_dynamic_quant_chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _replicate_chose_qparam_nodes_for_q_dq(
7777
)
7878
q_dq_pair.append((user, dq_node))
7979

80-
for (q_node, dq_node) in q_dq_pair:
80+
for q_node, dq_node in q_dq_pair:
8181
with gm.graph.inserting_after(get_item_node_1):
8282
new_get_item_node_1 = gm.graph.node_copy(get_item_node_1)
8383
new_get_item_node_2 = gm.graph.node_copy(get_item_node_2)

backends/xnnpack/passes/channels_last_tagged_reshape_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from executorch.exir.dialects._ops import ops as exir_ops
1313
from executorch.exir.pass_base import PassResult
1414

15+
1516
# TODO(T151254305) use subgraph_rewriter
1617
class ChannelsLastTaggedReshapePass(XNNPACKPass):
1718
"""

backends/xnnpack/utils/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from executorch.exir.pass_manager import PassType
1515

16+
1617
### XNNPACK Configs ###
1718
def get_xnnpack_edge_compile_config() -> exir.EdgeCompileConfig:
1819
return exir.EdgeCompileConfig(

backends/xnnpack/utils/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
is_param,
2727
)
2828

29+
2930
### XNNPACK Capture ###
3031
def capture_graph_for_xnnpack(
3132
module: torch.nn.Module,

codegen/tools/test/test_gen_oplist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,9 @@ def test_dump_operator_from_ops_schema_yaml_with_mix_syntax(self) -> None:
231231
self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])
232232

233233
def test_get_kernel_metadata_from_ops_yaml(self) -> None:
234-
metadata: Dict[
235-
str, List[str]
236-
] = gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
234+
metadata: Dict[str, List[str]] = (
235+
gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
236+
)
237237

238238
self.assertEqual(len(metadata), 2)
239239

examples/models/llama2/export_llama_lib.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901
387387
# to_backend
388388
partitioners = {}
389389
if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
390-
partitioners[
391-
XnnpackDynamicallyQuantizedPartitioner.__name__
392-
] = XnnpackDynamicallyQuantizedPartitioner()
390+
partitioners[XnnpackDynamicallyQuantizedPartitioner.__name__] = (
391+
XnnpackDynamicallyQuantizedPartitioner()
392+
)
393393
modelname = f"xnnpack_dq_{modelname}"
394394

395395
if args.xnnpack:

examples/xtensa/aot/quantizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def quantize_tensor_multiplier(
4646
result = RoundingRightShift(FixedPointMultiplication(int32_value,
4747
out_multiplier[i]), right_shift[i])
4848
"""
49+
4950
# This is identical to C++11 std::round(). The general python round rounds
5051
# down, and C++ rounds away from zero.
5152
def round_away_zero(f) -> int:

exir/backend/backend_details.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def enforcedmethod(func):
2121
@dataclass
2222
class PreprocessResult:
2323
processed_bytes: bytes = bytes()
24-
debug_handle_map: Optional[
25-
Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]
26-
] = None
24+
debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = (
25+
None
26+
)
2727

2828

2929
"""

exir/backend/test/backend_with_delegate_mapping_demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch import nn
1515
from torch.export.exported_program import ExportedProgram
1616

17+
1718
# A simple way to represent an op along with its delegate debug identifier.
1819
class DummyOp:
1920
def __init__(

exir/backend/test/test_backends.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,11 @@ def forward(self, x_raw, h, c):
720720
composite_m = CompositeModel(3)
721721
orig_res = composite_m(*inputs)
722722

723-
traced = exir.capture(composite_m, inputs, exir.CaptureConfig(),).to_edge(
723+
traced = exir.capture(
724+
composite_m,
725+
inputs,
726+
exir.CaptureConfig(),
727+
).to_edge(
724728
# torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
725729
exir.EdgeCompileConfig(_check_ir_validity=False)
726730
)

exir/backend/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,9 @@ def __init__(self, generated_identifiers: bool = False):
277277

278278
# Note that the internal struct has a Set value, while the getter
279279
# function returns the values as a tuple
280-
self._debug_handle_map: Union[
281-
Dict[int, Set[int]], Dict[str, Set[int]]
282-
] = defaultdict(set)
280+
self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = (
281+
defaultdict(set)
282+
)
283283
self._next_index: int = 0
284284

285285
def get_delegate_mapping(

exir/capture/_capture.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,11 @@ def capture_multiple(
432432
"forward",
433433
m.forward,
434434
args,
435-
dynamic_shapes["forward"]
436-
if dynamic_shapes and "forward" in dynamic_shapes
437-
else None,
435+
(
436+
dynamic_shapes["forward"]
437+
if dynamic_shapes and "forward" in dynamic_shapes
438+
else None
439+
),
438440
)
439441
)
440442
else:
@@ -447,9 +449,11 @@ def capture_multiple(
447449
method_name,
448450
getattr(m, method_name),
449451
method_args,
450-
dynamic_shapes[method_name]
451-
if dynamic_shapes and method_name in dynamic_shapes
452-
else None,
452+
(
453+
dynamic_shapes[method_name]
454+
if dynamic_shapes and method_name in dynamic_shapes
455+
else None
456+
),
453457
)
454458
)
455459
if prim_getters is not None:

exir/capture/_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ class CaptureConfig:
2121
pt2_mode: bool = True
2222
enable_functionalization: bool = True
2323
enable_dynamic_shape: bool = False # This flag does nothing if enable_aot is True
24-
enable_aot: bool = False # When it's true it implies automatic dynamic shapes via default dynamo config
24+
enable_aot: bool = (
25+
False # When it's true it implies automatic dynamic shapes via default dynamo config
26+
)
2527
_dynamo_config: "ExirDynamoConfig" = field(default_factory=ExirDynamoConfig)
2628
_unlift: bool = False # This flag does nothing if enable_aot is False.
2729
_use_old_decomp_table: bool = False

exir/capture/_unlift.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,17 +135,17 @@ def unlift_exported_program_lifted_states(
135135
if node.name in ep.graph_signature.inputs_to_buffers:
136136
buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
137137
if buffer_name in param_buffer_name_to_corrected_name:
138-
inp_pos_to_param_buffer_name[
139-
count
140-
] = param_buffer_name_to_corrected_name[buffer_name]
138+
inp_pos_to_param_buffer_name[count] = (
139+
param_buffer_name_to_corrected_name[buffer_name]
140+
)
141141
else:
142142
inp_pos_to_param_buffer_name[count] = buffer_name
143143
if node.name in ep.graph_signature.inputs_to_parameters:
144144
param_name = ep.graph_signature.inputs_to_parameters[node.name]
145145
if param_name in param_buffer_name_to_corrected_name:
146-
inp_pos_to_param_buffer_name[
147-
count
148-
] = param_buffer_name_to_corrected_name[param_name]
146+
inp_pos_to_param_buffer_name[count] = (
147+
param_buffer_name_to_corrected_name[param_name]
148+
)
149149
else:
150150
inp_pos_to_param_buffer_name[count] = param_name
151151
count += 1

exir/delegate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"
3737

38+
3839
# pyre-ignore
3940
def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
4041
# pyre-ignore

exir/dialects/_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def wrapper(f: Callable):
7272

7373

7474
class _OpNamespace(types.ModuleType):
75-
7675
"""
7776
EXIR Dialect op namespace object. Contains ops and overloads registered into PyTorch dispatcher.
7877
"""

exir/dialects/edge/spec/gen.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,8 @@ def gen_op_yaml(op_name: str) -> Optional[EdgeOpYamlInfo]:
432432
Arguments:
433433
op_name: The name of operator. Needs to conform the convention of "<name>.<overload_name>".
434434
If no overload name for the operator, needs to use "default" as overload name.
435-
Return the yaml info for given operator if generation succeed. Otherwise return None."""
435+
Return the yaml info for given operator if generation succeed. Otherwise return None.
436+
"""
436437

437438
try:
438439
func_schema: torch._C.FunctionSchema = get_callable(op_name)._schema

exir/dialects/edge/spec/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ def gen_index_pairs_to_types_mapping(
171171
type_alias: Dict[Tuple[str], int], type_constraint: List[List[int]]
172172
) -> Dict[Tuple[int], List[str]]:
173173
"""Generate mapping from index pairs to types. For example, given type_constraint [0, 0], [1, 1]
174-
type_alias ('Double',): 0, ('Int',): 1, output will be {(0, 1): ['Double', 'Int', 'Double', 'Int']}."""
174+
type_alias ('Double',): 0, ('Int',): 1, output will be {(0, 1): ['Double', 'Int', 'Double', 'Int']}.
175+
"""
175176

176177
def gen(x: List[int]):
177178
"""Generate all possible pairs of elements in the list."""

exir/emit/_emit_program.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ def emit_program(
186186
plans.append(emitter.plan())
187187

188188
debug_handle_map[name] = emitter.debug_handle_map
189-
method_to_delegate_debug_id_map[
190-
name
191-
] = emitter.instr_id_to_delegate_debug_id_map
189+
method_to_delegate_debug_id_map[name] = (
190+
emitter.instr_id_to_delegate_debug_id_map
191+
)
192192

193193
# emit any primitive getters
194194
if prim_getters is not None:

exir/emit/_emitter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ class _AbstractValue:
187187
Dict[int, Tuple[int]], Dict[str, Tuple[int]]
188188
]
189189

190+
190191
# pyre-ignore[13]: Attribute `node` is never initialized.
191192
class _Emitter(torch.fx.Interpreter):
192193
"""An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the

exir/experimental/export_pt2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def trace(root: Callable[..., Value], concrete_args: Tuple[Value, ...]) -> Trace
141141
concrete_args,
142142
CaptureConfig(enable_functionalization=False),
143143
).graph_module
144+
144145
# TODO convert torchdynamo guards to our own guards
145146
def _convert_dynamo_guard_to_exir_guard(
146147
dynamo_guard: DynamoGuard,

exir/memory_planning.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ class SharedObject:
435435
last_used_index attribute. The shared object will be available for nodes
436436
with index greater than last_used_index.
437437
"""
438+
438439
# index of the shared object in the list of shared objects, used as a unique id
439440
idx: int
440441
# offset in the memory buffer

exir/operator/convert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,9 @@ def set_mapping_for_op(op: OpOverload) -> None:
206206
mismatched_out_schema: Optional[FunctionSchema] = next(
207207
(s for s in all_schemas if s.kind() == SchemaKind.out), None
208208
)
209-
_schema_mismatch_map[
210-
schema_to_opoverload(func_op_schema)
211-
] = mismatched_out_schema
209+
_schema_mismatch_map[schema_to_opoverload(func_op_schema)] = (
210+
mismatched_out_schema
211+
)
212212

213213
# update hte map even if scratch_schema is None to cache the negative
214214
# case

exir/passes/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -483,14 +483,14 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
483483
]
484484
).passes
485485

486-
base_post_op_replace_passes: List[
487-
Callable[[torch.nn.Module], PassResult]
488-
] = PassManager(
489-
passes=[
490-
dead_code_elimination_pass,
491-
DebugHandleGeneratorPass(),
492-
]
493-
).passes
486+
base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
487+
PassManager(
488+
passes=[
489+
dead_code_elimination_pass,
490+
DebugHandleGeneratorPass(),
491+
]
492+
).passes
493+
)
494494

495495

496496
def propagate_dynamic_shape(

exir/passes/constant_prop_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
112112
)
113113
prop_constant_data.append(prop_constant_node_input_spec)
114114
buffers.append(prop_constant_tensor_fqn)
115-
exported_program.state_dict[
116-
prop_constant_tensor_fqn
117-
] = prop_constant_tensor
115+
exported_program.state_dict[prop_constant_tensor_fqn] = (
116+
prop_constant_tensor
117+
)
118118
exported_program.graph_signature.input_specs.append(
119119
prop_constant_node_input_spec
120120
)

exir/passes/insert_write_back_for_buffers_pass.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,17 @@ def _insert_copy(
7272
def insert_write_back_for_buffers_pass(ep: ExportedProgram):
7373
gm: torch.fx.GraphModule = ep.graph_module
7474
lifted_inputs: List[Optional[str]] = [
75-
in_spec.target
76-
if in_spec.kind
77-
in (
78-
InputKind.BUFFER,
79-
InputKind.CONSTANT_TENSOR,
80-
InputKind.PARAMETER,
81-
InputKind.CUSTOM_OBJ,
75+
(
76+
in_spec.target
77+
if in_spec.kind
78+
in (
79+
InputKind.BUFFER,
80+
InputKind.CONSTANT_TENSOR,
81+
InputKind.PARAMETER,
82+
InputKind.CUSTOM_OBJ,
83+
)
84+
else None
8285
)
83-
else None
8486
for in_spec in ep.graph_signature.input_specs
8587
]
8688

exir/serde/export_serialize.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -835,9 +835,11 @@ def serialize_module_call_graph(
835835
return [
836836
ModuleCallEntry(
837837
fqn=entry.fqn,
838-
signature=self.serialize_module_call_signature(entry.signature)
839-
if entry.signature
840-
else None,
838+
signature=(
839+
self.serialize_module_call_signature(entry.signature)
840+
if entry.signature
841+
else None
842+
),
841843
)
842844
for entry in module_call_graph
843845
]
@@ -1668,9 +1670,11 @@ def deserialize_module_call_graph(
16681670
return [
16691671
ep.ModuleCallEntry(
16701672
fqn=entry.fqn,
1671-
signature=self.deserialize_module_call_signature(entry.signature)
1672-
if entry.signature
1673-
else None,
1673+
signature=(
1674+
self.deserialize_module_call_signature(entry.signature)
1675+
if entry.signature
1676+
else None
1677+
),
16741678
)
16751679
for entry in module_call_graph
16761680
]

exir/tests/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_get_schema_for_operators(self) -> None:
2828

2929
schemas = get_schema_for_operators(op_list)
3030
pat = re.compile(r"[^\(]+\([^\)]+\) -> ")
31-
for (_op_name, schema) in schemas.items():
31+
for _op_name, schema in schemas.items():
3232
self.assertIsNotNone(re.match(pat, schema))
3333

3434
def test_get_out_args(self) -> None:

exir/tests/test_pass_infra.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_pass_registry_func(self) -> None:
3636
"""
3737
Test if we register a callable correctly
3838
"""
39+
3940
# Registering w/o specifying pass_name
4041
@PassRegistry.register()
4142
def test_pass1(graph_module: torch.fx.GraphModule) -> None:

0 commit comments

Comments
 (0)