14
14
15
15
#pragma once
16
16
#include < type_traits>
17
+ #include < vector>
17
18
#if __cplusplus < 201703L
18
19
#error "This header requires C++17"
19
20
#endif
@@ -29,21 +30,62 @@ namespace executor {
29
30
class KernelRuntimeContext ; // Forward declaration
30
31
using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove
31
32
33
+ // Map types from ETen to ATen.
34
+ // This is used to convert ETen arguments into ATen.
32
35
template <typename T>
33
36
struct type_map final {
34
37
using type = T;
35
38
};
36
39
37
- template <>
38
- struct type_map <torch::executor::Tensor&> final {
39
- using type = at::Tensor&;
40
+ // Const.
41
+ template <typename T>
42
+ struct type_map <const T> final {
43
+ using type = const typename type_map<T>::type;
44
+ };
45
+
46
+ // Ref.
47
+ template <typename T>
48
+ struct type_map <T&> final {
49
+ using type = typename type_map<T>::type&;
50
+ };
51
+
52
+ // Const ref.
53
+ template <typename T>
54
+ struct type_map <const T&> final {
55
+ using type = const typename type_map<T>::type&;
40
56
};
41
57
58
+ // Tensor.
42
59
template <>
43
- struct type_map <const torch::executor::Tensor&> final {
44
- using type = const at::Tensor&;
60
+ struct type_map <torch::executor::Tensor> final {
61
+ using type = at::Tensor;
62
+ };
63
+
64
+ // Optional.
65
+ template <class T >
66
+ struct type_map <torch::executor::optional<T>> final {
67
+ using type = c10::optional<typename type_map<T>::type>;
68
+ };
69
+
70
+ template <class T >
71
+ struct type_map <torch::executor::optional<T>&> final {
72
+ using type = c10::optional<typename type_map<T>::type>&;
73
+ };
74
+
75
+ // ArrayRef.
76
+ template <class T >
77
+ struct type_map <torch::executor::ArrayRef<T>> final {
78
+ using type = at::ArrayRef<typename type_map<T>::type>;
79
+ };
80
+
81
+ template <typename T>
82
+ struct remove_const_ref final {
83
+ using type = std::remove_const_t <std::remove_reference_t <T>>;
45
84
};
46
85
86
+ // Convert ATen->ETen: input args.
87
+ // Convert ETen->ATen: output args.
88
+ // Default argument conversions between ATen and ETen (scalars).
47
89
template <typename F, typename T, typename Enable = void >
48
90
struct type_convert final {
49
91
public:
@@ -54,11 +96,7 @@ struct type_convert final {
54
96
}
55
97
};
56
98
57
- template <typename T>
58
- struct remove_const_ref final {
59
- using type = std::remove_const_t <std::remove_reference_t <T>>;
60
- };
61
-
99
+ // Tensors: ATen to ETen.
62
100
template <class ATensor , class ETensor >
63
101
struct type_convert <
64
102
ATensor,
@@ -90,13 +128,22 @@ struct type_convert<
90
128
}
91
129
};
92
130
93
- template <>
94
- struct type_convert <torch::executor::Tensor&, at::Tensor&> final {
131
+ // Tensors: ETen to ATen.
132
+ template <class ETensor , class ATensor >
133
+ struct type_convert <
134
+ ETensor,
135
+ ATensor,
136
+ std::enable_if_t <
137
+ std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> &&
138
+ std::is_same_v<
139
+ typename remove_const_ref<ETensor>::type,
140
+ torch::executor::Tensor>>>
141
+ final {
95
142
public:
96
- torch::executor::Tensor& val;
143
+ ETensor val;
97
144
at::Tensor converted;
98
145
std::vector<int64_t > sizes;
99
- explicit type_convert (torch::executor::Tensor& value) : val(value) {
146
+ explicit type_convert (ETensor value) : val(value) {
100
147
for (auto size : val.sizes ()) {
101
148
sizes.push_back (size);
102
149
}
@@ -105,11 +152,87 @@ struct type_convert<torch::executor::Tensor&, at::Tensor&> final {
105
152
converted =
106
153
at::from_blob (val.mutable_data_ptr (), val.numel (), sizes, scalar_type);
107
154
}
108
- at::Tensor& call () {
155
+ ATensor call () {
109
156
return converted;
110
157
}
111
158
};
112
159
160
+ // Optionals: ATen to ETen.
161
+ template <class F , class T >
162
+ struct type_convert <c10::optional<F>, torch::executor::optional<T>> final {
163
+ public:
164
+ c10::optional<F> val;
165
+ std::unique_ptr<struct type_convert <F, T>> convert_struct;
166
+ explicit type_convert (c10::optional<F> value) : val(value) {}
167
+ torch::executor::optional<T> call () {
168
+ if (val.has_value ()) {
169
+ convert_struct = std::make_unique<struct type_convert <F, T>>(
170
+ type_convert<F, T>(val.value ()));
171
+ return torch::executor::optional<T>(convert_struct->call ());
172
+ } else {
173
+ return torch::executor::optional<T>();
174
+ }
175
+ }
176
+ };
177
+
178
+ // Optionals: ETen to ATen.
179
+ template <class F , class T >
180
+ struct type_convert <torch::executor::optional<F>, c10::optional<T>> final {
181
+ public:
182
+ torch::executor::optional<F> val;
183
+ std::unique_ptr<struct type_convert <F, T>> convert_struct;
184
+ explicit type_convert (torch::executor::optional<F> value) : val(value) {}
185
+ c10::optional<T> call () {
186
+ if (val.has_value ()) {
187
+ convert_struct = std::make_unique<struct type_convert <F, T>>(
188
+ type_convert<F, T>(val.value ()));
189
+ return c10::optional<T>(convert_struct->call ());
190
+ } else {
191
+ return c10::optional<T>();
192
+ }
193
+ }
194
+ };
195
+
196
+ // ArrayRefs: ATen to ETen.
197
+ template <class F , class T >
198
+ struct type_convert <c10::ArrayRef<F>, torch::executor::ArrayRef<T>> final {
199
+ public:
200
+ c10::ArrayRef<F> val;
201
+ std::vector<T> converted;
202
+ std::vector<type_convert<F, T>> converters;
203
+ explicit type_convert (c10::ArrayRef<F> value) : val(value) {
204
+ for (int i = 0 ; i < val.size (); i++) {
205
+ converters.push_back (type_convert<F, T>(val[i]));
206
+ }
207
+ }
208
+ torch::executor::ArrayRef<T> call () {
209
+ for (int i = 0 ; i < val.size (); i++) {
210
+ converted.push_back (converters[i].call ());
211
+ }
212
+ return torch::executor::ArrayRef<T>(converted.data (), converted.size ());
213
+ }
214
+ };
215
+
216
+ // ArrayRefs: ETen to ATen.
217
+ template <class F , class T >
218
+ struct type_convert <torch::executor::ArrayRef<F>, c10::ArrayRef<T>> final {
219
+ public:
220
+ torch::executor::ArrayRef<F> val;
221
+ std::vector<T> converted;
222
+ std::vector<type_convert<F, T>> converters;
223
+ explicit type_convert (torch::executor::ArrayRef<F> value) : val(value) {
224
+ for (int i = 0 ; i < val.size (); i++) {
225
+ converters.push_back (type_convert<F, T>(val[i]));
226
+ }
227
+ }
228
+ c10::ArrayRef<T> call () {
229
+ for (int i = 0 ; i < val.size (); i++) {
230
+ converted.push_back (converters[i].call ());
231
+ }
232
+ return c10::ArrayRef<T>(converted);
233
+ }
234
+ };
235
+
113
236
template <class F , F f, typename N = int , N index = N(-1 )>
114
237
struct wrapper_impl ;
115
238
0 commit comments