Skip to content

Commit d13bb4e

Browse files
committed
Fix lint
1 parent 724d676 commit d13bb4e

File tree

5 files changed

+34
-15
lines changed

5 files changed

+34
-15
lines changed

backends/apple/mps/mps_preprocess.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def preprocess(
6666
input_ids=[],
6767
output_ids=[],
6868
constant_ids=[],
69-
graph_type=OpType.mps_graph
69+
graph_type=OpType.mps_graph,
7070
)
7171

7272
convert_model_to_fp16 = True
@@ -114,8 +114,13 @@ def handle_call_function(
114114
) -> None:
115115
logging.info(f"Visiting: {node}, {node.target.__name__}")
116116

117-
if "delegation_tag" in node.meta and "metal_kernel" in node.meta["delegation_tag"]:
118-
logging.info(f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!")
117+
if (
118+
"delegation_tag" in node.meta
119+
and "metal_kernel" in node.meta["delegation_tag"]
120+
):
121+
logging.info(
122+
f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!"
123+
)
119124
mps_graph.graph_type = OpType.metal_kernel
120125

121126
if node.target.__name__ in node_visitors:

backends/apple/mps/operators/indexing_ops.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
1414
MPSEmbedding,
1515
MPSGraph,
16-
MPSIndexTensor,
1716
MPSIndexPut,
1817
MPSIndexSelect,
18+
MPSIndexTensor,
1919
)
2020
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
21-
from executorch.exir.sym_util import eval_expr
2221
from executorch.backends.transforms import get_shape
22+
from executorch.exir.sym_util import eval_expr
23+
2324

2425
@register_node_visitor
2526
class IndexSelectVisitor(NodeVisitor):
@@ -41,6 +42,7 @@ def define_node(
4142

4243
mps_graph.mps_nodes.append(mps_node)
4344

45+
4446
@register_node_visitor
4547
class IndexTensorVisitor(NodeVisitor):
4648
target = "aten.index.Tensor"
@@ -56,12 +58,13 @@ def define_node(
5658
mps_node = self.create_unary_node(node, mps_graph, MPSIndexTensor)
5759
tensors = cast(List[torch.fx.Node], node.args[1])
5860
for tensor in tensors:
59-
mps_node.mpsnode_union.indices_id.append(self.define_tensor(tensor, mps_graph))
61+
mps_node.mpsnode_union.indices_id.append(
62+
self.define_tensor(tensor, mps_graph)
63+
)
6064

6165
mps_graph.mps_nodes.append(mps_node)
6266

6367

64-
6568
# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens
6669
# are wrong when using Index put. Disabling it for now.
6770
@register_node_visitor
@@ -87,7 +90,6 @@ def infer_sizes(self, a: List[int], b: List[int]):
8790

8891
return expandedSizes
8992

90-
9193
def define_node(
9294
self,
9395
node: torch.fx.Node,
@@ -103,13 +105,16 @@ def define_node(
103105

104106
tensors = cast(List[torch.fx.Node], node.args[1])
105107
for tensor in tensors:
106-
mps_node.mpsnode_union.indices_id.append(self.define_tensor(tensor, mps_graph))
108+
mps_node.mpsnode_union.indices_id.append(
109+
self.define_tensor(tensor, mps_graph)
110+
)
107111

108112
mps_node.mpsnode_union.values_id = self.define_tensor(
109113
get_input_node(node, 2), mps_graph
110114
)
111115
mps_graph.mps_nodes.append(mps_node)
112116

117+
113118
@register_node_visitor
114119
class EmbeddingVisitor(NodeVisitor):
115120
target = "aten.embedding.default"

backends/apple/mps/operators/unary_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
MPSLog,
3131
MPSLog10,
3232
MPSLog2,
33+
MPSLogicalNot,
3334
MPSNeg,
3435
MPSReciprocal,
3536
MPSRound,
@@ -41,7 +42,6 @@
4142
MPSSqrt,
4243
MPSTan,
4344
MPSTanh,
44-
MPSLogicalNot,
4545
)
4646
from executorch.exir.dialects._ops import ops as exir_ops
4747

backends/apple/mps/partition/mps_partitioner.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
#
55

66
import logging
7-
from typing import cast, Any, Dict, List, Union
7+
from typing import Any, cast, Dict, List, Union
88

99
import torch
1010
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
1111
from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors
1212
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
13+
from executorch.backends.transforms import get_shape
1314
from executorch.exir.backend.backend_details import CompileSpec
1415
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
1516
generate_partitions_from_list_of_nodes,
@@ -20,11 +21,10 @@
2021
PartitionResult,
2122
)
2223
from executorch.exir.backend.utils import tag_constant_data
24+
from executorch.exir.dialects._ops import ops as exir_ops
2325
from torch.export.exported_program import ExportedProgram
2426
from torch.fx.passes.infra.partitioner import Partition
2527
from torch.fx.passes.operator_support import OperatorSupportBase
26-
from executorch.exir.dialects._ops import ops as exir_ops
27-
from executorch.backends.transforms import get_shape
2828

2929
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
3030
logging.basicConfig(level=logging.DEBUG, format=FORMAT)
@@ -36,6 +36,7 @@
3636
exir_ops.edge.aten.index_put.default,
3737
]
3838

39+
3940
class MPSOperatorSupport(OperatorSupportBase):
4041
def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
4142
self.node_visitors = get_node_visitors(edge_program)
@@ -90,7 +91,10 @@ def mps_graph_advanced_indexing_support(self, node: torch.fx.Node):
9091

9192
def use_metal_kernel(self, node: torch.fx.Node):
9293
if node.target in METAL_KERNELS:
93-
if node.target == exir_ops.edge.aten.index.Tensor or node.target == exir_ops.edge.aten.index_put.default:
94+
if (
95+
node.target == exir_ops.edge.aten.index.Tensor
96+
or node.target == exir_ops.edge.aten.index_put.default
97+
):
9498
if not self.mps_graph_advanced_indexing_support(node):
9599
return True
96100
return False
@@ -104,7 +108,9 @@ def tag_nodes(self, partitions: List[Partition]) -> None:
104108
logging.warning(f"[WARNING] Using Metal kernel for op {node.name}!")
105109
# Partition the Metal kernel into a separate partition
106110
crt_partition_counter += 1
107-
delegation_tag = f"{delegation_tag}_metal_kernel_{crt_partition_counter}"
111+
delegation_tag = (
112+
f"{delegation_tag}_metal_kernel_{crt_partition_counter}"
113+
)
108114
crt_partition_counter += 1
109115
else:
110116
delegation_tag = f"{delegation_tag}_{crt_partition_counter}"

backends/apple/mps/serialization/mps_graph_schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ class MPSDataType(IntEnum):
2626
mps_data_type_complex_float16 = 10
2727
mps_data_type_complex_float32 = 11
2828

29+
2930
class OpType(IntEnum):
3031
mps_graph = 0
3132
metal_kernel = 1
3233

34+
3335
@dataclass
3436
class MPSNode1x1:
3537
input1_id: int
@@ -453,6 +455,7 @@ class MPSIndexPut(MPSNode1x1):
453455
values_shape: List[int] = field(default_factory=list)
454456
values_id: int = -1
455457

458+
456459
##
457460
## Shape ops
458461
##

0 commit comments

Comments
 (0)