Skip to content

Commit 98ad1ab

Browse files
committed
Add ZMQ Interface sample python client
1 parent 8cff816 commit 98ad1ab

3 files changed

Lines changed: 371 additions & 0 deletions

File tree

Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
import matplotlib.pyplot as plt
2+
import zmq
3+
import sys
4+
import numpy as np
5+
import json
6+
import uuid
7+
import time
8+
9+
10+
class OpenEphysEvent(object):
11+
event_types = {0: 'TIMESTAMP', 1: 'BUFFER_SIZE', 2: 'PARAMETER_CHANGE',
12+
3: 'TTL', 4: 'SPIKE', 5: 'MESSAGE', 6: 'BINARY_MSG'}
13+
14+
def __init__(self, _d, _data=None):
15+
self.type = None
16+
self.stream = ''
17+
self.sample_num = 0
18+
self.source_node = 0
19+
self.event_state = 0
20+
self.event_line = 0
21+
self.event_word = 0
22+
self.numBytes = 0
23+
self.data = b''
24+
self.__dict__.update(_d)
25+
self.timestamp = None
26+
# noinspection PyTypeChecker
27+
self.type = OpenEphysEvent.event_types[self.type]
28+
if _data:
29+
self.data = _data
30+
self.numBytes = len(_data)
31+
32+
dfb = np.frombuffer(self.data, dtype=np.uint8)
33+
self.event_line = dfb[0]
34+
35+
dfb = np.frombuffer(self.data, dtype=np.uint8, offset=1)
36+
self.event_state = dfb[0]
37+
38+
dfb = np.frombuffer(self.data, dtype=np.uint64, offset=2)
39+
self.event_word = dfb[0]
40+
if self.type == 'TIMESTAMP':
41+
t = np.frombuffer(self.data, dtype=np.int64)
42+
self.timestamp = t[0]
43+
44+
def set_data(self, _data):
45+
self.data = _data
46+
self.numBytes = len(_data)
47+
48+
def __str__(self):
49+
ds = self.__dict__.copy()
50+
del ds['data']
51+
return str(ds)
52+
53+
54+
class OpenEphysSpikeEvent(object):
55+
56+
def __init__(self, _d, _data=None):
57+
self.stream = ''
58+
self.source_node = 0
59+
self.electrode = 0
60+
self.sample_num = 0
61+
self.num_channels = 0
62+
self.num_samples = 0
63+
self.sorted_id = 0
64+
self.threshold = []
65+
66+
self.__dict__.update(_d)
67+
self.data = _data
68+
69+
def __str__(self):
70+
ds = self.__dict__.copy()
71+
del ds['data']
72+
return str(ds)
73+
74+
75+
class PlotProcess(object): # TODO more configuration stuff that may be obtained
76+
def __init__(self, ):
77+
# keep this slot for multiprocessing related initialization if needed
78+
self.context = zmq.Context()
79+
self.data_socket = None
80+
self.event_socket = None
81+
self.poller = zmq.Poller()
82+
self.message_num = -1
83+
self.socket_waits_reply = False
84+
self.event_no = 0
85+
self.app_name = 'Plot Process'
86+
self.uuid = str(uuid.uuid4())
87+
self.last_heartbeat_time = 0
88+
self.last_reply_time = time.time()
89+
self.isTesting = True
90+
91+
def startup(self):
92+
pass
93+
94+
@staticmethod
95+
def param_config():
96+
# TODO we'll have to pass the parameter requests via a second socket
97+
# this is meant to support a mechanism to set parameters of the application from the Open Ephys GUI.
98+
# not sure if it will be needed actually, it may disappear
99+
return ()
100+
101+
def update_plot(self, n_arr, sample_rate):
102+
pass
103+
104+
def update_plot_event(self, event):
105+
print(event)
106+
107+
# noinspection PyMethodMayBeStatic
108+
def update_plot_spike(self, spike):
109+
print(spike)
110+
111+
def send_heartbeat(self):
112+
d = {'application': self.app_name, 'uuid': self.uuid, 'type': 'heartbeat'}
113+
j_msg = json.dumps(d)
114+
print("sending heartbeat")
115+
self.event_socket.send(j_msg.encode('utf-8'))
116+
self.last_heartbeat_time = time.time()
117+
self.socket_waits_reply = True
118+
119+
def send_event(self, event_list=None, event_type=3, sample_num=0, event_id=2, event_channel=1):
120+
if not self.socket_waits_reply:
121+
self.event_no += 1
122+
if event_list:
123+
for e in event_list:
124+
self.send_event(event_type=e['event_type'], sample_num=e['sample_num'], event_id=e['event_id'],
125+
event_channel=e['event_channel'])
126+
else:
127+
de = {'type': event_type, 'sample_num': sample_num, 'event_id': event_id % 2 + 1,
128+
'event_channel': event_channel}
129+
d = {'application': self.app_name, 'uuid': self.uuid, 'type': 'event', 'event': de}
130+
j_msg = json.dumps(d)
131+
print(j_msg)
132+
if self.socket_waits_reply:
133+
print("Can't send event")
134+
else:
135+
self.event_socket.send(j_msg.encode('utf-8'), 0)
136+
self.socket_waits_reply = True
137+
self.last_reply_time = time.time()
138+
else:
139+
print("can't send event, still waiting for previous reply")
140+
141+
def callback(self):
142+
events = []
143+
144+
if not self.data_socket:
145+
print("init socket")
146+
self.data_socket = self.context.socket(zmq.SUB)
147+
self.data_socket.connect("tcp://localhost:5556")
148+
149+
self.event_socket = self.context.socket(zmq.REQ)
150+
self.event_socket.connect("tcp://localhost:5557")
151+
152+
# self.data_socket.connect("ipc://data.ipc")
153+
self.data_socket.setsockopt(zmq.SUBSCRIBE, b'')
154+
self.poller.register(self.data_socket, zmq.POLLIN)
155+
self.poller.register(self.event_socket, zmq.POLLIN)
156+
157+
# send every two seconds a "heartbeat" so that Open Ephys knows we're alive
158+
159+
if self.isTesting:
160+
if np.random.random() < 0.005:
161+
self.send_event(event_type=3, sample_num=0, event_id=self.event_no, event_channel=1)
162+
163+
while True:
164+
if (time.time() - self.last_heartbeat_time) > 2.:
165+
if self.socket_waits_reply:
166+
print("heartbeat haven't got reply, retrying...")
167+
self.last_heartbeat_time += 1.
168+
if (time.time() - self.last_reply_time) > 10.:
169+
# reconnecting the socket as per the "lazy pirate" pattern (see the ZeroMQ guide)
170+
print("looks like we lost the server, trying to reconnect")
171+
self.poller.unregister(self.event_socket)
172+
self.event_socket.close()
173+
self.event_socket = self.context.socket(zmq.REQ)
174+
self.event_socket.connect("tcp://localhost:5557")
175+
self.poller.register(self.event_socket)
176+
self.socket_waits_reply = False
177+
self.last_reply_time = time.time()
178+
else:
179+
self.send_heartbeat()
180+
181+
socks = dict(self.poller.poll(1))
182+
if not socks:
183+
# print("poll exits")
184+
break
185+
if self.data_socket in socks:
186+
try:
187+
message = self.data_socket.recv_multipart(zmq.NOBLOCK)
188+
except zmq.ZMQError as err:
189+
print("got error: {0}".format(err))
190+
break
191+
if message:
192+
if len(message) < 2:
193+
print("no frames for message: ", message[0])
194+
try:
195+
header = json.loads(message[1].decode('utf-8'))
196+
except ValueError as e:
197+
print("ValueError: ", e)
198+
print(message[1])
199+
if self.message_num != -1 and header['message_num'] != self.message_num + 1:
200+
print("missing a message at number", self.message_num)
201+
self.message_num = header['message_num']
202+
if header['type'] == 'data':
203+
c = header['content']
204+
num_samples = c['num_samples']
205+
channel_num = c['channel_num']
206+
sample_rate = c['sample_rate']
207+
208+
if channel_num == 1:
209+
try:
210+
n_arr = np.frombuffer(message[2], dtype=np.float32)
211+
n_arr = np.reshape(n_arr, num_samples)
212+
if num_samples > 0:
213+
self.update_plot(n_arr, sample_rate)
214+
except IndexError as e:
215+
print(e)
216+
print(header)
217+
print(message[1])
218+
if len(message) > 2:
219+
print(len(message[2]))
220+
else:
221+
print("only one frame???")
222+
223+
elif header['type'] == 'event':
224+
225+
if header['data_size'] > 0:
226+
event = OpenEphysEvent(header['content'], message[2])
227+
else:
228+
event = OpenEphysEvent(header['content'])
229+
self.update_plot_event(event)
230+
elif header['type'] == 'spike':
231+
spike = OpenEphysSpikeEvent(header['spike'], message[2])
232+
self.update_plot_spike(spike)
233+
234+
elif header['type'] == 'param':
235+
c = header['content']
236+
self.__dict__.update(c)
237+
print(c)
238+
else:
239+
raise ValueError("message type unknown")
240+
else:
241+
print("got not data")
242+
243+
break
244+
elif self.event_socket in socks and self.socket_waits_reply:
245+
message = self.event_socket.recv()
246+
print("event reply received")
247+
print(message)
248+
if self.socket_waits_reply:
249+
self.socket_waits_reply = False
250+
251+
else:
252+
print("???? getting a reply before a send?")
253+
# print "finishing callback"
254+
if events:
255+
pass # TODO implement the event passing
256+
257+
return True
258+
259+
@staticmethod
260+
def terminate():
261+
plt.close()
262+
sys.exit(0)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import sys
2+
import simple_plotter_zmq
3+
4+
if __name__ == '__main__':
5+
pl = simple_plotter_zmq.SimplePlotter(40000.)
6+
pl.startup()
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
2+
3+
import numpy as np
4+
# import matplotlib
5+
# matplotlib.use('QT4Agg')
6+
import matplotlib.pyplot as plt
7+
from matplotlib.widgets import Slider
8+
9+
from plot_process_zmq import PlotProcess
10+
11+
12+
class SimplePlotter(PlotProcess):
13+
def __init__(self, sampling_rate):
14+
"""
15+
:param sampling_rate: the sampling rate of the process
16+
:return: None
17+
Here all the configuration detail that is available at initialization time, however,
18+
no matplotlib object should be defined in here because they can't be pickled and sent it
19+
through the process borders. The constructor gets called in the
20+
"""
21+
22+
super(SimplePlotter, self).__init__()
23+
print("in init")
24+
self.y = np.empty(0, dtype=np.float32) # the buffer for the data that gets accumulated
25+
# self.chan_in = 10
26+
self.plotting_interval = 1000. # in ms
27+
self.frame_count = 0
28+
self.frame_max = 0
29+
self.sampling_rate = sampling_rate
30+
self.app_name = "Simple Plotter"
31+
# matplotlib members, initialized to None, they will be filled in startup
32+
self.ax = None
33+
self.hl = None
34+
self.figure = None
35+
self.num_samples = 0
36+
self.pipe = None
37+
self.code = 0
38+
39+
def startup(self):
40+
# build the plot
41+
ylim0 = 200
42+
print("starting plot")
43+
self.figure, self.ax = plt.subplots()
44+
plt.subplots_adjust(left=0.1, bottom=0.2)
45+
axcolor = 'lightgoldenrodyellow'
46+
axylim = plt.axes([0.1, 0.05, 0.65, 0.03], facecolor=axcolor)
47+
sylim = Slider(axylim, 'Ylim', 1, 600, valinit=ylim0)
48+
49+
# noinspection PyUnusedLocal
50+
def update(val):
51+
yl = sylim.val
52+
self.ax.set_ylim(-yl, yl)
53+
plt.draw()
54+
55+
sylim.on_changed(update)
56+
57+
self.hl, = self.ax.plot([], [])
58+
self.ax.set_autoscaley_on(True)
59+
self.ax.margins(y=0.1)
60+
self.ax.set_xlim(0., 1)
61+
self.ax.set_ylim(-ylim0, ylim0)
62+
# initialize timer
63+
timer = self.figure.canvas.new_timer(interval=50, )
64+
timer.add_callback(self.callback)
65+
timer.start()
66+
plt.show(block=True)
67+
68+
@staticmethod
69+
def param_config():
70+
chan_labels = list(range(32))
71+
return ("int_set", "chan_in", chan_labels),
72+
73+
def update_plot(self, n_arr, sample_rate):
74+
# setting up frame dependent parameters
75+
self.num_samples = int(n_arr.shape[0])
76+
self.sampling_rate = sample_rate
77+
events = []
78+
frame_time = 1000. * self.num_samples / self.sampling_rate
79+
self.frame_max = int(self.plotting_interval / frame_time)
80+
# increment the buffer
81+
self.y = np.append(self.y, n_arr)
82+
self.frame_count += 1
83+
84+
if self.frame_count == self.frame_max:
85+
# update the plot
86+
x = np.arange(len(self.y), dtype=np.float32) * 1000. / self.sampling_rate
87+
self.hl.set_ydata(self.y)
88+
self.hl.set_xdata(x)
89+
# print ("shape(x): ", x.shape, " shape(y): ", self.y.shape,
90+
# " min:", np.min(self.y), " max:", np.max(self.y) )
91+
self.ax.set_xlim(0., self.plotting_interval)
92+
self.ax.relim()
93+
self.ax.autoscale_view(True, True, False)
94+
self.figure.canvas.draw()
95+
self.figure.canvas.flush_events()
96+
97+
self.frame_count = 0
98+
self.y = np.empty(0, dtype=np.float32)
99+
100+
# if np.random.random() < 0.5:
101+
# events.append({'type': 3, 'sampleNum': 0, 'eventId': self.code})
102+
# self.code += 1
103+
return events

0 commit comments

Comments
 (0)