Skip to content

Add missing context kwarg to _sentry_task_factory #2267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions sentry_sdk/integrations/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

if TYPE_CHECKING:
from typing import Any
from collections.abc import Coroutine

from sentry_sdk._types import ExcInfo

Expand All @@ -37,8 +38,8 @@ def patch_asyncio():
loop = asyncio.get_running_loop()
orig_task_factory = loop.get_task_factory()

def _sentry_task_factory(loop, coro):
# type: (Any, Any) -> Any
def _sentry_task_factory(loop, coro, **kwargs):
# type: (asyncio.AbstractEventLoop, Coroutine[Any, Any, Any], Any) -> asyncio.Future[Any]

async def _coro_creating_hub_and_span():
# type: () -> Any
Expand All @@ -56,7 +57,7 @@ async def _coro_creating_hub_and_span():

# Trying to use user set task factory (if there is one)
if orig_task_factory:
return orig_task_factory(loop, _coro_creating_hub_and_span())
return orig_task_factory(loop, _coro_creating_hub_and_span(), **kwargs)

# The default task factory in `asyncio` does not have its own function
# but is just a couple of lines in `asyncio.base_events.create_task()`
Expand All @@ -65,13 +66,13 @@ async def _coro_creating_hub_and_span():
# WARNING:
# If the default behavior of the task creation in asyncio changes,
# this will break!
task = Task(_coro_creating_hub_and_span(), loop=loop)
task = Task(_coro_creating_hub_and_span(), loop=loop, **kwargs)
if task._source_traceback: # type: ignore
del task._source_traceback[-1] # type: ignore

return task

loop.set_task_factory(_sentry_task_factory)
loop.set_task_factory(_sentry_task_factory) # type: ignore
except RuntimeError:
# When there is no running loop, we have nothing to patch.
pass
Expand Down
200 changes: 199 additions & 1 deletion tests/integrations/asyncio/test_asyncio_py3.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
import asyncio
import inspect
import sys

import pytest

import sentry_sdk
from sentry_sdk.consts import OP
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sentry_sdk.integrations.asyncio import AsyncioIntegration, patch_asyncio

try:
from unittest.mock import MagicMock, patch
except ImportError:
from mock import MagicMock, patch

try:
from contextvars import Context, ContextVar
except ImportError:
pass # All tests will be skipped with incompatible versions


minimum_python_37 = pytest.mark.skipif(
sys.version_info < (3, 7), reason="Asyncio tests need Python >= 3.7"
)


minimum_python_311 = pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Asyncio task context parameter was introduced in Python 3.11",
)


async def foo():
await asyncio.sleep(0.01)

Expand All @@ -33,6 +50,17 @@ def event_loop(request):
loop.close()


def get_sentry_task_factory(mock_get_running_loop):
"""
Patches (mocked) asyncio and gets the sentry_task_factory.
"""
mock_loop = mock_get_running_loop.return_value
patch_asyncio()
patched_factory = mock_loop.set_task_factory.call_args[0][0]

return patched_factory


@minimum_python_37
@pytest.mark.asyncio
async def test_create_task(
Expand Down Expand Up @@ -170,3 +198,173 @@ async def add(a, b):

result = await asyncio.create_task(add(1, 2))
assert result == 3, result


@minimum_python_311
@pytest.mark.asyncio
async def test_task_with_context(sentry_init):
"""
Integration test to ensure working context parameter in Python 3.11+
"""
sentry_init(
integrations=[
AsyncioIntegration(),
],
)

var = ContextVar("var")
var.set("original value")

async def change_value():
var.set("changed value")

async def retrieve_value():
return var.get()

# Create a context and run both tasks within the context
ctx = Context()
async with asyncio.TaskGroup() as tg:
tg.create_task(change_value(), context=ctx)
retrieve_task = tg.create_task(retrieve_value(), context=ctx)

assert retrieve_task.result() == "changed value"


@minimum_python_37
@patch("asyncio.get_running_loop")
def test_patch_asyncio(mock_get_running_loop):
"""
Test that the patch_asyncio function will patch the task factory.
"""
mock_loop = mock_get_running_loop.return_value

patch_asyncio()

assert mock_loop.set_task_factory.called

set_task_factory_args, _ = mock_loop.set_task_factory.call_args
assert len(set_task_factory_args) == 1

sentry_task_factory, *_ = set_task_factory_args
assert callable(sentry_task_factory)


@minimum_python_37
@pytest.mark.forked
@patch("asyncio.get_running_loop")
@patch("sentry_sdk.integrations.asyncio.Task")
def test_sentry_task_factory_no_factory(MockTask, mock_get_running_loop): # noqa: N803
mock_loop = mock_get_running_loop.return_value
mock_coro = MagicMock()

# Set the original task factory to None
mock_loop.get_task_factory.return_value = None

# Retieve sentry task factory (since it is an inner function within patch_asyncio)
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)

# The call we are testing
ret_val = sentry_task_factory(mock_loop, mock_coro)

assert MockTask.called
assert ret_val == MockTask.return_value

task_args, task_kwargs = MockTask.call_args
assert len(task_args) == 1

coro_param, *_ = task_args
assert inspect.iscoroutine(coro_param)

assert "loop" in task_kwargs
assert task_kwargs["loop"] == mock_loop


@minimum_python_37
@pytest.mark.forked
@patch("asyncio.get_running_loop")
def test_sentry_task_factory_with_factory(mock_get_running_loop):
mock_loop = mock_get_running_loop.return_value
mock_coro = MagicMock()

# The original task factory will be mocked out here, let's retrieve the value for later
orig_task_factory = mock_loop.get_task_factory.return_value

# Retieve sentry task factory (since it is an inner function within patch_asyncio)
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)

# The call we are testing
ret_val = sentry_task_factory(mock_loop, mock_coro)

assert orig_task_factory.called
assert ret_val == orig_task_factory.return_value

task_factory_args, _ = orig_task_factory.call_args
assert len(task_factory_args) == 2

loop_arg, coro_arg = task_factory_args
assert loop_arg == mock_loop
assert inspect.iscoroutine(coro_arg)


@minimum_python_311
@patch("asyncio.get_running_loop")
@patch("sentry_sdk.integrations.asyncio.Task")
def test_sentry_task_factory_context_no_factory(
MockTask, mock_get_running_loop # noqa: N803
):
mock_loop = mock_get_running_loop.return_value
mock_coro = MagicMock()
mock_context = MagicMock()

# Set the original task factory to None
mock_loop.get_task_factory.return_value = None

# Retieve sentry task factory (since it is an inner function within patch_asyncio)
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)

# The call we are testing
ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context)

assert MockTask.called
assert ret_val == MockTask.return_value

task_args, task_kwargs = MockTask.call_args
assert len(task_args) == 1

coro_param, *_ = task_args
assert inspect.iscoroutine(coro_param)

assert "loop" in task_kwargs
assert task_kwargs["loop"] == mock_loop
assert "context" in task_kwargs
assert task_kwargs["context"] == mock_context


@minimum_python_311
@patch("asyncio.get_running_loop")
def test_sentry_task_factory_context_with_factory(mock_get_running_loop):
mock_loop = mock_get_running_loop.return_value
mock_coro = MagicMock()
mock_context = MagicMock()

# The original task factory will be mocked out here, let's retrieve the value for later
orig_task_factory = mock_loop.get_task_factory.return_value

# Retieve sentry task factory (since it is an inner function within patch_asyncio)
sentry_task_factory = get_sentry_task_factory(mock_get_running_loop)

# The call we are testing
ret_val = sentry_task_factory(mock_loop, mock_coro, context=mock_context)

assert orig_task_factory.called
assert ret_val == orig_task_factory.return_value

task_factory_args, task_factory_kwargs = orig_task_factory.call_args
assert len(task_factory_args) == 2

loop_arg, coro_arg = task_factory_args
assert loop_arg == mock_loop
assert inspect.iscoroutine(coro_arg)

assert "context" in task_factory_kwargs
assert task_factory_kwargs["context"] == mock_context