Skip to content

Commit 32f8083

Browse files
committed
Refactor _on_disconnect_cb to handle varying MQTT version signatures
1 parent 2fc6f0f commit 32f8083

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

paradox/interfaces/mqtt/core.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
import asyncio
2+
from enum import Enum
23
import logging
34
import os
45
import socket
56
import ssl
67
import sys
78
import time
89
import typing
9-
from enum import Enum
1010

11-
from paho.mqtt.client import LOGGING_LEVEL, MQTT_ERR_SUCCESS, Client, connack_string, MQTTv311, MQTTv31, MQTTv5
11+
from paho.mqtt.client import (
12+
LOGGING_LEVEL,
13+
MQTT_ERR_SUCCESS,
14+
Client,
15+
MQTTv5,
16+
MQTTv31,
17+
MQTTv311,
18+
connack_string,
19+
)
1220

1321
from paradox.config import config as cfg
1422
from paradox.data.enums import RunState
@@ -34,13 +42,10 @@ class ConnectionState(Enum):
3442
RunState.STOP: "stopped",
3543
}
3644

37-
protocol_map = {
38-
"3.1": MQTTv31,
39-
"3.1.1": MQTTv311,
40-
"5": MQTTv5
41-
}
45+
protocol_map = {"3.1": MQTTv31, "3.1.1": MQTTv311, "5": MQTTv5}
46+
4247

43-
class MQTTConnection():
48+
class MQTTConnection:
4449
client: Client
4550
_instance = None
4651

@@ -53,7 +58,7 @@ def get_instance(cls) -> "MQTTConnection":
5358

5459
def __init__(self):
5560
self.client = Client(
56-
"pai"+os.urandom(8).hex(),
61+
"pai" + os.urandom(8).hex(),
5762
protocol=protocol_map.get(str(cfg.MQTT_PROTOCOL), MQTTv311),
5863
transport=cfg.MQTT_TRANSPORT,
5964
)
@@ -78,7 +83,9 @@ def __init__(self):
7883
self.registrars = []
7984

8085
if cfg.MQTT_USERNAME is not None and cfg.MQTT_PASSWORD is not None:
81-
self.client.username_pw_set(username=cfg.MQTT_USERNAME, password=cfg.MQTT_PASSWORD)
86+
self.client.username_pw_set(
87+
username=cfg.MQTT_USERNAME, password=cfg.MQTT_PASSWORD
88+
)
8289

8390
if cfg.MQTT_TLS_CERT_PATH is not None:
8491
self.client.tls_set(
@@ -149,7 +156,7 @@ def stop(self):
149156
logger.info("MQTT loop stopped")
150157

151158
def publish(self, topic, payload=None, *args, **kwargs):
152-
logger.debug("MQTT: {}={}".format(topic, payload))
159+
logger.debug(f"MQTT: {topic}={payload}")
153160

154161
self.client.publish(topic, payload, *args, **kwargs)
155162

@@ -160,7 +167,7 @@ def _call_registars(self, method, *args, **kwargs):
160167
getattr(r, method), typing.Callable
161168
):
162169
getattr(r, method)(*args, **kwargs)
163-
except:
170+
except Exception:
164171
logger.exception(
165172
'Failed to call "%s" on "%s"', method, r.__class__.__name__
166173
)
@@ -198,10 +205,15 @@ def _on_connect_cb(self, client, userdata, flags, result, properties=None):
198205
self._report_pai_status(self._last_pai_status)
199206
self._call_registars("on_connect", client, userdata, flags, result)
200207
else:
201-
logger.error(f"Failed to connect to MQTT: {connack_string(result)} ({result})")
208+
logger.error(
209+
f"Failed to connect to MQTT: {connack_string(result)} ({result})"
210+
)
202211

203-
def _on_disconnect_cb(self, client, userdata, rc, properties=None):
212+
def _on_disconnect_cb(self, client, userdata, *args, **kwargs):
204213
# called on Thread-6
214+
# Handle different MQTT version signatures by using the first argument as rc
215+
rc = args[0] if args else MQTT_ERR_SUCCESS
216+
205217
if rc == MQTT_ERR_SUCCESS:
206218
logger.info("MQTT Broker Disconnected")
207219
else:
@@ -237,7 +249,7 @@ def start(self):
237249
logger.debug("Registars: %d", len(self.mqtt.registrars))
238250

239251
def stop(self):
240-
""" Stops the MQTT Interface Thread"""
252+
"""Stops the MQTT Interface Thread"""
241253

242254
def stop_loop():
243255
self.republish_task.cancel()
@@ -261,7 +273,7 @@ async def republish_loop(self):
261273
self.publish(k, v["value"], v["qos"], v["retain"])
262274

263275
def _run(self):
264-
super(AbstractMQTTInterface, self)._run()
276+
super()._run()
265277

266278
self.loop = asyncio.new_event_loop()
267279
asyncio.set_event_loop(self.loop)
@@ -294,9 +306,9 @@ def subscribe_callback(self, sub, callback: typing.Callable):
294306
self.mqtt.subscribe(sub)
295307

296308
def on_disconnect(self, client, userdata, rc):
297-
""" Called from MQTT connection """
309+
"""Called from MQTT connection"""
298310
pass
299311

300312
def on_connect(self, client, userdata, flags, result):
301-
""" Called from MQTT connection """
313+
"""Called from MQTT connection"""
302314
pass

0 commit comments

Comments
 (0)