Skip to content

add on_message callback with user data #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,11 @@ def __init__(
# List of subscribed topics, used for tracking
self._subscribed_topics: List[str] = []
self._on_message_filtered = MQTTMatcher()
self._on_message_filtered_user_data = MQTTMatcher()

# Default topic callback methods
self._on_message = None
self.on_message_user_data = None
self.on_connect = None
self.on_disconnect = None
self.on_publish = None
Expand Down Expand Up @@ -418,6 +420,18 @@ def add_topic_callback(self, mqtt_topic: str, callback_method) -> None:
raise ValueError("MQTT topic and callback method must both be defined.")
self._on_message_filtered[mqtt_topic] = callback_method

def add_topic_callback_user_data(self, mqtt_topic: str, callback_method) -> None:
"""Registers a callback_method for a specific MQTT topic.

:param str mqtt_topic: MQTT topic identifier.
:param function callback_method: The callback method with user data.
"""
if mqtt_topic is None or callback_method is None or self._user_data is None:
raise ValueError(
"MQTT topic, callback method and user data must both be defined."
)
self._on_message_filtered_user_data[mqtt_topic] = callback_method

def remove_topic_callback(self, mqtt_topic: str) -> None:
"""Removes a registered callback method.

Expand All @@ -427,10 +441,24 @@ def remove_topic_callback(self, mqtt_topic: str) -> None:
raise ValueError("MQTT Topic must be defined.")
try:
del self._on_message_filtered[mqtt_topic]
except KeyError:
raise KeyError(
except KeyError as exc:
raise MMQTTException(
"MQTT topic callback not added with add_topic_callback."
) from None
) from exc

def remove_topic_callback_user_data(self, mqtt_topic: str) -> None:
"""Removes a registered callback method with user data.

:param str mqtt_topic: MQTT topic identifier string.
"""
if mqtt_topic is None:
raise ValueError("MQTT Topic must be defined.")
try:
del self._on_message_filtered_user_data[mqtt_topic]
except KeyError as exc:
raise MMQTTException(
"MQTT topic callback not added with add_topic_callback_user_data."
) from exc

@property
def on_message(self):
Expand All @@ -451,8 +479,15 @@ def _handle_on_message(self, topic: str, message: str):
callback(self, topic, message) # on_msg with callback
matched = True

if not matched and self.on_message: # regular on_message
self.on_message(self, topic, message)
for callback in self._on_message_filtered_user_data.iter_match(topic):
callback(self, self._user_data, topic, message) # on_msg with callback
matched = True

if not matched: # regular on_message
if self.on_message:
self.on_message(self, topic, message)
if self.on_message_user_data:
self.on_message_user_data(self, self._user_data, topic, message)

def username_pw_set(self, username: str, password: Optional[str] = None) -> None:
"""Set client's username and an optional password.
Expand Down
77 changes: 77 additions & 0 deletions tests/test_handle_on_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
#
# SPDX-License-Identifier: Unlicense

"""_handle_on_message() tests"""

import socket
import ssl
from unittest import TestCase, main
from unittest.mock import MagicMock

import adafruit_minimqtt.adafruit_minimqtt as MQTT


class OnMessage(TestCase):
"""unit tests for _handle_on_message()"""

# pylint: disable=no-self-use
def test_handle_on_message(self) -> None:
"""
test that _handle_on_message() calls both regular message handlers if set.
"""

host = "172.40.0.3"
port = 1883

user_data = "regular"
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
socket_pool=socket,
ssl_context=ssl.create_default_context(),
user_data=user_data,
)

mqtt_client.on_message_user_data = MagicMock()
mqtt_client.on_message = MagicMock()

topic = "devices/foo/bar"
message = '{"foo": "bar"}'
# pylint: disable=protected-access
mqtt_client._handle_on_message(topic, message)
mqtt_client.on_message.assert_called_with(mqtt_client, topic, message)
mqtt_client.on_message_user_data.assert_called_with(
mqtt_client, user_data, topic, message
)

# pylint: disable=no-self-use
def test_handle_on_message_filtered(self) -> None:
"""
test that _handle_on_message() calls the callback for filtered topic if set.
"""

host = "172.40.0.3"
port = 1883

user_data = "filtered"
mqtt_client = MQTT.MQTT(
broker=host,
port=port,
socket_pool=socket,
ssl_context=ssl.create_default_context(),
user_data=user_data,
)

topic = "devices/foo/bar"
mock_callback = MagicMock()
mqtt_client.add_topic_callback_user_data(topic, mock_callback)

message = '{"foo": "bar"}'
# pylint: disable=protected-access
mqtt_client._handle_on_message(topic, message)
mock_callback.assert_called_with(mqtt_client, user_data, topic, message)


if __name__ == "__main__":
main()