-
Notifications
You must be signed in to change notification settings - Fork 60
RSDK-4977 - add ml training apis #455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e6befd9
f0f2a64
081d5ea
b330218
9b0d8eb
b3c9269
c45e888
86f7374
0fecedc
65ac46e
70428f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from typing import Mapping, List, Optional | ||
|
||
from grpclib.client import Channel | ||
|
||
from viam import logging | ||
from viam.proto.app.mltraining import ( | ||
CancelTrainingJobRequest, | ||
GetTrainingJobRequest, | ||
GetTrainingJobResponse, | ||
ListTrainingJobsRequest, | ||
ListTrainingJobsResponse, | ||
MLTrainingServiceStub, | ||
ModelType, | ||
TrainingStatus, | ||
TrainingJobMetadata, | ||
) | ||
from viam.proto.app.data import Filter | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class MLTrainingClient: | ||
"""gRPC client for working with ML training jobs. | ||
|
||
Constructor is used by `ViamClient` to instantiate relevant service stubs. | ||
Calls to `MLTrainingClient` methods should be made through `ViamClient`. | ||
""" | ||
|
||
def __init__(self, channel: Channel, metadata: Mapping[str, str]): | ||
"""Create a `MLTrainingClient` that maintains a connection to app. | ||
|
||
Args: | ||
channel (grpclib.client.Channel): Connection to app. | ||
metadata (Mapping[str, str]): Required authorization token to send requests to app. | ||
""" | ||
self._metadata = metadata | ||
self._ml_training_client = MLTrainingServiceStub(channel) | ||
self._channel = channel | ||
|
||
async def submit_training_job( | ||
self, | ||
org_id: str, | ||
model_name: str, | ||
model_version: str, | ||
model_type: ModelType, | ||
tags: List[str], | ||
filter: Optional[Filter] = None, | ||
) -> str: | ||
raise NotImplementedError() | ||
|
||
async def get_training_job(self, id: str) -> TrainingJobMetadata: | ||
"""Gets training job data. | ||
|
||
Args: | ||
id (str): id of the requested training job. | ||
|
||
Returns: | ||
viam.proto.app.mltraining.TrainingJobMetadata: training job data. | ||
""" | ||
|
||
request = GetTrainingJobRequest(id=id) | ||
response: GetTrainingJobResponse = await self._ml_training_client.GetTrainingJob(request, metadata=self._metadata) | ||
|
||
return response.metadata | ||
|
||
async def list_training_jobs( | ||
self, | ||
org_id: str, | ||
training_status: Optional[TrainingStatus.ValueType] = None, | ||
) -> List[TrainingJobMetadata]: | ||
"""Returns training job data for all jobs within an org. | ||
|
||
Args: | ||
org_id (str): the id of the org to request training job data from. | ||
training_status (Optional[TrainingStatus]): status of training jobs to filter the list by. | ||
If unspecified, all training jobs will be returned. | ||
|
||
Returns: | ||
List[viam.proto.app.mltraining.TrainingJobMetadata]: a list of training job data. | ||
""" | ||
|
||
training_status = training_status if training_status else TrainingStatus.TRAINING_STATUS_UNSPECIFIED | ||
request = ListTrainingJobsRequest(organization_id=org_id, status=training_status) | ||
response: ListTrainingJobsResponse = await self._ml_training_client.ListTrainingJobs(request, metadata=self._metadata) | ||
|
||
return list(response.jobs) | ||
|
||
async def cancel_training_job(self, id: str) -> None: | ||
"""Cancels the specified training job. | ||
|
||
Args: | ||
id (str): the id of the job to be canceled. | ||
|
||
Raises: | ||
GRPCError: if no training job exists with the given id. | ||
""" | ||
|
||
request = CancelTrainingJobRequest(id=id) | ||
await self._ml_training_client.CancelTrainingJob(request, metadata=self._metadata) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
from viam import logging | ||
from viam.app.app_client import AppClient | ||
from viam.app.data_client import DataClient | ||
from viam.app.ml_training_client import MLTrainingClient | ||
from viam.rpc.dial import DialOptions, _dial_app, _get_access_token | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
@@ -68,10 +69,15 @@ def app_client(self) -> AppClient: | |
"""Insantiate and return an `AppClient` used to make `app` method calls.""" | ||
return AppClient(self._channel, self._metadata, self._location_id) | ||
|
||
@property | ||
def ml_training_client(self) -> MLTrainingClient: | ||
"""Instantiate and return a `MLTrainingClient` used to make `ml_training` method calls.""" | ||
return MLTrainingClient(self._channel, self._metadata) | ||
|
||
def close(self): | ||
"""Close opened channels used for the various service stubs initialized.""" | ||
if self._closed: | ||
LOGGER.debug("AppClient is already closed.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (flyby) fix type name. |
||
LOGGER.debug("ViamClient is already closed.") | ||
return | ||
LOGGER.debug("Closing gRPC channel to app.") | ||
self._channel.close() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,11 @@ | |
from google.protobuf.timestamp_pb2 import Timestamp | ||
|
||
from viam.proto.common import Geometry, GeoPoint, GetGeometriesRequest, GetGeometriesResponse, Orientation, ResourceName, Vector3 | ||
from viam.proto.app.data import ( | ||
CaptureInterval, | ||
Filter, | ||
TagsFilter, | ||
) | ||
from viam.resource.base import ResourceBase | ||
from viam.resource.registry import Registry | ||
from viam.resource.types import Subtype, SupportsGetGeometries | ||
|
@@ -263,3 +268,65 @@ def from_dm_from_extra(extra: Optional[Dict[str, Any]]) -> bool: | |
return False | ||
|
||
return bool(extra.get("fromDataManagement", False)) | ||
|
||
|
||
def create_filter( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Couldn't put this in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's fine here. If you really wanted, you could create a single-use mixin, but beware the pitfalls |
||
component_name: Optional[str] = None, | ||
component_type: Optional[str] = None, | ||
method: Optional[str] = None, | ||
robot_name: Optional[str] = None, | ||
robot_id: Optional[str] = None, | ||
part_name: Optional[str] = None, | ||
part_id: Optional[str] = None, | ||
location_ids: Optional[List[str]] = None, | ||
organization_ids: Optional[List[str]] = None, | ||
mime_type: Optional[List[str]] = None, | ||
start_time: Optional[datetime] = None, | ||
end_time: Optional[datetime] = None, | ||
tags: Optional[List[str]] = None, | ||
bbox_labels: Optional[List[str]] = None, | ||
) -> Filter: | ||
"""Create a `Filter`. | ||
|
||
Args: | ||
component_name (Optional[str]): Optional name of the component that captured the data being filtered (e.g., "left_motor"). | ||
component_type (Optional[str]): Optional type of the componenet that captured the data being filtered (e.g., "motor"). | ||
method (Optional[str]): Optional name of the method used to capture the data being filtered (e.g., "IsPowered"). | ||
robot_name (Optional[str]): Optional name of the robot associated with the data being filtered (e.g., "viam_rover_1"). | ||
robot_id (Optional[str]): Optional ID of the robot associated with the data being filtered. | ||
part_name (Optional[str]): Optional name of the system part associated with the data being filtered (e.g., "viam_rover_1-main"). | ||
part_id (Optional[str]): Optional ID of the system part associated with the data being filtered. | ||
location_ids (Optional[List[str]]): Optional list of location IDs associated with the data being filtered. | ||
organization_ids (Optional[List[str]]): Optional list of organization IDs associated with the data being filtered. | ||
mime_type (Optional[List[str]]): Optional mime type of data being filtered (e.g., "image/png"). | ||
start_time (Optional[datetime.datetime]): Optional start time of an interval to filter data by. | ||
end_time (Optional[datetime.datetime]): Optional end time of an interval to filter data by. | ||
tags (Optional[List[str]]): Optional list of tags attached to the data being filtered (e.g., ["test"]). | ||
bbox_labels (Optional[List[str]]): Optional list of bounding box labels attached to the data being filtered (e.g., ["square", | ||
"circle"]). | ||
|
||
Returns: | ||
viam.proto.app.data.Filter: The `Filter` object. | ||
""" | ||
return Filter( | ||
component_name=component_name if component_name else "", | ||
component_type=component_type if component_type else "", | ||
method=method if method else "", | ||
robot_name=robot_name if robot_name else "", | ||
robot_id=robot_id if robot_id else "", | ||
part_name=part_name if part_name else "", | ||
part_id=part_id if part_id else "", | ||
location_ids=location_ids, | ||
organization_ids=organization_ids, | ||
mime_type=mime_type, | ||
interval=( | ||
CaptureInterval( | ||
start=datetime_to_timestamp(start_time), | ||
end=datetime_to_timestamp(end_time), | ||
) | ||
) | ||
if start_time or end_time | ||
else None, | ||
tags_filter=TagsFilter(tags=tags), | ||
bbox_labels=bbox_labels, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(note to reviewers) I think there's an argument to be made for creating our own shadow type for
TrainingJobMetadata
. The type ultimately breaks down into a struct of primitives, a request (which is just a a string and afilter
, which we already have conversion functions for), and an enum (which is relatively straightforward). So, the type isn't terribly complicated, but it is a bit nested (i.e., it contains a request which contains a filter) and I could imagine users not loving interacting with it.I opted for the simpler thing for now because 1) why do extra work if it's not necessary, and 2) my understanding is @njooma has a preference for using the
proto
types when we can reasonably get away with it. But, I'm curious to get y'all's thoughts. You can see what a manually constructedTrainingJobMetadata
looks like in the test file below.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can see how hoisting some of the fields to top-level would be useful. If we were to make a shadow class, would filter also be automatically converted to the python-native type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we actually have a python-native
Filter
type, we just have a helper function for creating aFilter
. So users would still be stuck with the proto type unless we want to now create a python-nativeFilter
. This would be a bit of a reversal from our previous approach, but maybe worth considering given the now-expanded presence ofFilter
s?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case I think we can leave this as is -- no need to create a shadow type