1- import redis
2- import gevent
31import cPickle
42from types import FunctionType
5- from promise import Promise
3+
64from graphql import parse , validate , specified_rules , value_from_ast , execute
75from graphql .language .ast import OperationDefinition
6+ from promise import Promise
7+ import gevent
8+ import redis
89
10+ from .utils import to_snake_case
11+ from .validation import SubscriptionHasSingleRootField
912
10- class RedisPubsub (object ):
1113
14+ class RedisPubsub (object ):
1215 def __init__ (self , host = 'localhost' , port = 6379 , * args , ** kwargs ):
1316 redis .connection .socket = gevent .socket
1417 self .redis = redis .StrictRedis (host , port , * args , ** kwargs )
@@ -29,13 +32,10 @@ def subscribe(self, trigger_name, on_message_handler, options):
2932 except IndexError :
3033 self .pubsub .subscribe (trigger_name )
3134 self .subscriptions [self .sub_id_counter ] = [
32- trigger_name ,
33- on_message_handler
35+ trigger_name , on_message_handler
3436 ]
3537 if not self .greenlet :
36- self .greenlet = gevent .spawn (
37- self .wait_and_get_message
38- )
38+ self .greenlet = gevent .spawn (self .wait_and_get_message )
3939 return Promise .resolve (self .sub_id_counter )
4040
4141 def unsubscribe (self , sub_id ):
@@ -63,14 +63,12 @@ def handle_message(self, message):
6363
6464
6565class ValidationError (Exception ):
66-
6766 def __init__ (self , errors ):
6867 self .errors = errors
6968 self .message = 'Subscription query has validation errors'
7069
7170
7271class SubscriptionManager (object ):
73-
7472 def __init__ (self , schema , pubsub , setup_funcs = {}):
7573 self .schema = schema
7674 self .pubsub = pubsub
@@ -84,16 +82,11 @@ def publish(self, trigger_name, payload):
8482 def subscribe (self , query , operation_name , callback , variables , context ,
8583 format_error , format_response ):
8684 parsed_query = parse (query )
87- errors = validate (
88- self .schema ,
89- parsed_query ,
90- # TODO: Need to create/add subscriptionHasSingleRootField
91- # rule from apollo subscription manager package
92- rules = specified_rules
93- )
85+ rules = specified_rules + [SubscriptionHasSingleRootField ]
86+ errors = validate (self .schema , parsed_query , rules = rules )
9487
9588 if errors :
96- return Promise .reject (ValidationError (errors ))
89+ return Promise .rejected (ValidationError (errors ))
9790
9891 args = {}
9992
@@ -110,29 +103,25 @@ def subscribe(self, query, operation_name, callback, variables, context,
110103 for arg in root_field .arguments :
111104
112105 arg_definition = [
113- arg_def for _ , arg_def in
114- fields .get (subscription_name ). args . iteritems () if
115- arg_def .out_name == arg .name .value
106+ arg_def
107+ for _ , arg_def in fields .get (subscription_name )
108+ . args . iteritems () if arg_def .out_name == arg .name .value
116109 ][0 ]
117110
118111 args [arg_definition .out_name ] = value_from_ast (
119- arg .value ,
120- arg_definition .type ,
121- variables = variables
122- )
123-
124- if self .setup_funcs .get (subscription_name ):
125- trigger_map = self .setup_funcs [subscription_name ](
126- query ,
127- operation_name ,
128- callback ,
129- variables ,
130- context ,
131- format_error ,
132- format_response ,
133- args ,
134- subscription_name
135- )
112+ arg .value , arg_definition .type , variables = variables )
113+
114+ if self .setup_funcs .get (to_snake_case (subscription_name )):
115+ trigger_map = self .setup_funcs [to_snake_case (subscription_name )](
116+ query = query ,
117+ operation_name = operation_name ,
118+ callback = callback ,
119+ variables = variables ,
120+ context = context ,
121+ format_error = format_error ,
122+ format_response = format_response ,
123+ args = args ,
124+ subscription_name = subscription_name )
136125 else :
137126 trigger_map = {}
138127 trigger_map [subscription_name ] = {}
@@ -143,71 +132,55 @@ def subscribe(self, query, operation_name, callback, variables, context,
143132 subscription_promises = []
144133
145134 for trigger_name in trigger_map .viewkeys ():
146- channel_options = trigger_map [trigger_name ].get (
147- 'channel_options' ,
148- {}
149- )
150- filter = trigger_map [trigger_name ].get (
151- 'filter' ,
152- lambda arg1 , arg2 : True
153- )
135+ try :
136+ channel_options = trigger_map [trigger_name ].get (
137+ 'channel_options' , {})
138+ filter = trigger_map [trigger_name ].get ('filter' ,
139+ lambda arg1 , arg2 : True )
140+ # TODO: Think about this some more...the Apollo library
141+ # let's all messages through by default, even if
142+ # the users incorrectly uses the setup_funcs (does not
143+ # use 'filter' or 'channel_options' keys); I think it
144+ # would be better to raise an exception here
145+ except AttributeError :
146+ channel_options = {}
147+
148+ def filter (arg1 , arg2 ):
149+ return True
154150
155151 def on_message (root_value ):
156-
157152 def context_promise_handler (result ):
158153 if isinstance (context , FunctionType ):
159154 return context ()
160155 else :
161156 return context
162157
163158 def filter_func_promise_handler (context ):
164- return Promise .all ([
165- context ,
166- filter (root_value , context )
167- ])
159+ return Promise .all ([context , filter (root_value , context )])
168160
169161 def context_do_execute_handler (result ):
170162 context , do_execute = result
171163 if not do_execute :
172164 return
173165 else :
174- return execute (
175- self .schema ,
176- parsed_query ,
177- root_value ,
178- context ,
179- variables ,
180- operation_name
181- )
182-
183- return Promise .resolve (
184- True
185- ).then (
186- context_promise_handler
187- ).then (
188- filter_func_promise_handler
189- ).then (
190- context_do_execute_handler
191- ).then (
192- lambda result : callback (None , result )
193- ).catch (
194- lambda error : callback (error , None )
195- )
166+ return execute (self .schema , parsed_query , root_value ,
167+ context , variables , operation_name )
168+
169+ return Promise .resolve (True ).then (
170+ context_promise_handler ).then (
171+ filter_func_promise_handler ).then (
172+ context_do_execute_handler ).then (
173+ lambda result : callback (None , result )).catch (
174+ lambda error : callback (error , None ))
196175
197176 subscription_promises .append (
198- self .pubsub .subscribe (
199- trigger_name ,
200- on_message ,
201- channel_options
202- ).then (
203- lambda id : self .subscriptions [
204- external_subscription_id ].append (id )
205- )
206- )
177+ self .pubsub .
178+ subscribe (trigger_name , on_message , channel_options ).then (
179+ lambda id : self .subscriptions [external_subscription_id ].append (id )
180+ ))
207181
208182 return Promise .all (subscription_promises ).then (
209- lambda result : external_subscription_id
210- )
183+ lambda result : external_subscription_id )
211184
212185 def unsubscribe (self , sub_id ):
213186 for internal_id in self .subscriptions .get (sub_id ):
0 commit comments