Skip to content

Commit 597d133

Browse files
committed
Add index.Tensor and logical_not
1 parent 4389442 commit 597d133

File tree

15 files changed

+430
-11
lines changed

15 files changed

+430
-11
lines changed

backends/apple/mps/mps_preprocess.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
1919
MPSGraph,
2020
MPSTensor,
21+
OpType,
2122
)
2223

2324
from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
@@ -65,6 +66,7 @@ def preprocess(
6566
input_ids=[],
6667
output_ids=[],
6768
constant_ids=[],
69+
graph_type=OpType.mps_graph
6870
)
6971

7072
convert_model_to_fp16 = True
@@ -111,6 +113,11 @@ def handle_call_function(
111113
mps_graph: MPSGraph,
112114
) -> None:
113115
logging.info(f"Visiting: {node}, {node.target.__name__}")
116+
117+
if "delegation_tag" in node.meta and "metal_kernel" in node.meta["delegation_tag"]:
118+
logging.info(f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!")
119+
mps_graph.graph_type = OpType.metal_kernel
120+
114121
if node.target.__name__ in node_visitors:
115122
node_visitors[node.target.__name__].define_node(node, mps_graph)
116123
else:

backends/apple/mps/operators/indexing_ops.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Provided subject to the LICENSE file in the top level directory.
44
#
55

6-
from typing import cast
6+
from typing import cast, List
77

88
import torch
99
from executorch.backends.apple.mps.operators.node_visitor import (
@@ -13,11 +13,13 @@
1313
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
1414
MPSEmbedding,
1515
MPSGraph,
16+
MPSIndexTensor,
17+
MPSIndexPut,
1618
MPSIndexSelect,
1719
)
1820
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
1921
from executorch.exir.sym_util import eval_expr
20-
22+
from executorch.backends.transforms import get_shape
2123

2224
@register_node_visitor
2325
class IndexSelectVisitor(NodeVisitor):
@@ -39,6 +41,76 @@ def define_node(
3941

4042
mps_graph.mps_nodes.append(mps_node)
4143

44+
@register_node_visitor
45+
class IndexTensorVisitor(NodeVisitor):
46+
target = "aten.index.Tensor"
47+
48+
def __init__(self, *args) -> None:
49+
super().__init__(*args)
50+
51+
def define_node(
52+
self,
53+
node: torch.fx.Node,
54+
mps_graph: MPSGraph,
55+
) -> None:
56+
mps_node = self.create_unary_node(node, mps_graph, MPSIndexTensor)
57+
tensors = cast(List[torch.fx.Node], node.args[1])
58+
for tensor in tensors:
59+
mps_node.mpsnode_union.indices_id.append(self.define_tensor(tensor, mps_graph))
60+
61+
mps_graph.mps_nodes.append(mps_node)
62+
63+
64+
65+
# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens
66+
# are wrong when using Index put. Disabling it for now.
67+
@register_node_visitor
68+
class IndexPutVisitor(NodeVisitor):
69+
# target = "aten.index_put.default"
70+
target = "disabled"
71+
72+
def __init__(self, *args) -> None:
73+
super().__init__(*args)
74+
75+
def infer_sizes(self, a: List[int], b: List[int]):
76+
dimsA = len(a)
77+
dimsB = len(b)
78+
print(dimsA)
79+
print(dimsB)
80+
ndim = dimsA if dimsA > dimsB else dimsB
81+
expandedSizes = [0] * ndim
82+
for i in range(ndim - 1, -1, -1):
83+
offset = ndim - 1 - i
84+
dimA = dimsA - 1 - offset
85+
dimB = dimsB - 1 - offset
86+
sizeA = a[dimA] if dimA >= 0 else -1
87+
sizeB = b[dimB] if dimB >= 0 else -1
88+
expandedSizes[i] = sizeA if sizeB == -1 else sizeB
89+
90+
return expandedSizes
91+
92+
93+
def define_node(
94+
self,
95+
node: torch.fx.Node,
96+
mps_graph: MPSGraph,
97+
) -> None:
98+
mps_node = self.create_unary_node(node, mps_graph, MPSIndexPut)
99+
updates_shape = get_shape(node.args[2])
100+
input_shape = get_shape(node.args[0])
101+
new_shape = []
102+
if len(updates_shape) != 1 and len(updates_shape) != len(input_shape):
103+
new_shape = self.infer_sizes(input_shape, updates_shape)
104+
mps_node.mpsnode_union.values_shape = new_shape
105+
106+
tensors = cast(List[torch.fx.Node], node.args[1])
107+
for tensor in tensors:
108+
mps_node.mpsnode_union.indices_id.append(self.define_tensor(tensor, mps_graph))
109+
110+
mps_node.mpsnode_union.values_id = self.define_tensor(
111+
get_input_node(node, 2), mps_graph
112+
)
113+
mps_graph.mps_nodes.append(mps_node)
42114

43115
@register_node_visitor
44116
class EmbeddingVisitor(NodeVisitor):

backends/apple/mps/operators/unary_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
MPSSqrt,
4242
MPSTan,
4343
MPSTanh,
44+
MPSLogicalNot,
4445
)
4546
from executorch.exir.dialects._ops import ops as exir_ops
4647

@@ -79,6 +80,7 @@ class UnaryOpVisitor(NodeVisitor):
7980
"aten.isnan.default",
8081
"aten.isinf.default",
8182
"aten.round.default",
83+
"aten.logical_not.default",
8284
]
8385

8486
def __init__(self, *args) -> None:
@@ -115,6 +117,7 @@ def __init__(self, *args) -> None:
115117
exir_ops.edge.aten.isnan.default: MPSIsnan,
116118
exir_ops.edge.aten.isinf.default: MPSIsinf,
117119
exir_ops.edge.aten.round.default: MPSRound,
120+
exir_ops.edge.aten.logical_not.default: MPSLogicalNot,
118121
}
119122

120123
def define_node(

backends/apple/mps/partition/mps_partitioner.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55

66
import logging
7-
from typing import Any, Dict, List, Union
7+
from typing import cast, Any, Dict, List, Union
88

99
import torch
1010
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
@@ -23,11 +23,19 @@
2323
from torch.export.exported_program import ExportedProgram
2424
from torch.fx.passes.infra.partitioner import Partition
2525
from torch.fx.passes.operator_support import OperatorSupportBase
26+
from executorch.exir.dialects._ops import ops as exir_ops
27+
from executorch.backends.transforms import get_shape
2628

2729
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
2830
logging.basicConfig(level=logging.DEBUG, format=FORMAT)
2931

3032

33+
# ops implemented as Metal kernels.
34+
METAL_KERNELS = [
35+
exir_ops.edge.aten.index.Tensor,
36+
exir_ops.edge.aten.index_put.default,
37+
]
38+
3139
class MPSOperatorSupport(OperatorSupportBase):
3240
def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
3341
self.node_visitors = get_node_visitors(edge_program)
@@ -65,10 +73,42 @@ def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]:
6573
op_support=self.supported_ops,
6674
)
6775

76+
def mps_graph_advanced_indexing_support(self, node: torch.fx.Node):
77+
num_indices = 0
78+
tensors = cast(List[torch.fx.Node], node.args[1])
79+
input = cast(torch.fx.Node, node.args[0])
80+
for t in tensors:
81+
if t is not None:
82+
num_indices += 1
83+
# Can dispatch to MPSGraph if the length of the slices is equal
84+
# to the number of dimensions of the sliced tensors, or only one
85+
# slice is present. All other cases will fallback to a Metal kernel.
86+
if num_indices == len(get_shape(input)) or num_indices == 1:
87+
return True
88+
89+
return False
90+
91+
def use_metal_kernel(self, node: torch.fx.Node):
92+
if node.target in METAL_KERNELS:
93+
if node.target == exir_ops.edge.aten.index.Tensor or node.target == exir_ops.edge.aten.index_put.default:
94+
if not self.mps_graph_advanced_indexing_support(node):
95+
return True
96+
return False
97+
6898
def tag_nodes(self, partitions: List[Partition]) -> None:
6999
for partition in partitions:
70-
for node in partition.nodes:
100+
crt_partition_counter = 0
101+
for node in sorted(partition.nodes):
71102
delegation_tag = f"mps_{partition.id}"
103+
if self.use_metal_kernel(node):
104+
logging.warning(f"[WARNING] Using Metal kernel for op {node.name}!")
105+
# Partition the Metal kernel into a separate partition
106+
crt_partition_counter += 1
107+
delegation_tag = f"{delegation_tag}_metal_kernel_{crt_partition_counter}"
108+
crt_partition_counter += 1
109+
else:
110+
delegation_tag = f"{delegation_tag}_{crt_partition_counter}"
111+
72112
node.meta["delegation_tag"] = delegation_tag
73113
self.partition_tags[delegation_tag] = self.delegation_spec
74114

backends/apple/mps/runtime/MPSDevice.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,19 @@
55

66
#pragma once
77

8+
// Obj-C headers
89
#include <Foundation/Foundation.h>
910
#include <Metal/Metal.h>
11+
12+
// Runtime headers
13+
#include <executorch/runtime/backend/interface.h>
14+
15+
// MPS headers
1016
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
1117

18+
#include <unordered_map>
19+
#include <vector>
20+
1221
#define MB(x) (x * 1048576UL)
1322

1423
namespace torch {
@@ -25,6 +34,11 @@ enum class MacOSVersion : uint32_t {
2534
MACOS_VER_14_0_PLUS,
2635
};
2736

37+
enum class LibraryType : uint32_t {
38+
INDEXING_KERNELS = 0,
39+
MAX = INDEXING_KERNELS,
40+
};
41+
2842
class MPSDevice {
2943
public:
3044
/**
@@ -53,9 +67,18 @@ class MPSDevice {
5367

5468
~MPSDevice();
5569

70+
/**
71+
* Compile a PSO for a given library type.
72+
* Once compiled, the library and PSOs are cached.
73+
*/
74+
Error compilePSO(LibraryType libraryType, const char* kernelName);
75+
Error compileLibrary(LibraryType);
76+
5677
private:
5778
static MPSDevice* _device;
5879
id<MTLDevice> _mtl_device;
80+
std::unordered_map<LibraryType, id<MTLLibrary>> _m_library_cache;
81+
std::unordered_map<std::string, id<MTLComputePipelineState>> _m_pso_cache;
5982
MPSDevice();
6083
};
6184

backends/apple/mps/runtime/MPSDevice.mm

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@
1616
static std::unique_ptr<MPSDevice> mps_device;
1717
static std::once_flag mpsdev_init;
1818

19+
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
20+
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
21+
// host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+)
22+
MTLLanguageVersion languageVersion = MTLLanguageVersion2_3;
23+
#if defined(__MAC_13_0)
24+
if (macOS13Plus) {
25+
languageVersion = MTLLanguageVersion3_0;
26+
}
27+
#endif
28+
29+
ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
30+
return languageVersion;
31+
}
32+
1933
MPSDevice::~MPSDevice() {
2034
[_mtl_device release];
2135
_mtl_device = nil;
@@ -79,6 +93,57 @@
7993
}
8094
}
8195

96+
const char* getLibraryCString(LibraryType libraryType) {
97+
switch (libraryType) {
98+
case LibraryType::INDEXING_KERNELS:
99+
return "Hello";
100+
default:
101+
ET_CHECK_MSG(false, "Unhandled library type!");
102+
}
103+
}
104+
105+
Error
106+
MPSDevice::compileLibrary(LibraryType libraryType) {
107+
Error err = Error::Ok;
108+
NSError* error = nil;
109+
MTLCompileOptions* options = [MTLCompileOptions new];
110+
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
111+
[options setFastMathEnabled:YES];
112+
id<MTLLibrary> lib =
113+
[_mtl_device newLibraryWithSource:[NSString stringWithCString:getLibraryCString(libraryType)
114+
encoding:NSASCIIStringEncoding]
115+
options:options
116+
error:&error];
117+
118+
ET_CHECK_OR_RETURN_ERROR(
119+
lib != nil,
120+
Internal,
121+
"Failed to create indexing library, error: %s", [[error description] UTF8String]
122+
);
123+
124+
_m_library_cache[libraryType] = lib;
125+
return err;
126+
}
127+
128+
Error
129+
MPSDevice::compilePSO(LibraryType libraryType, const char* kernelName) {
130+
Error err = Error::Ok;
131+
if (_m_library_cache.find(libraryType) == _m_library_cache.end()) {
132+
ET_LOG(Debug, "Compiling library type: %d", libraryType);
133+
err = compileLibrary(libraryType);
134+
ET_CHECK_OR_RETURN_ERROR(
135+
err == Error::Ok,
136+
Internal,
137+
"An error occured occured while compiling library %d", libraryType
138+
);
139+
}
140+
if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {
141+
ET_LOG(Debug, "Compiling kernel: %s", kernelName);
142+
// err = compilePSO(libraryType, kernelName);
143+
}
144+
return err;
145+
}
146+
82147
bool isMacOS13OrNewer(MacOSVersion version) {
83148
return MPSDevice::getInstance()->isMacOS13Plus(version);
84149
}

backends/apple/mps/runtime/MPSGraphBuilder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class MPSGraphBuilder {
109109
_DEFINE_MPS_OP(Isnan);
110110
_DEFINE_MPS_OP(Isinf);
111111
_DEFINE_MPS_OP(Round);
112+
_DEFINE_MPS_OP(LogicalNot);
112113
_DEFINE_MPS_OP(NormCdf);
113114
// Clamp ops
114115
_DEFINE_MPS_OP(Clamp);
@@ -120,6 +121,8 @@ class MPSGraphBuilder {
120121
// Indexing ops
121122
_DEFINE_MPS_OP(IndexSelect);
122123
_DEFINE_MPS_OP(Embedding);
124+
_DEFINE_MPS_OP(IndexTensor);
125+
_DEFINE_MPS_OP(IndexPut);
123126
// Linear algebra ops
124127
_DEFINE_MPS_OP(MatMul);
125128
_DEFINE_MPS_OP(Addmm);
@@ -153,6 +156,7 @@ class MPSGraphBuilder {
153156

154157
// Helper functions
155158
Error addNodeToMPSGraph(NodePtr nodePtr);
159+
Error compileMetalKernel(NodePtr nodePtr);
156160
MPSShape *getMPSShape(int32_t id);
157161
MPSShape *getMPSShape(const flatbuffers::Vector<int32_t> *shape);
158162
int64_t numel(const flatbuffers::Vector<int32_t> *shape);
@@ -161,6 +165,8 @@ class MPSGraphBuilder {
161165
MPSGraphTensor *getMPSGraphTensor(int32_t id);
162166
NSData *getConstantData(int32_t id);
163167
std::pair<float, float> getMinMaxValues(NodePtr nodePtr);
168+
Error compileMPSGraph();
169+
Error compileMetalKernel();
164170

165171
// Each MPSGraph op result in at least MPSGraphTensor being
166172
// produced, which will be stored in this structure. Other ops
@@ -172,6 +178,7 @@ class MPSGraphBuilder {
172178
// FlatBuffer raw bytes of the serialized MPS model.
173179
const void *_buffer_pointer;
174180

181+
bool _metal_kernel;
175182
MPSGraph *_mpsGraph;
176183
MPSGraphExecutable *_mpsGraphExecutable;
177184
NSMutableDictionary<MPSGraphTensor *, MPSGraphShapedType *> *_feeds;

0 commit comments

Comments
 (0)