Skip to content

Commit a0cd50d

Browse files
authored
refactor: Refactor string input checks (#263)
Refactor string input tensor checks
1 parent 373bd88 commit a0cd50d

File tree

1 file changed

+15
-40
lines changed

1 file changed

+15
-40
lines changed

src/onnxruntime.cc

Lines changed: 15 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2477,58 +2477,33 @@ ModelInstanceState::SetStringInputBuffer(
24772477
std::vector<TRITONBACKEND_Response*>* responses, char* input_buffer,
24782478
std::vector<const char*>* string_ptrs)
24792479
{
2480+
std::vector<std::pair<const char*, const uint32_t>> str_list;
24802481
// offset for each response
24812482
size_t buffer_copy_offset = 0;
24822483
for (size_t idx = 0; idx < expected_byte_sizes.size(); idx++) {
24832484
const size_t expected_byte_size = expected_byte_sizes[idx];
24842485
const size_t expected_element_cnt = expected_element_cnts[idx];
24852486

2486-
size_t element_cnt = 0;
24872487
if ((*responses)[idx] != nullptr) {
2488-
size_t remaining_bytes = expected_byte_size;
24892488
char* data_content = input_buffer + buffer_copy_offset;
2490-
// Continue if the remaining bytes may still contain size info
2491-
while (remaining_bytes >= sizeof(uint32_t)) {
2492-
if (element_cnt >= expected_element_cnt) {
2493-
RESPOND_AND_SET_NULL_IF_ERROR(
2494-
&((*responses)[idx]),
2495-
TRITONSERVER_ErrorNew(
2496-
TRITONSERVER_ERROR_INVALID_ARG,
2497-
(std::string("unexpected number of string elements ") +
2498-
std::to_string(element_cnt + 1) + " for inference input '" +
2499-
input_name + "', expecting " +
2500-
std::to_string(expected_element_cnt))
2501-
.c_str()));
2502-
break;
2503-
}
2504-
2505-
const uint32_t len = *(reinterpret_cast<const uint32_t*>(data_content));
2506-
remaining_bytes -= sizeof(uint32_t);
2489+
TRITONSERVER_Error* err = ValidateStringBuffer(
2490+
data_content, expected_byte_size, expected_element_cnt,
2491+
input_name.c_str(), &str_list);
2492+
// Set string values.
2493+
for (const auto& [addr, len] : str_list) {
25072494
// Make first byte of size info 0, so that if there is string data
25082495
// in front of it, the data becomes valid C string.
2509-
*data_content = 0;
2510-
data_content = data_content + sizeof(uint32_t);
2511-
if (len > remaining_bytes) {
2512-
RESPOND_AND_SET_NULL_IF_ERROR(
2513-
&((*responses)[idx]),
2514-
TRITONSERVER_ErrorNew(
2515-
TRITONSERVER_ERROR_INVALID_ARG,
2516-
(std::string("incomplete string data for inference input '") +
2517-
input_name + "', expecting string of length " +
2518-
std::to_string(len) + " but only " +
2519-
std::to_string(remaining_bytes) + " bytes available")
2520-
.c_str()));
2521-
break;
2522-
} else {
2523-
string_ptrs->push_back(data_content);
2524-
element_cnt++;
2525-
data_content = data_content + len;
2526-
remaining_bytes -= len;
2527-
}
2496+
*const_cast<char*>(addr - sizeof(uint32_t)) = 0;
2497+
string_ptrs->push_back(addr);
25282498
}
2529-
}
25302499

2531-
FillStringData(string_ptrs, expected_element_cnt - element_cnt);
2500+
size_t element_cnt = str_list.size();
2501+
if (err != nullptr) {
2502+
RESPOND_AND_SET_NULL_IF_ERROR(&((*responses)[idx]), err);
2503+
FillStringData(string_ptrs, expected_element_cnt - element_cnt);
2504+
}
2505+
str_list.clear();
2506+
}
25322507
buffer_copy_offset += expected_byte_size;
25332508
}
25342509
}

0 commit comments

Comments
 (0)