feat(serde)!: Refactor CudaDevice struct, implement ABI versioning, serde cleanup #520
@@ -0,0 +1,106 @@
#include "cuda_runtime.h"

#include "core/runtime/runtime.h"
#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace runtime {

const std::string DEVICE_INFO_DELIM = "%";

typedef enum { ID_IDX = 0, SM_MAJOR_IDX, SM_MINOR_IDX, DEVICE_TYPE_IDX, DEVICE_NAME_IDX } SerializedDeviceInfoIndex;

CudaDevice::CudaDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}

CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
  cudaDeviceProp device_prop;

  // Device ID
  this->id = gpu_id;

  // Get Device Properties
  cudaGetDeviceProperties(&device_prop, gpu_id);

  // Compute capability major version
  this->major = device_prop.major;

  // Compute capability minor version
  this->minor = device_prop.minor;

  std::string device_name(device_prop.name);

  // Set Device name
  this->device_name = device_name;

  // Set Device Type
  this->device_type = device_type;
}

Comment: Switched to a delimited string vs. a byte-indexed string; should be easier to parse.

Reply: It still requires parsing based on the positional encoding of the parameters during deserialization.

// NOTE: Serialization Format for Device Info:
// id%major%minor%(enum)device_type%device_name
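// For illustration (hypothetical values): a GPU with id 0, compute capability 8.6, and
// device_type kGPU (enum value 0) would serialize to "0%8%6%0%NVIDIA GeForce RTX 3080".
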
CudaDevice::CudaDevice(std::string device_info) {
  LOG_DEBUG("Deserializing Device Info: " << device_info);

  std::vector<std::string> tokens;
  int64_t start = 0;
  int64_t end = device_info.find(DEVICE_INFO_DELIM);

  while (end != -1) {
    tokens.push_back(device_info.substr(start, end - start));
    start = end + DEVICE_INFO_DELIM.size();
    end = device_info.find(DEVICE_INFO_DELIM, start);
  }
  tokens.push_back(device_info.substr(start, end - start));

  TRTORCH_CHECK(tokens.size() == DEVICE_NAME_IDX + 1, "Unable to deserialize program target device information");

  id = std::stoi(tokens[ID_IDX]);
  major = std::stoi(tokens[SM_MAJOR_IDX]);
  minor = std::stoi(tokens[SM_MINOR_IDX]);
  device_type = (nvinfer1::DeviceType)(std::stoi(tokens[DEVICE_TYPE_IDX]));
  device_name = tokens[DEVICE_NAME_IDX];

  LOG_DEBUG("Deserialized Device Info: " << *this);
}

std::string CudaDevice::serialize() {
  std::vector<std::string> content;
  content.resize(DEVICE_NAME_IDX + 1);

  content[ID_IDX] = std::to_string(id);
  content[SM_MAJOR_IDX] = std::to_string(major);
  content[SM_MINOR_IDX] = std::to_string(minor);
  content[DEVICE_TYPE_IDX] = std::to_string((int64_t)device_type);
  content[DEVICE_NAME_IDX] = device_name;

  std::stringstream ss;
  for (size_t i = 0; i < content.size() - 1; i++) {
    ss << content[i] << DEVICE_INFO_DELIM;
  }
  ss << content[DEVICE_NAME_IDX];

  std::string serialized_device_info = ss.str();

  LOG_DEBUG("Serialized Device Info: " << serialized_device_info);

  return serialized_device_info;
}

std::string CudaDevice::getSMCapability() const {
  std::stringstream ss;
  ss << major << "." << minor;
  return ss.str();
}

std::ostream& operator<<(std::ostream& os, const CudaDevice& device) {
  os << "Device(ID: " << device.id << ", Name: " << device.device_name << ", SM Capability: " << device.major << '.'
     << device.minor << ", Type: " << device.device_type << ')';
  return os;
}

} // namespace runtime
} // namespace core
} // namespace trtorch
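
As a usage illustration (not part of this PR's diff): the sketch below round-trips a CudaDevice through the delimited format using only the constructors, serialize(), getSMCapability(), and operator<< defined above. It assumes the trtorch runtime headers are on the include path, CUDA is available, and device 0 exists.

#include <iostream>
#include <string>

#include "core/runtime/runtime.h"

int main() {
  using trtorch::core::runtime::CudaDevice;

  // Query device 0 and capture its properties
  CudaDevice original(0, nvinfer1::DeviceType::kGPU);

  // Serialize into the "id%major%minor%device_type%device_name" string
  std::string blob = original.serialize();

  // Rebuild the device description from the delimited string
  CudaDevice restored(blob);

  std::cout << "Serialized:    " << blob << "\n"
            << "Restored:      " << restored << "\n"
            << "SM Capability: " << restored.getSMCapability() << std::endl;
  return 0;
}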
@@ -0,0 +1,45 @@
#include "cuda_runtime.h"

#include "core/runtime/runtime.h"
#include "core/util/prelude.h"

namespace trtorch {
namespace core {
namespace runtime {

DeviceList::DeviceList() {
  int num_devices = 0;
  auto status = cudaGetDeviceCount(&num_devices);
  TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
  for (int i = 0; i < num_devices; i++) {
    device_list[i] = CudaDevice(i, nvinfer1::DeviceType::kGPU);
  }

  // REVIEW: DO WE CARE ABOUT DLA?

  LOG_DEBUG("Runtime:\n Available CUDA Devices: \n" << this->dump_list());
}

void DeviceList::insert(int device_id, CudaDevice cuda_device) {
  device_list[device_id] = cuda_device;
}

CudaDevice DeviceList::find(int device_id) {
  return device_list[device_id];
}

DeviceList::DeviceMap DeviceList::get_devices() {
  return device_list;
}

std::string DeviceList::dump_list() {
  std::stringstream ss;
  for (auto it = device_list.begin(); it != device_list.end(); ++it) {
    ss << " " << it->second << std::endl;
  }
  return ss.str();
}

} // namespace runtime
} // namespace core
} // namespace trtorch
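
Similarly, a minimal enumeration sketch (also not part of the diff), assuming DeviceList's default constructor and dump_list() are publicly accessible as defined above and at least one CUDA device is present:

#include <iostream>

#include "core/runtime/runtime.h"

int main() {
  // Construction enumerates every CUDA device via cudaGetDeviceCount
  trtorch::core::runtime::DeviceList devices;

  // Prints one "Device(ID: ..., Name: ..., SM Capability: ..., Type: ...)" line per GPU
  std::cout << devices.dump_list();
  return 0;
}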