Skip to content

RSDK-4977 - add ml training apis #455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 18 additions & 46 deletions src/viam/app/data_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from datetime import datetime
from pathlib import Path
from typing import Any, List, Mapping, Optional, Tuple
Expand All @@ -17,7 +18,6 @@
BinaryMetadata,
BoundingBoxLabelsByFilterRequest,
BoundingBoxLabelsByFilterResponse,
CaptureInterval,
CaptureMetadata,
DataRequest,
DataServiceStub,
Expand All @@ -38,7 +38,6 @@
TabularDataByFilterResponse,
TagsByFilterRequest,
TagsByFilterResponse,
TagsFilter,
)
from viam.proto.app.datasync import (
DataCaptureUploadRequest,
Expand All @@ -52,7 +51,7 @@
SensorMetadata,
UploadMetadata,
)
from viam.utils import datetime_to_timestamp, struct_to_dict
from viam.utils import create_filter, datetime_to_timestamp, struct_to_dict

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -689,47 +688,20 @@ def create_filter(
tags: Optional[List[str]] = None,
bbox_labels: Optional[List[str]] = None,
) -> Filter:
"""Create a `Filter`.

Args:
component_name (Optional[str]): Optional name of the component that captured the data being filtered (e.g., "left_motor").
component_type (Optional[str]): Optional type of the componenet that captured the data being filtered (e.g., "motor").
method (Optional[str]): Optional name of the method used to capture the data being filtered (e.g., "IsPowered").
robot_name (Optional[str]): Optional name of the robot associated with the data being filtered (e.g., "viam_rover_1").
robot_id (Optional[str]): Optional ID of the robot associated with the data being filtered.
part_name (Optional[str]): Optional name of the system part associated with the data being filtered (e.g., "viam_rover_1-main").
part_id (Optional[str]): Optional ID of the system part associated with the data being filtered.
location_ids (Optional[List[str]]): Optional list of location IDs associated with the data being filtered.
organization_ids (Optional[List[str]]): Optional list of organization IDs associated with the data being filtered.
mime_type (Optional[List[str]]): Optional mime type of data being filtered (e.g., "image/png").
start_time (Optional[datetime.datetime]): Optional start time of an interval to filter data by.
end_time (Optional[datetime.datetime]): Optional end time of an interval to filter data by.
tags (Optional[List[str]]): Optional list of tags attached to the data being filtered (e.g., ["test"]).
bbox_labels (Optional[List[str]]): Optional list of bounding box labels attached to the data being filtered (e.g., ["square",
"circle"]).

Returns:
viam.proto.app.data.Filter: The `Filter` object.
"""
return Filter(
component_name=component_name if component_name else "",
component_type=component_type if component_type else "",
method=method if method else "",
robot_name=robot_name if robot_name else "",
robot_id=robot_id if robot_id else "",
part_name=part_name if part_name else "",
part_id=part_id if part_id else "",
location_ids=location_ids,
organization_ids=organization_ids,
mime_type=mime_type,
interval=(
CaptureInterval(
start=datetime_to_timestamp(start_time),
end=datetime_to_timestamp(end_time),
)
)
if start_time or end_time
else None,
tags_filter=TagsFilter(tags=tags),
bbox_labels=bbox_labels,
warnings.warn("DataClient.create_filter is deprecated. Use AppClient.create_filter instead.", DeprecationWarning, stacklevel=2)
return create_filter(
component_name,
component_type,
method,
robot_name,
robot_id,
part_name,
part_id,
location_ids,
organization_ids,
mime_type,
start_time,
end_time,
tags,
bbox_labels,
)
99 changes: 99 additions & 0 deletions src/viam/app/ml_training_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Mapping, List, Optional

from grpclib.client import Channel

from viam import logging
from viam.proto.app.mltraining import (
CancelTrainingJobRequest,
GetTrainingJobRequest,
GetTrainingJobResponse,
ListTrainingJobsRequest,
ListTrainingJobsResponse,
MLTrainingServiceStub,
ModelType,
TrainingStatus,
TrainingJobMetadata,
)
from viam.proto.app.data import Filter

LOGGER = logging.getLogger(__name__)


class MLTrainingClient:
"""gRPC client for working with ML training jobs.

Constructor is used by `ViamClient` to instantiate relevant service stubs.
Calls to `MLTrainingClient` methods should be made through `ViamClient`.
"""

def __init__(self, channel: Channel, metadata: Mapping[str, str]):
"""Create a `MLTrainingClient` that maintains a connection to app.

Args:
channel (grpclib.client.Channel): Connection to app.
metadata (Mapping[str, str]): Required authorization token to send requests to app.
"""
self._metadata = metadata
self._ml_training_client = MLTrainingServiceStub(channel)
self._channel = channel

async def submit_training_job(
self,
org_id: str,
model_name: str,
model_version: str,
model_type: ModelType,
tags: List[str],
filter: Optional[Filter] = None,
) -> str:
raise NotImplementedError()

async def get_training_job(self, id: str) -> TrainingJobMetadata:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(note to reviewers) I think there's an argument to be made for creating our own shadow type for TrainingJobMetadata. The type ultimately breaks down into a struct of primitives, a request (which is just a a string and a filter, which we already have conversion functions for), and an enum (which is relatively straightforward). So, the type isn't terribly complicated, but it is a bit nested (i.e., it contains a request which contains a filter) and I could imagine users not loving interacting with it.

I opted for the simpler thing for now because 1) why do extra work if it's not necessary, and 2) my understanding is @njooma has a preference for using the proto types when we can reasonably get away with it. But, I'm curious to get y'all's thoughts. You can see what a manually constructed TrainingJobMetadata looks like in the test file below.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see how hoisting some of the fields to top-level would be useful. If we were to make a shadow class, would filter also be automatically converted to the python-native type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we actually have a python-native Filter type, we just have a helper function for creating a Filter. So users would still be stuck with the proto type unless we want to now create a python-native Filter. This would be a bit of a reversal from our previous approach, but maybe worth considering given the now-expanded presence of Filters?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case I think we can leave this as is -- no need to create a shadow type

"""Gets training job data.

Args:
id (str): id of the requested training job.

Returns:
viam.proto.app.mltraining.TrainingJobMetadata: training job data.
"""

request = GetTrainingJobRequest(id=id)
response: GetTrainingJobResponse = await self._ml_training_client.GetTrainingJob(request, metadata=self._metadata)

return response.metadata

async def list_training_jobs(
self,
org_id: str,
training_status: Optional[TrainingStatus.ValueType] = None,
) -> List[TrainingJobMetadata]:
"""Returns training job data for all jobs within an org.

Args:
org_id (str): the id of the org to request training job data from.
training_status (Optional[TrainingStatus]): status of training jobs to filter the list by.
If unspecified, all training jobs will be returned.

Returns:
List[viam.proto.app.mltraining.TrainingJobMetadata]: a list of training job data.
"""

training_status = training_status if training_status else TrainingStatus.TRAINING_STATUS_UNSPECIFIED
request = ListTrainingJobsRequest(organization_id=org_id, status=training_status)
response: ListTrainingJobsResponse = await self._ml_training_client.ListTrainingJobs(request, metadata=self._metadata)

return list(response.jobs)

async def cancel_training_job(self, id: str) -> None:
"""Cancels the specified training job.

Args:
id (str): the id of the job to be canceled.

Raises:
GRPCError: if no training job exists with the given id.
"""

request = CancelTrainingJobRequest(id=id)
await self._ml_training_client.CancelTrainingJob(request, metadata=self._metadata)
8 changes: 7 additions & 1 deletion src/viam/app/viam_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from viam import logging
from viam.app.app_client import AppClient
from viam.app.data_client import DataClient
from viam.app.ml_training_client import MLTrainingClient
from viam.rpc.dial import DialOptions, _dial_app, _get_access_token

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,10 +69,15 @@ def app_client(self) -> AppClient:
"""Insantiate and return an `AppClient` used to make `app` method calls."""
return AppClient(self._channel, self._metadata, self._location_id)

@property
def ml_training_client(self) -> MLTrainingClient:
"""Instantiate and return a `MLTrainingClient` used to make `ml_training` method calls."""
return MLTrainingClient(self._channel, self._metadata)

def close(self):
"""Close opened channels used for the various service stubs initialized."""
if self._closed:
LOGGER.debug("AppClient is already closed.")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(flyby) fix type name.

LOGGER.debug("ViamClient is already closed.")
return
LOGGER.debug("Closing gRPC channel to app.")
self._channel.close()
Expand Down
67 changes: 67 additions & 0 deletions src/viam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from google.protobuf.timestamp_pb2 import Timestamp

from viam.proto.common import Geometry, GeoPoint, GetGeometriesRequest, GetGeometriesResponse, Orientation, ResourceName, Vector3
from viam.proto.app.data import (
CaptureInterval,
Filter,
TagsFilter,
)
from viam.resource.base import ResourceBase
from viam.resource.registry import Registry
from viam.resource.types import Subtype, SupportsGetGeometries
Expand Down Expand Up @@ -263,3 +268,65 @@ def from_dm_from_extra(extra: Optional[Dict[str, Any]]) -> bool:
return False

return bool(extra.get("fromDataManagement", False))


def create_filter(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't put this in ViamClient because it would create a circular dependency, didn't want to leave it in DataClient because MLTrainingClient uses it too, and creating a third class (UsesFilter or something) that the relevant clients could inherit from felt clunky to me. Happy to revisit if others have opinions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine here. If you really wanted, you could create a single-use mixin, but beware the pitfalls

component_name: Optional[str] = None,
component_type: Optional[str] = None,
method: Optional[str] = None,
robot_name: Optional[str] = None,
robot_id: Optional[str] = None,
part_name: Optional[str] = None,
part_id: Optional[str] = None,
location_ids: Optional[List[str]] = None,
organization_ids: Optional[List[str]] = None,
mime_type: Optional[List[str]] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
tags: Optional[List[str]] = None,
bbox_labels: Optional[List[str]] = None,
) -> Filter:
"""Create a `Filter`.

Args:
component_name (Optional[str]): Optional name of the component that captured the data being filtered (e.g., "left_motor").
component_type (Optional[str]): Optional type of the componenet that captured the data being filtered (e.g., "motor").
method (Optional[str]): Optional name of the method used to capture the data being filtered (e.g., "IsPowered").
robot_name (Optional[str]): Optional name of the robot associated with the data being filtered (e.g., "viam_rover_1").
robot_id (Optional[str]): Optional ID of the robot associated with the data being filtered.
part_name (Optional[str]): Optional name of the system part associated with the data being filtered (e.g., "viam_rover_1-main").
part_id (Optional[str]): Optional ID of the system part associated with the data being filtered.
location_ids (Optional[List[str]]): Optional list of location IDs associated with the data being filtered.
organization_ids (Optional[List[str]]): Optional list of organization IDs associated with the data being filtered.
mime_type (Optional[List[str]]): Optional mime type of data being filtered (e.g., "image/png").
start_time (Optional[datetime.datetime]): Optional start time of an interval to filter data by.
end_time (Optional[datetime.datetime]): Optional end time of an interval to filter data by.
tags (Optional[List[str]]): Optional list of tags attached to the data being filtered (e.g., ["test"]).
bbox_labels (Optional[List[str]]): Optional list of bounding box labels attached to the data being filtered (e.g., ["square",
"circle"]).

Returns:
viam.proto.app.data.Filter: The `Filter` object.
"""
return Filter(
component_name=component_name if component_name else "",
component_type=component_type if component_type else "",
method=method if method else "",
robot_name=robot_name if robot_name else "",
robot_id=robot_id if robot_id else "",
part_name=part_name if part_name else "",
part_id=part_id if part_id else "",
location_ids=location_ids,
organization_ids=organization_ids,
mime_type=mime_type,
interval=(
CaptureInterval(
start=datetime_to_timestamp(start_time),
end=datetime_to_timestamp(end_time),
)
)
if start_time or end_time
else None,
tags_filter=TagsFilter(tags=tags),
bbox_labels=bbox_labels,
)
48 changes: 48 additions & 0 deletions tests/mocks/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@
StreamingDataCaptureUploadRequest,
StreamingDataCaptureUploadResponse,
)
from viam.proto.app.mltraining import (
CancelTrainingJobRequest,
CancelTrainingJobResponse,
GetTrainingJobRequest,
GetTrainingJobResponse,
ListTrainingJobsRequest,
ListTrainingJobsResponse,
MLTrainingServiceBase,
SubmitTrainingJobRequest,
SubmitTrainingJobResponse,
TrainingJobMetadata,
)
from viam.proto.common import DoCommandRequest, DoCommandResponse, GeoObstacle, GeoPoint, PointCloudObject, Pose, PoseInFrame, ResourceName
from viam.proto.service.mlmodel import (
FlatTensor,
Expand Down Expand Up @@ -765,6 +777,42 @@ async def StreamingDataCaptureUpload(
raise NotImplementedError()


class MockMLTraining(MLTrainingServiceBase):
def __init__(self, job_id: str, training_metadata: TrainingJobMetadata):
self.job_id = job_id
self.training_metadata = training_metadata

async def SubmitTrainingJob(self, stream: Stream[SubmitTrainingJobRequest, SubmitTrainingJobResponse]) -> None:
request = await stream.recv_message()
assert request is not None
self.filter = request.filter
self.org_id = request.organization_id
self.model_name = request.model_name
self.model_version = request.model_version
self.model_type = request.model_type
self.tags = request.tags
await stream.send_message(SubmitTrainingJobResponse(id=self.job_id))

async def GetTrainingJob(self, stream: Stream[GetTrainingJobRequest, GetTrainingJobResponse]) -> None:
request = await stream.recv_message()
assert request is not None
self.training_job_id = request.id
await stream.send_message(GetTrainingJobResponse(metadata=self.training_metadata))

async def ListTrainingJobs(self, stream: Stream[ListTrainingJobsRequest, ListTrainingJobsResponse]) -> None:
request = await stream.recv_message()
assert request is not None
self.training_status = request.status
self.org_id = request.organization_id
await stream.send_message(ListTrainingJobsResponse(jobs=[self.training_metadata]))

async def CancelTrainingJob(self, stream: Stream[CancelTrainingJobRequest, CancelTrainingJobResponse]) -> None:
request = await stream.recv_message()
assert request is not None
self.cancel_job_id = request.id
await stream.send_message(CancelTrainingJobResponse())


class MockApp(AppServiceBase):
def __init__(
self,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_data_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

from .mocks.services import MockData
from viam.utils import create_filter

INCLUDE_BINARY = True
COMPONENT_NAME = "component_name"
Expand All @@ -41,7 +42,7 @@
END_DATETIME = END_TS.ToDatetime()
TAGS = ["tag"]
BBOX_LABELS = ["bbox_label"]
FILTER = DataClient.create_filter(
FILTER = create_filter(
component_name=COMPONENT_NAME,
component_type=COMPONENT_TYPE,
method=METHOD,
Expand Down
Loading