@@ -1741,11 +1741,32 @@ ModelInstanceState::SetInputTensors(
 
     input_names->emplace_back(input_name);
 
-    // The shape for the entire input batch, [total_batch_size, ...]
-    std::vector<int64_t> batchn_shape(
-        input_shape, input_shape + input_dims_count);
-    if (supports_batching_) {
-      batchn_shape[0] = total_batch_size;
+    // The shape for the entire input batch,
+    // [total_batch_size, ...] for non-ragged input and
+    // [total_element_count] for ragged input (non-nested tensor)
+    std::vector<int64_t> batchn_shape;
+    if (StateForModel()->IsInputRagged(input_name)) {
+      batchn_shape = std::vector<int64_t>{0};
+      for (size_t idx = 0; idx < request_count; idx++) {
+        TRITONBACKEND_Input* input;
+        RESPOND_AND_SET_NULL_IF_ERROR(
+            &((*responses)[idx]),
+            TRITONBACKEND_RequestInput(requests[idx], input_name, &input));
+        const int64_t* input_shape;
+        uint32_t input_dims_count;
+        RESPOND_AND_SET_NULL_IF_ERROR(
+            &((*responses)[idx]), TRITONBACKEND_InputProperties(
+                                      input, nullptr, nullptr, &input_shape,
+                                      &input_dims_count, nullptr, nullptr));
+
+        batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
+      }
+    }
+    else {
+      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+      if (supports_batching_) {
+        batchn_shape[0] = total_batch_size;
+      }
     }
 
     // The input must be in contiguous CPU/GPU memory.
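Note on the ragged path above: when an input is marked with allow_ragged_batch in the model configuration, Triton concatenates the requests in a batch without padding, so the batched shape collapses to a single dimension holding the summed element counts. A minimal standalone sketch of that accumulation, assuming GetElementCount is the product-of-dimensions helper from the backend common utilities and using hypothetical request shapes:

// Sketch only: mirrors the accumulation in the hunk above, assuming
// GetElementCount(dims, dims_count) returns the product of all dimensions.
#include <cstdint>
#include <iostream>
#include <vector>

static int64_t ElementCount(const std::vector<int64_t>& dims) {
  int64_t count = 1;
  for (int64_t d : dims) {
    count *= d;  // e.g. shape [2, 3] -> 6 elements
  }
  return count;
}

int main() {
  // Two hypothetical ragged requests with shapes [4] and [7]: the batch
  // becomes a flat tensor of shape [11] instead of a padded [2, 7].
  std::vector<std::vector<int64_t>> request_shapes = {{4}, {7}};
  std::vector<int64_t> batchn_shape{0};
  for (const auto& shape : request_shapes) {
    batchn_shape[0] += ElementCount(shape);
  }
  std::cout << "batched shape: [" << batchn_shape[0] << "]\n";  // [11]
}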
@@ -1866,28 +1887,36 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
+      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
+                                                                     : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+
+      // Batch output doesn't support string data type yet, as it is not trivial
+      // to parse string output
+      const BatchOutput* batch_output = StateForModel()->FindBatchOutput(name);
+      if (batch_output == nullptr) {
+        // Get output shape
+        std::vector<int64_t> batchn_shape;
+        auto shape = output_tensors[op_index].toTensor().sizes();
+        for (auto itr = shape.begin(); itr != shape.end(); itr++) {
+          batchn_shape.push_back(*itr);
+        }
 
-      // Get output shape
-      std::vector<int64_t> batchn_shape;
-      auto shape = output_tensors[op_index].toTensor().sizes();
-      for (auto itr = shape.begin(); itr != shape.end(); itr++) {
-        batchn_shape.push_back(*itr);
-      }
+        if (batchn_shape.size() == 0) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INVALID_ARG,
+              (std::string("output '") + name +
+               "' is a scalar which is not supported.")
+                  .c_str());
+        }
 
-      if (batchn_shape.size() == 0) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_INVALID_ARG,
-            (std::string("output '") + name +
-             "' is a scalar which is not supported.")
-                .c_str());
+        responder.ProcessTensor(
+            name, output_dtype, batchn_shape, output_buffer,
+            memory_type, memory_id);
+      } else {
+        responder.ProcessBatchOutput(
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
-
-      responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                : TRITONSERVER_MEMORY_GPU,
-          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index());
-
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...
       torch::List<torch::jit::IValue> output_list =
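Aside on the scalar guard in this hunk: a 0-dim torch tensor reports an empty sizes() list, which cannot be described as a Triton output shape, hence the INVALID_ARG error. A small libtorch-only sketch of that check:

// Sketch only: shows why batchn_shape stays empty for scalar outputs.
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor scalar = torch::scalar_tensor(3.0f);  // 0-dim tensor
  torch::Tensor batch = torch::zeros({8, 4});         // 2-dim tensor
  std::cout << scalar.sizes().size() << "\n";  // 0 -> would be rejected
  std::cout << batch.sizes().size() << "\n";   // 2 -> shape [8, 4] reported
}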