|
1 | 1 | import json |
| 2 | +import gevent |
2 | 3 | from geventwebsocket import WebSocketApplication |
3 | 4 | from promise import Promise |
4 | 5 |
|
|
13 | 14 | INIT_FAIL = 'init_fail' |
14 | 15 | GRAPHQL_SUBSCRIPTIONS = 'graphql-subscriptions' |
15 | 16 |
|
16 | | -# TODO: Implement 'keep_alive' message sent to client that is in |
17 | | -# apollo subscription-transport constructor |
18 | | - |
19 | 17 |
|
20 | 18 | class ApolloSubscriptionServer(WebSocketApplication): |
21 | 19 |
|
22 | | - def __init__(self, subscription_manager, websocket): |
| 20 | + def __init__(self, subscription_manager, websocket, keep_alive=None, |
| 21 | + on_subscribe=None, on_unsubscribe=None, on_connect=None, |
| 22 | + on_disconnect=None): |
| 23 | + |
23 | 24 | assert subscription_manager, "Must provide\ |
24 | 25 | 'subscription_manager' to websocket app constructor" |
| 26 | + |
25 | 27 | self.subscription_manager = subscription_manager |
| 28 | + self.on_subscribe = on_subscribe |
| 29 | + self.on_unsubscribe = on_unsubscribe |
| 30 | + self.on_connect = on_connect |
| 31 | + self.on_disconnect = on_disconnect |
| 32 | + self.keep_alive = keep_alive |
26 | 33 | self.connection_subscriptions = {} |
27 | 34 | self.connection_context = {} |
| 35 | + |
28 | 36 | super(ApolloSubscriptionServer, self).__init__(websocket) |
29 | 37 |
|
| 38 | + def timer(self, callback, period): |
| 39 | + while True: |
| 40 | + callback() |
| 41 | + gevent.sleep(period) |
| 42 | + |
30 | 43 | def unsubscribe(self, graphql_sub_id): |
31 | 44 | self.subscription_manager.unsubscribe(graphql_sub_id) |
32 | 45 |
|
| 46 | + if self.on_unsubscribe: |
| 47 | + self.on_unsubscribe(self.ws) |
| 48 | + |
33 | 49 | def on_open(self): |
34 | 50 | if self.ws.protocol is None or (GRAPHQL_SUBSCRIPTIONS not in self.ws.protocol): |
35 | 51 | self.ws.close(1002) |
36 | 52 |
|
| 53 | + def keep_alive_callback(): |
| 54 | + if not self.ws.closed: |
| 55 | + self.send_keep_alive() |
| 56 | + else: |
| 57 | + gevent.kill(keep_alive_timer) |
| 58 | + |
| 59 | + if self.keep_alive: |
| 60 | + keep_alive_timer = gevent.spawn(self.timer, keep_alive_callback, |
| 61 | + self.keep_alive) |
| 62 | + |
37 | 63 | def on_close(self, reason): |
38 | | - for sub_id in self.connection_subscriptions.keys(): |
| 64 | + for sub_id in self.connection_subscriptions.viewkeys(): |
39 | 65 | self.unsubscribe(self.connection_subscriptions[sub_id]) |
40 | 66 | del self.connection_subscriptions[sub_id] |
41 | 67 |
|
| 68 | + if self.on_disconnect: |
| 69 | + self.on_disconnect(self.ws) |
| 70 | + |
42 | 71 | def on_message(self, msg): |
43 | 72 | if msg is None: |
44 | 73 | return |
@@ -67,6 +96,12 @@ def on_message_return_handler(message): |
67 | 96 | if parsed_message.get('type') == INIT: |
68 | 97 |
|
69 | 98 | on_connect_promise = Promise.resolve(True) |
| 99 | + |
| 100 | + if self.on_connect: |
| 101 | + on_connect_promise = Promise.resolve(self.on_connect( |
| 102 | + parsed_message.get('payload'), self.ws |
| 103 | + )) |
| 104 | + |
70 | 105 | nonlocal.on_init_resolve(on_connect_promise) |
71 | 106 |
|
72 | 107 | def init_success_promise_handler(result): |
@@ -99,6 +134,13 @@ def subscription_start_promise_handler(init_result): |
99 | 134 | } |
100 | 135 | promised_params = Promise.resolve(base_params) |
101 | 136 |
|
| 137 | + if self.on_subscribe: |
| 138 | + promised_params = Promise.resolve(self.on_subscribe( |
| 139 | + parsed_message, |
| 140 | + base_params, |
| 141 | + self.ws |
| 142 | + )) |
| 143 | + |
102 | 144 | if self.connection_subscriptions.get(sub_id): |
103 | 145 | self.unsubscribe(self.connection_subscriptions[sub_id]) |
104 | 146 | del self.connection_subscriptions[sub_id] |
|
0 commit comments