@@ -35,62 +35,62 @@ @implementation GenericTests
35
35
}
36
36
37
37
+ (NSDictionary <NSString *, BOOL (^)(NSString *)> *)predicates {
38
- return @{@" model" : ^BOOL (NSString *filename){
38
+ return @{
39
+ @" model" : ^BOOL (NSString *filename){
39
40
return [filename hasSuffix: @" .pte" ];
40
- }
41
- }
42
- ;
41
+ },
42
+ };
43
43
}
44
44
45
45
+ (NSDictionary <NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources :
46
46
(NSDictionary <NSString *, NSString *> *)resources {
47
47
NSString *modelPath = resources[@" model" ];
48
- return @{@" load" : ^(XCTestCase *testCase){
48
+ return @{
49
+ @" load" : ^(XCTestCase *testCase){
49
50
[testCase
50
51
measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
51
52
block: ^{
52
53
XCTAssertEqual (
53
54
Module (modelPath.UTF8String ).load_forward (),
54
55
Error::Ok);
55
56
}];
56
- }
57
- , @" forward" : ^(XCTestCase *testCase) {
58
- auto __block module = std::make_unique<Module>(modelPath.UTF8String );
59
-
60
- const auto method_meta = module ->method_meta (" forward" );
61
- ASSERT_OK_OR_RETURN (method_meta);
62
-
63
- const auto num_inputs = method_meta->num_inputs ();
64
- XCTAssertGreaterThan (num_inputs, 0 );
65
-
66
- std::vector<TensorPtr> tensors;
67
- tensors.reserve (num_inputs);
68
-
69
- for (auto index = 0 ; index < num_inputs; ++index) {
70
- const auto input_tag = method_meta->input_tag (index);
71
- ASSERT_OK_OR_RETURN (input_tag);
72
-
73
- switch (*input_tag) {
74
- case Tag::Tensor: {
75
- const auto tensor_meta = method_meta->input_tensor_meta (index);
76
- ASSERT_OK_OR_RETURN (tensor_meta);
77
-
78
- const auto sizes = tensor_meta->sizes ();
79
- tensors.emplace_back (
80
- ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
81
- XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
82
- } break ;
83
- default :
84
- XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
85
- }
86
- }
87
- [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
88
- block: ^{
89
- XCTAssertEqual (module ->forward ().error (), Error::Ok);
90
- }];
91
- }
92
- }
93
- ;
57
+ },
58
+ @" forward" : ^(XCTestCase *testCase) {
59
+ auto __block module = std::make_unique<Module>(modelPath.UTF8String );
60
+
61
+ const auto method_meta = module ->method_meta (" forward" );
62
+ ASSERT_OK_OR_RETURN (method_meta);
63
+
64
+ const auto num_inputs = method_meta->num_inputs ();
65
+ XCTAssertGreaterThan (num_inputs, 0 );
66
+
67
+ std::vector<TensorPtr> tensors;
68
+ tensors.reserve (num_inputs);
69
+
70
+ for (auto index = 0 ; index < num_inputs; ++index) {
71
+ const auto input_tag = method_meta->input_tag (index);
72
+ ASSERT_OK_OR_RETURN (input_tag);
73
+
74
+ switch (*input_tag) {
75
+ case Tag::Tensor: {
76
+ const auto tensor_meta = method_meta->input_tensor_meta (index);
77
+ ASSERT_OK_OR_RETURN (tensor_meta);
78
+
79
+ const auto sizes = tensor_meta->sizes ();
80
+ tensors.emplace_back (
81
+ ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
82
+ XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
83
+ } break ;
84
+ default :
85
+ XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
86
+ }
87
+ }
88
+ [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
89
+ block: ^{
90
+ XCTAssertEqual (module ->forward ().error (), Error::Ok);
91
+ }];
92
+ },
93
+ };
94
94
}
95
95
96
96
@end
0 commit comments