@@ -96,8 +96,13 @@ bool unpack_tensors(
    const std::vector<c10::Argument>& arguments,
    const torch::jit::Stack& stack,
    const c10::Device& device,
-    std::vector<at::Tensor>& inputs) {
+    std::vector<at::Tensor>& inputs,
+    bool with_scalar = false) {
  for (size_t idx = 0; idx < stack.size(); idx++) {
+    if (!with_scalar && stack[idx].isScalar()) {
+      continue;
+    }
+
    if (!unpack_ivalue(arguments[idx], stack[idx], device, inputs)) {
      return false;
    }
@@ -106,6 +111,40 @@ bool unpack_tensors(
  return true;
}

+std::vector<size_t> get_tensor_parameter_index(
+    const std::vector<c10::Argument>& arguments,
+    const torch::jit::Stack& stack) {
+  std::vector<size_t> tensor_parameter_index;
+  for (size_t idx = 0; idx < stack.size(); idx++) {
+    if (stack[idx].isScalar() || stack[idx].isTensor()) {
+      // scalar and tensor
+      tensor_parameter_index.push_back(idx);
+    } else if (stack[idx].isTensorList()) {
+      // tensor list
+      std::fill_n(
+          std::back_inserter(tensor_parameter_index),
+          stack[idx].toListRef().size(),
+          idx);
+    } else if (stack[idx].isOptionalTensorList()) {
+      // optional tensor list: std::vector<std::optional<at::Tensor>>
+      for (const auto& item : stack[idx].toListRef()) {
+        if (item.toOptional<at::Tensor>().has_value()) {
+          tensor_parameter_index.push_back(idx);
+        }
+      }
+    } else if (
+        *arguments[idx].real_type() ==
+        *c10::getTypePtr<c10::optional<at::Tensor>>()) {
+      // optional tensor
+      if (stack[idx].toOptional<at::Tensor>().has_value()) {
+        tensor_parameter_index.push_back(idx);
+      }
+    }
+  }
+
+  return tensor_parameter_index;
+}
+
} // namespace

AOTIPythonKernelHolder::AOTIPythonKernelHolder(
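
The `get_tensor_parameter_index` helper added above emits one index entry per tensor that `unpack_tensors` produces, pointing back at the schema argument it came from; list arguments repeat their index once per element, so the two vectors stay the same length and can be zipped later in `cache_lookup`. A minimal, self-contained sketch of that mapping idea (the `FakeArg` type and `map_tensors_to_arguments` name are illustrative stand-ins, not part of this PR):

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// Stand-in for a schema argument: how many tensors it unpacks into
// (1 for Tensor/Scalar, N for a TensorList of length N, 0 for None).
struct FakeArg {
  std::size_t num_tensors;
};

// For every tensor the unpacking step will emit, record the index of the
// argument it originated from, duplicating the index for list arguments.
std::vector<std::size_t> map_tensors_to_arguments(
    const std::vector<FakeArg>& args) {
  std::vector<std::size_t> index;
  for (std::size_t idx = 0; idx < args.size(); ++idx) {
    std::fill_n(std::back_inserter(index), args[idx].num_tensors, idx);
  }
  return index;
}
```
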
@@ -149,14 +188,19 @@ bool AOTIPythonKernelHolder::cache_lookup(
      "Not implemented for operations that return a non-Tensor value.");

  std::vector<at::Tensor> inputs;
-  auto res = unpack_tensors(op.schema().arguments(), *stack, device_, inputs);
+  auto res =
+      unpack_tensors(op.schema().arguments(), *stack, device_, inputs, true);
  TORCH_CHECK_NOT_IMPLEMENTED(
      res && inputs.size() > 0,
      "Not implemented for operations that contain a parameter which is ",
      "not one of the following types: at::Tensor, at::TensorList, ",
      "std::optional<at::Tensor>, std::vector<std::optional<at::Tensor>>.");

-  auto inputs_metadata = get_inputs_metadata(inputs);
+  auto tensor_parameter_index =
+      get_tensor_parameter_index(op.schema().arguments(), *stack);
+  TORCH_INTERNAL_ASSERT(tensor_parameter_index.size() == inputs.size());
+  auto inputs_metadata = get_inputs_metadata(
+      inputs, op.schema().arguments(), tensor_parameter_index);
  auto aoti_kernel_state = aoti_kernel_cache_.find(inputs_metadata);
  if (aoti_kernel_state == aoti_kernel_cache_.end()) {
    return false;
@@ -197,18 +241,49 @@ void AOTIPythonKernelHolder::cache_hit(
}

AOTIKernelMetadata AOTIPythonKernelHolder::get_inputs_metadata(
-    const std::vector<at::Tensor>& inputs) {
+    const std::vector<at::Tensor>& inputs,
+    const std::vector<c10::Argument>& inputs_argument,
+    const std::vector<size_t>& inputs_argument_index) {
  AOTIKernelMetadata inputs_metadata;
-  for (const auto& input : inputs) {
+  for (size_t idx = 0; idx < inputs.size(); ++idx) {
+    auto input = inputs[idx];
+    auto input_info = inputs_argument[inputs_argument_index[idx]];
+
    auto device = input.device();
    if (device.is_cpu()) {
      // If the device is CPU, set the device index to -1.
      device = c10::Device(device.type(), -1);
    }

+    c10::Scalar scalar_value((double)1.0);
+    auto tensor_type = input.scalar_type();
+
+    bool is_scalar = input_info.type()->isSubtypeOf(*c10::NumberType::get());
+    if (is_scalar) {
+      if (c10::isFloatingType(input.scalar_type())) {
+        auto scalar_numeric_value = input.item().toDouble();
+        tensor_type = c10::ScalarType::Double;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else if (c10::isIntegralType(input.scalar_type(), false)) {
+        auto scalar_numeric_value = input.item().toUInt64();
+        tensor_type = c10::ScalarType::UInt64;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else if (input.scalar_type() == c10::ScalarType::Bool) {
+        auto scalar_numeric_value = input.item().toBool();
+        tensor_type = c10::ScalarType::Bool;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else {
+        TORCH_CHECK(
+            false,
+            "Unsupported scalar tensor type: ",
+            c10::toString(input.scalar_type()));
+      }
+    }
+
    inputs_metadata.emplace_back(
-        false, // is symbloic
-        input.scalar_type(),
+        false,
+        tensor_type,
+        c10::IValue(scalar_value),
        device,
        input.sizes().vec(),
        input.strides().vec());
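
In `get_inputs_metadata`, a scalar argument (detected via `isSubtypeOf(*c10::NumberType::get())`) is widened to one of three canonical dtypes and its concrete value is stored next to the usual dtype/device/size/stride metadata, so a cache hit additionally requires an exact scalar match. A short sketch of just the widening rule; `canonical_scalar_dtype` is a hypothetical helper name, not a function in this PR:

```cpp
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>

// Mirrors the dtype-widening branches above: floating scalars are cached as
// Double, integral scalars as UInt64, booleans as Bool; anything else is rejected.
inline c10::ScalarType canonical_scalar_dtype(c10::ScalarType t) {
  if (c10::isFloatingType(t)) {
    return c10::ScalarType::Double;
  }
  if (c10::isIntegralType(t, /*includeBool=*/false)) {
    return c10::ScalarType::UInt64;
  }
  TORCH_CHECK(
      t == c10::ScalarType::Bool,
      "Unsupported scalar tensor type: ",
      c10::toString(t));
  return c10::ScalarType::Bool;
}
```
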
@@ -269,6 +344,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
        reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
    auto sizes = metadata["sizes"].cast<std::vector<int64_t>>();
    auto strides = metadata["strides"].cast<std::vector<int64_t>>();
+    bool is_scalar = metadata.contains("scalar_value");

    std::vector<std::optional<c10::SymInt>> sym_optional_sizes;
    std::vector<std::optional<c10::SymInt>> sym_optional_strides;
@@ -279,10 +355,34 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
      sym_optional_strides.push_back(std::optional<c10::SymInt>(stride));
    }

-    // Now you can use these variables in your code
+    // If an input parameter is a scalar, its detailed value is cached.
+    // This is done to ensure correctness during subsequent checks.
+    c10::Scalar scalar_value((double)1.0);
+    if (is_scalar) {
+      if (c10::isFloatingType(data_type)) {
+        auto scalar_numeric_value = metadata["scalar_value"].cast<double>();
+        data_type = c10::ScalarType::Double;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else if (c10::isIntegralType(data_type, false)) {
+        auto scalar_numeric_value = metadata["scalar_value"].cast<int64_t>();
+        data_type = c10::ScalarType::UInt64;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else if (data_type == c10::ScalarType::Bool) {
+        auto scalar_numeric_value = metadata["scalar_value"].cast<bool>();
+        data_type = c10::ScalarType::Bool;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else {
+        TORCH_CHECK(
+            false,
+            "Unsupported scalar tensor type: ",
+            c10::toString(data_type));
+      }
+    }
+
    tensor_metadata_list.emplace_back(
        is_dynamic,
        data_type,
+        c10::IValue(scalar_value),
        c10::Device(c10::Device(device_type).type(), device_index),
        sizes,
        strides);
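
`init_aoti_kernel_cache` applies the same widening when it loads persisted kernel metadata, keying off the presence of a `scalar_value` entry, so metadata rebuilt at lookup time and metadata loaded from the cache compare equal. A rough pybind11 sketch of reading that optional entry, assuming the dict layout shown above (the helper name is hypothetical, and only the double path is shown):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Read the optional "scalar_value" entry from a cached-kernel metadata dict,
// falling back to the same 1.0 placeholder the code above uses when the
// kernel was not specialized on a scalar argument.
double read_scalar_value_or_default(const py::dict& metadata) {
  if (metadata.contains("scalar_value")) {
    return metadata["scalar_value"].cast<double>();
  }
  return 1.0;
}
```
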