@@ -24,7 +24,10 @@
 
 #ifdef __AOTI_MODEL__
 #include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
-torch::Device cpu_device(torch::kCPU);
+#ifdef USE_CUDA
+#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
+#endif
+torch::Device aoti_device(torch::kCPU);
 
 #else // __ET_MODEL__
 #include <executorch/extension/module/module.h>
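The hunk above swaps the hard-coded `cpu_device` global for a mutable `aoti_device` that defaults to CPU; `build_transformer` reassigns it once a CUDA runner is known to work. A minimal sketch of the idea, with a hypothetical helper name (in the diff the promotion happens inline in `build_transformer`):

```cpp
#include <torch/torch.h>

// Default to CPU; promoted to CUDA only after a CUDA runner loads successfully.
torch::Device aoti_device(torch::kCPU);

void promote_device_if_possible() {
  // Hypothetical helper for illustration; guarded so CPU-only builds stay safe.
  if (torch::cuda::is_available()) {
    aoti_device = torch::Device(torch::kCUDA);
  }
}
```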
@@ -82,7 +85,7 @@ typedef struct {
   RunState state; // buffers for the "wave" of activations in the forward pass
 
 #ifdef __AOTI_MODEL__
-  torch::inductor::AOTIModelContainerRunnerCpu* runner;
+  torch::inductor::AOTIModelContainerRunner* runner;
 #else // __ET_MODEL__
   Module* runner;
 #endif
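Widening the struct field from `AOTIModelContainerRunnerCpu*` to the base `AOTIModelContainerRunner*` is what lets a single field hold either backend: `run()` is exposed on the base class, so call sites don't care which runner was constructed. A minimal sketch of the pattern, assuming only the runner header shown above:

```cpp
#include <string>
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>

// Hypothetical standalone helper mirroring the struct field: the base-class
// pointer accepts either concrete runner type.
torch::inductor::AOTIModelContainerRunner* make_cpu_runner(
    const std::string& model_path) {
  return new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
}
```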
@@ -132,9 +135,16 @@ void build_transformer(
   malloc_run_state(&t->state, &t->config);
 
 #ifdef __AOTI_MODEL__
-  t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(
-      /* path to model DSO */ model_path,
-      /* thread pool size */ 1);
+#ifdef USE_CUDA
+  try {
+    t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
+    aoti_device = torch::Device(torch::kCUDA);
+  } catch (std::runtime_error& e) {
+#else
+  {
+#endif
+    t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
+  }
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,
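The `#ifdef USE_CUDA` weaving above is easy to misread, since the braces of the `try`/`catch` are split across preprocessor branches. What the preprocessor actually emits in each configuration (same statements as the diff, just expanded):

```cpp
// With USE_CUDA defined: try CUDA first, fall back to CPU if the
// constructor throws (e.g. the model DSO was not compiled for CUDA).
try {
  t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
  aoti_device = torch::Device(torch::kCUDA);
} catch (std::runtime_error& e) {
  t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
}

// Without USE_CUDA: the same lines reduce to a plain block, CPU only.
{
  t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
}
```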
@@ -186,11 +196,11 @@ float* forward(Transformer* transformer, int token, int pos) {
   torch::Tensor token_tensor =
       torch::from_blob(token_buffer, {1, 1}, torch::kLong);
   torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
-  std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};
+  std::vector<torch::Tensor> inputs{token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};
 
   torch::Tensor result = transformer->runner->run(inputs)[0]
                              .to(torch::dtype(torch::kFloat32))
-                             .to(cpu_device);
+                             .to(torch::kCPU);
   auto logits = result[0].data_ptr();
 #else // __ET_MODEL__
   ManagedTensor pos_managed(pos_buffer, sizeof(int64_t), {1}, ScalarType::Long);
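Because `aoti_device` still defaults to CPU, the added `.to(aoti_device)` calls are no-ops on a CPU-only build (`Tensor::to` returns the input unchanged when it is already on the target device), while on CUDA they stage the host-side `from_blob` tensors onto the GPU; the result is then converted to float32 and moved back to the CPU so the sampling code can read the logits. A minimal self-contained sketch of this round-trip, with the model call stubbed out by a placeholder op:

```cpp
#include <torch/torch.h>

// Sketch of the host->device->host round-trip, assuming a global
// torch::Device aoti_device as in the diff.
torch::Tensor round_trip(torch::Tensor host_input, torch::Device aoti_device) {
  torch::Tensor staged = host_input.to(aoti_device);  // no-op when already on CPU
  torch::Tensor result = staged * 2;                  // placeholder for runner->run(...)
  // Convert to float32 and bring the logits back to host memory.
  return result.to(torch::dtype(torch::kFloat32)).to(torch::kCPU);
}
```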