Skip to content

Commit 1a918c7

Browse files
authored
Move ModelDataKit to ExecuTorch directory
Differential Revision: D70825994 Pull Request resolved: #9160
1 parent 1c2a69e commit 1a918c7

File tree

8 files changed

+120
-0
lines changed

8 files changed

+120
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import Foundation
4+
5+
public enum ModelRuntimeError: Error {
6+
case unsupportedInputType
7+
}
8+
9+
public protocol ModelRuntime {
10+
func infer(input: [ModelRuntimeValue]) throws -> [ModelRuntimeValue]
11+
12+
func getModelValueFactory() -> ModelRuntimeValueFactory
13+
func getModelTensorFactory() -> ModelRuntimeTensorValueFactory
14+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import Foundation
4+
5+
public enum ModelRuntimeValueError: Error, CustomStringConvertible {
6+
case unsupportedType(String)
7+
case invalidType(String, String)
8+
9+
public var description: String {
10+
switch self {
11+
case .unsupportedType(let type):
12+
return "Unsupported type: \(type)"
13+
case .invalidType(let expectedType, let type):
14+
return "Invalid type: \(type), expected \(expectedType)"
15+
}
16+
}
17+
}
18+
19+
@objc public class ModelRuntimeValueErrorFactory: NSObject {
20+
@objc public class func unsupportedType(_ type: String) -> Error {
21+
return ModelRuntimeValueError.unsupportedType(type)
22+
}
23+
24+
@objc public class func invalidType(_ actualType: String, expectedType: String) -> Error {
25+
return ModelRuntimeValueError.invalidType(expectedType, actualType)
26+
}
27+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import Foundation
4+
5+
public class ModelRuntimeTensorValue {
6+
public let innerValue: ModelRuntimeTensorValueBridging
7+
public init(innerValue: ModelRuntimeTensorValueBridging) {
8+
self.innerValue = innerValue
9+
}
10+
11+
public func floatRepresentation() throws -> (floatArray: [Float], shape: [Int]) {
12+
let value = try innerValue.floatRepresentation()
13+
let data = value.floatArray
14+
let shape = value.shape
15+
return (data.compactMap { $0.floatValue }, shape.compactMap { $0.intValue })
16+
}
17+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import Foundation
4+
5+
public class ModelRuntimeTensorValueBridgingTuple: NSObject {
6+
@objc public let floatArray: [NSNumber]
7+
@objc public let shape: [NSNumber]
8+
@objc public init(floatArray: [NSNumber], shape: [NSNumber]) {
9+
self.floatArray = floatArray
10+
self.shape = shape
11+
}
12+
}
13+
14+
@objc public protocol ModelRuntimeTensorValueBridging {
15+
func floatRepresentation() throws -> ModelRuntimeTensorValueBridgingTuple
16+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import Foundation
4+
5+
public protocol ModelRuntimeTensorValueFactory {
6+
func createFloatTensor(value: [Float], shape: [Int]) -> ModelRuntimeTensorValue
7+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import Foundation
4+
5+
public class ModelRuntimeValue {
6+
public let value: ModelRuntimeValueBridging
7+
public init(innerValue: ModelRuntimeValueBridging) {
8+
self.value = innerValue
9+
}
10+
11+
public func stringValue() throws -> String {
12+
return try value.stringValue()
13+
}
14+
15+
public func tensorValue() throws -> ModelRuntimeTensorValue {
16+
return try ModelRuntimeTensorValue(innerValue: value.tensorValue())
17+
}
18+
19+
public func arrayValue() throws -> [ModelRuntimeValue] {
20+
return try value.arrayValue().map { ModelRuntimeValue(innerValue: $0) }
21+
}
22+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import Foundation
4+
5+
@objc public protocol ModelRuntimeValueBridging {
6+
func stringValue() throws -> String
7+
func tensorValue() throws -> ModelRuntimeTensorValueBridging
8+
func arrayValue() throws -> [ModelRuntimeValueBridging]
9+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
import Foundation
4+
5+
public protocol ModelRuntimeValueFactory {
6+
func createString(value: String) throws -> ModelRuntimeValue
7+
func createTensor(value: ModelRuntimeTensorValue) throws -> ModelRuntimeValue
8+
}

0 commit comments

Comments
 (0)