Skip to content

Commit afc6c64

Browse files
author
Rena Chen
committed
Update async wrapper
1 parent 8f2e0d1 commit afc6c64

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

sdk/batch/azure-batch/azure/batch/_operations/_patch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
88
"""
99
import datetime
10-
from typing import Any, List, Optional
10+
from typing import Any, List, Optional, Iterator
1111
import collections
1212
import logging
1313
import threading
@@ -134,7 +134,7 @@ def get_node_file(
134134
if_unmodified_since: Optional[datetime.datetime] = None,
135135
ocp_range: Optional[str] = None,
136136
**kwargs: Any
137-
) -> bytes:
137+
) -> Iterator[bytes]:
138138
"""Returns the content of the specified Compute Node file.
139139
140140
:param pool_id: The ID of the Pool that contains the Compute Node. Required.
@@ -321,7 +321,7 @@ def get_task_file(
321321
if_unmodified_since: Optional[datetime.datetime] = None,
322322
ocp_range: Optional[str] = None,
323323
**kwargs: Any
324-
) -> bytes:
324+
) -> Iterator[bytes]:
325325
"""Returns the content of the specified Task file.
326326
327327
:param job_id: The ID of the Job that contains the Task. Required.

sdk/batch/azure-batch/azure/batch/aio/_operations/_patch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import datetime
1111
import collections
1212
import logging
13-
from typing import Any, List, Optional
13+
from typing import Any, AsyncIterator, List, Optional
1414

1515
from azure.batch import models as _models
1616
from azure.core import MatchConditions
@@ -125,7 +125,7 @@ async def get_node_file(
125125
if_unmodified_since: Optional[datetime.datetime] = None,
126126
ocp_range: Optional[str] = None,
127127
**kwargs: Any
128-
) -> bytes:
128+
) -> AsyncIterator[bytes]:
129129
"""Returns the content of the specified Compute Node file.
130130
131131
:param pool_id: The ID of the Pool that contains the Compute Node. Required.
@@ -312,7 +312,7 @@ async def get_task_file(
312312
if_unmodified_since: Optional[datetime.datetime] = None,
313313
ocp_range: Optional[str] = None,
314314
**kwargs: Any
315-
) -> bytes:
315+
) -> AsyncIterator[bytes]:
316316
"""Returns the content of the specified Task file.
317317
318318
:param job_id: The ID of the Job that contains the Task. Required.

sdk/batch/azure-batch/tests/async_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import AsyncIterable, AsyncIterator, Any, cast, Coroutine, Iterable, Iterator, TypeVar, Union
2+
from typing import Any, AsyncIterator, AsyncIterable, cast, Coroutine, Iterable, Iterator, TypeVar, Union
33

44
T = TypeVar("T")
55

sdk/batch/azure-batch/tests/test_batch.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Iterable, Union
2222

2323
from batch_preparers import AccountPreparer, PoolPreparer, JobPreparer
24-
from async_wrapper import wrap_list_result, wrap_result
24+
from async_wrapper import wrap_list_result, wrap_result, wrap_file_result
2525
from decorators import recorded_by_proxy_async, client_setup
2626

2727
from devtools_testutils import (
@@ -751,6 +751,7 @@ async def test_batch_files(self, client: BatchClient, **kwargs):
751751
nodes = list(await wrap_list_result(client.list_nodes(batch_pool.name)))
752752
assert len(nodes) == 1
753753
node = nodes[0].id
754+
assert node is not None
754755
task_id = "test_task"
755756
task_param = models.BatchTaskCreateContent(id=task_id, command_line='cmd /c "echo hello world"')
756757
response = await wrap_result(client.create_task(batch_job.id, task_param))
@@ -773,7 +774,7 @@ async def test_batch_files(self, client: BatchClient, **kwargs):
773774
# Test Get File from Batch Node
774775
file_length = 0
775776
with io.BytesIO() as file_handle:
776-
response = await wrap_result(client.get_node_file(batch_pool.name, node, only_files[0].name))
777+
response = await wrap_file_result(client.get_node_file(batch_pool.name, node, only_files[0].name))
777778
for data in response:
778779
file_length += 1
779780
assert file_length == int(props.headers["Content-Length"])
@@ -797,7 +798,7 @@ async def test_batch_files(self, client: BatchClient, **kwargs):
797798
# Test Get File from Task
798799
file_length = 0
799800
with io.BytesIO() as file_handle:
800-
response = await wrap_result(client.get_task_file(batch_job.id, task_id, only_files[0].name))
801+
response = await wrap_file_result(client.get_task_file(batch_job.id, task_id, only_files[0].name))
801802
for data in response:
802803
file_length += len(data)
803804
assert file_length == int(props.headers["Content-Length"])

0 commit comments

Comments
 (0)