@@ -35,62 +35,62 @@ @implementation GenericTests
35
35
}
36
36
37
37
+ (NSDictionary <NSString *, BOOL (^)(NSString *)> *)predicates {
38
- return @{
39
- @" model" : ^BOOL (NSString *filename){
38
+ return @{@" model" : ^BOOL (NSString *filename){
40
39
return [filename hasSuffix: @" .pte" ];
41
- }
42
- };
40
+ }
41
+ }
42
+ ;
43
43
}
44
44
45
45
+ (NSDictionary <NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources :
46
46
(NSDictionary <NSString *, NSString *> *)resources {
47
47
NSString *modelPath = resources[@" model" ];
48
- return @{
49
- @" load" : ^(XCTestCase *testCase){
48
+ return @{@" load" : ^(XCTestCase *testCase){
50
49
[testCase
51
50
measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
52
51
block: ^{
53
52
XCTAssertEqual (
54
53
Module (modelPath.UTF8String ).load_forward (),
55
54
Error::Ok);
56
55
}];
57
- },
58
- @" forward" : ^(XCTestCase *testCase) {
59
- auto __block module = std::make_unique<Module>(modelPath.UTF8String );
60
-
61
- const auto method_meta = module ->method_meta (" forward" );
62
- ASSERT_OK_OR_RETURN (method_meta);
63
-
64
- const auto num_inputs = method_meta->num_inputs ();
65
- XCTAssertGreaterThan (num_inputs, 0 );
66
-
67
- std::vector<TensorPtr> tensors;
68
- tensors.reserve (num_inputs);
69
-
70
- for (auto index = 0 ; index < num_inputs; ++index) {
71
- const auto input_tag = method_meta->input_tag (index);
72
- ASSERT_OK_OR_RETURN (input_tag);
73
-
74
- switch (*input_tag) {
75
- case Tag::Tensor: {
76
- const auto tensor_meta = method_meta->input_tensor_meta (index);
77
- ASSERT_OK_OR_RETURN (tensor_meta);
78
-
79
- const auto sizes = tensor_meta->sizes ();
80
- tensors.emplace_back (
81
- ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
82
- XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
83
- } break ;
84
- default :
85
- XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
86
- }
87
- }
88
- [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
89
- block: ^{
90
- XCTAssertEqual (module ->forward ().error (), Error::Ok);
91
- }];
56
+ }
57
+ , @" forward" : ^(XCTestCase *testCase) {
58
+ auto __block module = std::make_unique<Module>(modelPath.UTF8String );
59
+
60
+ const auto method_meta = module ->method_meta (" forward" );
61
+ ASSERT_OK_OR_RETURN (method_meta);
62
+
63
+ const auto num_inputs = method_meta->num_inputs ();
64
+ XCTAssertGreaterThan (num_inputs, 0 );
65
+
66
+ std::vector<TensorPtr> tensors;
67
+ tensors.reserve (num_inputs);
68
+
69
+ for (auto index = 0 ; index < num_inputs; ++index) {
70
+ const auto input_tag = method_meta->input_tag (index);
71
+ ASSERT_OK_OR_RETURN (input_tag);
72
+
73
+ switch (*input_tag) {
74
+ case Tag::Tensor: {
75
+ const auto tensor_meta = method_meta->input_tensor_meta (index);
76
+ ASSERT_OK_OR_RETURN (tensor_meta);
77
+
78
+ const auto sizes = tensor_meta->sizes ();
79
+ tensors.emplace_back (
80
+ ones ({sizes.begin (), sizes.end ()}, tensor_meta->scalar_type ()));
81
+ XCTAssertEqual (module ->set_input (tensors.back (), index), Error::Ok);
82
+ } break ;
83
+ default :
84
+ XCTFail (" Unsupported tag %i at input %d" , *input_tag, index);
92
85
}
93
- };
86
+ }
87
+ [testCase measureWithMetrics: @[ [XCTClockMetric new ], [XCTMemoryMetric new ] ]
88
+ block: ^{
89
+ XCTAssertEqual (module ->forward ().error (), Error::Ok);
90
+ }];
91
+ }
92
+ }
93
+ ;
94
94
}
95
95
96
96
@end
0 commit comments