Skip to content

Commit 37c97af

Browse files
authored
Merge pull request #2588 from Dramex/fix/2572-retain-update-processing-tasks
Retain strong refs to update-processing tasks in AsyncTeleBot
2 parents 87924d4 + ea13186 commit 37c97af

2 files changed

Lines changed: 84 additions & 2 deletions

File tree

telebot/async_telebot.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[
193193

194194
self._user = None # set during polling
195195
self._polling = None
196+
# Strong references to background tasks created via asyncio.create_task().
197+
# asyncio only keeps weak references, so unreferenced tasks can be GC'd
198+
# mid-execution; see https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
199+
self._pending_tasks: set[asyncio.Task[Any]] = set()
196200
self.webhook_listener = None
197201

198202
if validate_token:
@@ -456,8 +460,10 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout:
456460
updates = await self.get_updates(offset=self.offset, allowed_updates=allowed_updates, timeout=timeout, request_timeout=request_timeout)
457461
if updates:
458462
self.offset = updates[-1].update_id + 1
459-
# noinspection PyAsyncCall
460-
asyncio.create_task(self.process_new_updates(updates)) # Seperate task for processing updates
463+
# Retain a strong reference so the task isn't GC'd mid-execution.
464+
task = asyncio.create_task(self.process_new_updates(updates))
465+
self._pending_tasks.add(task)
466+
task.add_done_callback(self._pending_tasks.discard)
461467
if interval: await asyncio.sleep(interval)
462468
error_interval = 0.25 # drop error_interval if no errors
463469

tests/test_async_telebot.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding: utf-8 -*-
2+
"""Unit tests for `telebot.async_telebot.AsyncTeleBot`.
3+
4+
These tests are self-contained (no TOKEN/CHAT_ID required) and stub out all
5+
network I/O.
6+
"""
7+
import asyncio
8+
9+
from telebot import types
10+
from telebot.async_telebot import AsyncTeleBot
11+
12+
13+
def _make_fake_me() -> types.User:
14+
return types.User.de_json({
15+
"id": 1,
16+
"is_bot": True,
17+
"first_name": "Test",
18+
"username": "test_bot",
19+
})
20+
21+
22+
def test_process_polling_retains_update_processing_tasks():
23+
"""Regression test for issue #2572.
24+
25+
Tasks fired by `_process_polling` for `process_new_updates` must be held
26+
in `self._pending_tasks` while running and discarded on completion, so
27+
they cannot be garbage-collected mid-execution.
28+
"""
29+
bot = AsyncTeleBot("1:fake", validate_token=False)
30+
31+
task_was_tracked_during_run: list[bool] = []
32+
process_completed = asyncio.Event()
33+
34+
async def fake_process_new_updates(updates):
35+
current = asyncio.current_task()
36+
task_was_tracked_during_run.append(current in bot._pending_tasks)
37+
process_completed.set()
38+
39+
async def fake_get_me():
40+
return _make_fake_me()
41+
42+
# Deliver a single update batch, then stop polling on the next tick.
43+
fake_update = types.Update.de_json({"update_id": 1})
44+
call_count = {"n": 0}
45+
46+
async def fake_get_updates(*args, **kwargs):
47+
call_count["n"] += 1
48+
if call_count["n"] == 1:
49+
return [fake_update]
50+
bot._polling = False
51+
return []
52+
53+
async def noop():
54+
return None
55+
56+
bot.get_me = fake_get_me
57+
bot.get_updates = fake_get_updates
58+
bot.process_new_updates = fake_process_new_updates
59+
bot.close_session = noop # stub: no real aiohttp session in tests
60+
61+
async def driver():
62+
await bot._process_polling(non_stop=True, interval=0, timeout=0)
63+
# Allow the fire-and-forget task to finish plus one yield for the
64+
# add_done_callback discard to run. A timeout guards against the
65+
# stub ever being rewired such that the processing task never runs.
66+
await asyncio.wait_for(process_completed.wait(), timeout=1)
67+
await asyncio.sleep(0)
68+
69+
asyncio.run(driver())
70+
71+
assert task_was_tracked_during_run == [True], (
72+
"In-flight processing task must be held by _pending_tasks"
73+
)
74+
assert bot._pending_tasks == set(), (
75+
"Completed processing tasks must be discarded from _pending_tasks"
76+
)

0 commit comments

Comments
 (0)