Skip to content

Commit f7c095e

Browse files
authored
Merge pull request #12 from hballard/tests
Added tests for subscription_manager module
2 parents 4326eef + 39990e7 commit f7c095e

7 files changed

Lines changed: 732 additions & 177 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ This is a implementation of apollographql [subscriptions-transport-ws](https://
66

77
Meant to be used in conjunction with [graphql-python](https://github.com/graphql-python) / [graphene](http://graphene-python.org/) server and [apollo-client](http://dev.apollodata.com/) for graphql. The api is below, but if you want more information, consult the apollo graphql libraries referenced above.
88

9-
Initial implementation. Currently only works with Python 2. No tests yet.
9+
Initial implementation. Currently only works with Python 2.
1010

1111
## Installation
1212
```
@@ -39,7 +39,7 @@ $ pip install graphql-subscriptions
3939
args = kwargs.get('args')
4040
return {
4141
'new_user_channel': {
42-
'filter': lambda user, context: user.active == args.active
42+
'filter': lambda root, context: root.active == args.active
4343
}
4444
}
4545

graphql_subscriptions/subscription_manager.py

Lines changed: 58 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
import redis
2-
import gevent
31
import cPickle
42
from types import FunctionType
5-
from promise import Promise
3+
64
from graphql import parse, validate, specified_rules, value_from_ast, execute
75
from graphql.language.ast import OperationDefinition
6+
from promise import Promise
7+
import gevent
8+
import redis
89

10+
from .utils import to_snake_case
11+
from .validation import SubscriptionHasSingleRootField
912

10-
class RedisPubsub(object):
1113

14+
class RedisPubsub(object):
1215
def __init__(self, host='localhost', port=6379, *args, **kwargs):
1316
redis.connection.socket = gevent.socket
1417
self.redis = redis.StrictRedis(host, port, *args, **kwargs)
@@ -29,13 +32,10 @@ def subscribe(self, trigger_name, on_message_handler, options):
2932
except IndexError:
3033
self.pubsub.subscribe(trigger_name)
3134
self.subscriptions[self.sub_id_counter] = [
32-
trigger_name,
33-
on_message_handler
35+
trigger_name, on_message_handler
3436
]
3537
if not self.greenlet:
36-
self.greenlet = gevent.spawn(
37-
self.wait_and_get_message
38-
)
38+
self.greenlet = gevent.spawn(self.wait_and_get_message)
3939
return Promise.resolve(self.sub_id_counter)
4040

4141
def unsubscribe(self, sub_id):
@@ -63,14 +63,12 @@ def handle_message(self, message):
6363

6464

6565
class ValidationError(Exception):
66-
6766
def __init__(self, errors):
6867
self.errors = errors
6968
self.message = 'Subscription query has validation errors'
7069

7170

7271
class SubscriptionManager(object):
73-
7472
def __init__(self, schema, pubsub, setup_funcs={}):
7573
self.schema = schema
7674
self.pubsub = pubsub
@@ -84,16 +82,11 @@ def publish(self, trigger_name, payload):
8482
def subscribe(self, query, operation_name, callback, variables, context,
8583
format_error, format_response):
8684
parsed_query = parse(query)
87-
errors = validate(
88-
self.schema,
89-
parsed_query,
90-
# TODO: Need to create/add subscriptionHasSingleRootField
91-
# rule from apollo subscription manager package
92-
rules=specified_rules
93-
)
85+
rules = specified_rules + [SubscriptionHasSingleRootField]
86+
errors = validate(self.schema, parsed_query, rules=rules)
9487

9588
if errors:
96-
return Promise.reject(ValidationError(errors))
89+
return Promise.rejected(ValidationError(errors))
9790

9891
args = {}
9992

@@ -110,29 +103,25 @@ def subscribe(self, query, operation_name, callback, variables, context,
110103
for arg in root_field.arguments:
111104

112105
arg_definition = [
113-
arg_def for _, arg_def in
114-
fields.get(subscription_name).args.iteritems() if
115-
arg_def.out_name == arg.name.value
106+
arg_def
107+
for _, arg_def in fields.get(subscription_name)
108+
.args.iteritems() if arg_def.out_name == arg.name.value
116109
][0]
117110

118111
args[arg_definition.out_name] = value_from_ast(
119-
arg.value,
120-
arg_definition.type,
121-
variables=variables
122-
)
123-
124-
if self.setup_funcs.get(subscription_name):
125-
trigger_map = self.setup_funcs[subscription_name](
126-
query,
127-
operation_name,
128-
callback,
129-
variables,
130-
context,
131-
format_error,
132-
format_response,
133-
args,
134-
subscription_name
135-
)
112+
arg.value, arg_definition.type, variables=variables)
113+
114+
if self.setup_funcs.get(to_snake_case(subscription_name)):
115+
trigger_map = self.setup_funcs[to_snake_case(subscription_name)](
116+
query=query,
117+
operation_name=operation_name,
118+
callback=callback,
119+
variables=variables,
120+
context=context,
121+
format_error=format_error,
122+
format_response=format_response,
123+
args=args,
124+
subscription_name=subscription_name)
136125
else:
137126
trigger_map = {}
138127
trigger_map[subscription_name] = {}
@@ -143,71 +132,55 @@ def subscribe(self, query, operation_name, callback, variables, context,
143132
subscription_promises = []
144133

145134
for trigger_name in trigger_map.viewkeys():
146-
channel_options = trigger_map[trigger_name].get(
147-
'channel_options',
148-
{}
149-
)
150-
filter = trigger_map[trigger_name].get(
151-
'filter',
152-
lambda arg1, arg2: True
153-
)
135+
try:
136+
channel_options = trigger_map[trigger_name].get(
137+
'channel_options', {})
138+
filter = trigger_map[trigger_name].get('filter',
139+
lambda arg1, arg2: True)
140+
# TODO: Think about this some more...the Apollo library
141+
# let's all messages through by default, even if
142+
# the users incorrectly uses the setup_funcs (does not
143+
# use 'filter' or 'channel_options' keys); I think it
144+
# would be better to raise an exception here
145+
except AttributeError:
146+
channel_options = {}
147+
148+
def filter(arg1, arg2):
149+
return True
154150

155151
def on_message(root_value):
156-
157152
def context_promise_handler(result):
158153
if isinstance(context, FunctionType):
159154
return context()
160155
else:
161156
return context
162157

163158
def filter_func_promise_handler(context):
164-
return Promise.all([
165-
context,
166-
filter(root_value, context)
167-
])
159+
return Promise.all([context, filter(root_value, context)])
168160

169161
def context_do_execute_handler(result):
170162
context, do_execute = result
171163
if not do_execute:
172164
return
173165
else:
174-
return execute(
175-
self.schema,
176-
parsed_query,
177-
root_value,
178-
context,
179-
variables,
180-
operation_name
181-
)
182-
183-
return Promise.resolve(
184-
True
185-
).then(
186-
context_promise_handler
187-
).then(
188-
filter_func_promise_handler
189-
).then(
190-
context_do_execute_handler
191-
).then(
192-
lambda result: callback(None, result)
193-
).catch(
194-
lambda error: callback(error, None)
195-
)
166+
return execute(self.schema, parsed_query, root_value,
167+
context, variables, operation_name)
168+
169+
return Promise.resolve(True).then(
170+
context_promise_handler).then(
171+
filter_func_promise_handler).then(
172+
context_do_execute_handler).then(
173+
lambda result: callback(None, result)).catch(
174+
lambda error: callback(error, None))
196175

197176
subscription_promises.append(
198-
self.pubsub.subscribe(
199-
trigger_name,
200-
on_message,
201-
channel_options
202-
).then(
203-
lambda id: self.subscriptions[
204-
external_subscription_id].append(id)
205-
)
206-
)
177+
self.pubsub.
178+
subscribe(trigger_name, on_message, channel_options).then(
179+
lambda id: self.subscriptions[external_subscription_id].append(id)
180+
))
207181

208182
return Promise.all(subscription_promises).then(
209-
lambda result: external_subscription_id
210-
)
183+
lambda result: external_subscription_id)
211184

212185
def unsubscribe(self, sub_id):
213186
for internal_id in self.subscriptions.get(sub_id):

0 commit comments

Comments
 (0)