
Commit b51e6dd

EikanWang authored and ZelboK committed
[2/N] Non-Tensor: Scalar Support: Add scalar to the cache for eager-through-torch.compile (pytorch#124070)
Add scalar information to the kernel configuration.

#### Additional Context
Currently, the input parameters are orchestrated by their positional order in the kernel configuration and are loaded/mapped to the kernel in that same order at runtime. For example, the cache order of the input parameters of `torch.add(a, b, alpha=2.0)` is `a` first, followed by `b` and then `alpha`, and the same order is used when loading from the cache. However, this orchestration mechanism does not support kwargs, because keyword arguments carry no reliable positional order. For example, the `out` of `aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)` may be passed before `approximate`. We will support kwargs in subsequent PRs.

Pull Request resolved: pytorch#124070
Approved by: https://github.com/jansel, https://github.com/jgong5
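To make the positional-order caching concrete, here is a minimal sketch (illustrative values only, following the metadata fields written by `aoti_compile_with_persistent_cache` in this PR) of what the cached kernel configuration for `torch.add(a, b, alpha=2.0)` on CPU could look like: one `meta_info` entry per flattened input, in call order, with the scalar `alpha` carrying an explicit `scalar_value` and empty `sizes`/`strides`.

```python
# Illustrative sketch only -- the tensor shapes/strides are made up.
# Entries appear in the same positional order as the call: a, b, alpha.
kernel_config = {
    "meta_info": [
        {  # `a`: a CPU float tensor
            "is_dynamic": False, "device_type": "cpu", "device_index": -1,
            "dtype": "torch.float32", "sizes": [128], "strides": [1],
        },
        {  # `b`: same layout as `a`
            "is_dynamic": False, "device_type": "cpu", "device_index": -1,
            "dtype": "torch.float32", "sizes": [128], "strides": [1],
        },
        {  # `alpha`: a Python float, so the value itself is cached
            "is_dynamic": False, "device_type": "cpu", "device_index": -1,
            "dtype": "torch.float32", "sizes": [], "strides": [],
            "scalar_value": 2.0,
        },
    ]
}
```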
1 parent 8a7f719 commit b51e6dd

File tree

6 files changed: 278 additions, 23 deletions


test/inductor/test_torchinductor.py

Lines changed: 80 additions & 0 deletions
@@ -841,6 +841,86 @@ def fn(a):
 
         self.assertTrue(kernel_lib_path in kernel_libs_abs_path)
 
+    @skipCUDAIf(not SM80OrLater, "Requires sm80")
+    def test_eager_aoti_with_scalar(self):
+        namespace_name = "aten"
+        op_name = "add"
+        op_overload_name = "Tensor"
+        op_name_with_overload = f"{op_name}.{op_overload_name}"
+
+        dispatch_key = "CPU"
+        device = torch.device("cpu")
+        if self.device.lower() == "cuda":
+            dispatch_key = "CUDA"
+            device = torch.device("cuda")
+
+        # Test the difference between scalar tensor and scalar
+        a = torch.scalar_tensor(1.0, device=device)
+        b = torch.scalar_tensor(2.0, device=device)
+
+        kernel_lib_path = aoti_compile_with_persistent_cache(
+            namespace_name,
+            op_name_with_overload,
+            a.device.type,
+            False,
+            torch.ops.aten.add,
+            args=(a, b),
+            kwargs={"alpha": 3.0},
+        )
+        self.assertTrue(Path(kernel_lib_path).exists())
+        device_kernel_cache = aoti_eager_cache_dir(namespace_name, device.type)
+        kernel_conf = device_kernel_cache / f"{op_name_with_overload}.json"
+        self.assertTrue(kernel_conf.exists())
+        json_data = load_aoti_eager_cache(
+            namespace_name, op_name_with_overload, a.device.type
+        )
+        op_info = json_data[0]
+        self.assertTrue(isinstance(op_info, dict))
+        self.assertTrue("meta_info" in op_info)
+        self.assertTrue(len(op_info["meta_info"]) == 3)
+        self.assertTrue(op_info["meta_info"][0]["sizes"] == [])
+        self.assertTrue(op_info["meta_info"][0]["strides"] == [])
+        # Scalar Tensor
+        self.assertTrue("scalar_value" not in op_info["meta_info"][0])
+        self.assertTrue(op_info["meta_info"][1]["sizes"] == [])
+        self.assertTrue(op_info["meta_info"][1]["strides"] == [])
+        # Scalar Tensor
+        self.assertTrue("scalar_value" not in op_info["meta_info"][1])
+        self.assertTrue(op_info["meta_info"][2]["sizes"] == [])
+        self.assertTrue(op_info["meta_info"][2]["strides"] == [])
+        # Scalar
+        self.assertTrue("scalar_value" in op_info["meta_info"][2])
+
+        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
+            a = torch.randn(128, device=device)
+            b = torch.randn(128, device=device)
+
+            scalar_values = [1.0, 2.0, 3.0]
+            ref_values = []
+            for scalar_value in scalar_values:
+                ref_values.append(torch.add(a, b, alpha=scalar_value))
+
+            qualified_op_name = f"{namespace_name}::{op_name}"
+            _, overload_names = torch._C._jit_get_operation(qualified_op_name)
+            for overload_name in overload_names:
+                try:
+                    reg_op_name = qualified_op_name
+                    schema = torch._C._get_schema(reg_op_name, overload_name)
+                    if schema.overload_name:
+                        reg_op_name = f"{reg_op_name}.{schema.overload_name}"
+                    torch_compile_op_lib_impl._impl_with_aoti_compile(  # noqa: F821
+                        reg_op_name, dispatch_key
+                    )
+                except Exception as e:
+                    continue
+
+            res_values = []
+            for scalar_value in scalar_values:
+                res_values.append(torch.add(a, b, alpha=scalar_value))
+
+            self.assertEqual(len(ref_values), len(res_values))
+            self.assertEqual(ref_values, res_values)
+
     @skipCUDAIf(not SM80OrLater, "Requires sm80")
     def test_torch_compile_override_registration(self):
         dynamic = False

torch/_inductor/utils.py

Lines changed: 33 additions & 14 deletions
@@ -1578,16 +1578,23 @@ def aoti_compile_with_persistent_cache(
     """
     Compile the given function with persistent cache for AOTI eager mode.
     """
-    flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
-    assert all(
-        isinstance(input, torch.Tensor) for input in flattened_inputs
-    ), "Only support tensor for now"
     assert not dynamic, "Only support static shape for now"
+    type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool}
+    supported_scalar_types = tuple(type_to_torch_dtype.keys())
+    flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
+    if not all(
+        isinstance(input, (supported_scalar_types, torch.Tensor))
+        for input in flattened_inputs
+    ):
+        raise NotImplementedError("Only support tensor, int, float, bool for now")
 
     persistent_cache = aoti_eager_cache_dir(ns, device_type)
-    persistent_cache.mkdir(parents=True, exist_ok=True)
+    if not persistent_cache.exists():
+        persistent_cache.mkdir(parents=True)
+
     persistent_cache_lib = persistent_cache / "lib"
-    persistent_cache_lib.mkdir(parents=True, exist_ok=True)
+    if not persistent_cache_lib.exists():
+        persistent_cache_lib.mkdir()
 
     with mock.patch.dict(
         os.environ,
@@ -1609,18 +1616,30 @@ def aoti_compile_with_persistent_cache(
             )
 
             kernel_metadata_items = []
-            for input_tensor in flattened_inputs:
+            for input in flattened_inputs:
                 # TODO(Eikan): To add dynamic support
                 metadata: Dict[str, Any] = {}
                 metadata["is_dynamic"] = dynamic
-                metadata["device_type"] = f"{input_tensor.device.type}"
-                if is_cpu_device([input_tensor]):
-                    metadata["device_index"] = -1
+
+                if isinstance(input, torch.Tensor):
+                    metadata["device_type"] = f"{input.device.type}"
+                    if is_cpu_device([input]):
+                        metadata["device_index"] = -1
+                    else:
+                        metadata["device_index"] = input.device.index
+                    metadata["dtype"] = f"{input.dtype}"
+                    metadata["sizes"] = list(input.size())
+                    metadata["strides"] = list(input.stride())
                 else:
-                    metadata["device_index"] = input_tensor.device.index
-                metadata["dtype"] = f"{input_tensor.dtype}"
-                metadata["sizes"] = list(input_tensor.size())
-                metadata["strides"] = list(input_tensor.stride())
+                    assert isinstance(input, supported_scalar_types)
+                    # Scalar tensor
+                    metadata["device_type"] = device_type
+                    metadata["device_index"] = -1 if device_type == "cpu" else 0
+                    metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}"
+                    metadata["sizes"] = []
+                    metadata["strides"] = []
+                    metadata["scalar_value"] = input
+
                 kernel_metadata_items.append(metadata)
 
             kernel_meta_info: Dict[str, Any] = {}
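As a standalone illustration of the classification above (a hedged sketch, not part of the diff; it only relies on the public `torch` and `torch.utils._pytree` APIs used in this file), the flattened inputs of `torch.add(a, b, alpha=3.0)` split into two tensor entries and one supported scalar entry:

```python
import torch
import torch.utils._pytree as pytree

# Same mapping as in the diff above: Python scalar type -> cached torch dtype.
type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool}
supported_scalar_types = tuple(type_to_torch_dtype.keys())

a, b = torch.randn(128), torch.randn(128)
flattened_inputs = pytree.arg_tree_leaves(a, b, alpha=3.0)

for inp in flattened_inputs:
    if isinstance(inp, torch.Tensor):
        print("tensor:", inp.dtype, list(inp.size()), list(inp.stride()))
    elif isinstance(inp, supported_scalar_types):
        # Scalars get empty sizes/strides and a dtype inferred from the Python type.
        print("scalar:", type_to_torch_dtype[type(inp)], inp)
    else:
        raise NotImplementedError("Only support tensor, int, float, bool for now")
```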

torch/csrc/inductor/aoti_eager/kernel_holder.cpp

Lines changed: 108 additions & 8 deletions
@@ -96,8 +96,13 @@ bool unpack_tensors(
     const std::vector<c10::Argument>& arguments,
     const torch::jit::Stack& stack,
     const c10::Device& device,
-    std::vector<at::Tensor>& inputs) {
+    std::vector<at::Tensor>& inputs,
+    bool with_scalar = false) {
   for (size_t idx = 0; idx < stack.size(); idx++) {
+    if (!with_scalar && stack[idx].isScalar()) {
+      continue;
+    }
+
     if (!unpack_ivalue(arguments[idx], stack[idx], device, inputs)) {
       return false;
     }
@@ -106,6 +111,40 @@ bool unpack_tensors(
   return true;
 }
 
+std::vector<size_t> get_tensor_parameter_index(
+    const std::vector<c10::Argument>& arguments,
+    const torch::jit::Stack& stack) {
+  std::vector<size_t> tensor_parameter_index;
+  for (size_t idx = 0; idx < stack.size(); idx++) {
+    if (stack[idx].isScalar() || stack[idx].isTensor()) {
+      // scalar and tensor
+      tensor_parameter_index.push_back(idx);
+    } else if (stack[idx].isTensorList()) {
+      // tensor list
+      std::fill_n(
+          std::back_inserter(tensor_parameter_index),
+          stack[idx].toListRef().size(),
+          idx);
+    } else if (stack[idx].isOptionalTensorList()) {
+      // optional tensor list: std::vector<std::optional<at::Tensor>>
+      for (const auto& item : stack[idx].toListRef()) {
+        if (item.toOptional<at::Tensor>().has_value()) {
+          tensor_parameter_index.push_back(idx);
+        }
+      }
+    } else if (
+        *arguments[idx].real_type() ==
+        *c10::getTypePtr<c10::optional<at::Tensor>>()) {
+      // optional tensor
+      if (stack[idx].toOptional<at::Tensor>().has_value()) {
+        tensor_parameter_index.push_back(idx);
+      }
+    }
+  }
+
+  return tensor_parameter_index;
+}
+
 } // namespace
 
 AOTIPythonKernelHolder::AOTIPythonKernelHolder(
@@ -149,14 +188,19 @@ bool AOTIPythonKernelHolder::cache_lookup(
       "Not implemented for operations that return a non-Tensor value.");
 
   std::vector<at::Tensor> inputs;
-  auto res = unpack_tensors(op.schema().arguments(), *stack, device_, inputs);
+  auto res =
+      unpack_tensors(op.schema().arguments(), *stack, device_, inputs, true);
   TORCH_CHECK_NOT_IMPLEMENTED(
       res && inputs.size() > 0,
       "Not implemented for operations that contain a parameter which is ",
       "not one of the following types: at::Tensor, at::TensorList, ",
       "std::optional<at::Tensor>, std::vector<std::optional<at::Tensor>>.");
 
-  auto inputs_metadata = get_inputs_metadata(inputs);
+  auto tensor_parameter_index =
+      get_tensor_parameter_index(op.schema().arguments(), *stack);
+  TORCH_INTERNAL_ASSERT(tensor_parameter_index.size() == inputs.size());
+  auto inputs_metadata = get_inputs_metadata(
+      inputs, op.schema().arguments(), tensor_parameter_index);
   auto aoti_kernel_state = aoti_kernel_cache_.find(inputs_metadata);
   if (aoti_kernel_state == aoti_kernel_cache_.end()) {
     return false;
@@ -197,18 +241,49 @@ void AOTIPythonKernelHolder::cache_hit(
 }
 
 AOTIKernelMetadata AOTIPythonKernelHolder::get_inputs_metadata(
-    const std::vector<at::Tensor>& inputs) {
+    const std::vector<at::Tensor>& inputs,
+    const std::vector<c10::Argument>& inputs_argument,
+    const std::vector<size_t>& inputs_argument_index) {
   AOTIKernelMetadata inputs_metadata;
-  for (const auto& input : inputs) {
+  for (size_t idx = 0; idx < inputs.size(); ++idx) {
+    auto input = inputs[idx];
+    auto input_info = inputs_argument[inputs_argument_index[idx]];
+
     auto device = input.device();
     if (device.is_cpu()) {
       // If the device is CPU, set the device index to -1.
       device = c10::Device(device.type(), -1);
     }
 
+    c10::Scalar scalar_value((double)1.0);
+    auto tensor_type = input.scalar_type();
+
+    bool is_scalar = input_info.type()->isSubtypeOf(*c10::NumberType::get());
+    if (is_scalar) {
+      if (c10::isFloatingType(input.scalar_type())) {
+        auto scalar_numeric_value = input.item().toDouble();
+        tensor_type = c10::ScalarType::Double;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else if (c10::isIntegralType(input.scalar_type(), false)) {
+        auto scalar_numeric_value = input.item().toUInt64();
+        tensor_type = c10::ScalarType::UInt64;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else if (input.scalar_type() == c10::ScalarType::Bool) {
+        auto scalar_numeric_value = input.item().toBool();
+        tensor_type = c10::ScalarType::Bool;
+        scalar_value = c10::Scalar(scalar_numeric_value);
+      } else {
+        TORCH_CHECK(
+            false,
+            "Unsupported scalar tensor type: ",
+            c10::toString(input.scalar_type()));
+      }
+    }
+
     inputs_metadata.emplace_back(
-        false, // is symbloic
-        input.scalar_type(),
+        false,
+        tensor_type,
+        c10::IValue(scalar_value),
         device,
         input.sizes().vec(),
         input.strides().vec());
@@ -269,6 +344,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
           reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;
       auto sizes = metadata["sizes"].cast<std::vector<int64_t>>();
       auto strides = metadata["strides"].cast<std::vector<int64_t>>();
+      bool is_scalar = metadata.contains("scalar_value");
 
       std::vector<std::optional<c10::SymInt>> sym_optional_sizes;
       std::vector<std::optional<c10::SymInt>> sym_optional_strides;
@@ -279,10 +355,34 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() {
         sym_optional_strides.push_back(std::optional<c10::SymInt>(stride));
       }
 
-      // Now you can use these variables in your code
+      // If an input parameter is a scalar, its detailed value is cached.
+      // This is done to ensure correctness during subsequent checks.
+      c10::Scalar scalar_value((double)1.0);
+      if (is_scalar) {
+        if (c10::isFloatingType(data_type)) {
+          auto scalar_numeric_value = metadata["scalar_value"].cast<double>();
+          data_type = c10::ScalarType::Double;
+          scalar_value = c10::Scalar(scalar_numeric_value);
+        } else if (c10::isIntegralType(data_type, false)) {
+          auto scalar_numeric_value = metadata["scalar_value"].cast<int64_t>();
+          data_type = c10::ScalarType::UInt64;
+          scalar_value = c10::Scalar(scalar_numeric_value);
+        } else if (data_type == c10::ScalarType::Bool) {
+          auto scalar_numeric_value = metadata["scalar_value"].cast<bool>();
+          data_type = c10::ScalarType::Bool;
+          scalar_value = c10::Scalar(scalar_numeric_value);
+        } else {
+          TORCH_CHECK(
+              false,
+              "Unsupported scalar tensor type: ",
+              c10::toString(data_type));
+        }
+      }
+
       tensor_metadata_list.emplace_back(
           is_dynamic,
           data_type,
+          c10::IValue(scalar_value),
           c10::Device(c10::Device(device_type).type(), device_index),
           sizes,
           strides);
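For intuition about the index bookkeeping above, the following is a hypothetical Python model (not code from this PR) of what `get_tensor_parameter_index` together with the scalar-aware `unpack_tensors` produce for a call like `aten::add.Tensor(self, other, alpha)`: every unpacked input keeps the stack index of the schema argument it came from, so `get_inputs_metadata` can check whether that argument is a `Scalar`.

```python
# Hypothetical model of the C++ index mapping; the stack entries are stand-ins.
stack = [
    ("tensor", "self"),   # stack idx 0
    ("tensor", "other"),  # stack idx 1
    ("scalar", "alpha"),  # stack idx 2
]

# With with_scalar=True, scalars are unpacked alongside tensors...
inputs = [entry for entry in stack if entry[0] in ("tensor", "scalar")]
# ...and each unpacked input remembers which schema argument produced it.
tensor_parameter_index = [
    idx for idx, entry in enumerate(stack) if entry[0] in ("tensor", "scalar")
]

assert len(inputs) == len(tensor_parameter_index)
print(tensor_parameter_index)  # [0, 1, 2] -> arguments[2] is the Scalar `alpha`
```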

torch/csrc/inductor/aoti_eager/kernel_holder.h

Lines changed: 5 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/core/boxing/KernelFunction.h>
+#include <ATen/core/function_schema.h>
 
 #include <torch/csrc/dynamo/guards.h>
 #include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h>
@@ -82,7 +83,10 @@ class AOTIPythonKernelHolder : public c10::OperatorKernel {
   void init_aoti_kernel_cache();
   // Abstract the meta information of each tensor for the given operation. The
   // meta infomation will be used for cache lookup as the key.
-  AOTIKernelMetadata get_inputs_metadata(const std::vector<at::Tensor>&);
+  AOTIKernelMetadata get_inputs_metadata(
+      const std::vector<at::Tensor>& inputs,
+      const std::vector<c10::Argument>& inputs_argument,
+      const std::vector<size_t>& inputs_argument_index);
   // Load the AOTIModelContainerRunner object from the given file path.
   std::shared_ptr<AOTIModelContainerRunner> load_aoti_model_runner(
       const std::string&);

0 commit comments
