Skip to content

Commit 8716780

Browse files
lucylqfacebook-github-bot
authored andcommitted
Optional and ArrayRef for make_aten_functor_from_et_functor (#2495)
Summary: Pull Request resolved: #2495 Add optional and ArrayRef to WRAP_ATEN Add tests for - type_map - type_convert - WRAP Reviewed By: larryliu0820 Differential Revision: D54971798 fbshipit-source-id: 0b69730c54347a33e4daab283a3f5f58126d59b0
1 parent 6bef9e7 commit 8716780

File tree

3 files changed

+462
-18
lines changed

3 files changed

+462
-18
lines changed

extension/aten_util/make_aten_functor_from_et_functor.h

Lines changed: 138 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616
#include <type_traits>
17+
#include <vector>
1718
#if __cplusplus < 201703L
1819
#error "This header requires C++17"
1920
#endif
@@ -29,21 +30,62 @@ namespace executor {
2930
class KernelRuntimeContext; // Forward declaration
3031
using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove
3132

33+
// Map types from ETen to ATen.
34+
// This is used to convert ETen arguments into ATen.
3235
template <typename T>
3336
struct type_map final {
3437
using type = T;
3538
};
3639

37-
template <>
38-
struct type_map<torch::executor::Tensor&> final {
39-
using type = at::Tensor&;
40+
// Const.
41+
template <typename T>
42+
struct type_map<const T> final {
43+
using type = const typename type_map<T>::type;
44+
};
45+
46+
// Ref.
47+
template <typename T>
48+
struct type_map<T&> final {
49+
using type = typename type_map<T>::type&;
50+
};
51+
52+
// Const ref.
53+
template <typename T>
54+
struct type_map<const T&> final {
55+
using type = const typename type_map<T>::type&;
4056
};
4157

58+
// Tensor.
4259
template <>
43-
struct type_map<const torch::executor::Tensor&> final {
44-
using type = const at::Tensor&;
60+
struct type_map<torch::executor::Tensor> final {
61+
using type = at::Tensor;
62+
};
63+
64+
// Optional.
65+
template <class T>
66+
struct type_map<torch::executor::optional<T>> final {
67+
using type = c10::optional<typename type_map<T>::type>;
68+
};
69+
70+
template <class T>
71+
struct type_map<torch::executor::optional<T>&> final {
72+
using type = c10::optional<typename type_map<T>::type>&;
73+
};
74+
75+
// ArrayRef.
76+
template <class T>
77+
struct type_map<torch::executor::ArrayRef<T>> final {
78+
using type = at::ArrayRef<typename type_map<T>::type>;
79+
};
80+
81+
template <typename T>
82+
struct remove_const_ref final {
83+
using type = std::remove_const_t<std::remove_reference_t<T>>;
4584
};
4685

86+
// Convert ATen->ETen: input args.
87+
// Convert ETen->ATen: output args.
88+
// Default argument conversions between ATen and ETen (scalars).
4789
template <typename F, typename T, typename Enable = void>
4890
struct type_convert final {
4991
public:
@@ -54,11 +96,7 @@ struct type_convert final {
5496
}
5597
};
5698

57-
template <typename T>
58-
struct remove_const_ref final {
59-
using type = std::remove_const_t<std::remove_reference_t<T>>;
60-
};
61-
99+
// Tensors: ATen to ETen.
62100
template <class ATensor, class ETensor>
63101
struct type_convert<
64102
ATensor,
@@ -90,13 +128,22 @@ struct type_convert<
90128
}
91129
};
92130

93-
template <>
94-
struct type_convert<torch::executor::Tensor&, at::Tensor&> final {
131+
// Tensors: ETen to ATen.
132+
template <class ETensor, class ATensor>
133+
struct type_convert<
134+
ETensor,
135+
ATensor,
136+
std::enable_if_t<
137+
std::is_same_v<typename remove_const_ref<ATensor>::type, at::Tensor> &&
138+
std::is_same_v<
139+
typename remove_const_ref<ETensor>::type,
140+
torch::executor::Tensor>>>
141+
final {
95142
public:
96-
torch::executor::Tensor& val;
143+
ETensor val;
97144
at::Tensor converted;
98145
std::vector<int64_t> sizes;
99-
explicit type_convert(torch::executor::Tensor& value) : val(value) {
146+
explicit type_convert(ETensor value) : val(value) {
100147
for (auto size : val.sizes()) {
101148
sizes.push_back(size);
102149
}
@@ -105,11 +152,87 @@ struct type_convert<torch::executor::Tensor&, at::Tensor&> final {
105152
converted =
106153
at::from_blob(val.mutable_data_ptr(), val.numel(), sizes, scalar_type);
107154
}
108-
at::Tensor& call() {
155+
ATensor call() {
109156
return converted;
110157
}
111158
};
112159

160+
// Optionals: ATen to ETen.
161+
template <class F, class T>
162+
struct type_convert<c10::optional<F>, torch::executor::optional<T>> final {
163+
public:
164+
c10::optional<F> val;
165+
std::unique_ptr<struct type_convert<F, T>> convert_struct;
166+
explicit type_convert(c10::optional<F> value) : val(value) {}
167+
torch::executor::optional<T> call() {
168+
if (val.has_value()) {
169+
convert_struct = std::make_unique<struct type_convert<F, T>>(
170+
type_convert<F, T>(val.value()));
171+
return torch::executor::optional<T>(convert_struct->call());
172+
} else {
173+
return torch::executor::optional<T>();
174+
}
175+
}
176+
};
177+
178+
// Optionals: ETen to ATen.
179+
template <class F, class T>
180+
struct type_convert<torch::executor::optional<F>, c10::optional<T>> final {
181+
public:
182+
torch::executor::optional<F> val;
183+
std::unique_ptr<struct type_convert<F, T>> convert_struct;
184+
explicit type_convert(torch::executor::optional<F> value) : val(value) {}
185+
c10::optional<T> call() {
186+
if (val.has_value()) {
187+
convert_struct = std::make_unique<struct type_convert<F, T>>(
188+
type_convert<F, T>(val.value()));
189+
return c10::optional<T>(convert_struct->call());
190+
} else {
191+
return c10::optional<T>();
192+
}
193+
}
194+
};
195+
196+
// ArrayRefs: ATen to ETen.
197+
template <class F, class T>
198+
struct type_convert<c10::ArrayRef<F>, torch::executor::ArrayRef<T>> final {
199+
public:
200+
c10::ArrayRef<F> val;
201+
std::vector<T> converted;
202+
std::vector<type_convert<F, T>> converters;
203+
explicit type_convert(c10::ArrayRef<F> value) : val(value) {
204+
for (int i = 0; i < val.size(); i++) {
205+
converters.push_back(type_convert<F, T>(val[i]));
206+
}
207+
}
208+
torch::executor::ArrayRef<T> call() {
209+
for (int i = 0; i < val.size(); i++) {
210+
converted.push_back(converters[i].call());
211+
}
212+
return torch::executor::ArrayRef<T>(converted.data(), converted.size());
213+
}
214+
};
215+
216+
// ArrayRefs: ETen to ATen.
217+
template <class F, class T>
218+
struct type_convert<torch::executor::ArrayRef<F>, c10::ArrayRef<T>> final {
219+
public:
220+
torch::executor::ArrayRef<F> val;
221+
std::vector<T> converted;
222+
std::vector<type_convert<F, T>> converters;
223+
explicit type_convert(torch::executor::ArrayRef<F> value) : val(value) {
224+
for (int i = 0; i < val.size(); i++) {
225+
converters.push_back(type_convert<F, T>(val[i]));
226+
}
227+
}
228+
c10::ArrayRef<T> call() {
229+
for (int i = 0; i < val.size(); i++) {
230+
converted.push_back(converters[i].call());
231+
}
232+
return c10::ArrayRef<T>(converted);
233+
}
234+
};
235+
113236
template <class F, F f, typename N = int, N index = N(-1)>
114237
struct wrapper_impl;
115238

0 commit comments

Comments
 (0)