|
25 | 25 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
26 | 26 |
|
27 | 27 | #include <stdint.h>
|
28 |
| -<<<<<<< HEAD |
29 |
| - |
30 |
| -======= |
31 |
| ->>>>>>> implicit state management |
32 |
| -#include <cstdint> |
33 | 28 | #include <exception>
|
34 | 29 |
|
35 | 30 | #include "libtorch_utils.h"
|
|
41 | 36 | #include "triton/backend/backend_output_responder.h"
|
42 | 37 | #include "triton/common/nvtx.h"
|
43 | 38 | #include "triton/core/tritonbackend.h"
|
44 |
| -#include "triton/core/tritonserver.h" |
45 | 39 |
|
46 | 40 | #ifdef TRITON_PYTORCH_ENABLE_TORCHVISION
|
47 | 41 | // Suppress warnings in torch headers
|
@@ -153,7 +147,9 @@ class ModelState : public BackendModel {
|
153 | 147 | torch_models_;
|
154 | 148 |
|
155 | 149 | // model_outputs is a map that contains unique outputs that the model must
|
156 |
| - // provide. In the model configuration, the output in the state configuration |
| 150 | + // provide. The first pair is the model output index and the second is |
| 151 | + // the index in the model state, -1 is used if one is not required. |
| 152 | + // In the model configuration, the output in the state configuration |
157 | 153 | // can have intersection with the outputs section of the model. If an output
|
158 | 154 | // is specified both in the output section and state section, it indicates
|
159 | 155 | // that the backend must return the output state to the client too.
|
@@ -539,10 +535,6 @@ class ModelInstanceState : public BackendModelInstance {
|
539 | 535 | TRITONSERVER_Error* ValidateTypedSequenceControl(
|
540 | 536 | triton::common::TritonJson::Value& sequence_batching,
|
541 | 537 | const std::string& control_kind, bool required, bool* have_control);
|
542 |
| - void AddInputToMap( |
543 |
| - NamingConvention naming_convention, |
544 |
| - const std::vector<std::string> allowed_inputs, const std::string& io_name, |
545 |
| - const uint32_t index); |
546 | 538 | TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
|
547 | 539 | void AddInputToMap(
|
548 | 540 | NamingConvention naming_convention,
|
@@ -814,42 +806,6 @@ ModelInstanceState::ValidateTypedSequenceControl(
|
814 | 806 |
|
815 | 807 | return nullptr; // success
|
816 | 808 | }
|
817 |
| -void |
818 |
| -ModelInstanceState::AddInputToMap( |
819 |
| - NamingConvention naming_convention, |
820 |
| - const std::vector<std::string> allowed_inputs, const std::string& io_name, |
821 |
| - const uint32_t index) |
822 |
| -{ |
823 |
| - std::string deliminator = "__"; |
824 |
| - |
825 |
| - if (is_dict_input_) { |
826 |
| - // If dictionary, index is irrelevant but we use the map to store the |
827 |
| - // input names since they are the keys for the dictionary |
828 |
| - input_index_map_[io_name] = index; |
829 |
| - } else { |
830 |
| - switch (naming_convention) { |
831 |
| - case NamingConvention::FORWARD_ARGUMENT: { |
832 |
| - auto itr = |
833 |
| - std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name); |
834 |
| - if (itr != allowed_inputs.end()) { |
835 |
| - input_index_map_[io_name] = |
836 |
| - std::distance(allowed_inputs.begin(), itr); |
837 |
| - } |
838 |
| - return; |
839 |
| - } |
840 |
| - case NamingConvention::NAMED_INDEX: { |
841 |
| - int start_pos = io_name.find(deliminator); |
842 |
| - int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str()); |
843 |
| - input_index_map_[io_name] = ip_index; |
844 |
| - return; |
845 |
| - } |
846 |
| - case NamingConvention::STRICT_CONFIG_ORDERING: { |
847 |
| - input_index_map_[io_name] = index; |
848 |
| - return; |
849 |
| - } |
850 |
| - } |
851 |
| - } |
852 |
| - } |
853 | 809 |
|
854 | 810 | void
|
855 | 811 | ModelInstanceState::AddInputToMap(
|
@@ -972,10 +928,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
|
972 | 928 | std::string io_name;
|
973 | 929 | RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
|
974 | 930 | AddInputToMap(naming_convention, allowed_inputs, io_name, i);
|
975 |
| -<<<<<<< HEAD |
976 |
| -======= |
977 |
| - |
978 |
| ->>>>>>> implicit state management |
979 | 931 | // Validate data type
|
980 | 932 | std::string io_dtype;
|
981 | 933 | RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
|
@@ -1035,7 +987,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
|
1035 | 987 | .c_str());
|
1036 | 988 | }
|
1037 | 989 |
|
1038 |
| - |
1039 | 990 | // Validate shape for String inputs. Only allow 1 dimension.
|
1040 | 991 | if (state_dtype == "TYPE_STRING") {
|
1041 | 992 | std::vector<int64_t> dims;
|
|
0 commit comments