"""spikeGLX_remote.socket_utils: message definitions and socket communication helpers."""

import socket
import ssl
import threading
import select
import logging
import time
import json

from enum import Enum


class MessageType(Enum):
    """
    Wire strings used as the ``type`` field of socket messages.

    Keeping them in one enum lets a message name be changed in a single place.
    Several members share a wire string (for example both "start recording"
    commands are ``'start_rec'``); for an ``Enum`` every later member with an
    already-used value is an alias of the first member that defined it.
    """

    # --- DAQ control ---------------------------------------------------------
    start_daq = 'start_rec'
    stop_daq = 'stop'
    start_daq_pulses = 'start_pulses'
    stop_daq_pulses = 'stop_pulses'
    start_daq_viewing = 'start_viewing'
    poll_status = 'status_poll'

    # --- video control (the first three alias the DAQ members above) ---------
    start_video_rec = 'start_rec'
    start_video_view = 'start_viewing'
    stop_video = 'stop'
    start_video_calibrec = 'start_calibrec'

    # --- replies and file management -----------------------------------------
    status = 'status'
    response = 'response'
    disconnected = 'disconnected'
    copy_files = 'copy_files'
    purge_files = 'purge_files'
class MessageStatus(Enum):
    """
    Wire strings used as the ``status`` field of status and response messages.

    Keeping them in one enum lets a status string be changed in a single place.
    """

    # --- device states -------------------------------------------------------
    ready = 'ready'
    error = 'error'
    viewing = 'viewing'
    recording = 'recording'

    # --- acknowledgements of commands ----------------------------------------
    viewing_ok = 'viewing_ok'
    recording_ok = 'recording_ok'
    recording_fail = 'recording_fail'
    stop_ok = 'stop_ok'
    pulsing_ok = 'pulsing_ok'
    calib_ok = 'calib_ok'

    # --- file copy results ---------------------------------------------------
    copy_ok = 'copy_ok'
    copy_fail = 'copy_fail'
class SocketMessage:
    """
    Class to hold all the messages that can be sent over the socket

    The fixed status/response messages are class attributes shared by all
    instances. The command messages are per-instance dicts. Their
    settings-dependent fields are listed once, in ``_message_fields``, and are
    rewritten in place by ``update_messages`` whenever one of the settings
    properties below is assigned.

    :param session_path: str: path to the session
    :param fps: float: frames per second
    :param session_id: str: session id
    :param daq_setting_file: str: path to the daq setting file
    :param basler_setting_file: str: path to the basler setting file
    :param pulse_lag: int: pulse lag after DAQ start
    :param start_daq: dict: message to start the daq
    :param stop_daq: dict: message to stop the daq
    :param start_daq_pulses: dict: message to start the daq pulses
    :param stop_daq_pulses: dict: message to stop the daq pulses
    :param start_daq_viewing: dict: message to start the daq viewing
    :param poll_status: dict: message to poll the status
    :param start_video_rec: dict: message to start the video recording
    :param start_video_view: dict: message to start the video viewing
    :param stop_video: dict: message to stop the video
    :param start_video_calibrec: dict: message to start the calibration recording
    :param copy_files: dict: message to copy the files
    :param purge_files: dict: message to purge the files
    :param view_spike_glx: dict: message to view the spike glx
    :param start_spike_glx: dict: message to start the spike glx
    :param stop_spike_glx: dict: message to stop the spike glx
    """

    # Fixed status / response messages (shared class attributes).
    status_error = {'type': MessageType.status.value, 'status': MessageStatus.error.value}
    status_ready = {'type': MessageType.status.value, 'status': MessageStatus.ready.value}
    status_recording = {'type': MessageType.status.value, 'status': MessageStatus.recording.value}
    status_viewing = {'type': MessageType.status.value, 'status': MessageStatus.viewing.value}
    respond_recording = {'type': MessageType.response.value, 'status': MessageStatus.recording_ok.value}
    respond_recording_fail = {'type': MessageType.response.value, 'status': MessageStatus.recording_fail.value}
    respond_viewing = {'type': MessageType.response.value, 'status': MessageStatus.viewing_ok.value}
    respond_stop = {'type': MessageType.response.value, 'status': MessageStatus.stop_ok.value}
    respond_pulsing = {'type': MessageType.response.value, 'status': MessageStatus.pulsing_ok.value}
    respond_calib = {'type': MessageType.response.value, 'status': MessageStatus.calib_ok.value}
    respond_copy = {'type': MessageType.response.value, 'status': MessageStatus.copy_ok.value}
    respond_copy_fail = {'type': MessageType.response.value, 'status': MessageStatus.copy_fail.value}
    client_disconnected = {'type': MessageType.disconnected.value}

    def __init__(self):
        self._session_path = None
        self._fps = 30
        self._session_id = "test"
        self._daq_setting_file = ''
        self._basler_setting_file = ''
        self._pulse_lag = 0
        # Each command message starts out with just its ``type``. The
        # settings-dependent fields are filled in by update_messages() below,
        # which takes them from _message_fields(). New parameterised messages
        # are created here and their fields are listed there; nothing else
        # needs to change.
        self.start_daq = {'type': MessageType.start_daq.value}
        self.stop_daq = {'type': MessageType.stop_daq.value}
        self.start_daq_pulses = {'type': MessageType.start_daq_pulses.value}
        self.stop_daq_pulses = {'type': MessageType.stop_daq_pulses.value}
        self.start_daq_viewing = {'type': MessageType.start_daq_viewing.value}
        self.poll_status = {'type': MessageType.poll_status.value}
        self.start_video_rec = {'type': MessageType.start_video_rec.value}
        self.start_video_view = {'type': MessageType.start_video_view.value}
        self.stop_video = {'type': MessageType.stop_video.value}
        self.start_video_calibrec = {'type': MessageType.start_video_calibrec.value}
        self.copy_files = {'type': MessageType.copy_files.value}
        self.purge_files = {'type': MessageType.purge_files.value}
        self.view_spike_glx = {'type': MessageType.start_video_view.value}  # maybe further params
        self.start_spike_glx = {'type': MessageType.start_video_rec.value}
        self.stop_spike_glx = {'type': MessageType.stop_video.value}
        self.update_messages()
        # The calibration recording uses a fixed frame rate that does not follow ``fps``.
        self.start_video_calibrec['frame_rate'] = 5

    @property
    def pulse_lag(self):
        return self._pulse_lag

    @pulse_lag.setter
    def pulse_lag(self, value: int):
        self._pulse_lag = value
        self.update_messages()

    @property
    def session_id(self):
        return self._session_id

    @session_id.setter
    def session_id(self, value: str):
        self._session_id = value
        self.update_messages()

    @property
    def session_path(self):
        return self._session_path

    @session_path.setter
    def session_path(self, value: str):
        self._session_path = value
        self.update_messages()

    @property
    def fps(self):
        return self._fps

    @fps.setter
    def fps(self, value: float):
        self._fps = value
        self.update_messages()

    @property
    def daq_setting_file(self):
        return self._daq_setting_file

    @daq_setting_file.setter
    def daq_setting_file(self, value: str):
        self._daq_setting_file = value
        self.update_messages()

    @property
    def basler_setting_file(self):
        return self._basler_setting_file

    @basler_setting_file.setter
    def basler_setting_file(self, value: str):
        self._basler_setting_file = value
        self.update_messages()

    def _message_fields(self) -> dict:
        """
        Return the settings-dependent fields of every parameterised message.

        :return: dict mapping a message attribute name to the fields that
                 update_messages() writes into that message dict.
        """
        return {
            'start_daq': {'session_id': self._session_id, 'setting_file': self._daq_setting_file},
            'start_daq_viewing': {'session_id': self._session_id, 'setting_file': self._daq_setting_file},
            'start_daq_pulses': {'fps': self._fps, 'pulse_lag': self._pulse_lag},
            'start_video_rec': {'session_id': self._session_id, 'setting_file': self._basler_setting_file,
                                'frame_rate': self._fps},
            'start_video_view': {'session_id': self._session_id, 'setting_file': self._basler_setting_file,
                                 'frame_rate': self._fps},
            # Calibration recordings always use the fixed session id 'calibration'.
            'start_video_calibrec': {'session_id': 'calibration', 'setting_file': self._basler_setting_file},
            'copy_files': {'session_id': self._session_id, 'session_path': self._session_path},
            'purge_files': {'session_id': self._session_id},
            'view_spike_glx': {'session_id': self._session_id},  # maybe further params
            'start_spike_glx': {'session_id': self._session_id},
            'stop_spike_glx': {'session_id': self._session_id},
        }

    def update_messages(self):
        """
        Updates all the messages with current values.

        The message dicts are updated in place, so references to them that
        were obtained earlier also see the new values.
        :return:
        """
        for name, fields in self._message_fields().items():
            getattr(self, name).update(fields)
class SocketComm:
    """
    Class to handle socket communication between processes or devices

    Messages are JSON objects terminated by a line break (see
    ``send_json_message``). The read methods return ``None`` when nothing
    (decodable) was received before the socket timeout. They return
    ``SocketMessage.client_disconnected`` when the peer reset or closed the
    connection.

    :param soctype: str: type of socket, either 'client' or 'server'
    :param host: str: host IP address
    :param port: int: port number
    :param use_ssl: bool: use ssl encryption (not implemented yet, True raises NotImplementedError)
    :param acception_thread: threading.Thread: thread to accept connection
    :param ssl_sock: ssl.SSLSocket: ssl socket
    :param sock: socket.socket: socket
    :param _sock: socket.socket: socket
    :param _ssl_sock: ssl.SSLSocket: ssl socket
    :param context: ssl.SSLContext: ssl context
    :param connected: bool: connection status
    :param addr: address of the connected client (server only, set by accept_connection)
    :param stop_event: threading.Event: event to stop waiting for connection
    :param log: logging.Logger: logger
    :param message_time: float: time of the last 'waiting for connection...' log message
    """

    def __init__(self, soctype: str = "server", host: str = "localhost", port: int = 8800, use_ssl: bool = False):
        self.acception_thread = None
        self.ssl_sock = None
        self.sock = None
        self._sock = None
        self._ssl_sock = None
        self.addr = None
        self.type = soctype
        self.host = host
        self.port = port
        if self.type == "server":
            self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        else:
            self.context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        # self.context.set_ciphers('DEFAULT')
        self.use_ssl = use_ssl
        if use_ssl:
            # this doesn't work yet, the ssl module raises a weird error
            raise NotImplementedError("SSL not implemented yet")
        self.connected = False
        self.stop_event = threading.Event()
        self.log = logging.getLogger(f"SocketComm_{self.type}")
        self.log.setLevel(logging.DEBUG)
        self.message_time = time.monotonic()

    def create_socket(self):
        """
        Creates the socket for the server or client.

        Any socket created by an earlier call is closed first. accept_connection()
        always calls this method, so without that step the old socket would leak
        and the new bind would fail because the old listener still holds the port.

        :raises OSError: if a server socket cannot be bound to ``(host, port)``.
        """
        if self._ssl_sock is not None:
            self._ssl_sock.close()
            self._ssl_sock = None
        if self._sock is not None:
            self._sock.close()
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if self.type == 'server':
            try:
                self._sock.bind((self.host, self.port))
            except OSError:
                # Calling listen() on an unbound socket would silently pick a
                # random port that no client knows about, so fail loudly instead.
                self.log.error(f'Could not bind to {self.host}:{self.port} (address already in use?)')
                self._sock.close()
                self._sock = None
                raise
            self._sock.listen()
            if self.use_ssl:
                self._ssl_sock = self.context.wrap_socket(self._sock, server_side=True,
                                                          do_handshake_on_connect=False)

    def accept_connection(self):
        """
        Accepts connection-request from client.

        Creates the listening socket, then polls it in 0.1 s steps until a
        client connects or ``stop_event`` is set. While waiting, a debug
        message is logged at most every 5 s.
        :return:
        """
        self.create_socket()
        listener = self._ssl_sock if self.use_ssl else self._sock
        while not self.stop_event.is_set():
            if time.monotonic() - self.message_time > 5:
                self.message_time = time.monotonic()
                self.log.debug('waiting for connection...')
            ready, _, _ = select.select([listener], [], [], 0.1)
            if ready:
                conn, self.addr = listener.accept()
                conn.settimeout(0.1)
                if self.use_ssl:
                    self.ssl_sock = conn
                else:
                    self.sock = conn
                self.connected = True
                self.log.info(f"Connected to {self.addr}")
                break
        else:
            self.log.debug("Stop event set. Stopping thread...")

    def threaded_accept_connection(self):
        """
        Accepts connection in a separate thread, to not block the main thread
        """
        self.stop_event.clear()
        self.acception_thread = threading.Thread(target=self.accept_connection)
        self.acception_thread.start()

    def stop_waiting_for_connection(self):
        """
        sets the stop event, so the thread will stop waiting for a connection
        """
        self.stop_event.set()

    def connect(self) -> bool:
        """
        Connects to the server.

        Only valid for client sockets. ``create_socket`` must have been called first.
        :return: bool: True once connected, False if this is a server socket
        """
        if self.type != 'client':
            # raise RuntimeError("Error: Cannot connect on server socket")
            return False
        if self.use_ssl:
            self.ssl_sock = self.context.wrap_socket(self._sock, server_hostname=self.host,
                                                     do_handshake_on_connect=False)
        else:
            self.sock = self._sock
            # Otherwise reads block when nothing is coming.
            # NOTE(review): this timeout also bounds the connect() call below to 0.1 s.
            self.sock.settimeout(0.1)
        self._connect(self.host, self.port)
        self.connected = True
        return True

    def close_socket(self):
        """
        Closes every socket this object holds and marks it as disconnected.

        Sockets that were never created (None) are skipped. Closing an
        already-closed socket is harmless.
        :return:
        """
        for conn in (self.ssl_sock, self._ssl_sock, self.sock, self._sock):
            if conn is not None:
                conn.close()
        self.connected = False

    def _decode_json(self, raw):
        """
        Decode one raw message as returned by ``_recv`` / ``_recv_until``.

        :param raw: bytes, None or int: received data, None on timeout, -1 on connection reset
        :return: dict, None: the decoded message; SocketMessage.client_disconnected if
                 the peer reset (-1) or closed (empty read) the connection; None on
                 timeout or if the data is not valid UTF-8 JSON
        """
        if raw is None:
            return None
        if raw == -1 or raw == b'':
            return SocketMessage.client_disconnected
        try:
            return json.loads(raw.decode())
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.log.error('message decoding failed')
            return None

    def read_json_message(self) -> [dict, None]:
        """
        Reads a json message from the socket until a linebreak is reached then decodes it via json
        :return: dict, None: message, None if no (decodable) message is received,
                 or SocketMessage.client_disconnected if the peer went away
        """
        return self._decode_json(self._recv_until(b'\n'))

    def read_json_message_fast(self) -> [dict, None]:
        """
        Reads a json message from the socket via a large bulk (up to 1024 bytes) then decodes it via json
        :return: dict, None: message, None if no (decodable) message is received,
                 or SocketMessage.client_disconnected if the peer went away
        """
        return self._decode_json(self._recv(1024))

    def read_json_message_fast_linebreak(self) -> [dict, None]:
        """
        Reads a json message from the socket until a linebreak is reached then decodes it via json.
        Unlike read_json_message, socket errors (OSError) are logged and turned into None.
        :return: dict, None: message, None if no (decodable) message is received,
                 or SocketMessage.client_disconnected if the peer went away
        """
        try:
            raw = self._recv_until(b'\n')
        except OSError:
            self.log.warning('socket disconnected and deleted')
            return None
        return self._decode_json(raw)

    def send_json_message(self, message: dict):
        """
        Sends a json message over the socket
        :param message: dict: message to send of SocketMessage type
        :return:
        """
        self._send(json.dumps(message).encode() + b'\n')

    def _active_socket(self):
        """Return the socket used for data transfer (SSL-wrapped or plain)."""
        return self.ssl_sock if self.use_ssl else self.sock

    def _connect(self, host, port):
        self._active_socket().connect((host, port))

    def _send(self, data):
        try:
            self._active_socket().sendall(data)
        except (BrokenPipeError, ConnectionResetError):
            self.log.error("Connection reset by peer")
            self.connected = False

    def _recv(self, size) -> (bytes, int):
        """
        Receives up to ``size`` bytes.
        :return: bytes, None, int: data (b'' if the peer closed the connection),
                 None on timeout, -1 if the connection was reset
        """
        try:
            data = self._active_socket().recv(size)
        except socket.timeout:
            return None
        except ConnectionResetError:
            self.log.warning("Client disconnected")
            self.connected = False
            return -1
        if data == b'':
            # An empty read means the peer closed the connection.
            self.connected = False
        return data

    def _recv_until(self, delimiter: bytes) -> [bytes, None, int]:
        """
        Receives data byte by byte until a delimiter is reached

        Reading single bytes ensures nothing after the delimiter is consumed.
        :param delimiter: bytes: delimiter to stop receiving
        :return: bytes, None, int: received data (possibly incomplete or empty if the
                 peer closed the connection), None if the timeout expired (any partial
                 data is discarded) or -1 if the connection was reset
        """
        conn = self._active_socket()
        data = b''
        try:
            while not data.endswith(delimiter):
                received = conn.recv(1)
                if received == b'':
                    self.connected = False
                    break
                data += received
        except socket.timeout:
            data = None
        except (BrokenPipeError, ConnectionResetError):
            self.log.warning("Client disconnected")
            self.connected = False
            data = -1
        return data

    def _recv_all(self):
        """
        Receives everything available until the socket times out or the peer closes.
        :return: bytes: all received data
        """
        conn = self._active_socket()
        chunks = []
        while True:
            try:
                chunk = conn.recv(1024)
            except socket.timeout:
                break
            if not chunk:
                # Peer closed: recv would keep returning b'' without ever timing out.
                self.connected = False
                break
            chunks.append(chunk)
        return b''.join(chunks)
if __name__ == "__main__":
    # Manual test of the client side against a running spikeGLX remote server.
    # ``time`` and ``json`` are already imported at module level.
    from pathlib import PureWindowsPath

    # Server-side variant (not executed; it also needs ``import argparse``):
    # parser = argparse.ArgumentParser(description='Socket communication test')
    # parser.add_argument('--type', type=str, default='server', help='Socket type: client or server')
    # parser.add_argument('--host', type=str, default='localhost', help='Host IP address')
    # parser.add_argument('--port', type=int, default=8800, help='Port number')
    # parser.add_argument('--use_ssl', type=bool, default=False, help='Use SSL')
    # args = parser.parse_args()
    # sock = SocketComm('server')
    # sock.create_socket()
    # sock.threaded_accept_connection()
    # while not sock.connected:
    #     print('no connection established,waiting...')
    #     time.sleep(1)
    # try:
    #     data = sock.read_json_message()
    #     print(data)
    # except Exception as e:
    #     print(e)
    #     pass
    # sock.close_socket()

    sock = SocketComm('client', port=8882)
    sock.create_socket()
    try:
        sock.connect()
        time.sleep(0.5)
        response = sock.read_json_message_fast()
        socket_messages = SocketMessage()
        print(response)

        # Start a SpikeGLX recording for a test session, then stop it after 5 s.
        socket_messages.session_id = 'testMousy42_yeah'
        sock.send_json_message(socket_messages.start_spike_glx)
        time.sleep(0.5)
        response = sock.read_json_message_fast()
        print(response)
        time.sleep(5)
        sock.send_json_message(socket_messages.stop_spike_glx)
        time.sleep(0.5)
        response = sock.read_json_message_fast()
        print(response)
        time.sleep(1)

        # Ask the server to copy the session files to a Windows archive path.
        sess_path = r"2023_BehFlex\HillYmaze_training\data\0_raw\20230620_r0083_wt_1711"
        win_path = PureWindowsPath("O:\\archive\\users\\as153\\Copytest") / sess_path
        socket_messages.session_path = str(win_path)
        sock.send_json_message(socket_messages.copy_files)
        time.sleep(2)
        response = sock.read_json_message_fast()
        print(response)
    finally:
        # Release the socket even if the server is unreachable or a step fails.
        sock.close_socket()