From 35e26a82e075f51b95f0ed0eb18e46f977c220d0 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Thu, 2 Apr 2020 14:45:41 -0700 Subject: [PATCH 001/242] add multithread simulation test --- multithread_test/test_BMI3D_interface.py | 47 ++++++++++++++++++++++ multithread_test/test_client.py | 51 ++++++++++++++++++++++++ multithread_test/test_script.py | 24 +++++++++++ multithread_test/test_server.py | 35 ++++++++++++++++ 4 files changed, 157 insertions(+) create mode 100644 multithread_test/test_BMI3D_interface.py create mode 100644 multithread_test/test_client.py create mode 100644 multithread_test/test_script.py create mode 100644 multithread_test/test_server.py diff --git a/multithread_test/test_BMI3D_interface.py b/multithread_test/test_BMI3D_interface.py new file mode 100644 index 00000000..d32cfd93 --- /dev/null +++ b/multithread_test/test_BMI3D_interface.py @@ -0,0 +1,47 @@ +from test_client import TestClient +import numpy as np +from multiprocessing import Process,Lock + +mutex = Lock() + +class MotionData(object): + + + def __init__(self, num_length): + self.test_client = TestClient() + #self.data_array = np.zeros(num_length) + #self.data_array = np.zeros(num_length) + self.data_array = [None] * num_length + self.num_length = num_length + + + def receive_data(self, data): + #print( "Received data from client", data) + rec_num = float(data) + #self.data_array[2:] = self.data_array[1:] + + #make a running buffer + with mutex: + self.data_array.insert(0,rec_num) + self.data_array.pop() + + #print(self.data_array) + + #update data in the motion buffer + + #save data + + def start(self): + self.test_client.dataListener = self.receive_data + self.test_client.run() + print('Start the interface thread') + + def stop(self): + pass + + def get(self): + current_value = None + with mutex: + current_value = self.data_array[0] + #return the latest saved data + return current_value \ No newline at end of file diff --git a/multithread_test/test_client.py b/multithread_test/test_client.py new file mode 100644 index 00000000..8bec80eb --- /dev/null +++ b/multithread_test/test_client.py @@ -0,0 +1,51 @@ +import socket +import struct +from threading import Thread + +class TestClient: + def __init__(self): + # Change this value to the IP address of the NatNet server. + self.serverIPAddress = "127.0.0.1" + + # Change this value to the IP address of your local network interface + self.localIPAddress = "127.0.0.1" + + self.dataPort = 5005 + + # similar to the rigidBodyListener + self.dataListener = None + + # Create a data socket to attach to the NatNet stream + def __createDataSocket( self, port ): + result = socket.socket( socket.AF_INET, # Internet + socket.SOCK_DGRAM, + socket.IPPROTO_UDP) # UDP + + #result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + #result.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(self.multicastAddress) + socket.inet_aton(self.localIPAddress)) + + result.bind( (self.localIPAddress, port) ) + + return result + + def __dataThreadFunction( self, socket ): + while True: + # Block for input + data, addr = socket.recvfrom( 32768 ) # 32k byte buffer size + if( len( data ) > 0 ): + #self.__processMessage( data ) + + # Send information to any listener. 
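+            # note: this callback fires on the receive thread, so listeners must be fast and thread-safe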
+ if self.dataListener is not None: + self.dataListener(data) + + def run( self ): + # Create the data socket + self.dataSocket = self.__createDataSocket( self.dataPort ) + if( self.dataSocket is None ): + print( "Could not open data channel" ) + exit + + # Create a separate thread for receiving data packets + dataThread = Thread( target = self.__dataThreadFunction, args = (self.dataSocket, )) + dataThread.start() \ No newline at end of file diff --git a/multithread_test/test_script.py b/multithread_test/test_script.py new file mode 100644 index 00000000..cb49942c --- /dev/null +++ b/multithread_test/test_script.py @@ -0,0 +1,24 @@ +from test_server import TestServer +from test_client import TestClient +from test_BMI3D_interface import MotionData +import time + +#fire up the server +test_server = TestServer() +test_server.run() + +num_length = 10 +motion_data = MotionData(num_length) +motion_data.start() + +while True: + print(motion_data.get()) + time.sleep(0.1) + + + + + +#fire up the client +#t_client = TestClient() +#t_client.run() \ No newline at end of file diff --git a/multithread_test/test_server.py b/multithread_test/test_server.py new file mode 100644 index 00000000..73dc72e9 --- /dev/null +++ b/multithread_test/test_server.py @@ -0,0 +1,35 @@ +import socket,time +from threading import Thread +import numpy as np +import string + + + +''' +print "UDP target IP:", UDP_IP +print "UDP target port:", UDP_PORT +print "message:", MESSAGE +''' +class TestServer(object): + UDP_IP = "127.0.0.1" + UDP_PORT = 5005 + MESSAGE = "Hello, World!" + SLEEP_TIME = 0.01 # seond + + def __init__(self): + self.sock = socket.socket(socket.AF_INET, # Internet + socket.SOCK_DGRAM) # UDP + + def __dataThreadFunction( self, socket,sleep_time ): + while True: + # Block for input + rand_num = np.random.rand() + + self.sock.sendto(str(rand_num).encode(), (self.UDP_IP, self.UDP_PORT)) + time.sleep(sleep_time) + + def run( self ): + # Create a separate thread for receiving data packets + dataThread = Thread( target = self.__dataThreadFunction, args = (self.sock, self.SLEEP_TIME)) + print('Server starts to broadcast data') + dataThread.start() From b037105b437ce877d9e20bcd1bf3f8d6e85c0b00 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Thu, 2 Apr 2020 15:12:36 -0700 Subject: [PATCH 002/242] commit -m 'save to mocap simulation to tests' --- .../multithread_test/test_BMI3D_interface.py | 47 +++++++++++++++++ tests/multithread_test/test_client.py | 51 +++++++++++++++++++ tests/multithread_test/test_script.py | 24 +++++++++ tests/multithread_test/test_server.py | 35 +++++++++++++ 4 files changed, 157 insertions(+) create mode 100644 tests/multithread_test/test_BMI3D_interface.py create mode 100644 tests/multithread_test/test_client.py create mode 100644 tests/multithread_test/test_script.py create mode 100644 tests/multithread_test/test_server.py diff --git a/tests/multithread_test/test_BMI3D_interface.py b/tests/multithread_test/test_BMI3D_interface.py new file mode 100644 index 00000000..d32cfd93 --- /dev/null +++ b/tests/multithread_test/test_BMI3D_interface.py @@ -0,0 +1,47 @@ +from test_client import TestClient +import numpy as np +from multiprocessing import Process,Lock + +mutex = Lock() + +class MotionData(object): + + + def __init__(self, num_length): + self.test_client = TestClient() + #self.data_array = np.zeros(num_length) + #self.data_array = np.zeros(num_length) + self.data_array = [None] * num_length + self.num_length = num_length + + + def receive_data(self, data): + #print( "Received data from client", 
data) + rec_num = float(data) + #self.data_array[2:] = self.data_array[1:] + + #make a running buffer + with mutex: + self.data_array.insert(0,rec_num) + self.data_array.pop() + + #print(self.data_array) + + #update data in the motion buffer + + #save data + + def start(self): + self.test_client.dataListener = self.receive_data + self.test_client.run() + print('Start the interface thread') + + def stop(self): + pass + + def get(self): + current_value = None + with mutex: + current_value = self.data_array[0] + #return the latest saved data + return current_value \ No newline at end of file diff --git a/tests/multithread_test/test_client.py b/tests/multithread_test/test_client.py new file mode 100644 index 00000000..8bec80eb --- /dev/null +++ b/tests/multithread_test/test_client.py @@ -0,0 +1,51 @@ +import socket +import struct +from threading import Thread + +class TestClient: + def __init__(self): + # Change this value to the IP address of the NatNet server. + self.serverIPAddress = "127.0.0.1" + + # Change this value to the IP address of your local network interface + self.localIPAddress = "127.0.0.1" + + self.dataPort = 5005 + + # similar to the rigidBodyListener + self.dataListener = None + + # Create a data socket to attach to the NatNet stream + def __createDataSocket( self, port ): + result = socket.socket( socket.AF_INET, # Internet + socket.SOCK_DGRAM, + socket.IPPROTO_UDP) # UDP + + #result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + #result.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(self.multicastAddress) + socket.inet_aton(self.localIPAddress)) + + result.bind( (self.localIPAddress, port) ) + + return result + + def __dataThreadFunction( self, socket ): + while True: + # Block for input + data, addr = socket.recvfrom( 32768 ) # 32k byte buffer size + if( len( data ) > 0 ): + #self.__processMessage( data ) + + # Send information to any listener. + if self.dataListener is not None: + self.dataListener(data) + + def run( self ): + # Create the data socket + self.dataSocket = self.__createDataSocket( self.dataPort ) + if( self.dataSocket is None ): + print( "Could not open data channel" ) + exit + + # Create a separate thread for receiving data packets + dataThread = Thread( target = self.__dataThreadFunction, args = (self.dataSocket, )) + dataThread.start() \ No newline at end of file diff --git a/tests/multithread_test/test_script.py b/tests/multithread_test/test_script.py new file mode 100644 index 00000000..cb49942c --- /dev/null +++ b/tests/multithread_test/test_script.py @@ -0,0 +1,24 @@ +from test_server import TestServer +from test_client import TestClient +from test_BMI3D_interface import MotionData +import time + +#fire up the server +test_server = TestServer() +test_server.run() + +num_length = 10 +motion_data = MotionData(num_length) +motion_data.start() + +while True: + print(motion_data.get()) + time.sleep(0.1) + + + + + +#fire up the client +#t_client = TestClient() +#t_client.run() \ No newline at end of file diff --git a/tests/multithread_test/test_server.py b/tests/multithread_test/test_server.py new file mode 100644 index 00000000..73dc72e9 --- /dev/null +++ b/tests/multithread_test/test_server.py @@ -0,0 +1,35 @@ +import socket,time +from threading import Thread +import numpy as np +import string + + + +''' +print "UDP target IP:", UDP_IP +print "UDP target port:", UDP_PORT +print "message:", MESSAGE +''' +class TestServer(object): + UDP_IP = "127.0.0.1" + UDP_PORT = 5005 + MESSAGE = "Hello, World!" 
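+    # inter-send delay in seconds: 0.01 s gives roughly a 100 Hz broadcast rate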
+ SLEEP_TIME = 0.01 # seond + + def __init__(self): + self.sock = socket.socket(socket.AF_INET, # Internet + socket.SOCK_DGRAM) # UDP + + def __dataThreadFunction( self, socket,sleep_time ): + while True: + # Block for input + rand_num = np.random.rand() + + self.sock.sendto(str(rand_num).encode(), (self.UDP_IP, self.UDP_PORT)) + time.sleep(sleep_time) + + def run( self ): + # Create a separate thread for receiving data packets + dataThread = Thread( target = self.__dataThreadFunction, args = (self.sock, self.SLEEP_TIME)) + print('Server starts to broadcast data') + dataThread.start() From 3fcf4ebedd0c3ae28f15d554004b4ba1655a891a Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Thu, 2 Apr 2020 15:13:28 -0700 Subject: [PATCH 003/242] 'move mocap simulation to tests folder' --- multithread_test/test_BMI3D_interface.py | 47 ---------------------- multithread_test/test_client.py | 51 ------------------------ multithread_test/test_script.py | 24 ----------- multithread_test/test_server.py | 35 ---------------- 4 files changed, 157 deletions(-) delete mode 100644 multithread_test/test_BMI3D_interface.py delete mode 100644 multithread_test/test_client.py delete mode 100644 multithread_test/test_script.py delete mode 100644 multithread_test/test_server.py diff --git a/multithread_test/test_BMI3D_interface.py b/multithread_test/test_BMI3D_interface.py deleted file mode 100644 index d32cfd93..00000000 --- a/multithread_test/test_BMI3D_interface.py +++ /dev/null @@ -1,47 +0,0 @@ -from test_client import TestClient -import numpy as np -from multiprocessing import Process,Lock - -mutex = Lock() - -class MotionData(object): - - - def __init__(self, num_length): - self.test_client = TestClient() - #self.data_array = np.zeros(num_length) - #self.data_array = np.zeros(num_length) - self.data_array = [None] * num_length - self.num_length = num_length - - - def receive_data(self, data): - #print( "Received data from client", data) - rec_num = float(data) - #self.data_array[2:] = self.data_array[1:] - - #make a running buffer - with mutex: - self.data_array.insert(0,rec_num) - self.data_array.pop() - - #print(self.data_array) - - #update data in the motion buffer - - #save data - - def start(self): - self.test_client.dataListener = self.receive_data - self.test_client.run() - print('Start the interface thread') - - def stop(self): - pass - - def get(self): - current_value = None - with mutex: - current_value = self.data_array[0] - #return the latest saved data - return current_value \ No newline at end of file diff --git a/multithread_test/test_client.py b/multithread_test/test_client.py deleted file mode 100644 index 8bec80eb..00000000 --- a/multithread_test/test_client.py +++ /dev/null @@ -1,51 +0,0 @@ -import socket -import struct -from threading import Thread - -class TestClient: - def __init__(self): - # Change this value to the IP address of the NatNet server. 
- self.serverIPAddress = "127.0.0.1" - - # Change this value to the IP address of your local network interface - self.localIPAddress = "127.0.0.1" - - self.dataPort = 5005 - - # similar to the rigidBodyListener - self.dataListener = None - - # Create a data socket to attach to the NatNet stream - def __createDataSocket( self, port ): - result = socket.socket( socket.AF_INET, # Internet - socket.SOCK_DGRAM, - socket.IPPROTO_UDP) # UDP - - #result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - #result.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(self.multicastAddress) + socket.inet_aton(self.localIPAddress)) - - result.bind( (self.localIPAddress, port) ) - - return result - - def __dataThreadFunction( self, socket ): - while True: - # Block for input - data, addr = socket.recvfrom( 32768 ) # 32k byte buffer size - if( len( data ) > 0 ): - #self.__processMessage( data ) - - # Send information to any listener. - if self.dataListener is not None: - self.dataListener(data) - - def run( self ): - # Create the data socket - self.dataSocket = self.__createDataSocket( self.dataPort ) - if( self.dataSocket is None ): - print( "Could not open data channel" ) - exit - - # Create a separate thread for receiving data packets - dataThread = Thread( target = self.__dataThreadFunction, args = (self.dataSocket, )) - dataThread.start() \ No newline at end of file diff --git a/multithread_test/test_script.py b/multithread_test/test_script.py deleted file mode 100644 index cb49942c..00000000 --- a/multithread_test/test_script.py +++ /dev/null @@ -1,24 +0,0 @@ -from test_server import TestServer -from test_client import TestClient -from test_BMI3D_interface import MotionData -import time - -#fire up the server -test_server = TestServer() -test_server.run() - -num_length = 10 -motion_data = MotionData(num_length) -motion_data.start() - -while True: - print(motion_data.get()) - time.sleep(0.1) - - - - - -#fire up the client -#t_client = TestClient() -#t_client.run() \ No newline at end of file diff --git a/multithread_test/test_server.py b/multithread_test/test_server.py deleted file mode 100644 index 73dc72e9..00000000 --- a/multithread_test/test_server.py +++ /dev/null @@ -1,35 +0,0 @@ -import socket,time -from threading import Thread -import numpy as np -import string - - - -''' -print "UDP target IP:", UDP_IP -print "UDP target port:", UDP_PORT -print "message:", MESSAGE -''' -class TestServer(object): - UDP_IP = "127.0.0.1" - UDP_PORT = 5005 - MESSAGE = "Hello, World!" 
- SLEEP_TIME = 0.01 # seond - - def __init__(self): - self.sock = socket.socket(socket.AF_INET, # Internet - socket.SOCK_DGRAM) # UDP - - def __dataThreadFunction( self, socket,sleep_time ): - while True: - # Block for input - rand_num = np.random.rand() - - self.sock.sendto(str(rand_num).encode(), (self.UDP_IP, self.UDP_PORT)) - time.sleep(sleep_time) - - def run( self ): - # Create a separate thread for receiving data packets - dataThread = Thread( target = self.__dataThreadFunction, args = (self.sock, self.SLEEP_TIME)) - print('Server starts to broadcast data') - dataThread.start() From 0c912af910b454203c8ce35a2c62145e828082d4 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Thu, 2 Apr 2020 21:57:11 -0700 Subject: [PATCH 004/242] use pickle to serialize the data --- tests/multithread_test/notes | 0 .../multithread_test/test_BMI3D_interface.py | 4 +- tests/multithread_test/test_client.py | 7 +-- tests/multithread_test/test_script.py | 4 +- tests/multithread_test/test_server.py | 48 ++++++++++++++++++- 5 files changed, 57 insertions(+), 6 deletions(-) create mode 100644 tests/multithread_test/notes diff --git a/tests/multithread_test/notes b/tests/multithread_test/notes new file mode 100644 index 00000000..e69de29b diff --git a/tests/multithread_test/test_BMI3D_interface.py b/tests/multithread_test/test_BMI3D_interface.py index d32cfd93..f83267bd 100644 --- a/tests/multithread_test/test_BMI3D_interface.py +++ b/tests/multithread_test/test_BMI3D_interface.py @@ -1,6 +1,7 @@ from test_client import TestClient import numpy as np from multiprocessing import Process,Lock +import pickle mutex = Lock() @@ -17,7 +18,8 @@ def __init__(self, num_length): def receive_data(self, data): #print( "Received data from client", data) - rec_num = float(data) + rec_num = data + #self.data_array[2:] = self.data_array[1:] #make a running buffer diff --git a/tests/multithread_test/test_client.py b/tests/multithread_test/test_client.py index 8bec80eb..75c1d12d 100644 --- a/tests/multithread_test/test_client.py +++ b/tests/multithread_test/test_client.py @@ -1,6 +1,7 @@ import socket import struct from threading import Thread +import pickle class TestClient: def __init__(self): @@ -34,10 +35,10 @@ def __dataThreadFunction( self, socket ): data, addr = socket.recvfrom( 32768 ) # 32k byte buffer size if( len( data ) > 0 ): #self.__processMessage( data ) - - # Send information to any listener. + data_arr = pickle.loads(data) + # Send information to any listener. 
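+            # caution: unpickling data received from a socket is unsafe outside a trusted local test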
if self.dataListener is not None: - self.dataListener(data) + self.dataListener(data_arr) def run( self ): # Create the data socket diff --git a/tests/multithread_test/test_script.py b/tests/multithread_test/test_script.py index cb49942c..b089ea6d 100644 --- a/tests/multithread_test/test_script.py +++ b/tests/multithread_test/test_script.py @@ -1,10 +1,12 @@ from test_server import TestServer +from test_server import TestServerMouse from test_client import TestClient from test_BMI3D_interface import MotionData import time #fire up the server -test_server = TestServer() +#test_server = TestServer() +test_server = TestServerMouse() test_server.run() num_length = 10 diff --git a/tests/multithread_test/test_server.py b/tests/multithread_test/test_server.py index 73dc72e9..6379b557 100644 --- a/tests/multithread_test/test_server.py +++ b/tests/multithread_test/test_server.py @@ -2,15 +2,19 @@ from threading import Thread import numpy as np import string +import pickle + +import pygame ''' +this class generates random numbers and broadcast them via the UDP protocal to the local network print "UDP target IP:", UDP_IP print "UDP target port:", UDP_PORT print "message:", MESSAGE ''' -class TestServer(object): +class TestServer(object): UDP_IP = "127.0.0.1" UDP_PORT = 5005 MESSAGE = "Hello, World!" @@ -33,3 +37,45 @@ def run( self ): dataThread = Thread( target = self.__dataThreadFunction, args = (self.sock, self.SLEEP_TIME)) print('Server starts to broadcast data') dataThread.start() + +#this child class replaces the generator and then waits for the mouse command +class TestServerMouse(TestServer): + + def __init__(self): + + super().__init__() + + def __dataThreadFunction( self, socket,sleep_time ): + print('here ') + pygame.init() + (width, height) = (300, 200) + screen = pygame.display.set_mode((width, height)) + + while True: + + #need to poll the event before + pygame.event.get() + # Block for input + cursor_pos = pygame.mouse.get_pos() + #print(cursor_pos) + + screen.fill((0, 0, 0)) + pygame.display.flip() + + #prepare dump data + dump_data = pickle.dumps(cursor_pos) + + self.sock.sendto(dump_data, (self.UDP_IP, self.UDP_PORT)) + time.sleep(sleep_time) + + def run(self): + dataThread = Thread( target = self.__dataThreadFunction, args = (self.sock, self.SLEEP_TIME)) + print('Server starts to broadcast data') + dataThread.start() + + + +#test function +if __name__ == "__main__": + tsm = TestServerMouse() + tsm.run() \ No newline at end of file From b84234e9cbd8f85f84c01ad0a5ea17c822d490b7 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Thu, 2 Apr 2020 22:32:25 -0700 Subject: [PATCH 005/242] use pyautogui to get cursor's position --- .../multithread_test/test_BMI3D_interface.py | 3 --- tests/multithread_test/test_client.py | 3 ++- tests/multithread_test/test_script.py | 2 +- tests/multithread_test/test_server.py | 19 ++++--------------- 4 files changed, 7 insertions(+), 20 deletions(-) diff --git a/tests/multithread_test/test_BMI3D_interface.py b/tests/multithread_test/test_BMI3D_interface.py index f83267bd..82c230d4 100644 --- a/tests/multithread_test/test_BMI3D_interface.py +++ b/tests/multithread_test/test_BMI3D_interface.py @@ -29,9 +29,6 @@ def receive_data(self, data): #print(self.data_array) - #update data in the motion buffer - - #save data def start(self): self.test_client.dataListener = self.receive_data diff --git a/tests/multithread_test/test_client.py b/tests/multithread_test/test_client.py index 75c1d12d..0e5359b9 100644 --- a/tests/multithread_test/test_client.py +++ 
b/tests/multithread_test/test_client.py @@ -2,6 +2,7 @@ import struct from threading import Thread import pickle +import numpy class TestClient: def __init__(self): @@ -35,7 +36,7 @@ def __dataThreadFunction( self, socket ): data, addr = socket.recvfrom( 32768 ) # 32k byte buffer size if( len( data ) > 0 ): #self.__processMessage( data ) - data_arr = pickle.loads(data) + data_arr = numpy.asarray(pickle.loads(data)) # Send information to any listener. if self.dataListener is not None: self.dataListener(data_arr) diff --git a/tests/multithread_test/test_script.py b/tests/multithread_test/test_script.py index b089ea6d..64974c81 100644 --- a/tests/multithread_test/test_script.py +++ b/tests/multithread_test/test_script.py @@ -15,7 +15,7 @@ while True: print(motion_data.get()) - time.sleep(0.1) + time.sleep(1) diff --git a/tests/multithread_test/test_server.py b/tests/multithread_test/test_server.py index 6379b557..76661ddb 100644 --- a/tests/multithread_test/test_server.py +++ b/tests/multithread_test/test_server.py @@ -4,7 +4,7 @@ import string import pickle -import pygame +import pyautogui @@ -18,7 +18,7 @@ class TestServer(object): UDP_IP = "127.0.0.1" UDP_PORT = 5005 MESSAGE = "Hello, World!" - SLEEP_TIME = 0.01 # seond + SLEEP_TIME = 0.01 # seond 100 Hz per second def __init__(self): self.sock = socket.socket(socket.AF_INET, # Internet @@ -46,21 +46,10 @@ def __init__(self): super().__init__() def __dataThreadFunction( self, socket,sleep_time ): - print('here ') - pygame.init() - (width, height) = (300, 200) - screen = pygame.display.set_mode((width, height)) while True: - - #need to poll the event before - pygame.event.get() - # Block for input - cursor_pos = pygame.mouse.get_pos() - #print(cursor_pos) - - screen.fill((0, 0, 0)) - pygame.display.flip() + #get cursor position with pyautogui + cursor_pos = pyautogui.position() #prepare dump data dump_data = pickle.dumps(cursor_pos) From 140f5fa263b1635d3df49643a976d59d52b072aa Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Thu, 2 Apr 2020 22:33:21 -0700 Subject: [PATCH 006/242] use pyautogui to get the cursor position --- tests/multithread_test/test_BMI3D_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/multithread_test/test_BMI3D_interface.py b/tests/multithread_test/test_BMI3D_interface.py index 82c230d4..d3e086db 100644 --- a/tests/multithread_test/test_BMI3D_interface.py +++ b/tests/multithread_test/test_BMI3D_interface.py @@ -27,7 +27,7 @@ def receive_data(self, data): self.data_array.insert(0,rec_num) self.data_array.pop() - #print(self.data_array) + #save data to a def start(self): From a2184914c422408d9812cc23c509bb1626b4de8a Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Thu, 2 Apr 2020 22:35:55 -0700 Subject: [PATCH 007/242] move the mocap simulation to riglib --- {tests => riglib}/multithread_test/notes | 0 {tests => riglib}/multithread_test/test_BMI3D_interface.py | 0 {tests => riglib}/multithread_test/test_client.py | 0 {tests => riglib}/multithread_test/test_script.py | 0 {tests => riglib}/multithread_test/test_server.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename {tests => riglib}/multithread_test/notes (100%) rename {tests => riglib}/multithread_test/test_BMI3D_interface.py (100%) rename {tests => riglib}/multithread_test/test_client.py (100%) rename {tests => riglib}/multithread_test/test_script.py (100%) rename {tests => riglib}/multithread_test/test_server.py (100%) diff --git a/tests/multithread_test/notes b/riglib/multithread_test/notes similarity index 100% rename from 
tests/multithread_test/notes rename to riglib/multithread_test/notes diff --git a/tests/multithread_test/test_BMI3D_interface.py b/riglib/multithread_test/test_BMI3D_interface.py similarity index 100% rename from tests/multithread_test/test_BMI3D_interface.py rename to riglib/multithread_test/test_BMI3D_interface.py diff --git a/tests/multithread_test/test_client.py b/riglib/multithread_test/test_client.py similarity index 100% rename from tests/multithread_test/test_client.py rename to riglib/multithread_test/test_client.py diff --git a/tests/multithread_test/test_script.py b/riglib/multithread_test/test_script.py similarity index 100% rename from tests/multithread_test/test_script.py rename to riglib/multithread_test/test_script.py diff --git a/tests/multithread_test/test_server.py b/riglib/multithread_test/test_server.py similarity index 100% rename from tests/multithread_test/test_server.py rename to riglib/multithread_test/test_server.py From e2dd2a49289b4734d88fc8856a7cf6d89a296d7d Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Sun, 5 Apr 2020 10:59:07 -0700 Subject: [PATCH 008/242] fixed the stereo window size casting type error --- built_in_tasks/cursorControlTasks.py | 72 +++++++++++++++++++++++++++ riglib/stereo_opengl/render/stereo.py | 4 ++ tests/test_window_2d.py | 34 ++++++++++--- 3 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 built_in_tasks/cursorControlTasks.py diff --git a/built_in_tasks/cursorControlTasks.py b/built_in_tasks/cursorControlTasks.py new file mode 100644 index 00000000..3889d1bd --- /dev/null +++ b/built_in_tasks/cursorControlTasks.py @@ -0,0 +1,72 @@ + +from .manualcontrolmultitasks import ManualControlMulti +from riglib.stereo_opengl.window import WindowDispl2D +from .bmimultitasks import BMIControlMulti +import pygame +import numpy as np +import copy + +from riglib.bmi.extractor import DummyExtractor +from riglib.bmi.state_space_models import StateSpaceEndptVel2D +from riglib.bmi.bmi import Decoder, BMISystem, GaussianStateHMM, BMILoop, GaussianState, MachineOnlyFilter + +class CursorControl(ManualControlMulti, WindowDispl2D): + ''' + this class implements a python cursor control task for human + ''' + + def __init__(self, *args, **kwargs): + # just run the parent ManualControlMulti's initialization + self.move_step = 1 + super(CursorControl, self).__init__(*args, **kwargs) + + def init(self): + pygame.init() + self.assist_level = (0.5, 0.5) + super(CursorControl, self).init() + + # override the _cycle function + def _cycle(self): + #print(self.state) + + self.move_effector_cursor() + super(CursorControl, self)._cycle() + + # do nothing + def move_effector(self): + pass + + def move_plant(self, **kwargs): + pass + + # use keyboard to control the task + def move_effector_cursor(self): + np.array([0., 0., 0.]) + curr_pos = copy.deepcopy(self.plant.get_endpoint_pos()) + + for event in pygame.event.get(): + if event.type == pygame.KEYUP: + if event.type == pygame.K_q: + pygame.quit() + quit() + if event.key == pygame.K_LEFT: + curr_pos[0] -= self.move_step + if event.key == pygame.K_RIGHT: + curr_pos[0] += self.move_step + if event.key == pygame.K_UP: + curr_pos[2] += self.move_step + if event.key == pygame.K_DOWN: + curr_pos[2] -= self.move_step + #print('Current position: ') + #print(curr_pos) + + # set the current position + self.plant.set_endpoint_pos(curr_pos) + + def _start_wait(self): + self.wait_time = 0. 
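+        # zero inter-trial wait so keyboard-controlled trials start immediately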
+ super(CursorControl, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + \ No newline at end of file diff --git a/riglib/stereo_opengl/render/stereo.py b/riglib/stereo_opengl/render/stereo.py index 630ab9ea..f28a6811 100644 --- a/riglib/stereo_opengl/render/stereo.py +++ b/riglib/stereo_opengl/render/stereo.py @@ -48,7 +48,11 @@ def draw(self, root, **kwargs): kwargs: optional keyword-arguments Optional shaders and stuff to pass to the lower-level drawing functions ''' + print('aha') + print(self.size) w, h = self.size + w = int(w) + h = int(h) # draw the portion of the screen with lower-left corner (0, 0), width 'w' and height 'h' glViewport(0, 0, w, h) diff --git a/tests/test_window_2d.py b/tests/test_window_2d.py index f8b1200f..91d27e19 100644 --- a/tests/test_window_2d.py +++ b/tests/test_window_2d.py @@ -8,8 +8,11 @@ import pygame import numpy as np +import importlib -reload(window) +#importlib.reload(window) +m_to_cm = 100 +target_pos_radius = 10 class TestGraphics(Sequence, Window): status = dict( @@ -19,12 +22,15 @@ class TestGraphics(Sequence, Window): #initial state state = "wait" target_radius = 2. + + #create targets, cursor objects, initialize def __init__(self, *args, **kwargs): # Add the target and cursor locations to the task data to be saved to # file - super(TestGraphics, self).__init__(*args, **kwargs) + #super(TestGraphics, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.dtype = [('target', 'f', (3,)), ('cursor', 'f', (3,)), (('target_index', 'i', (1,)))] self.target1 = Sphere(radius=self.target_radius, color=(1,0,0,.5)) self.add_model(self.target1) @@ -36,13 +42,13 @@ def __init__(self, *args, **kwargs): ##### HELPER AND UPDATE FUNCTIONS #### -<<<<<<< HEAD +#<<<<<<< HEAD def _get_renderer(self): return stereo.MirrorDisplay(self.window_size, self.fov, 1, 1024, self.screen_dist, self.iod) #### STATE FUNCTIONS #### def _while_wait(self): - print "_while_wait" + print("_while_wait") self.target1.translate(0, 0, 0, reset=True) self.target1.attach() self.requeue() @@ -50,8 +56,16 @@ def _while_wait(self): def target_seq_generator(n_targs, n_trials): + #generate targets + angles = np.transpose(np.arange(0,2*np.pi,2*np.pi / n_targs)) + unit_targets = targets = np.stack((np.cos(angles), np.sin(angles)),1) + targets = unit_targets * target_pos_radius + + center = np.array((0,0)) + target_inds = np.random.randint(0, n_targs, n_trials) target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + k = 0 while k < n_trials: targ = m_to_cm*targets[target_inds[k], :] @@ -59,7 +73,11 @@ def target_seq_generator(n_targs, n_trials): [targ[0], 0, targ[1]]]) k += 1 -gen = target_seq_generator(8, 1000) -w = TestGraphics(gen) -w.init() -w.run() + +if __name__ == "__main__": + gen = target_seq_generator(8, 1000) + w = TestGraphics(gen) + w.init() + w.run() + + From 1effba24e9b8c829b3578eefc2e0410d70e0d6d0 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Sun, 5 Apr 2020 12:21:28 -0700 Subject: [PATCH 009/242] 'added movement to target 1' --- riglib/stereo_opengl/render/stereo.py | 2 -- riglib/stereo_opengl/window.py | 5 ++++- tests/test_window_2d.py | 21 ++++++++++++++++----- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/riglib/stereo_opengl/render/stereo.py b/riglib/stereo_opengl/render/stereo.py index f28a6811..c397c757 100644 --- a/riglib/stereo_opengl/render/stereo.py +++ b/riglib/stereo_opengl/render/stereo.py @@ -48,8 +48,6 @@ def draw(self, root, **kwargs): kwargs: optional keyword-arguments 
Optional shaders and stuff to pass to the lower-level drawing functions ''' - print('aha') - print(self.size) w, h = self.size w = int(w) h = int(h) diff --git a/riglib/stereo_opengl/window.py b/riglib/stereo_opengl/window.py index bda0b5fc..a18deccc 100644 --- a/riglib/stereo_opengl/window.py +++ b/riglib/stereo_opengl/window.py @@ -39,7 +39,9 @@ class Window(LogExperiment): state = "draw" stop = False - window_size = traits.Tuple((1920*2, 1080), descr='window size, in pixels') + #window_size = traits.Tuple((1920*2, 1080), descr='window size, in pixels') + #XPS computer + window_size = traits.Tuple((1280, 360), descr='window size, in pixels') # window_size = (1920*2, 1080) background = (0,0,0,1) @@ -64,6 +66,7 @@ def __init__(self, *args, **kwargs): def set_os_params(self): os.environ['SDL_VIDEO_WINDOW_POS'] = config.display_start_pos + #print(os.environ['SDL_VIDEO_WINDOW_POS']) os.environ['SDL_VIDEO_X11_WMCLASS'] = "monkey_experiment" def screen_init(self): diff --git a/tests/test_window_2d.py b/tests/test_window_2d.py index 91d27e19..627bf47c 100644 --- a/tests/test_window_2d.py +++ b/tests/test_window_2d.py @@ -34,25 +34,34 @@ def __init__(self, *args, **kwargs): self.dtype = [('target', 'f', (3,)), ('cursor', 'f', (3,)), (('target_index', 'i', (1,)))] self.target1 = Sphere(radius=self.target_radius, color=(1,0,0,.5)) self.add_model(self.target1) - self.target2 = Sphere(radius=self.target_radius, color=(1,0,0,.5)) + self.target2 = Sphere(radius=self.target_radius, color=(1,0,0,0.5)) self.add_model(self.target2) # Initialize target location variable - self.target_location = np.array([0,0,0]) + self.target_location = np.array([0.0,0.0,0.0]) ##### HELPER AND UPDATE FUNCTIONS #### #<<<<<<< HEAD def _get_renderer(self): return stereo.MirrorDisplay(self.window_size, self.fov, 1, 1024, self.screen_dist, self.iod) + def _cycle(self): + + super()._cycle() #### STATE FUNCTIONS #### def _while_wait(self): - print("_while_wait") - self.target1.translate(0, 0, 0, reset=True) - self.target1.attach() + #print("_while_wait") + + delta_movement = np.array([0,0,0.01]) + self.target_location += delta_movement + + self.target1.translate(self.target_location[0], + self.target_location[1], + self.target_location[2], reset=True) self.requeue() self.draw_world() + print('current target 1 position ' + np.array2string(self.target_location)) def target_seq_generator(n_targs, n_trials): @@ -75,9 +84,11 @@ def target_seq_generator(n_targs, n_trials): if __name__ == "__main__": + print('Remember to set window size in stereoOpenGL class') gen = target_seq_generator(8, 1000) w = TestGraphics(gen) w.init() w.run() + From 4055383baf83acc727a03fb88ddb02e2a704ba02 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Sun, 5 Apr 2020 15:20:24 -0700 Subject: [PATCH 010/242] add generator to run cursorControlTask from itself --- built_in_tasks/bmimultitasks.py | 4 +-- built_in_tasks/cursorControlTasks.py | 37 +++++++++++++++++++++-- built_in_tasks/manualcontrolmultitasks.py | 4 +-- riglib/stereo_opengl/window.py | 2 +- 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index c4340ce9..7692cb2d 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -29,7 +29,7 @@ from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceNLinkPlanarChain -from . 
import manualcontrolmultitasks +from manualcontrolmultitasks import ManualControlMulti target_colors = {"blue":(0,0,1,0.5), "yellow": (1,1,0,0.5), @@ -195,7 +195,7 @@ def __init__(self, *args, **kwargs): ################# ##### Tasks ##### ################# -class BMIControlMulti(BMILoop, LinearlyDecreasingAssist, manualcontrolmultitasks.ManualControlMulti): +class BMIControlMulti(BMILoop, LinearlyDecreasingAssist, ManualControlMulti): ''' Target capture task with cursor position controlled by BMI output. Cursor movement can be assisted toward target by setting assist_level > 0. diff --git a/built_in_tasks/cursorControlTasks.py b/built_in_tasks/cursorControlTasks.py index 3889d1bd..57555753 100644 --- a/built_in_tasks/cursorControlTasks.py +++ b/built_in_tasks/cursorControlTasks.py @@ -1,7 +1,7 @@ -from .manualcontrolmultitasks import ManualControlMulti +from manualcontrolmultitasks import ManualControlMulti from riglib.stereo_opengl.window import WindowDispl2D -from .bmimultitasks import BMIControlMulti +from bmimultitasks import BMIControlMulti import pygame import numpy as np import copy @@ -69,4 +69,37 @@ def _start_wait(self): def _test_start_trial(self, ts): return ts > self.wait_time and not self.pause + +#this task can be run on its +#we will not involve database at this time +target_pos_radius = 10 + +def target_seq_generator(n_targs, n_trials): + #generate targets + angles = np.transpose(np.arange(0,2*np.pi,2*np.pi / n_targs)) + unit_targets = targets = np.stack((np.cos(angles), np.sin(angles)),1) + targets = unit_targets * target_pos_radius + + center = np.array((0,0)) + + target_inds = np.random.randint(0, n_targs, n_trials) + target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + + k = 0 + while k < n_trials: + targ = targets[target_inds[k], :] + yield np.array([[center[0], 0, center[1]], + [targ[0], 0, targ[1]]]) + k += 1 + + +if __name__ == "__main__": + print('Remember to set window size in stereoOpenGL class') + gen = target_seq_generator(8, 1000) + + w = CursorControl(gen) + w.init() + w.run() + + \ No newline at end of file diff --git a/built_in_tasks/manualcontrolmultitasks.py b/built_in_tasks/manualcontrolmultitasks.py index 563fe0b1..d5f1a3ef 100644 --- a/built_in_tasks/manualcontrolmultitasks.py +++ b/built_in_tasks/manualcontrolmultitasks.py @@ -16,7 +16,7 @@ from riglib.stereo_opengl.render import stereo, Renderer from riglib.stereo_opengl.utils import cloudy_tex -from .plantlist import plantlist +from plantlist import plantlist from riglib.stereo_opengl import ik import os @@ -31,7 +31,7 @@ GOLD = (1., 0.843, 0., 0.5) mm_per_cm = 1./10 -from .target_graphics import * +from target_graphics import * target_colors = { "yellow": (1,1,0,0.75), diff --git a/riglib/stereo_opengl/window.py b/riglib/stereo_opengl/window.py index a18deccc..43451656 100644 --- a/riglib/stereo_opengl/window.py +++ b/riglib/stereo_opengl/window.py @@ -41,7 +41,7 @@ class Window(LogExperiment): #window_size = traits.Tuple((1920*2, 1080), descr='window size, in pixels') #XPS computer - window_size = traits.Tuple((1280, 360), descr='window size, in pixels') + window_size = traits.Tuple((1280, 720), descr='window size, in pixels') # window_size = (1920*2, 1080) background = (0,0,0,1) From 8ecfc077f1f115177531fd525eb776446aa8da98 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Sun, 5 Apr 2020 15:42:22 -0700 Subject: [PATCH 011/242] created a new start from cmd line test file --- tests/start_task_from_cmd_line_sim.py | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 
100644 tests/start_task_from_cmd_line_sim.py diff --git a/tests/start_task_from_cmd_line_sim.py b/tests/start_task_from_cmd_line_sim.py new file mode 100644 index 00000000..2569600e --- /dev/null +++ b/tests/start_task_from_cmd_line_sim.py @@ -0,0 +1,36 @@ +''' +Test script to run the visual feedback task from the command line +''' +from db import dbfunctions +from db.tracker import models + +from riglib import experiment +from features.generator_features import Autostart +from features.hdf_features import SaveHDF +from features.plexon_features import PlexonBMI + +from tasks import generatorfunctions as genfns +from analysis import performance + +# Tell linux to use Display 0 (the monitor physically attached to the +# machine. Otherwise, if you are connected remotely, it will try to run +# the graphics through SSH, which doesn't work for some reason. +import os +os.environ['DISPLAY'] = ':0' + +task = models.Task.objects.get(name='clda_kf_ofc_tentacle_rml_trial') +base_class = task.get() + +feats = [SaveHDF, PlexonBMI, Autostart] +Exp = experiment.make(base_class, feats=feats) + +#params.trait_norm(Exp.class_traits()) +params = dict(session_length=30, arm_visible=True, arm_class='RobotArmGen2D', + assist_level=(2., 2.), assist_time=60., rand_start=(0.,0.), max_tries=1) + +gen = genfns.sim_target_seq_generator_multi(8, 1000) +exp = Exp(gen, **params) + +exp.decoder = performance._get_te(3979).decoder + +exp.start() From 7e3968e3f1004e4294905373c5414ef522d8b21d Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Sun, 5 Apr 2020 16:58:33 -0700 Subject: [PATCH 012/242] fixed django.shortcuts.render_to_response depreciation issue --- db/tracker/views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/db/tracker/views.py b/db/tracker/views.py index eb02ece4..f9303403 100644 --- a/db/tracker/views.py +++ b/db/tracker/views.py @@ -8,7 +8,7 @@ import os from django.template import RequestContext -from django.shortcuts import render_to_response, render +from django.shortcuts import render from django.http import HttpResponse from . import exp_tracker @@ -113,7 +113,7 @@ def list_exp_history(request, **kwargs): if tracker.task_proxy is not None and "saveid" in tracker.task_kwargs: fields['running'] = tracker.task_kwargs["saveid"] - resp = render_to_response('list.html', fields, RequestContext(request)) + resp = render(none,'list.html', fields, RequestContext(request)) return resp def setup(request): From feee76131ffa8fc01b7ae72d37dea666ea3813aa Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Sun, 5 Apr 2020 17:17:28 -0700 Subject: [PATCH 013/242] fixed django.shortcuts.render_to_response depreciation issue and re setup the database settings --- db/settings.py | 31 ++++++++++++++++++++++++++----- db/tracker/views.py | 4 ++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/db/settings.py b/db/settings.py index ddf04c67..1bf0873b 100644 --- a/db/settings.py +++ b/db/settings.py @@ -12,6 +12,7 @@ djcelery.setup_loader() # Django settings for db project. 
+ DEBUG = True TEMPLATE_DEBUG = DEBUG @@ -24,12 +25,11 @@ db_dir = cwd def get_sqlite3_databases(): dbs = dict() - db_files = glob.glob(os.path.join(db_dir, '*.sql')) for db in db_files: db_name_re = re.match('db(.*?).sql', os.path.basename(db)) db_name = db_name_re.group(1) - + print(db_name) if db_name.startswith('_'): db_name = db_name[1:] elif db_name == "": @@ -40,7 +40,7 @@ def get_sqlite3_databases(): continue dbs[db_name] = { - 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. + 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or'oracle'. 'NAME': db, # Or path to database file if using sqlite3. 'USER': '', # Not used with sqlite3. 'PASSWORD': '', # Not used with sqlite3. @@ -50,7 +50,28 @@ def get_sqlite3_databases(): return dbs -DATABASES = get_sqlite3_databases() +#DATABASES = get_sqlite3_databases() +#DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': 'mydatabase', } } + + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. + 'NAME': os.path.join(cwd, "db.sql"), # Or path to database file if using sqlite3. + 'USER': '', # Not used with sqlite3. + 'PASSWORD': '', # Not used with sqlite3. + 'HOST': '', # Set to empty string for localhost. Not used with sqlite3. + 'PORT': '', # Set to empty string for default. Not used with sqlite3. + }, + 'testing': { + 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. + 'NAME': os.path.join(cwd, "db_testing.sql"), # Or path to database file if using sqlite3. + 'USER': '', # Not used with sqlite3. + 'PASSWORD': '', # Not used with sqlite3. + 'HOST': '', # Set to empty string for localhost. Not used with sqlite3. + 'PORT': '', # Set to empty string for default. Not used with sqlite3. + }, +} # Local time zone for this installation. Choices can be found here: # http://en.wikipedia.org/wiki/List_of_tz_zones_by_name @@ -170,4 +191,4 @@ def get_sqlite3_databases(): APPEND_SLASH=False -ALLOWED_HOSTS = ['127.0.0.1', 'localhost', "testserver"] \ No newline at end of file +ALLOWED_HOSTS = ['127.0.0.1', 'localhost', "testserver"] diff --git a/db/tracker/views.py b/db/tracker/views.py index eb02ece4..6a5fc7dc 100644 --- a/db/tracker/views.py +++ b/db/tracker/views.py @@ -8,7 +8,7 @@ import os from django.template import RequestContext -from django.shortcuts import render_to_response, render +from django.shortcuts import render from django.http import HttpResponse from . 
import exp_tracker @@ -113,7 +113,7 @@ def list_exp_history(request, **kwargs): if tracker.task_proxy is not None and "saveid" in tracker.task_kwargs: fields['running'] = tracker.task_kwargs["saveid"] - resp = render_to_response('list.html', fields, RequestContext(request)) + resp = render(None,'list.html', fields) return resp def setup(request): From f8d233b0f4e975ce123b6e037cbc78121a33e3fc Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 6 Apr 2020 10:19:10 -0700 Subject: [PATCH 014/242] added saveHDF feature to the cursorControlTasks --- built_in_tasks/cursorControlTasks_saveHDF.py | 126 +++++++++++++++++++ db/tracker/views.py | 4 - 2 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 built_in_tasks/cursorControlTasks_saveHDF.py diff --git a/built_in_tasks/cursorControlTasks_saveHDF.py b/built_in_tasks/cursorControlTasks_saveHDF.py new file mode 100644 index 00000000..d24d9a00 --- /dev/null +++ b/built_in_tasks/cursorControlTasks_saveHDF.py @@ -0,0 +1,126 @@ + +from manualcontrolmultitasks import ManualControlMulti +from riglib.stereo_opengl.window import WindowDispl2D +#from bmimultitasks import BMIControlMulti +import pygame +import numpy as np +import copy + +#from riglib.bmi.extractor import DummyExtractor +#from riglib.bmi.state_space_models import StateSpaceEndptVel2D +#from riglib.bmi.bmi import Decoder, BMISystem, GaussianStateHMM, BMILoop, GaussianState, MachineOnlyFilter +from riglib import experiment +from features.hdf_features import SaveHDF + +class CursorControl(ManualControlMulti, WindowDispl2D): + ''' + this class implements a python cursor control task for human + ''' + + def __init__(self, *args, **kwargs): + # just run the parent ManualControlMulti's initialization + self.move_step = 1 + + # Initialize target location variable + #target location and index have been initializd + + super(CursorControl, self).__init__(*args, **kwargs) + + def init(self): + pygame.init() + + + + self.assist_level = (0, 0) + super(CursorControl, self).init() + + # override the _cycle function + def _cycle(self): + #print(self.state) + + #target and plant data have been saved in + #the parent manualcontrolmultitasks + + self.move_effector_cursor() + super(CursorControl, self)._cycle() + + # do nothing + def move_effector(self): + pass + + def move_plant(self, **kwargs): + pass + + # use keyboard to control the task + def move_effector_cursor(self): + np.array([0., 0., 0.]) + curr_pos = copy.deepcopy(self.plant.get_endpoint_pos()) + + for event in pygame.event.get(): + if event.type == pygame.KEYUP: + if event.type == pygame.K_q: + pygame.quit() + quit() + if event.key == pygame.K_LEFT: + curr_pos[0] -= self.move_step + if event.key == pygame.K_RIGHT: + curr_pos[0] += self.move_step + if event.key == pygame.K_UP: + curr_pos[2] += self.move_step + if event.key == pygame.K_DOWN: + curr_pos[2] -= self.move_step + #print('Current position: ') + #print(curr_pos) + + # set the current position + self.plant.set_endpoint_pos(curr_pos) + + def _start_wait(self): + self.wait_time = 0. 
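+        # as in the non-HDF variant: no wait period before each trial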
+ super(CursorControl, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + +#this task can be run on its +#we will not involve database at this time +target_pos_radius = 10 + +def target_seq_generator(n_targs, n_trials): + #generate targets + angles = np.transpose(np.arange(0,2*np.pi,2*np.pi / n_targs)) + unit_targets = targets = np.stack((np.cos(angles), np.sin(angles)),1) + targets = unit_targets * target_pos_radius + + center = np.array((0,0)) + + target_inds = np.random.randint(0, n_targs, n_trials) + target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + + k = 0 + while k < n_trials: + targ = targets[target_inds[k], :] + yield np.array([[center[0], 0, center[1]], + [targ[0], 0, targ[1]]]) + k += 1 + + +if __name__ == "__main__": + print('Remember to set window size in stereoOpenGL class') + gen = target_seq_generator(8, 1000) + + #incorporate the saveHDF feature by blending code + #see tests\start_From_cmd_line_sim + + base_class = CursorControl + feats = [SaveHDF] + Exp = experiment.make(base_class, feats=feats) + print(Exp) + + exp = Exp(gen) + exp.init() + exp.run() #start the task + + + + \ No newline at end of file diff --git a/db/tracker/views.py b/db/tracker/views.py index f021fc4e..6a5fc7dc 100644 --- a/db/tracker/views.py +++ b/db/tracker/views.py @@ -113,11 +113,7 @@ def list_exp_history(request, **kwargs): if tracker.task_proxy is not None and "saveid" in tracker.task_kwargs: fields['running'] = tracker.task_kwargs["saveid"] -<<<<<<< HEAD resp = render(None,'list.html', fields) -======= - resp = render(none,'list.html', fields, RequestContext(request)) ->>>>>>> 7e3968e3f1004e4294905373c5414ef522d8b21d return resp def setup(request): From 4c641532ecadaa9403ddf7f0bb12bfe62739f50b Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Fri, 10 Apr 2020 16:07:26 -0700 Subject: [PATCH 015/242] added test_db file --- tests/test_db.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_db.py diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 00000000..36725fb6 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,26 @@ +import os +import django + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'db.settings') +django.setup() + +from tracker import models + +fe = models.Feature.objects.all() +print(fe) + +ID_NUMBER = 77 +te = models.TaskEntry.objects.get(id=ID_NUMBER) +print(te.subject) + + +# try a different method +# need to take out the dot in the .tracker import +# this method gets depreciated +import db.dbfunctions as dbfn +te_1 = dbfn.TaskEntry(ID_NUMBER) +print(te_1.date) + +''' + +''' \ No newline at end of file From 02c9048c95adaef9646c0b0d3dbf8a5cd47c46d7 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Fri, 10 Apr 2020 16:08:31 -0700 Subject: [PATCH 016/242] fixed the relative import in dbfunctions --- db/dbfunctions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/db/dbfunctions.py b/db/dbfunctions.py index f0c9a4af..9257e59c 100644 --- a/db/dbfunctions.py +++ b/db/dbfunctions.py @@ -24,7 +24,7 @@ except: pass -from .tracker import models +from tracker import models # default DB, change this variable from python session to switch to other database db_name = 'default' From 549e0cfabcb06768a800b8c18330c5e16da98e33 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Fri, 10 Apr 2020 16:09:16 -0700 Subject: [PATCH 017/242] seem to fix django start sequence --- tests/test_dbfunctions.py | 30 +++++++++++++++++++----------- 1 file changed, 19 
insertions(+), 11 deletions(-) diff --git a/tests/test_dbfunctions.py b/tests/test_dbfunctions.py index 748a4898..6602714b 100644 --- a/tests/test_dbfunctions.py +++ b/tests/test_dbfunctions.py @@ -2,26 +2,34 @@ ''' A set of tests to ensure that all the dbfunctions are still functional ''' -from db import dbfunctions as dbfn -reload(dbfn) +import os +import django -id = 4150 +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'db.settings') +django.setup() + +from tracker import models +import db.dbfunctions as dbfn + +id = 77 te = dbfn.TaskEntry(id) -print(dbfn.get_plx_file(id)) -print(dbfn.get_decoder_name(id)) -print(dbfn.get_decoder_name_full(id)) -print(dbfn.get_decoder(id)) -print(dbfn.get_params(id)) -print(dbfn.get_param(id, 'decoder')) print(dbfn.get_date(id)) print(dbfn.get_notes(id)) print(dbfn.get_subject(id)) print(dbfn.get_length(id)) print(dbfn.get_success_rate(id)) -id = 1956 -print(dbfn.get_bmiparams_file(id)) +#print(dbfn.get_plx_file(id)) +#print(dbfn.get_decoder_name(id)) +#print(dbfn.get_decoder_name_full(id)) +#print(dbfn.get_decoder(id)) +#print(dbfn.get_params(id)) +#print(dbfn.get_param(id, 'decoder')) + + +#id = 1956 +#print(dbfn.get_bmiparams_file(id)) # TODO check blackrock file fns From f1eca7b577f0d59fba97bee2a523d064b6e97e61 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 13 Apr 2020 11:00:44 -0700 Subject: [PATCH 018/242] defaut wt neuron sources to averaging --- riglib/bmi/sim_neurons.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/riglib/bmi/sim_neurons.py b/riglib/bmi/sim_neurons.py index ce887270..dc91d60c 100644 --- a/riglib/bmi/sim_neurons.py +++ b/riglib/bmi/sim_neurons.py @@ -320,7 +320,13 @@ def gen_spikes(self, next_state, mode=None): self.shar_unt = t_unt #Now weight everything together: - w = self.wt_sources + if self.wt_sources is None: # if mp wt_sources, equally weight the sources + w = np.array([1,1,1,1]) / 4 + else: + w = self.wt_sources + + + counts = np.squeeze(np.array(w[0]*self.priv_unt + w[1]*self.priv_tun + w[2]*self.shar_unt + w[3]*self.shar_tun)) #Adding back the mean FR From d5349a6e9c40616bf8c8d105ca87cfdda53b399e Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 13 Apr 2020 11:10:34 -0700 Subject: [PATCH 019/242] update file references --- tests/test_fa_decoding.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_fa_decoding.py b/tests/test_fa_decoding.py index 62fc06dd..e6cc126e 100644 --- a/tests/test_fa_decoding.py +++ b/tests/test_fa_decoding.py @@ -6,8 +6,8 @@ from riglib.bmi.feedback_controllers import LQRController from riglib import experiment -from tasks.bmimultitasks import BMIControlMulti -from tasks import manualcontrolmultitasks +from built_in_tasks.bmimultitasks import BMIControlMulti +from built_in_tasks import manualcontrolmultitasks import plantlist @@ -16,7 +16,7 @@ import pickle import time, datetime -from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceEndptVelY +from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceEndptPos1D from riglib.bmi import feedback_controllers class SuperSimpleEndPtAssister(object): @@ -206,7 +206,7 @@ def main_xz(session_length): return task def main_Y(session_length): - ssm_y = StateSpaceEndptVelY() + ssm_y = StateSpaceEndptPos1D() Task = experiment.make(SimVFB, [SaveHDF]) targets = SimVFB.centerout_Y_discrete() #targets = manualcontrolmultitasks.ManualControlMulti.centerout_2D_discrete() @@ -243,4 +243,10 @@ def save_stuff(task, suffix=''): f.close() #Return filename 
- return pnm \ No newline at end of file + return pnm + + +if __name__ == "__main__": + session_length = 100 + main_xz(session_length) + pass \ No newline at end of file From 188f8069198e588770e1e16ac02e537acd8894fb Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 13 Apr 2020 11:55:22 -0700 Subject: [PATCH 020/242] check if sim_c exists --- built_in_tasks/bmimultitasks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index 7692cb2d..f058c319 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -487,8 +487,12 @@ class SimBMIControlMulti(SimCosineTunedEnc, SimKFDecoderSup, BMIControlMulti): def __init__(self, *args, **kwargs): from riglib.bmi.state_space_models import StateSpaceEndptVel2D ssm = StateSpaceEndptVel2D() + + if 'sim_C' in kwargs: + self.sim_C = kwargs['sim_C'] A, B, W = ssm.get_ssm_matrices() + Q = np.mat(np.diag([1., 1, 1, 0, 0, 0, 0])) R = 10000*np.mat(np.diag([1., 1., 1.])) self.fb_ctrl = LQRController(A, B, Q, R) From d13ddbc7250fc3f2fc36f7662d3c35990d5ab76f Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 13 Apr 2020 13:41:05 -0700 Subject: [PATCH 021/242] set up start trial conditions --- built_in_tasks/bmimultitasks.py | 12 +++++++++++ built_in_tasks/test_SimBMIControlMulti.py | 25 +++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 built_in_tasks/test_SimBMIControlMulti.py diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index f058c319..a9c0de18 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -490,6 +490,10 @@ def __init__(self, *args, **kwargs): if 'sim_C' in kwargs: self.sim_C = kwargs['sim_C'] + if 'assist_level' in kwargs: + self.assist_level = kwargs['assist_level'] + + A, B, W = ssm.get_ssm_matrices() @@ -501,6 +505,14 @@ def __init__(self, *args, **kwargs): super(SimBMIControlMulti, self).__init__(*args, **kwargs) + + def _start_wait(self): + self.wait_time = 0. 
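+        # simulated trials begin without an inter-trial delay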
+ super()._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + @staticmethod def sim_target_seq_generator_multi(n_targs, n_trials): ''' diff --git a/built_in_tasks/test_SimBMIControlMulti.py b/built_in_tasks/test_SimBMIControlMulti.py new file mode 100644 index 00000000..561c8471 --- /dev/null +++ b/built_in_tasks/test_SimBMIControlMulti.py @@ -0,0 +1,25 @@ +from bmimultitasks import SimBMIControlMulti +import numpy as np +#build a sequence generator +N_TARGETS = 8 +N_TRIALS = 6 +seq = SimBMIControlMulti.sim_target_seq_generator_multi(N_TARGETS, N_TRIALS) + +#build a observer matrix +N_NEURONS = 20 +N_STATES = 7 #3 positions and 3 velocities and an offset + +sim_C = np.zeros((N_NEURONS,N_STATES)) +#control x positive directions +sim_C[2,:] = np.array([0,0,0,1,0,0,0]) +sim_C[3,:] = np.array([0,0,0,-1,0,0,0]) +#control z positive directions +sim_C[5,:] = np.array([0,0,0,0,0,1,0]) +sim_C[6,:] = np.array([0,0,0,0,0,-1,0]) + +#set up assist level +assist_level = (0.1, 0.1) + +exp = SimBMIControlMulti(seq,sim_C = sim_C, assist_level = assist_level) +exp.init() +exp.run() \ No newline at end of file From e448ef331689a27f1d8b61fa09528b0fc268453f Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 13 Apr 2020 14:21:48 -0700 Subject: [PATCH 022/242] added saveHDF feature to the simulation task --- built_in_tasks/test_SimBMIControlMulti.py | 68 +++++++++++++++-------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/built_in_tasks/test_SimBMIControlMulti.py b/built_in_tasks/test_SimBMIControlMulti.py index 561c8471..a8d16ffb 100644 --- a/built_in_tasks/test_SimBMIControlMulti.py +++ b/built_in_tasks/test_SimBMIControlMulti.py @@ -1,25 +1,47 @@ from bmimultitasks import SimBMIControlMulti +from features import SaveHDF import numpy as np -#build a sequence generator -N_TARGETS = 8 -N_TRIALS = 6 -seq = SimBMIControlMulti.sim_target_seq_generator_multi(N_TARGETS, N_TRIALS) - -#build a observer matrix -N_NEURONS = 20 -N_STATES = 7 #3 positions and 3 velocities and an offset - -sim_C = np.zeros((N_NEURONS,N_STATES)) -#control x positive directions -sim_C[2,:] = np.array([0,0,0,1,0,0,0]) -sim_C[3,:] = np.array([0,0,0,-1,0,0,0]) -#control z positive directions -sim_C[5,:] = np.array([0,0,0,0,0,1,0]) -sim_C[6,:] = np.array([0,0,0,0,0,-1,0]) - -#set up assist level -assist_level = (0.1, 0.1) - -exp = SimBMIControlMulti(seq,sim_C = sim_C, assist_level = assist_level) -exp.init() -exp.run() \ No newline at end of file +from riglib import experiment + +# build a sequence generator +if __name__ == "__main__": + N_TARGETS = 8 + N_TRIALS = 6 + seq = SimBMIControlMulti.sim_target_seq_generator_multi( + N_TARGETS, N_TRIALS) + + # build a observer matrix + N_NEURONS = 20 + N_STATES = 7 # 3 positions and 3 velocities and an offset + + # build the observation matrix + sim_C = np.zeros((N_NEURONS, N_STATES)) + # control x positive directions + sim_C[2, :] = np.array([0, 0, 0, 1, 0, 0, 0]) + sim_C[3, :] = np.array([0, 0, 0, -1, 0, 0, 0]) + # control z positive directions + sim_C[5, :] = np.array([0, 0, 0, 0, 0, 1, 0]) + sim_C[6, :] = np.array([0, 0, 0, 0, 0, -1, 0]) + + # set up assist level + assist_level = (0.1, 0.1) + + #exp = SimBMIControlMulti(seq,sim_C = sim_C, assist_level = assist_level) + # exp.init() + # exp.run() + + kwargs = dict() + + kwargs['sim_C'] = sim_C + kwargs['assist_level'] = assist_level + + base_class = SimBMIControlMulti + feats = [SaveHDF] + #feats = [] + Exp = experiment.make(base_class, feats=feats) + print(Exp) + + + exp = Exp(seq, 
**kwargs) + exp.init() + exp.run() # start the task From 2c71f2a0acff0bd5978e4d87cdf88ae01197cf87 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 13 Apr 2020 14:26:05 -0700 Subject: [PATCH 023/242] move hdf analysis to shared analysis folder --- analysis_shared/open_hdf_tables.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 analysis_shared/open_hdf_tables.py diff --git a/analysis_shared/open_hdf_tables.py b/analysis_shared/open_hdf_tables.py new file mode 100644 index 00000000..138de956 --- /dev/null +++ b/analysis_shared/open_hdf_tables.py @@ -0,0 +1,23 @@ +import tables +import numpy +import matplotlib.pyplot as plt + +#replace this with your hdf filename +fname = 'c:\\Users\\Si Jia\\AppData\\Local\\Temp\\tmp9fswwtwp.h5' +hdffile = tables.open_file(fname,'r') #read-onl + +print(hdffile) + +#get table information +# more methods refer to this +# https://www.pytables.org/usersguide/libref/structured_storage.html#tables.Table +table = hdffile.root.task +print(table.description) + +#look at cursor trajectory +cursor_coor = table.col('cursor') +plt.plot(cursor_coor[:,0], cursor_coor[:,2]) +plt.show() + + + From 499365782bca93870da6cdbc6bf1343c5aa0339a Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Fri, 24 Apr 2020 17:13:41 -0700 Subject: [PATCH 024/242] revert to the original database finder --- db/settings.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/db/settings.py b/db/settings.py index 1bf0873b..b95fd8e7 100644 --- a/db/settings.py +++ b/db/settings.py @@ -50,10 +50,10 @@ def get_sqlite3_databases(): return dbs -#DATABASES = get_sqlite3_databases() +DATABASES = get_sqlite3_databases() #DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': 'mydatabase', } } - +''' DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. @@ -72,6 +72,7 @@ def get_sqlite3_databases(): 'PORT': '', # Set to empty string for default. Not used with sqlite3. }, } +''' # Local time zone for this installation. 
Choices can be found here: # http://en.wikipedia.org/wiki/List_of_tz_zones_by_name From 4090de65effb0fda7b1cc9f17a588ac1e59fca6b Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Sun, 26 Apr 2020 20:24:43 -0700 Subject: [PATCH 025/242] fixed cursorControl task's issue of not starting from django --- built_in_tasks/cursorControlTasks.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/built_in_tasks/cursorControlTasks.py b/built_in_tasks/cursorControlTasks.py index 57555753..13549d5d 100644 --- a/built_in_tasks/cursorControlTasks.py +++ b/built_in_tasks/cursorControlTasks.py @@ -18,13 +18,14 @@ class CursorControl(ManualControlMulti, WindowDispl2D): def __init__(self, *args, **kwargs): # just run the parent ManualControlMulti's initialization self.move_step = 1 + self.assist_level = (0.5, 0.5) super(CursorControl, self).__init__(*args, **kwargs) - + def init(self): - pygame.init() - self.assist_level = (0.5, 0.5) - super(CursorControl, self).init() + #pygame.init() + super(CursorControl, self).init() + # override the _cycle function def _cycle(self): #print(self.state) From b3509ca516e4d9b62c1e7454940548584709f542 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Tue, 2 Jun 2020 16:01:14 -0700 Subject: [PATCH 026/242] set up tests for reward system --- tests/test_reward_ao.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/test_reward_ao.py diff --git a/tests/test_reward_ao.py b/tests/test_reward_ao.py new file mode 100644 index 00000000..a7e7c61c --- /dev/null +++ b/tests/test_reward_ao.py @@ -0,0 +1,27 @@ +from riglib import ao_reward + +import unittest + +class TestAoReward(unittest.TestCase): + + def setUp(self): + self.reward_sys = ao_reward.Basic() + + + @unittest.skip("not sure which method to use") + def test_connection(self): + pass + + def test_flow_out(self): + reward_time = 1000 #ms + self.reward_sys(reward_time) + + @unittest.skip("not sure how to calibrate yet") + def test_calibration(self): + print('fill up bottle') + + +if __name__ == '__main__': + unittest.main() + + From 6e3bcf980e13e08d2aa2a8000c155d8e5e463629 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 15 Jun 2020 14:23:41 -0700 Subject: [PATCH 027/242] removed logging of eye_tracking --- features/eyetracker_features.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/features/eyetracker_features.py b/features/eyetracker_features.py index d4e706fc..6caa0703 100644 --- a/features/eyetracker_features.py +++ b/features/eyetracker_features.py @@ -32,12 +32,12 @@ def init(self): self.sinks = sink.sinks src, ekw = self.eye_source - f = open('/home/helene/code/bmi3d/log/eyetracker', 'a') + #f = open('/home/helene/code/bmi3d/log/eyetracker', 'a') self.eyedata = source.DataSource(src, **ekw) self.sinks.register(self.eyedata) - f.write('instantiated source\n') + #f.write('instantiated source\n') super(EyeData, self).init() - f.close() + #f.close() @property def eye_source(self): @@ -58,10 +58,10 @@ def run(self): Code to execute immediately prior to the beginning of the task FSM executing, or after the FSM has finished running. See riglib.experiment.Experiment.run(). 
This 'run' method starts the 'eyedata' source and stops it after the FSM has finished running ''' - f = open('/home/helene/code/bmi3d/log/eyetracker', 'a') + #f = open('/home/helene/code/bmi3d/log/eyetracker', 'a') self.eyedata.start() - f.write('started eyedata\n') - f.close() + #f.write('started eyedata\n') + #f.close() try: super(EyeData, self).run() finally: @@ -121,6 +121,7 @@ def cleanup(self, database, saveid, **kwargs): Returns ------- ''' + super(EyeData, self).cleanup(database, saveid, **kwargs) dbname = kwargs['dbname'] if 'dbname' in kwargs else 'default' if dbname == 'default': @@ -145,7 +146,7 @@ def eye_source(self): ------- ''' from riglib import eyetracker - return eyetracker.Simulate, dict(fixations=fixations, fixation_len=fixation_len) + return eyetracker.Simulate, dict(fixations= self.fixations) class CalibratedEyeData(EyeData): '''Filters eyetracking data with a calibration profile''' @@ -218,4 +219,12 @@ def _test_start_trial(self, ts): Returns ------- ''' - return ts > self.fixation_length \ No newline at end of file + return ts > self.fixation_length + +''' +if __name__ == "__main__": + sim_eye_data = SimulatedEyeData() + + sim_eye_data.init() + sim_eye_data.run() +''' From f1fc6d89b4e7a4aba185d1486181f436ac15e24d Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 15 Jun 2020 14:25:48 -0700 Subject: [PATCH 028/242] add mock retrieve function to simulate --- riglib/eyetracker.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/riglib/eyetracker.py b/riglib/eyetracker.py index a1729e78..9fdea251 100644 --- a/riglib/eyetracker.py +++ b/riglib/eyetracker.py @@ -38,6 +38,7 @@ def __init__(self, fixations=[(0,0), (-0.6,0.3), (0.6,0.3)], isi=500, slen=15): self.isi = isi def start(self): + ''' Docstring @@ -48,6 +49,18 @@ def start(self): ------- ''' self.stime = time.time() + + def retrieve(self, filename): + ''' + for sim, there is no need to retrieve an file + + Parameters + ---------- + + Returns + ------- + ''' + pass def get(self): ''' @@ -59,7 +72,7 @@ def get(self): Returns ------- ''' - time.sleep(1./update_freq) + time.sleep(1./self.update_freq) return self.interp((time.time() - self.stime) % self.mod) + np.random.randn(2)*.01 def stop(self): @@ -72,7 +85,19 @@ def stop(self): Returns ------- ''' - return + return + + def sendMsg(self, msg): + ''' + Docstring + + Parameters + ---------- + + Returns + ------- + ''' + pass class System(object): ''' From 3d552e545deaab1eadeabfabd3f0792de79c864b Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 15 Jun 2020 14:27:54 -0700 Subject: [PATCH 029/242] do not track dot graphic files --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index ab70be92..15874839 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ env/* tests/*.mat tests/*.hdf *.h5 +*.dot From d0a417eb94d069e6872283a91a17a1cbb232a963 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Thu, 25 Jun 2020 20:13:50 -0700 Subject: [PATCH 030/242] explictly add ms to reward time --- tests/test_reward_ao.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_reward_ao.py b/tests/test_reward_ao.py index a7e7c61c..f52db631 100644 --- a/tests/test_reward_ao.py +++ b/tests/test_reward_ao.py @@ -13,8 +13,8 @@ def test_connection(self): pass def test_flow_out(self): - reward_time = 1000 #ms - self.reward_sys(reward_time) + reward_time_ms = 1000 #ms + self.reward_sys(reward_time_ms) @unittest.skip("not sure how to calibrate yet") def test_calibration(self): From 
12b817885ac5d7c9e2f9dfc9a0d77d43073c51d5 Mon Sep 17 00:00:00 2001
From: Sijia Li
Date: Thu, 25 Jun 2020 20:14:56 -0700
Subject: [PATCH 031/242] add cursor control with sim mocap eye data

---
 .../cursorControlTasks_sim_mocap_eye.py | 132 ++++++++++++++++++
 features/eyetracker_features.py         |  26 +++++++++++++++-
 2 files changed, 157 insertions(+), 1 deletion(-)
 create mode 100644 built_in_tasks/cursorControlTasks_sim_mocap_eye.py

diff --git a/built_in_tasks/cursorControlTasks_sim_mocap_eye.py b/built_in_tasks/cursorControlTasks_sim_mocap_eye.py
new file mode 100644
index 00000000..470ae9f7
--- /dev/null
+++ b/built_in_tasks/cursorControlTasks_sim_mocap_eye.py
@@ -0,0 +1,132 @@
+
+from manualcontrolmultitasks import ManualControlMulti
+from riglib.stereo_opengl.window import WindowDispl2D
+#from bmimultitasks import BMIControlMulti
+import pygame
+import numpy as np
+import copy
+
+#from riglib.bmi.extractor import DummyExtractor
+#from riglib.bmi.state_space_models import StateSpaceEndptVel2D
+#from riglib.bmi.bmi import Decoder, BMISystem, GaussianStateHMM, BMILoop, GaussianState, MachineOnlyFilter
+from riglib import experiment
+
+
+class CursorControl(ManualControlMulti, WindowDispl2D):
+    '''
+    This class implements a Python cursor control task for human subjects.
+    '''
+
+    def __init__(self, *args, **kwargs):
+        # just run the parent ManualControlMulti's initialization
+        self.move_step = 1
+
+        # target location and index have already been initialized
+        # by the parent class
+        super(CursorControl, self).__init__(*args, **kwargs)
+
+    def init(self):
+        pygame.init()
+
+        self.assist_level = (0, 0)
+        super(CursorControl, self).init()
+
+    # override the _cycle function
+    def _cycle(self):
+        #print(self.state)
+
+        # target and plant data have been saved in
+        # the parent manualcontrolmultitasks
+        self.move_effector_cursor()
+        super(CursorControl, self)._cycle()
+
+    # do nothing
+    def move_effector(self):
+        pass
+
+    def move_plant(self, **kwargs):
+        pass
+
+    # use the keyboard to control the task
+    def move_effector_cursor(self):
+        curr_pos = copy.deepcopy(self.plant.get_endpoint_pos())
+
+        for event in pygame.event.get():
+            if event.type == pygame.KEYUP:
+                if event.key == pygame.K_q:  # was event.type, which never matches a key constant
+                    pygame.quit()
+                    quit()
+                if event.key == pygame.K_LEFT:
+                    curr_pos[0] -= self.move_step
+                if event.key == pygame.K_RIGHT:
+                    curr_pos[0] += self.move_step
+                if event.key == pygame.K_UP:
+                    curr_pos[2] += self.move_step
+                if event.key == pygame.K_DOWN:
+                    curr_pos[2] -= self.move_step
+        #print('Current position: ')
+        #print(curr_pos)
+
+        # set the current position
+        self.plant.set_endpoint_pos(curr_pos)
+
+    def _start_wait(self):
+        self.wait_time = 0.
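+        # skip the inter-trial wait for this keyboard-driven demo; the parent
+        # call below still runs the normal wait-state bookkeeping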
+ super(CursorControl, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + +#this task can be run on its +#we will not involve database at this time +target_pos_radius = 10 + +def target_seq_generator(n_targs, n_trials): + #generate targets + angles = np.transpose(np.arange(0,2*np.pi,2*np.pi / n_targs)) + unit_targets = targets = np.stack((np.cos(angles), np.sin(angles)),1) + targets = unit_targets * target_pos_radius + + center = np.array((0,0)) + + target_inds = np.random.randint(0, n_targs, n_trials) + target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + + k = 0 + while k < n_trials: + targ = targets[target_inds[k], :] + yield np.array([[center[0], 0, center[1]], + [targ[0], 0, targ[1]]]) + k += 1 + + +if __name__ == "__main__": + print('Remember to set window size in stereoOpenGL class') + gen = target_seq_generator(8, 1000) + + #incorporate the saveHDF feature by blending code + #see tests\start_From_cmd_line_sim + + from features.hdf_features import SaveHDF + #incorporate eyetracking and mocap simulation data + from features.eyetracker_features import SimulatedEyeData + from features.phasespace_features import MotionSimulate + + #mix features into the experiment base class + base_class = CursorControl + feats = [SimulatedEyeData, SaveHDF] + Exp = experiment.make(base_class, feats=feats) + print(Exp) + + exp = Exp(gen) + exp.init() + exp.run() #start the task + + + + \ No newline at end of file diff --git a/features/eyetracker_features.py b/features/eyetracker_features.py index 6caa0703..6da7693e 100644 --- a/features/eyetracker_features.py +++ b/features/eyetracker_features.py @@ -59,6 +59,7 @@ def run(self): See riglib.experiment.Experiment.run(). This 'run' method starts the 'eyedata' source and stops it after the FSM has finished running ''' #f = open('/home/helene/code/bmi3d/log/eyetracker', 'a') + print("before eyedata run") self.eyedata.start() #f.write('started eyedata\n') #f.close() @@ -148,6 +149,30 @@ def eye_source(self): from riglib import eyetracker return eyetracker.Simulate, dict(fixations= self.fixations) + def _cycle(self): + ''' + Docstring + basically, extract the data and do something with it + + + Parameters + ---------- + + Returns + ------- + ''' + print('cycle started') + #retrieve data + data_temp = self.eyedata.get() + + #send the data to sinks + if data_temp is not None: + self.sinks.send(self.eyedata.name, data_temp) + + super(SimulatedEyeData, self)._cycle() + + + class CalibratedEyeData(EyeData): '''Filters eyetracking data with a calibration profile''' cal_profile = traits.Instance(calibrations.EyeProfile) @@ -164,7 +189,6 @@ def __init__(self, *args, **kwargs): ''' super(CalibratedEyeData, self).__init__(*args, **kwargs) self.eyedata.set_filter(self.cal_profile) - class FixationStart(CalibratedEyeData): '''Triggers the start_trial event whenever fixation exceeds *fixation_length*''' fixation_length = traits.Float(2., desc="Length of fixation required to start the task") From 9622292b16d0ae522e5659448bf2e990130a94df Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Thu, 25 Jun 2020 23:56:31 -0700 Subject: [PATCH 032/242] fixed sim data dimensions --- built_in_tasks/cursorControlTasks_sim_mocap_eye.py | 3 ++- features/eyetracker_features.py | 2 -- riglib/eyetracker.py | 7 ++++++- riglib/hdfwriter/hdfwriter/hdfwriter.py | 1 + 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/built_in_tasks/cursorControlTasks_sim_mocap_eye.py b/built_in_tasks/cursorControlTasks_sim_mocap_eye.py index 
470ae9f7..1061fd53 100644 --- a/built_in_tasks/cursorControlTasks_sim_mocap_eye.py +++ b/built_in_tasks/cursorControlTasks_sim_mocap_eye.py @@ -119,7 +119,8 @@ def target_seq_generator(n_targs, n_trials): #mix features into the experiment base class base_class = CursorControl - feats = [SimulatedEyeData, SaveHDF] + feats = [SaveHDF, SimulatedEyeData] + #feats = [SaveHDF] Exp = experiment.make(base_class, feats=feats) print(Exp) diff --git a/features/eyetracker_features.py b/features/eyetracker_features.py index 6da7693e..6ee2b1e4 100644 --- a/features/eyetracker_features.py +++ b/features/eyetracker_features.py @@ -59,7 +59,6 @@ def run(self): See riglib.experiment.Experiment.run(). This 'run' method starts the 'eyedata' source and stops it after the FSM has finished running ''' #f = open('/home/helene/code/bmi3d/log/eyetracker', 'a') - print("before eyedata run") self.eyedata.start() #f.write('started eyedata\n') #f.close() @@ -161,7 +160,6 @@ def _cycle(self): Returns ------- ''' - print('cycle started') #retrieve data data_temp = self.eyedata.get() diff --git a/riglib/eyetracker.py b/riglib/eyetracker.py index 9fdea251..d537ee55 100644 --- a/riglib/eyetracker.py +++ b/riglib/eyetracker.py @@ -48,6 +48,7 @@ def start(self): Returns ------- ''' + print("eyetracker.simulate.start()") self.stime = time.time() def retrieve(self, filename): @@ -73,7 +74,11 @@ def get(self): ------- ''' time.sleep(1./self.update_freq) - return self.interp((time.time() - self.stime) % self.mod) + np.random.randn(2)*.01 + + data = self.interp((time.time() - self.stime) % self.mod) + np.random.randn(2)*.01 + #expand dims + data_2 = np.expand_dims(data, axis = 0) + return data_2 def stop(self): ''' diff --git a/riglib/hdfwriter/hdfwriter/hdfwriter.py b/riglib/hdfwriter/hdfwriter/hdfwriter.py index 316555ee..01e6dd4f 100644 --- a/riglib/hdfwriter/hdfwriter/hdfwriter.py +++ b/riglib/hdfwriter/hdfwriter/hdfwriter.py @@ -56,6 +56,7 @@ def register(self, name, dtype, include_msgs=True): None ''' print("HDFWriter registered %r" % name) + print(dtype) if dtype.subdtype is not None: #just a simple dtype with a shape dtype, sliceshape = dtype.subdtype From ef948071d8adf1ba0905794fddeb6b0d25618e2e Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Fri, 26 Jun 2020 14:46:23 -0700 Subject: [PATCH 033/242] fixed dim expansion for mocap --- built_in_tasks/cursorControlTasks_sim_mocap_eye.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/built_in_tasks/cursorControlTasks_sim_mocap_eye.py b/built_in_tasks/cursorControlTasks_sim_mocap_eye.py index 1061fd53..656f6cf7 100644 --- a/built_in_tasks/cursorControlTasks_sim_mocap_eye.py +++ b/built_in_tasks/cursorControlTasks_sim_mocap_eye.py @@ -119,7 +119,7 @@ def target_seq_generator(n_targs, n_trials): #mix features into the experiment base class base_class = CursorControl - feats = [SaveHDF, SimulatedEyeData] + feats = [SaveHDF, SimulatedEyeData, MotionSimulate] #feats = [SaveHDF] Exp = experiment.make(base_class, feats=feats) print(Exp) From 79231f0542512e66b3a927d65309c5ef05df98bf Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Sat, 27 Jun 2020 16:10:56 -0700 Subject: [PATCH 034/242] fixed mocap sim array dim incompatibility --- analysis_shared/open_hdf_tables.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/analysis_shared/open_hdf_tables.py b/analysis_shared/open_hdf_tables.py index 138de956..ca29b86e 100644 --- a/analysis_shared/open_hdf_tables.py +++ b/analysis_shared/open_hdf_tables.py @@ -3,7 +3,8 @@ import matplotlib.pyplot as plt #replace 
this with your hdf filename -fname = 'c:\\Users\\Si Jia\\AppData\\Local\\Temp\\tmp9fswwtwp.h5' +#fname = 'c:\\Users\\Si Jia\\AppData\\Local\\Temp\\tmp9fswwtwp.h5' +fname = '/tmp/tmpdcbqn2zo.h5' hdffile = tables.open_file(fname,'r') #read-onl print(hdffile) From 8d5036f6fdeba5defd03cb9a5ce01cede165b4ef Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Sat, 27 Jun 2020 16:22:55 -0700 Subject: [PATCH 035/242] fixed mocap sim array dim incompatibility in eye tracker --- built_in_tasks/cursorControlTasks_sim_mocap_eye.py | 3 +-- riglib/motiontracker.py | 6 +++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/built_in_tasks/cursorControlTasks_sim_mocap_eye.py b/built_in_tasks/cursorControlTasks_sim_mocap_eye.py index 656f6cf7..47d9ab68 100644 --- a/built_in_tasks/cursorControlTasks_sim_mocap_eye.py +++ b/built_in_tasks/cursorControlTasks_sim_mocap_eye.py @@ -107,7 +107,7 @@ def target_seq_generator(n_targs, n_trials): if __name__ == "__main__": print('Remember to set window size in stereoOpenGL class') - gen = target_seq_generator(8, 1000) + gen = target_seq_generator(8, 6) #incorporate the saveHDF feature by blending code #see tests\start_From_cmd_line_sim @@ -122,7 +122,6 @@ def target_seq_generator(n_targs, n_trials): feats = [SaveHDF, SimulatedEyeData, MotionSimulate] #feats = [SaveHDF] Exp = experiment.make(base_class, feats=feats) - print(Exp) exp = Exp(gen) exp.init() diff --git a/riglib/motiontracker.py b/riglib/motiontracker.py index 792c654b..1d43afc5 100644 --- a/riglib/motiontracker.py +++ b/riglib/motiontracker.py @@ -73,7 +73,11 @@ def get(self): z = self.radius[2] * np.sin(ts / self.speed[2] * 2*np.pi + p) data[i] = x,y,z - return np.hstack([data + np.random.randn(self.n, 3)*0.1, np.ones((self.n,1))]) + #expands the dimension for HDFwriter saving + data_temp = np.hstack([data + np.random.randn(self.n, 3) * 0.1, np.ones((self.n, 1))]) + data_temp_expand = np.expand_dims(data_temp, axis = 0) + + return data_temp_expand def stop(self): ''' From a100fd2a0565a1a458998e972f29217ca74e4b00 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Sat, 27 Jun 2020 16:44:05 -0700 Subject: [PATCH 036/242] del .ipynb in gitignore to track jup notebook --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 15874839..aae994b8 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,6 @@ db/*.sql *.sql.* *.jpeg log/* -*.ipynb *.doctree tasks/debugging.txt riglib/plexon/plexfile.py From 30ecdb3c9e25510eba2d2a65a2798e4f282a1a0f Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Sat, 27 Jun 2020 16:47:03 -0700 Subject: [PATCH 037/242] example jupyter notebook file --- analysis_shared/Untitled.ipynb | 32 +++ analysis_shared/open_hdf_tables_jupyter.ipynb | 187 ++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 analysis_shared/Untitled.ipynb create mode 100644 analysis_shared/open_hdf_tables_jupyter.ipynb diff --git a/analysis_shared/Untitled.ipynb b/analysis_shared/Untitled.ipynb new file mode 100644 index 00000000..1594c89b --- /dev/null +++ b/analysis_shared/Untitled.ipynb @@ -0,0 +1,32 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": 
"ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/analysis_shared/open_hdf_tables_jupyter.ipynb b/analysis_shared/open_hdf_tables_jupyter.ipynb new file mode 100644 index 00000000..dff6c8a5 --- /dev/null +++ b/analysis_shared/open_hdf_tables_jupyter.ipynb @@ -0,0 +1,187 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Imports done! \n" + ] + } + ], + "source": [ + "import tables\n", + "import numpy\n", + "import matplotlib.pyplot as plt\n", + "print(\"Imports done! \")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/tmpdcbqn2zo.h5 (File) ''\n", + "Last modif.: 'Fri Jun 26 15:09:42 2020'\n", + "Object Tree: \n", + "/ (RootGroup) ''\n", + "/eyetracker (EArray(25710, 2), shuffle, zlib(5)) ''\n", + "/eyetracker_msgs (Table(54,), shuffle, zlib(5)) ''\n", + "/motiontracker (EArray(6444, 8, 4), shuffle, zlib(5)) ''\n", + "/motiontracker_msgs (Table(54,), shuffle, zlib(5)) ''\n", + "/task (Table(1994,), shuffle, zlib(5)) ''\n", + "/task_msgs (Table(54,), shuffle, zlib(5)) ''\n", + "\n" + ] + } + ], + "source": [ + "#replace this with your hdf filename\n", + "fname = '/tmp/tmpdcbqn2zo.h5'\n", + "hdffile = tables.open_file(fname, 'r')\n", + "print(hdffile)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Description([('cursor', '(3,)f8'), ('target', '(3,)f8'), ('target_index', '(1,)i4')])\n" + ] + } + ], + "source": [ + "tables = hdffile.root.task\n", + "print(tables.description)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFutJREFUeJzt3XmYVfV9x/H3RwQliiACgqxaiUat6wRIYlstikqNaB5riUuIiCQ2pjFLU1L7aJp0SdImafNopKjUHU1NjDTigiapMa3AgIAgKkswMLIMSABRhIFv/5hDn9vxznrO3DMz5/N6nvvcs/zu+X0593I/c9ariMDMzIrnoLwLMDOzfDgAzMwKygFgZlZQDgAzs4JyAJiZFZQDwMysoBwAZmYF5QAwMysoB4CZWUEdnHcBTenXr1+MGDEi7zLMzDqNhQsXbomI/i1p26EDYMSIEVRXV+ddhplZpyHpjZa29S4gM7OCcgCYmRWUA8DMrKAcAGZmBeUAMDMrKAeAmVlBOQDMzArKAWCWsfv+Zy0Pz/9t3mWYNatDXwhm1hnd8vhyACaOGpZzJWZN8xaAmVlBOQDMzArKAWBmVlAOADOzgnIAmJkVlAPAzKygHABmZgXVqgCQNFPSZknLSqb1lTRX0srk+chGXjspabNS0qS0hZuZWTqt3QK4B7iwwbRpwHMRMRJ4Lhn/fyT1BW4FRgOjgFsbCwozM6uMVgVARDwPvNVg8gTg3mT4XuDSMi+9AJgbEW9FxDZgLu8PEjMzq6AsjgEcHREbkuGNwNFl2gwG1pWMr0+mvY+kqZKqJVXX1tZmUJ6ZmZWT6UHgiAggUi5jRkRURURV//4t+mF7MzNrgywCYJOkQQDJ8+YybWqAoSXjQ5JpZmaWkywCYDZw4KyeScDjZdo8DYyTdGRy8HdcMs3MzHLS2tNAZwH/A5wgab2k64BvAedLWgmcl4wjqUrSXQAR8RbwTWBB8vhGMs3MzHLSqt8DiIhPNjJrbJm21cCUkvGZwMxWVWdmZu3GVwKbmRWUA8DMrKAcAGZmBeUAMDMrKAeAmXUpj720nm89+WreZXQKDgAz61K++MgSpv/X6rzL6BQcAGZmBeUAMDMrKAeAmVlBOQDMzArKAWBmVlAOADOzgnIAmJkVlAPAzKygHABmZgXlAOiA/urRpTyxdEPeZZhZF+cA6IAeqV7H5x5alHcZZp3K9nf38qUfLc67jE6lVb8IZmbWEb2wcgt/+egSNu98D4AhR/bMuaLOIfUWgKQTJC0ueeyQdFODNudI2l7S5pa0/ZqZvbtnH7c+voyr755Hzx7d+PENH+Xjpx1Dj27eudESqbcAIuI14HQASd2AGuCxMk1/FREXp+3PzAzgpd9u48s/WsKaLbu49mMj+OoFJ9KzR7e8y+pUst4FNBZYHRFvZLxcMzMA9tTt5wfPreSHv1zFoN49eWjKaD56fL+8y+qUsg6AicCsRuZ9RNIS4E3gKxGxPOO+zayLe23jTr74yGJe2bCDy88awi0fP4kjDu2ed1mdVmYBIKkHcAnwtTKzFwHDI+JtSeOBnwIjG1nOVGAqwLBhw7Iqz8w6sX37g7t+tYbvPvM6R/Q8mBnXnMW4kwfmXVanl+UWwEXAoojY1HBGROwoGZ4j6YeS+kXEljJtZwAzAKqqqiLD+sysE3pj6y6+8h9LWLB2GxecfDT/cNnvc9Thh+RdVpeQZQB8kkZ2/0gaCGyKiJA0ivqzj7Zm2LeZdTERwUPzf8vfP7GCbgeJ711xGpedMRhJeZfWZWQSAJIOA84HPlMy7bMAETEduBy4QVId8C4wMSL8172ZlbVpx26++uhS/uv1Ws4+vh/fufxUjunjc/uzlkkARMQu4KgG06aXDN8G3JZFX2bWsc1bs5UnXm77rUz27Q9+tnQD79Xt4xsTTubq0cM56CD/1d8efCWwmWXq7hd+w7MrNtG7Z9vPzvnQoF78w2W/z3H9D8+wMmvIAWBmmQrghIFH8OQX/iDvUqwZvl7azKygHABmZgXlADAzKygHgJlZQTkAzMwKygFgZlZQDgAzs4JyAJiZFZQDwMwytX+/b/PVWTgAzCwz/7nkTX7x2mZOHNgr71KsBRwAZpaJp5Zt5KZHFlM1vC9/f9kpeZdjLeAAMLPUnluxic/PWsRpQ3oz89oP84Eevs1YZ+AAMLNUnn+9lhseWMSHBh3BPZNHcfgh/vLvLBwAZtZm/716C9ffV83xAw7nvsmj/APtnYwDwMzaZMHat7junmqGH/UBHpgymj4f6JF3SdZKDgAza7WXfruNa/99AYP6HMqDU8bQ9zB/+XdGmQWApLWSXpa0WFJ1mfmS9ANJqyQtlXRmVn2bWeW8vH47n5o5n6MO78FDU8bQv9cheZdkbZT10ZpzI2JLI/MuAkYmj9HAHcmzmXUSr7y5g2tmzqN3z+48dP0YBvY+NO+SLIVK7gKaANwX9V4E+kgaVMH+zSyFlZt2cvXd8+jZvRuzrh/D4D498y7JUsoyAAJ4RtJCSVPLzB8MrCsZX59MM7MObk3t21x51zwOPkg8dP0Yhvb9QN4lWQay3AV0dkTUSBoAzJX0akQ839qFJOExFWDYsGEZlmdmbfHG1l1ceec8IoJZU8dwbL/D8i7JMpLZFkBE1CTPm4HHgFENmtQAQ0vGhyTTGi5nRkRURURV//79syrPzNpg/bZ3uPLOebxXt48Hpozm+AG+x09XkkkASDpMUq8Dw8A4YFmDZrOBTyVnA40BtkfEhiz6N7Psbdj+LlfeOY+du/dy/3WjOXHgEXmXZBnLahfQ0cBjkg4s86GIeErSZwEiYjowBxgPrALeAa7NqG8zy9jmHbu56s55vLVrDw9MGc0pg3vnXZK1g0wCICLWAKeVmT69ZDiAz2XRn5m1n61vv8dVd81j447d3H/dKE4f2ifvkqyd+EpgM/s/23bt4aq75rFu2zvM/PSHOWt437xLsnbk2/aZGQDb393LNTPnsWbLLmZO+jBjjjsq75KsnXkLwMzYuXsvk2bO57WNO/m3q8/i7JH98i7JKsABYFZwu96rY/I9C1hWs53brjyTc08ckHdJViEOALMC2713H1PurWbhG9v414lncMHJA/MuySrIxwDMCmr33n1MvX8hL/5mK9+/4nT+5FTfmqtovAVgVkB76vbzuQcX8fzrtXz7E6dy6Rm+LVcROQDMCmbvvv38xayXeO7VzfzdpadwxYeHNv8i65IcAGYFsm9/8KUfLeGp5Ru59eMncfWY4XmXZDlyAJgVxP79wV8+uoT/XPImX7voRK792LF5l2Q5cwCYFcD+/cFfP/YyP1lUw5fP/yCf+aPfy7sk6wAcAGZdXERw6+zlPLxgHZ//4+P5/NiReZdkHYQDwKwLiwj+7okV3P/iG3zmD4/jS+d/MO+SrANxAJh1URHBd55+jbtf+A3XfmwE0y46keSW7WaAA8Csy/qXZ1dyxy9Xc9XoYdxy8Un+8rf3cQCYdUG3/2IV//rcSq6oGsI3J5ziL38ry7eCMGvg8cU1PLtic+
rlfH7WSxlU03rvvFfHc69u5rIzBvOPnziVgw7yl7+V5wAwa2Dmr9fy+sadDOp9aKrlLK/ZnlFFrXfV6GH87SUn081f/tYEB4BZGaOO7cu9k0e16bUjpj0BwM+/ck6GFZllL/UxAElDJf1C0iuSlkv6Qpk250jaLmlx8rglbb9mZpZOFlsAdcCXI2KRpF7AQklzI+KVBu1+FREXZ9CfmZllIPUWQERsiIhFyfBOYAXge8uamXVwmZ4GKmkEcAYwr8zsj0haIulJSSdn2a+ZmbVeZgeBJR0O/Bi4KSJ2NJi9CBgeEW9LGg/8FCh7QxJJU4GpAMOGDcuqPDMzayCTLQBJ3an/8n8wIn7ScH5E7IiIt5PhOUB3Sf3KLSsiZkREVURU9e/fP4vyzMysjCzOAhJwN7AiIr7XSJuBSTskjUr63Zq2bzMza7ssdgF9DLgGeFnS4mTaXwPDACJiOnA5cIOkOuBdYGJERAZ9m5lZG6UOgIh4AWjycsOIuA24LW1fZp3JDQ8szLuEQlq87nf07N4t7zI6BV8JbNZOVte+nXcJhdTr0IM594QBeZfRKTgAzNrJM1/8o7xLMGuSbwdtZlZQDgAzs4JyAJiZFZQDwMysoBwAZmYF5QAwMysoB4CZWUE5AMzMCsoXgnVg199XnXcJhbRk3e8Ar3/r+hwAHdj6be/mXUKhef1bV+cA6IAG9DqEsR8awD9+4tS8SymkCbf/mj49u3Pv5FFtev2IaU9kXJFZ+/AxADOzgnIAmJkVlAPAzKygHABmZgXlADAzK6hMAkDShZJek7RK0rQy8w+R9Egyf56kEVn0a2ZmbZc6ACR1A24HLgJOAj4p6aQGza4DtkXE8cD3gW+n7dfMzNLJYgtgFLAqItZExB7gYWBCgzYTgHuT4UeBsZKa/CF5MzNrX1lcCDYYWFcyvh4Y3VibiKiTtB04CtiSQf/v88G/eZI9dfvbY9EVM2v+OjZu3513GYV04FYQaS/o8gVh1lYDeh3C/JvPa/d+OtyVwJKmAlMBhg0b1qZldPYv/wO27tqTdwlmloPNO9+rSD9ZBEANMLRkfEgyrVyb9ZIOBnoDW8stLCJmADMAqqqqoi0Frf3Wn7TlZWaZuPyO/+aQ7gfx4JQxeZdi1qQsjgEsAEZKOlZSD2AiMLtBm9nApGT4cuDnEdGmL3czM8tG6i2AZJ/+jcDTQDdgZkQsl/QNoDoiZgN3A/dLWgW8RX1ImJlZjjI5BhARc4A5DabdUjK8G/jTLPoyM7Ns+EpgM7OCcgCYmRWUA8DMrKAcAGZmBeUAMDMrKAeAmVlBOQDMzArKAWBmVlAOADOzgnIAmJkVlAPAzKygHABmZgXlADAzKygHgJlZQTkAzMwKygFgZlZQDgAzs4JyAJiZFZQDwMysoFL9JrCkfwI+DuwBVgPXRsTvyrRbC+wE9gF1EVGVpl8zM0sv7RbAXOCUiDgVeB34WhNtz42I0/3lb2bWMaQKgIh4JiLqktEXgSHpSzIzs0rI8hjAZODJRuYF8IykhZKmNrUQSVMlVUuqrq2tzbA8MzMr1ewxAEnPAgPLzLo5Ih5P2twM1AEPNrKYsyOiRtIAYK6kVyPi+XINI2IGMAOgqqoqWvBvMDOzNmg2ACLivKbmS/o0cDEwNiLKfmFHRE3yvFnSY8AooGwAmJlZZaTaBSTpQuCrwCUR8U4jbQ6T1OvAMDAOWJamXzMzSy/tMYDbgF7U79ZZLGk6gKRjJM1J2hwNvCBpCTAfeCIinkrZr5mZpZTqOoCIOL6R6W8C45PhNcBpafoxM7Ps+UpgM7OCcgCYmRWUA8DMrKAcAGZmBeUAMDMrKAeAmVlBOQDMzArKAWBmVlAOADOzgnIAmJkVlAPAzKygHABmZgXlADAzKygHgJlZQTkAzMwKygFgZlZQDgAzs4JyAJiZFVTaH4X/uqSa5PeAF0sa30i7CyW9JmmVpGlp+jQzs2yk+k3gxPcj4p8bmympG3A7cD6wHlggaXZEvJJB32Zm1kaV2AU0ClgVEWsiYg/wMDChAv2amVkTsgiAGyUtlTRT0pFl5g8G1pWMr0+mmZlZjpoNAEnPSlpW5jEBuAP4PeB0YAPw3bQFSZoqqVpSdW1tbdrFmZlZI5o9BhAR57VkQZLuBH5WZlYNMLRkfEgyrbH+ZgAzAKqqqqIlfZuZWeulPQtoUMnoZcCyMs0WACMlHSupBzARmJ2mXzMzSy/tWUDfkXQ6EMBa4DMAko4B7oqI8RFRJ+lG4GmgGzAzIpan7NfMzFJKFQARcU0j098ExpeMzwHmpOnLzMyy5SuBzcwKygFgZlZQDgAzs4JyAJiZFZQDwMysoBwAZmYF5QAwMysoB4CZWUE5AMzMCsoBYGZWUA4AM7OCcgCYmRWUA8DMrKAcAGZmBeUAMDMrKAeAmVlBpf1FsC7p0tt/za736vIuwzqpddve4azhR+ZdhlmzHABlHNf/MHbv3Zd3GdZJjTz6cC4+9Zi8yzBrlgOgjO9dcXreJZiZtbtUASDpEeCEZLQP8LuIeN+3p6S1wE5gH1AXEVVp+jUzs/TS/ij8nx0YlvRdYHsTzc+NiC1p+jMzs+xksgtIkoArgD/OYnlmZtb+sjoN9A+ATRGxspH5ATwjaaGkqU0tSNJUSdWSqmtrazMqz8zMGmp2C0DSs8DAMrNujojHk+FPArOaWMzZEVEjaQAwV9KrEfF8uYYRMQOYAVBVVRXN1WdmZm3TbABExHlNzZd0MPAJ4KwmllGTPG+W9BgwCigbAGZmVhlZ7AI6D3g1ItaXmynpMEm9DgwD44BlGfRrZmYpZBEAE2mw+0fSMZLmJKNHAy9IWgLMB56IiKcy6NfMzFJQRMfdzS6pFngj7zoa0Q/oyKe1ur50XF86ri+dNPUNj4j+LWnYoQOgI5NU3ZEvaHN96bi+dFxfOpWqz3cDNTMrKAeAmVlBOQDabkbeBTTD9aXj+tJxfelUpD4fAzAzKyhvAZiZFZQDoIUkPSJpcfJYK2lxI+3WSno5aVddwfq+LqmmpMbxjbS7UNJrklZJmlbB+v5J0quSlkp6TFKfRtpVdP01tz4kHZK896skzZM0or1rKul7qKRfSHpF0nJJXyjT5hxJ20ve91sqVV/Sf5Pvl+r9IFl/SyWdWcHaTihZL4sl7ZB0U4M2FV1/kmZK2ixpWcm0vpLmSlqZPJf9OTlJk5I2KyVNyqSgiPCjlQ/gu8AtjcxbC/TLoaavA19ppk03YDVwHNADWAKcVKH6xgEHJ8PfBr6d9/pryfoA/hyYngxPBB6p4Hs6CDgzGe4FvF6mvnOAn1X689bS9wsYDzwJCBgDzMupzm7ARurPkc9t/QF/CJwJLCuZ9h1gWjI8rdz/DaAvsCZ5PjIZPjJtPd4CaKWSW183dfO7jmoUsCoi1kTEHuBhYEIlOo6IZyLiwA8tvwgMqUS/zWjJ+pgA3JsMPwqMTT4D7S4iNkTEomR4J7ACGFyJvjM0Abgv6r0I9JE0KIc6xgKrIyLXC0uj/iaYbzWYXPoZuxe4tMxLLwDmRsRbEbENmAtcmLYeB0DrZ
Xbr63ZwY7KZPbORzcjBwLqS8fXk84Uymfq/Csup5Ppryfr4vzZJgG0Hjmrnut4n2fV0BjCvzOyPSFoi6UlJJ1e0sObfr47ymXvfLWtK5Ln+AI6OiA3J8Ebqb5/TULusR/8mcIlK3/o6y/qAO4BvUv8f8pvU76aanEW/LdWS9SfpZqAOeLCRxbTb+uusJB0O/Bi4KSJ2NJi9iPrdGm8nx31+CoysYHkd/v2S1AO4BPhamdl5r7//JyJCUsVOzXQAlIgOfuvr5uorqfNO4GdlZtUAQ0vGhyTTMtGC9fdp4GJgbCQ7Nssso5K3Dm/J+jjQZn3y/vcGtrZTPe8jqTv1X/4PRsRPGs4vDYSImCPph5L6RYV+frUF71e7fuZa6CJgUURsajgj7/WX2CRpUERsSHaPbS7Tpob64xUHDAF+mbZj7wJqnQ576+sG+1Uva6TfBcBISccmfxVNBGZXqL4Lga8Cl0TEO420qfT6a8n6mA0cOOPicuDnjYVX1pJjDXcDKyLie420GXjgmISkUdT/n65IQLXw/ZoNfCo5G2gMsL1kd0elNLrVnuf6K1H6GZsEPF6mzdPAOElHJrt3xyXT0qnU0e+u8ADuAT7bYNoxwJxk+DjqzyRZAiynftdHpWq7H3gZWJp8oAY1rC8ZH0/92SSrK1zfKur3YS5OHtMb1pfH+iu3PoBvUB9UAIcC/5HUPx84roLr7Gzqd+ktLVlv44HPHvgcAjcm62oJ9QfXP1rB+sq+Xw3qE3B7sn5fBqoqVV/S/2HUf6H3LpmW2/qjPog2AHup349/HfXHlJ4DVgLPAn2TtlXAXSWvnZx8DlcB12ZRj68ENjMrKO8CMjMrKAeAmVlBOQDMzArKAWBmVlAOADOzgnIAmJkVlAPAzKygHABmZgX1vxcPDOOxrBK0AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#plot out the cursor trajectories\n", + "cursor_coor = tables.col('cursor')\n", + "plt.plot(cursor_coor[:,0], cursor_coor[:,2])\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/eyetracker (EArray(25710, 2), shuffle, zlib(5)) ''\n" + ] + }, + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAD8CAYAAABzTgP2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xd4VFX6wPHvm04KAZIgndA7iIQgIIIK0lbYVSzYsPe1rQW7Yi8/Xdeyri52ERR1RbEhgoooJEiTHiBAApKQEJKQPjm/P2YSJo0kM3dmUt7P8+TJ3DvnnvsmhPvOPfcUMcaglFJKlfHzdQBKKaUaFk0MSimlKtDEoJRSqgJNDEoppSrQxKCUUqoCTQxKKaUq0MSglFKqAk0MSimlKtDEoJRSqoIAXwfgiujoaBMbG+vrMJRSqlFZs2bNIWNMTG3lLEkMIjIJeBHwB/5rjHmq0vvXATcCNiAXuMYYs1lEYoEtwDZH0d+MMdfVdr7Y2FgSExOtCF0ppZoNEdlTl3JuJwYR8QdeASYAKUCCiCwyxmx2KjbPGPOao/w04HlgkuO9ncaYE92NQymllDWseMYQDyQZY3YZY4qA+cB05wLGmGynzTBAZ+5TSqkGyorE0BHY57Sd4thXgYjcKCI7gWeAm53e6iYia0XkRxEZU9NJROQaEUkUkcT09HQLwlZKKVUdr/VKMsa8YozpAdwN3O/YfQDoYowZCtwOzBORljUc/7oxJs4YExcTU+uzE6WUUi6yIjGkAp2dtjs59tVkPvBXAGNMoTEmw/F6DbAT6G1BTEoppVxkRWJIAHqJSDcRCQIuABY5FxCRXk6bU4Edjv0xjofXiEh3oBewy4KYlFJKucjtXknGmBIRuQn4Fnt31TeNMZtEZA6QaIxZBNwkIuOBYuAwMMtx+KnAHBEpBkqB64wxme7GpJRSynXSGJf2jIuLMzqOQSnVlMx6czXPnjuYthEhHjuHiKwxxsTVVk6nxHDYl5nH5Bd/9nUYSqlm6N/Ld/Lj9nTiH1/q61AATQzlxjyzjC0Hsvlu05++DkUp1czsSMvxdQgVaGKopGfbcAAaYxObUqpxysorrrXMsm1p2Eq9c13SxFBJp9ah7MvMo9s9X3HL/LW+Dkcp5QW2UkNqVr7PPhDuTM897vu97/uay99K4MWlO7wST6OcXdWTet//dfnrz9ft58ULhvowGqWUlYpKStmTcZSktFyS0nLZ4fi++cCxWXv+e2kc4/uf4NW49mTklb+++L+rmDakA+cN70xhiY0+939T/t65wzp5JR5NDLU4598r+eT6Ub4OQylVD/lFNnam57IzPZcdB8uSQA57MvIocWqO6diqBb1OCGdUj24s/D2FrLxirnrX3uNx9X1n1KuHkDEGEalXnGv3HiYqLLjCvhVJh1iRdIizT+rIPZ9urPBe5zah9arfVc2+u+r2g/aHPme+8FONZaYOas8lI7tycvcoS86plLJGdkGx/dP/wVyS0nPZcTCHpPRcUg7nU3Zp8/cTukaF0jMmnJ5tw+l1Qjg9YyLo0TaM0KBjn43TcwrZkZbDhW+sKt937rBOPP63QQQF1NzqfvhoEee8tpJd6UcBuPmMXtw+oW4TOMTOXlznnzUqLIg1D0yoc/nq1LW7arNNDGv3HubAkQJu+OD3Oh+z9B9j6RET7tZ5lVL1Y4wh42hRedPPTsen/6S0XA5mF5aXCwrwo3t0mP3i3zaiPAl0jQolOMC/2rqzC4q58+P1fLvp4HFjePWik5g8sF2FO4J9mXmc+cJP5Bfbqj1m66OTCAmsel7nO4v6JIa1D0ygdVhQnctXRxNDLerzD+Is+ampbp1XKVU9YwwHjhSUt/snOS7+O9JyK/TaCQvyp2fbcHo6Lv72RBBO5zah+PvVrSnno4R93PXJhnrFFxbkz6c3jKaktJSp/1pR4b07zuzNjaf1JONoEXGPfV++//0rRzCyRxR5RSXkF9nILihm/PM1t04cjxXXnromBn3GUA/Xj+vhkXqHPbqEjKNFXDqyK5MHtmdY19bHvXVVqjGzlRr2ZuY5PQDOYafj9dGiY5++W4UG0qttOJMHti+/+PdsG077yJB6t+UD7ErP5fK3Eyo86AW45tTuXD+2B0W2UvKKbCRnHOXuhRtIyymsUO5okY2J/6x6UY+PbcOXGw4wb9VesgtKKrx38dxVVcq7oo75zjJ6x1APfzwykfBg93Pp5v3Z3PvZRtbty6qxzKgeUUwe1J6xvWLoEuWdB06NXWmpIbugmIyjRWxIySLQ34/YqDBCAv0JDvAjONCP4AD76yB/P/y8/b+tmSgtNeQUlJCVX8SWAznkFBSTmpVPUlou2w/msP1g1a6ZXaNC6RoVRqzje7foUFqGBFJsM+QVlZBdUEx2fgk5BcVkF5SQnV9cZV+OY7vIVuqDn9qz1j04gVah7jUjgd4xVKssGbSopt2vLlxNCsYYvt10kLsWrq/yiaImK3dmsHJnRo3vtw4NrPAfKTa67D9WGK1DA136RNXQOF/oM3KLyDxa6PS6yPG6sPx15tGieg0ACvL3q5owAvwILkskAcf2Vy4XXF25QL9jdZSXO7a/fJ+jHudmjz9Sj3DNu4nsP1JQJc6QQD8Gd2xF/w4t7V/tW9IuMgRbqaHYVkqJzVBSWkpRif17sc1QYiulxPF+2XZxqWO/zVBcWkpxSVmZiu+XHZNbWMKh3EIycos4lFvIodxCim2e+SC5JyOPPRl5uNbI0nCEBwdgKzX2L2MsG5BmRVKoj2aVGMrU9LDISsW2Uv77826e/mbrccsN7NiS+6b0Z0/GUZIz8sq/Jx86etw4D+cVczgv67h3HdUJDfKv8sksNiqM2Ogw2k
YEezShGGM4kl9cfhHPyLVf6DNzHRf5o46Lv2P78NGiCl0LnUWEBBAVFkSbsCA6twnlxM6taBMWRFR4MFFhQRw4UsDn61LZ+mfNUw0U2UopspXi3GLQJiyI3MISikoazqfOguJSVidnsjq5+Uw8HBEcQEiQf3nyDgo49j3IccdXloDL9uUWlLBo/f5q64uPbYOfn/13WVhSSmGxjcKSUgqcvtf0t1YfuYV1++BXX7GzF3v1+WazSgyLbz6lykMjK2UeLeKpr7fwUWJKlff6tougRZA/a/ceu5A793Ia2aP6rrDGGNJyCgkJ9GdDShY/bU/nx+3pVW7HA/yETq1bcDivmCP5NQ+vzyuyseVANlsOZNdYpib3TunLWUM6OD6h2j+NZuQWkZZTwJ9HCvgzu4C07EIOZh97bdVtfYCf0LJFIJEtAokICUCg/FPt4bxitjh9ci7/BGwziEB9WkszjxZZEq9yT05hCTkWXmSbQlJdseMQp/SK9sq5mt0zBld7I4H90/bontH4i9ibAQR2puUe91NpeHBArZ8i2kYEM7BjJLZSw5YD2VUeeqm68/ez/9sE+fsR4C8VerP4+wkdW7UgLDiAQH97uUA/P7YdzDluMq1JTEQw7VqGUFRSSmGJjbwiG/lFNvKKbV6b00Y1HzufmFLnXlc10WcM1Ug+dNSt4/OKbCzZfPz+zpXV5dYyLaeQH7amuRqWclLWvltdU1BZbxirpOcUkq5JXHnJ2yuTufKUbl45V7PqE/m3V3/xdQhKKeWSqYPae+1cliQGEZkkIttEJElEZlfz/nUislFE1onIChHp7/TePY7jtonIRCviqcmS28d6snqllPKYk59cSlLa8WdhtYrbiUFE/IFXgMlAf2Cm84XfYZ4xZpAx5kTgGeB5x7H9gQuAAcAk4FVHfR4RHR5ceyGllGqgErz0EN2KO4Z4IMkYs8sYUwTMB6Y7FzDGOHeBCQPKnsxNB+YbYwqNMbuBJEd9SimlKqk826qnWPHwuSOwz2k7BRhRuZCI3AjcDgQBpzsd+1ulYztaEJNSSjU5j04f4JXzeO3hszHmFWNMD+Bu4P76Hi8i14hIoogkpqenuxRD2RTbSinVGA3oGOmV81iRGFKBzk7bnRz7ajIf+Gt9jzXGvG6MiTPGxMXExLgU6P/WHi8spZRq2KyYq60urEgMCUAvEekmIkHYHyYvci4gIr2cNqcCZQuXLgIuEJFgEekG9AJWWxBTta45tbunqlZKKY+79r01XjmP2+nHGFMiIjcB3wL+wJvGmE0iMgdINMYsAm4SkfFAMXAYmOU4dpOIfARsBkqAG40xHpvIyNsTUSmllJV2uzlIt64suS8xxnwFfFVp34NOr285zrGPA49bEUdt/kg94o3TKOVRzjN4lpSWorNvKKs1qykxvv7jgK9DUMptnprBUzVsV4zuxoNnVR4i5hnNKjHcObEvryzb6bH6B3eKJD62DS2C7GP0nOcnNBzbqLi/6r6tf2azfFvNPa8C/IQuUaEcySvmcF6RfmJUqhl485fdnD+8M33aRXj8XM0qMXjaSV1ac/9fas/oBcU2DucVcfhoMVl5RY61FYrIyiviUG4RwQF+DOoYSWpWfrXTQJeUGnale6etsbkI8vcrXzy+1MAXNczrXx9jekUTERJQvoBOic1QZCstX0SnyLFQTtkU4fuP5NdrivDmSARCA/1pEWT/Cg0MIDBA8PfzY3091iYZ0imSEd2jKjTJ2UoNJbaybfu/T1pOAb/vrd+aJ5408Z8/eWVdBk0MFnp7ZTJvr0xmSKdIUrMKOJTbNGfe9PcTQhwrmIU4r2TmtB3oJxTZSskvslFQYp+OuqC4lPxiG3lFJRQUe3YhnE6tWzB31vBaP10dyS/mo4R9vL0ymc0HstlcwzoVJ3dvw2WjujG+X1sC/Ct25svKKyIh+TBjekUTUs/VAW2lhoJiGwXFNvKL7b8j5+38Ihu5hSVsOZDNF+sP8Ge2fYW3vwxuT25hCTmOJS3t30uafDOTMfa1l53Xhvb3EwL9hYjgAAL8hUB/PwL9/TDGVLsiHsD6lCOsT2k8zxyHdmnF+H4n0LmNd5b5bXbrMYB7azI0VP5+QnR4EG3CgmnhWDoypLrvjot4iNPF3Hk7wE/4aXs676/ae9w1BU7s3IorTunG5IHtCPS3fpxkaanhg9V7eeB/f1R5r0ubUF44fwjDurYp37d5fzZzV+zmk9/tiyTFRARz9kkdOXdYZ3q2DS+vc/n2NN76JZmfdxyq8dwxEcHMHN6Z6UM7EhLob79IF9koLLGRX1TqdBEv+yot367u4m7/XkpBhSRpL+fqQkZB/vYlQlsE+hMS6O/4bv93dN5uEeRPcID903VIgD8tgo6VqVDOeV+QPcGXHevOGgDGmKpLjNocS5KW76+6POmxu6uy/cfuroqcli6176+ufMV6yqbLr8v6KA2VFXcKdV2PQRNDAxIdHkT36HC6RIUSGxVKl6gwOrVuQXhwQMWLuWPtYG8tZr8/K593f93Dmyt2H/dCdsnJXbl8dCzdHavS1VeJrZQ3algOtX/7ljx37hD6d2h53DqKbaX8sDWNjxNTWLYtrV4L5kS2CCS/2Obysp6B/kJIgD8hQVUvtvaLsJ/Txbjsq+LFvexi3yLoeBd89y7Wzc1na1O4bcF67pvSj6urGcuUlVfEks0H6RYdxqBOkQQHHLvrK0tsldfJnvPFZhZv9F5nlkEdI/ni76e4XY8mhuPwZmIY2T2KHm3DaNcyhBNahtAuMsT+OjKEiOAAj66x7Clln7znrtjNL0kZNZbr2y6CK07pxrQhHWpsYikotvHSDzuq7RQQ360NT58zmG7RYS7FmZ5TyP/WpvL4V1sq7A8O8GNgx0hCg/yrfuIu/2R97FNz5YtylYu1o2zlJible2nZBUx44Sd6tg3no2tHWpJQ4x//vsIqi38/vScv/ZDkdr3H0yo0kHUPnul2PZoYajH1Xz+zaX/91z12xcrZp9OhVQuvnMtX0nMK+WDVHuau2E1OQc236ucO68SFI7rwv7WpvPPrnirvn9G3LY/9bSDtI5v270t5njGGq99N5Ocdh/j6ljEu38mW+Wl7Ope+eWxihnF9Yrh/aj/GP/+Tu6HWqGfb8PI1GLQpqRZWJAZvNyd9+fdTGOilCbAaAmMMv+7MYO6K3SytZdnS6Sd24MG/9CdK18tQFiprQrp/aj+uGuP6dDgZuYUMe+z7Cvv6tW/JO5cPJ/6Jpe6GeVy7npjC5gPZdI8JIzTI/b5CuuZzA/OXl1Ywd1YcZ/Q7wdeheIWIMKpnNKN6Rpfvy8orYn7CPuat2sspvaK5Z3JfIkICfRilaqrSsgt4eNFmhnVtzeWjXVsn2RjDTfPWVvssYcG1JzP44e/cDbNGzncHvvhA2WwbRa8de+wTRPJTU5kZ38Xj57zynUTe/mW3x8/TULUKDeK6sT346a7TeOJvgzQpKI8wxnDvZxspKLbx7IzBLj1X+HLDfrrd81V5Urh7Ut/y91684ESPJoWGoNkmhnsm92NEtzZ8fuNoAJ48e5BXz
vvwF5t5eNEmr5xLqebos7WpfL8ljTsn9nHpucK0l1dw07y1gL0DxfbHJlNQfGzcxC3z11kWq7MAP2HnE1O8MoCtNs02MQAsuHYkQzq3qrJ/UMdIrhvbw/LztQyxt9y9vTKZma//VktppVR9Hcwu4OFFm4hzsQlpX2YeGxwD35b+Yyzf3HoqhSU2Xly6o5Yj6++zG0ax+8kp5dtJT0xpMN2Qm3ViqOzG0+zJYOuf2dw9qQ+XjYq1tP7sghIiW9ibT37dlcHgh7+lMT78V6ohMsZw76cbKSwp5RkXm5DGPLMMgOfPG0IPx93G0DlLLI0TYOF1IxnapXWD7a6uicHJnRP74ifw4F/6IyI8+Jf+nB/XufYD6+FIfjEBjj/Y7IISut3zFSUujn5VSh3z2dpUlm51vQkp2Wmtg7NP6gTAsq1plHhglsq42Da1F/IhTQyV7HpyKpeMjAXAz0944uxBTD+xg6XnqPyH1vO+r8kpKLb0HEo1J+42IQGMe245AC/NHArY70AufzvBqhDLfXPrmArbUWFB9Grr3hgLq2liqIW/n/DcuUM4s79nu5kOevg79mfle/QcSjVFVjQhlQ0iAzhriP2D4BUeSAoAfdtVnNZlzQMTWHL7WI+cy1WaGOog0N+Ply4cytjeMYhA1yhrZzgsW+B71FM/6CpzStXTp7+714QEMP75HwH4zyXDANibkcey46yJ4oorRndrED2O6sKSxCAik0Rkm4gkicjsat6/XUQ2i8gGEVkqIl2d3rOJyDrH1yIr4vGE4AB/Xrt4GPGxbUg5nM+YXtG1H1RHuYUltAq1P5T+y0sr+N4xE6RS6vgOZhfwyBfuNSFt/fPY1DgTB7QD4NRnl1kSX5mE+8Z7bfU1K7idGETEH3gFmAz0B2aKSOXfwFogzhgzGFgIPOP0Xr4x5kTH1zR34/GkFkH+zL1sOIM7RbJqVybnOB5QWSErr5gWjonmrnq3eQ+EU6oujDHc42hCevbcIS539Zz0z58BeOuy4YC10+Vce2p3kp+aSkxE45ruxYo7hnggyRizyxhTBMwHpjsXMMYsM8bkOTZ/A6y7onpZeHAAb18eT8+24Xy5YT9XuPgppTr5ToNoHv5iMw99XnUtAqWU3ae/p/KDownJ1Rl4nZtux/WJoee9X1kVHsvvGMc9U/pZVp83WZEYOgL7nLZTHPtqciXwtdN2iIgkishvIvLXmg4SkWsc5RLT061t+6uvyBaBvHdlPF3ahLIgYS83jLN+MBzAO7/u4YLXf/VI3Uo1ZlY0IYG96Rbg1YtOsncdt6hrasJ944l1MVk1BF59+CwiFwNxwLNOu7s6Zvu7EPiniFR7lTXGvG6MiTPGxMXExHgh2uOLCg/mg6tGEB0RzPu/7fFYcvhtVyaDdCCcUuXKmpCKbO41Ia3de7j89Q0f/G5VePz+wIRG13RUmRWJIRVwHgXWybGvAhEZD9wHTDPGlK9yYYxJdXzfBSwHhloQk1e0bRnCB1eNIDw4gPkJ+7jyFOualZzl6EA4pcp9Ut6E1NflJiSAv7260sKo7NY9OIE2YUGW1+ttViSGBKCXiHQTkSDgAqBC7yIRGQr8B3tSSHPa31pEgh2vo4HRwGYLYvKaTq1DmXf1yfj7CV+s32/5SGlnOhBONXd/HrE3IQ2Pbc3lbkxZk5CcaV1QDusfOpNWoY0/KYAFicEYUwLcBHwLbAE+MsZsEpE5IlLWy+hZIBz4uFK31H5AooisB5YBTxljGlViAIiNDuODq0ZQbCtlRdIhJjm6vHnCoIe/qzDTo1LNRdl02sW2Up6ZMcTlNc9tpYZzX7Pu2Z0IbHz4zPJ50JqCZruCmyf8kXqEmW/8RlRYECe0DGHVbus/lYB1678q1ZgsXJPCHR+v54G/9He52fZQbiFxlVZjc0eLQH8S7x9PWHDjWPOsriu46chnCw3sGMnbl8eTllNI5tEi+raL8Mh5svKKWVbLcplKNSVWNCFlWJwUwP6gubEkhfrQxGCxYV1bM3fWcPZm2odttGsZ4pHzXP52AjYPzPqoVENj74W0gWJbKc+62IRUUGyrsm6zu7Y+OokWQf6W1tlQaGLwgJE9ovjPJcPYmZ5LVHhQ+Yhmq53mmA1SqaZs4ZoUlm1L566JfV0aG1Baauj7wDeWxrTtsUmEeOj/dUOgicFDxvVpy8sXnsTWP3Po6aEpdfdm5rF+X5ZH6laqIfjzSAFzvtxMfGwblxfO6m7haGaA7Y9NJjig6SYF0MTgURMHtOP584bwx/4jDO4U6ZFzTH/lFx38ppok5yakZ2YMdqkJqfd9X9deqB6SHp9MUEDTv2w2/Z/Qx6af2JGnzh7EhpQjDOromeTQ7R5rPxEp1RC424QUO3sxRRYOCt35xBQC/JvHJbN5/JQ+dv7wLjx8Vn82ph5hYMeWtR/ggt1OyxIq1di524Rk5QypYE8Krk690RhpYvCSy0Z3465JffgjNdsj3Vj1QbRqKowxzHaxCamopNTypLCrmSUF0MTgVTeM68nfT+/psQfS93y60fI6lfK2j9eksHxbOndPql8TUvKho/S+39pnCruemOLyCOvGTBODl90+oTdXntKNpLRcYi1eIvTD1XvJyC2svaBSDdSBI/k86mhCmjUyts7HGWMYZ/Fdc9Ljk5tlUgBNDF4nItw/tR8XjuhCckYeHSKtHQBn9SAepbylbDptV5qQrnlvjaWxbH10UrN50Fyd5vuT+5CI8Nj0gZw9tCP7jxTQOtTaybfe+zXZ0vqU8gZXm5DScgpYYuE66esenNCkB6/VhSYGH/HzE56ZMZgpg9pxOK+YkEDr/ike+HwTeUUlltWnlKcdOJLPo19sJr5b/ZuQ4h9falkcz583pMlMne0OTQw+FODvxz/PH8rpfdtSUGztIjxDHvnO0vqU8hRjDLM/2UhJqeHZejYhWTmGp0ubUM4+qdEuR28pTQw+FhTgx6sXncSoHlGW1ltsM3zzx5+W1qmUJ3y8JoUft6dz96Q+dI2qexPS+Od/tCyGW8f34qe7TrOsvsZOE0MDEBLozxuXxjGsa2tL673u/TW6HKhq0JybkC6tRxPS5W+tJikt15IYHpk2gFvH97akrqZCE0MDERYcwFuXD7d82oyeFs8Vo5RVXG1C+sdH61m2Ld2SGJ45ZzCz3FgitKmyJDGIyCQR2SYiSSIyu5r3bxeRzSKyQUSWikhXp/dmicgOx9csK+JprFqGBPLuFfH0OcHakdF3fLze0vqUssLHifVvQrrn04188nuKJed/dsZgzhvuuTXaGzO3E4OI+AOvAJOB/sBMEelfqdhaIM4YMxhYCDzjOLYN8BAwAogHHhIRa9tTGpnWYUG8f9UIurswaVhNFq5J4csN+y2rTyl37c+yD2QbUY8mpH98tJ4PV++15PxPnj2Ic+M0KdTEijuGeCDJGLPLGFMEzAemOxcwxiwzxuQ5Nn8Dyh79TwSWGGMyjTGHgSXAJAtiatRiIoL54OoRdGrdwrI6b5q3luXbdDlQ5Xv2uZDKmpBqX5HNVmq48u0Ey+4U7pvSj5nxXSypq6myIjF0
BPY5bac49tXkSqCs4bu+xzYb7SNbMO+qky1dGvSytxJISM60rD6lXPFxYgo/bU9n9uS+dKllWpicgmJmvLaSpRatcX79uB5cfWp3S+pqyrz68FlELgbigGddOPYaEUkUkcT0dGsePDV0XaJCef+qEUSFWTfg5oq3E9i0/4hl9SlVH85NSJec3PW4Zfdl5jHt5V9Yu9e+SmFXN+cWmxnfhbsm9nGrjubCisSQCjg31nVy7KtARMYD9wHTjDGF9TkWwBjzujEmzhgTFxMTY0HYjUPPtuG8f9UIIltYM21GTkEJs95cza50a7r6KVVX9WlCSkzOZPorv5SvMxIa5M+ejLway9dmyqB2PPbXgYg0z0nx6suKxJAA9BKRbiISBFwALHIuICJDgf9gTwrO94TfAmeKSGvHQ+czHfuUk37tW/LuFfGEBwdYUt+h3CIumbua/Vn5ltSnVF18lLivTk1In/6ewoVvrCLzaFH5vrwim8vnHd0zihfOP7HZrangDrcTgzGmBLgJ+wV9C/CRMWaTiMwRkWmOYs8C4cDHIrJORBY5js0EHsWeXBKAOY59qpIhnVvx5mXDLZtTKTUrn4vnrtJpupVX7M/K57Evt3By95qbkEpLDc98s5XbP1pv6ZKcr18SR3BA854Ur76kMS4kHxcXZxITE30dhk/8vCOdK99OtOw/zsCOLZl39cm0DLF2hlelyhhjmPVWAgm7M/n21lOrvVvIKyrhtgXr+HaTdbOkAqx9YAKtLXxG19iJyBpjTFxt5XTkcyMzplcMr150ElY1lf6Rms1V7yRSUOz6rbpSx1NbE9KBI/nM+PevLNl8kOhw6y7iv95zuiYFF2liaITG9z+Bl2YOtay+1bszueGD3ynWeZWUxWprQlq/L4tpL//CjrQcSo39+ZcV3rg0jvaR1o0Dam40MTRSfxncgefOHWJZfT9sTeMfH62ntLTxNS2qhqmsF5LNGJ45p2ovpC837Gf6K7+QnlNIsc26v7uBHVsyof8JltXXHFnTzUX5xIxhncgvKuGBzzcBEOTv59azh0Xr9xMeEsDj2q1PWWBBgr0Jac70ARWakIwx/OWlFWzan13lmPDgAHIL3Vtk6su/j3HreKV3DI3eJSNjuXciLEeKAAAgAElEQVRKXwCKbKVEuNmldd6qvTz1zVYrQlPNWGpWPo8ttjchXTyiYhPSv3/cWW1S6NiqhUtJITToWI+jlbNPr3+wqgpNDE3ANaf24NbxvQDIKSxx+wHef37cxcs/7LAiNNUM2afT3kCpqTqQLbugmDdXJNM2IphfnC7iXdqEkurCuJqQQL/yMQ43ndaTDq30uYIVNDE0Ebec0YtrHXPAHMotoqOb/0Ge+247c1fstiI01cwsSNjHzzsOMXtyXzq3qdgL6V/f7yDjaCH/PP9ERj/1A2C/U9ib6dqo5rIlcQd1jOQOne7CMpoYmggRYfbkvlw60n7bnpqV7/bU3Y9+uZn3f9tjRXiqmTheE9KOgzm8vTKZvw3tyIX/XQVAREiAS3cKzp48exBf/P0Ut+pQFenD5yZERHj4rAHkF9n4eE0Kuw4dpV/7lmw5ULU9t67u/98fBAX4cZ7OXa9qcbwmJGMMD3+xiRZB/nz6+7Hp0HIK3HvQvOLu0+jU2r3J9VRVesfQxPj5CU+dM5izhnQAYMuBbIZ2aeVWnXct3MCnFs2Fr5qu+Y4mpHuqaUL6+o8/+SUpw+1E4GznE1M0KXiIJoYmyN9PeP68IeV9udfuzWJUjyi36rz9o/X8b221E98qRWpWPo8v3sLI7lFcVKkJKb/IxuOLt1h6vu2PTdZJ8TxIE0MTFejvx8sXDmVMr2gAVu7M4Iy+bd2q89YF6/hsrd45qIqcm5CemTG4ykC2fy9Pcvs5Qpm2EcFsmTOJoAC9dHmS/nabsOAAf16/JI7hsfZltJduTWPywHZu1XnbgvV8skaTgzrmeE1IezPy+NcPSZacp88JESy5bSwtgnSmVE/TxNDEtQjy563L4xnQoSVgb+ste/7gqn98vJ75Fi3Krhq34zUhAZz67DJLztOlTSjvXRlPZKjOAuwNmhiagfDgAOZddTLdHN1Xv1i/n+knupccZn+6kXdWJlsQnWqsamtCuui/v1lynpiIYN6/cgRtLVz/XB2fJoZmIjI0kI+vG0l0eDAAn6/b7/adw0OLNvHajzutCE81Qh+udjQhTelXpQnpo8R9/JKU4Vb9HSJDaBkSwHtXxh93xTdlPU0MzUh0eDBf/v0Ugh0P7r5Yv5+JA1ybhbJFoL2d96mvt/J/322zLEbVOKQczuPxxZvtTUjxXSq89/OOdO5auMGt+ru0CSUzr4i3Lh9O33Yt3apL1Z8mhmamXWQI398+tnz7200HXerKml9so2WIfXzkSz8k8fCiTZbFqBo2Ywz3fLoRA1WakDamHOGSuavdqr9jqxbsz8rntYuHMaxrGzejVa6wJDGIyCQR2SYiSSIyu5r3TxWR30WkRERmVHrP5lgHunwtaOVZnduE8sM/jiWHxD2HXaonu6CEVo6HgW+vTOa2BessiU81bDU1Ie3JOMpZL69wu/79R/J5/vwTGdfHve7VynVuJwYR8QdeASYD/YGZItK/UrG9wGXAvGqqyDfGnOj4muZuPKpuuseE8+2tpwJQVFJK24hgl+rJyisuTw6frU3l0jfd+7SoGrayJqRRPSo2IdlKDWOfXW7JOeZMH8g0N59/KfdYcccQDyQZY3YZY4qA+cB05wLGmGRjzAZA145sQPq0i+CLm+yTj6XlFNLOxV4fWXnF5a9/2p7OnC82WxKfaljsvZA2AvD0ORWbkF62aKzC7RN6V7sEqPIuKxJDR2Cf03aKY19dhYhIooj8JiJ/ramQiFzjKJeYnp7uaqyqkkGdIvnk+pEA/JldYEmdb/6yG2N0idCm5sPV+1iRVLUJadWuDF5cut3t+i8fHcvfT+/pdj3KfQ3h4XNXY0wccCHwTxHpUV0hY8zrxpg4Y0xcTEyMdyNs4oZ1bcO8q0ZYWue0l3+xtD7lW85NSBc6NSEdPlrErQvWEe7myoFnD+3IA1P765KyDYQViSEVcJ6TuZNjX50YY1Id33cBy4GhFsSk6mlUz2jeuny4ZfVtTD3CF+v3W1af8p2ampCMMdy5cAOHcgvJdmPW1PH92vJ0NQPklO9YkRgSgF4i0k1EgoALgDr1LhKR1iIS7HgdDYwGtIHaR07r05bXLj7Jsvr+/uFaznvtV8vqU74xb/XeapuQ3lmZzPdbDlJsc73ZML5bG16+8CQC/RtC44Uq4/a/hjGmBLgJ+BbYAnxkjNkkInNEZBqAiAwXkRTgXOA/IlLW6b0fkCgi64FlwFPGGE0MPjRpYHueO3eIZfWtTs4kdvZi0nMKLatTec++zDyeWLyF0T2juGjEsSakP1KP8MRXW92qu88JEfx3VhwhgTopXkMjjfEhYVxcnElMTPR1GE3aOyuTecjiQWu7n5yibciNiDGGi+euYt3eLL659dTyu4WjhSWc9dIKdh066lb9ifePL5+iRXmHiKxxPNM9Lr1/U9W
aNSqW28b3trTObvd8xXmv/cqR/OLaCyufm7d6L78kZVRpQnrw801uJ4Wf7zpNk0IDpolB1eiW8b1cHvhWk9XJmQx55Dvu/WyjpfUqa9XUhPTZ2hQ+cXOZ1/nXnFxl0j3VsGhiUMe1+r7xHql33qq9ZOUVeaRu5R5jDLM/tU+C9/Q5g8ub/3YfOsptC9a7VfeQzq04ubt7y8wqz9PEoGr1y+zTPVLvac8t90i9yj0frLI3Id07tR+dWts/2ReW2Cz59/r8xtFu16E8TxODqlXHVi0Y29v6QYWH84rZk+FeW7Wy1r7MPJ78agun9IyuMJCtz/3fuF334ptPcbsO5R2aGFSdvHNFPOfFdbK8XqsmXlPuKy013P2JvQnpqXMGISIUFNuInb3Y7boHdGjJgA6RbtejvEMTg6qzZ2YM4bvbTrW83sTkTMvrVPU3b/VeVu6s2ITU9wH37xRAm5AaG00Mql56nxDB9scmW1rnDB0d7XPVNSHFPfa9JXW/NHMoATqyuVHRfy1Vb0EBfmx9dJKldf7tVZ10z1fKmpBEpLwJ6b8/7+JQrjWj1d1dW1x5nyYG5ZKQQH82PTLRsvrW7s1i4Rr3+scr13xQ1oQ0xd6E9OHqvTy2eIvL9Q2PbV3+ev2DZ1oRovIyTQzKZWHBAay+9wzL6rvj4/V8u+lPy+pTtXNuQpoZ35n5q/dyz6euDz4c1SOKhGT7UrHXje1BpGN1P9W4aGJQbmnbMoR/X2TdjKw3zfudlUmHLKtP1ay01HDXwg34OZqQFiTsY7YbSWFwp0j+NvTYGl2zJ/e1IkzlA5oYlNsmD2pfofnAHcU2w1XvJrJuX5Yl9amafbB6L7/usjch/ZJ0yK2k0D06jLmzhnPnQnt31yUe6L2mvEcTg7LEx9eNsqyuvCIbl721mm1/5lhWp6qorAlpTK9o/P1g9qcbcXXi2+jwIN6/agRXvWuf8Tg+tg29ToiwMFrlbZoYlGXWP2Tdg8asvGIumbuKvRl5ltWp7JybkIZ1bc3sTzcyrEtrXJmBP8BPWHDtSI7kF7PecZc3/5qTLY5YeZsmBmWZyBaBrL7vDCLcXP+3TFpOIRfPXUVadoEl9Sm7D1bt4dddGfRpF8GLS3cwplcMiXsOu1TXZzeMpkdMOJNf/BmANy6N0yU6mwBNDMpSbSNC2PjIRJ45Z7Al9e3NzOOSuat1JlaL7MvM48mv7SuvrdlzmFN6RvPT9nSX6vrw6pMZ1CmSOV/YF130E5jQ/wTLYlW+Y0liEJFJIrJNRJJEZHY1758qIr+LSImIzKj03iwR2eH4mmVFPMr3zhvembsnWdMrZdvBHC57K4Gjha4vOK/sTUh3LlxPXpENgDG9oikqKXWprv9cMoyRPaLIPFrEm7/sBqxtSlS+5XZiEBF/4BVgMtAfmCki/SsV2wtcBsyrdGwb4CFgBBAPPCQi1nRvUT53/bgeXDG6myV1rduXxWVvraawxGZJfc3RB6v28Nsu+7xUY3pFc1KX1qzaXf95qp6dMZiJA9oBcNKjSwC4fUJvIkJ0zEJTYcUdQzyQZIzZZYwpAuYD050LGGOSjTEbgMofTyYCS4wxmcaYw8ASwNq5FpRPPXhWf6YObm9JXQnJh7n5w7WU2Fz7lNuc7cvM44HP7Wt4j+4ZxfQTO/Li0h31rufy0bGcG9cZsK/mVubmM3pZE6hqEKxIDB2BfU7bKY59nj5WNRKvXHgSw7pacyP47aaD3PPpRkpLXehC00yVlhrGPLMMsI83uGpMd+742LWV2B46awAAxbbS8tXclt8xzpI4VcPRaB4+i8g1IpIoIonp6a49LFO+s/C6kbSPDLGkro/XpPD4V1swrvSvbIbOeW1l+etnzx3M5W8luFSP8zOEKY5eSOP6xBAbHeZegKrBsSIxpAKdnbY7OfZZeqwx5nVjTJwxJi4mxvrVxJRniQgrLVwidO6K3bz0Q5Jl9TVVLy3dwdq99vEFX98yhnP+7doU5zufmEJkC/szhA0pWexIywXgrcuGWxOoalCsSAwJQC8R6SYiQcAFwKI6HvstcKaItHY8dD7TsU81QSJC0uPWreXw/JLtvLMy2bL6mprP1qbwf0u2A/aupWVjDepr48Nn4u80NmHay/Yp0t+9Ih5xdbi0atDcTgzGmBLgJuwX9C3AR8aYTSIyR0SmAYjIcBFJAc4F/iMimxzHZgKPYk8uCcAcxz7VRAX4W7uWw0OLNlV4CKrsPl+XWv4M4PYJvZn5xm8u1fPk2YMq9Da651P7XEgRwQGc6oF1wFXDII2xnTYuLs4kJib6OgzlhuyCYgY//J1l9b1xaZwOrnL4fF0qt8xfB9hnPN2QcsTlupKfmlr+Oi2ngPjHlwKwZc4kWgT5uxeo8joRWWOMiautXKN5+KyalpYhgSTcN96y+q5+N5GVO3W67s/XpXLrgnXl2+4khS1zKt7ZlSWF+6b006TQxGliUD4TExHMT3eeZll9F76xig0pzXe67kXr93PbgnUuTYZX2XPnDqlw8f9w9d7y11ef2t39E6gGTROD8qkuUaF8fcsYy+qb9vIv7DjY/Kbr/mL9fm6dv5Z2Ld3vEtypdQtmDOtUvl1YYitf1W3F3dYlctVwaWJQPtevfUs+unakZfVNeOEn9mU2n+m6v9ywn1sXrCOuaxuiI4Ldru/nuype/E9/7kcApg5uT6fWoW7Xrxo+TQyqQYjv1oY3Lq31mVidnfnCT6TlNP3pur/csJ9b5q9jWJfWnNa3rVvPFADWP3hmhS6oCcmZpGblA/DyzKFu1a0aD00MqsGY0P8Et6brvnX8sfl68ottTP3XCo7kFVsRWoO0eMOB8qTwyPQBPP3NVrfq++T6kUSGHuuaaozh3NfsA+LmX3OyjlloRrS7qmpwYmcvdvnYF84fUt5/H+xzA3158ymEBlmzeFBDsXjDAW6ev5aTurTizcuGM8iCrr8jurUhOeMoB7MLK+w/oWUwq+61rgeZ8h3trqoarV1PTHH52NsWrOfVi046Vteho8x8Y1WTmq77q432pDC0cyveujzekqQAsGp3ZpWkAPCjhT3HVOOgdwyqQSq2ldLrvq8trfPasd3xdzSHnNIrmlE9oi2t3xu+3niAmz60J4W3r4hn4EPuzyBz5Snd6B4TRreoMGKjw2jXMkSX52yi6nrHoIlBNVi2UkOPe7/yWP2/3nM67SNbeKx+q33zxwFumreWIZ1b8c4V8Zz+3HLScqp+wq+PDQ+fSUtdYKfZ0KYk1ej5+wnbH7Nu0r0y4/u1BY51w2wMvvnjzwpJ4fnvtrudFD67YZQmBVUtTQyqQQsK8GPTIxNdPr5n23AuHdm1wr6hXeyLBuUX2ziY3fC7tNqTwu8M7hTJ25cPZ9ufOeXrLLvq7kl9y38PSlWmiUE1eGHBAax9YIJLxyal5bJ2bxb/dRoj8ey32+hzQgQA4/+vYd81OCeFd66IJz2nkHP+vbL2A4+jX/uWXD+uh0URqqZIE4NqFFqHBfHrPa4t9LMx9Qhv/LyLp84eVL5vm2PajJzCkgY7EO7bTfakMMiRFI4W2jjdxU
[... base64-encoded image/png payload of the eye-tracking plot omitted ...]\n",
+      "text/plain": [
+       "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#plot out the eye tracking results\n", + "eye_array = hdffile.root.eyetracker\n", + "print(eye_array)\n", + "\n", + "plt.plot(eye_array[:,0], eye_array[:,1])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#plot out the mocap data\n", + "mocap_array = hdffile.root.motiontracker\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 1416e8ec32d5275ea74c2b66c055676d46a6cdba Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Sat, 27 Jun 2020 16:49:11 -0700 Subject: [PATCH 038/242] clean up shared analysis folder --- analysis_shared/Untitled.ipynb | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 analysis_shared/Untitled.ipynb diff --git a/analysis_shared/Untitled.ipynb b/analysis_shared/Untitled.ipynb deleted file mode 100644 index 1594c89b..00000000 --- a/analysis_shared/Untitled.ipynb +++ /dev/null @@ -1,32 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.5.2" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From 9667667f459ec7c490ffe778419b6ff928f19340 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Mon, 29 Jun 2020 12:24:48 -0700 Subject: [PATCH 039/242] updated the install reqs and fixed some minor bugs --- requirements.txt | 52 +++++++++++++++++-------------- requirements_old.txt | 23 ++++++++++++++ riglib/stimulus_pulse.py | 20 ++++++------ robot/{drivebot.py => drivebot.m} | 0 tests/test_onedim_lfp.py | 2 +- 5 files changed, 63 insertions(+), 34 deletions(-) create mode 100644 requirements_old.txt rename robot/{drivebot.py => drivebot.m} (100%) diff --git a/requirements.txt b/requirements.txt index 2812921f..78146a86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,29 @@ -numpy -scipy -matplotlib -numexpr -cython -django-celery -traits -pandas -patsy -statsmodels -PyOpenGL -PyOpenGL_accelerate -Django -pylibftdi -nitime -sphinx -numpydoc -tornado -tables -pyserial -h5py -pygame -ipdb \ No newline at end of file +traits==5.2.0 +Django==3.0.2 +PyLink==0.3.3 +numpy==1.18.0 +matplotlib==3.1.2 +scipy==1.4.1 +tables==3.6.1 +celery==3.1.26.post2 +pandas==0.25.3 +h5py==2.10.0 +Cython==0.29.14 +nitime==0.8.1 +statsmodels==0.10.2 +tornado==6.0.3 +pygame==1.9.6 +django_celery==3.3.1 +namelist==0.1.0 +OWL==0.2.1 +Phidgets==2.1.8 +pyaudio==0.2.11 +pyautogui==0.9.50 +PyOpenGL==3.1.5 +pyserial==3.4 +qtpy==1.9.0 +robotframework==3.2.1 +scikit_learn==0.23.1 +tabulate==0.8.7 +tasks==2.5.0 +testing==0.0 diff --git a/requirements_old.txt b/requirements_old.txt new file mode 100644 index 00000000..2812921f --- /dev/null +++ b/requirements_old.txt @@ 
-0,0 +1,23 @@ +numpy +scipy +matplotlib +numexpr +cython +django-celery +traits +pandas +patsy +statsmodels +PyOpenGL +PyOpenGL_accelerate +Django +pylibftdi +nitime +sphinx +numpydoc +tornado +tables +pyserial +h5py +pygame +ipdb \ No newline at end of file diff --git a/riglib/stimulus_pulse.py b/riglib/stimulus_pulse.py index bb5b88d5..32b1fad0 100644 --- a/riglib/stimulus_pulse.py +++ b/riglib/stimulus_pulse.py @@ -4,16 +4,16 @@ class stimulus_pulse(object): - com = comedi.comedi_open('/dev/comedi0') - - def __init__(self, *args, **kwargs): - #self.com = comedi.comedi_open('/dev/comedi0') - super(stimulus_pulse, self).__init__(*args, **kwargs) - subdevice = 0 - write_mask = 0x800000 - val = 0x000000 - base_channel = 0 - comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel) + com = comedi.comedi_open('/dev/comedi0') + + def __init__(self, *args, **kwargs): + #self.com = comedi.comedi_open('/dev/comedi0') + super(stimulus_pulse, self).__init__(*args, **kwargs) + subdevice = 0 + write_mask = 0x800000 + val = 0x000000 + base_channel = 0 + comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel) def pulse(self,ts): #super(stimulus_pulse, self).pulse() diff --git a/robot/drivebot.py b/robot/drivebot.m similarity index 100% rename from robot/drivebot.py rename to robot/drivebot.m diff --git a/tests/test_onedim_lfp.py b/tests/test_onedim_lfp.py index 5bc1f302..65cd07b8 100644 --- a/tests/test_onedim_lfp.py +++ b/tests/test_onedim_lfp.py @@ -35,7 +35,7 @@ Exp = experiment.make(base_class, feats=feats) #params.trait_norm(Exp.class_traits()) -params = dict(session_length=10, plant_visible=True, lfp_plant_type='cursor_onedimLFP', mc_plant_type='cursor_14x14' +params = dict(session_length=10, plant_visible=True, lfp_plant_type='cursor_onedimLFP', mc_plant_type='cursor_14x14', rand_start=(0.,0.), max_tries=1) gen = SimBMIControlMulti.sim_target_seq_generator_multi(8, 1000) From 57f39062f21293931e4454c836ba29e86fa32560 Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Tue, 30 Jun 2020 15:23:13 -0700 Subject: [PATCH 040/242] add optitrack python files --- riglib/optitrack_client/NatNetClient.py | 516 ++++++++++++++++++ riglib/optitrack_client/PythonSample.py | 41 ++ .../test_NatNetClient_perframe.py | 22 + 3 files changed, 579 insertions(+) create mode 100644 riglib/optitrack_client/NatNetClient.py create mode 100644 riglib/optitrack_client/PythonSample.py create mode 100644 riglib/optitrack_client/test_NatNetClient_perframe.py diff --git a/riglib/optitrack_client/NatNetClient.py b/riglib/optitrack_client/NatNetClient.py new file mode 100644 index 00000000..2fcc28dc --- /dev/null +++ b/riglib/optitrack_client/NatNetClient.py @@ -0,0 +1,516 @@ +#Copyright © 2018 Naturalpoint +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +#http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
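
The 516-line client added here parses Motive's little-endian wire format by hand. As a distilled reference (a sketch, not part of the vendor file), every NatNet packet begins with a two-byte message ID and a two-byte payload size, the same four bytes that `__processMessage` peels off before dispatching:

```python
import struct

def peek_natnet_header(packet: bytes):
    """Return (messageID, packetSize) from the first four bytes of a NatNet packet."""
    message_id, packet_size = struct.unpack('<HH', packet[:4])
    return message_id, packet_size

# message_id 7 (NAT_FRAMEOFDATA) carries a mocap frame; 5 (NAT_MODELDEF) a model description.
```
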
+
+# OptiTrack NatNet direct depacketization library for Python 3.x
+
+import socket
+import struct
+from threading import Thread
+
+def trace( *args ):
+    pass # print( "".join(map(str,args)) )
+
+# Create structs for reading various object types to speed up parsing.
+Vector3 = struct.Struct( '<fff' )
+Quaternion = struct.Struct( '<ffff' )
+FloatValue = struct.Struct( '<f' )
+DoubleValue = struct.Struct( '<d' )
+
+class NatNetClient:
+    def __init__( self ):
+        # Change this value to the IP address of the NatNet server.
+        self.serverIPAddress = "127.0.0.1"
+
+        # Change this value to the IP address of your local network interface
+        self.localIPAddress = "127.0.0.1"
+
+        # This should match the multicast address listed in Motive's streaming settings.
+        self.multicastAddress = "239.255.42.99"
+
+        # NatNet Command channel
+        self.commandPort = 1510
+
+        # NatNet Data channel
+        self.dataPort = 1511
+
+        # Set these to callback methods of your choice to receive data at each frame.
+        self.newFrameListener = None
+        self.rigidBodyListener = None
+
+        # NatNet stream version. This will be updated to the actual version the server is using during runtime.
+        self.__natNetStreamVersion = (3,0,0,0)
+
+    # Client/server message ids
+    NAT_PING                  = 0
+    NAT_PINGRESPONSE          = 1
+    NAT_REQUEST               = 2
+    NAT_RESPONSE              = 3
+    NAT_REQUEST_MODELDEF      = 4
+    NAT_MODELDEF              = 5
+    NAT_REQUEST_FRAMEOFDATA   = 6
+    NAT_FRAMEOFDATA           = 7
+    NAT_MESSAGESTRING         = 8
+    NAT_DISCONNECT            = 9
+    NAT_UNRECOGNIZED_REQUEST  = 100
+
+    # Create a data socket to attach to the NatNet stream
+    def __createDataSocket( self, port ):
+        result = socket.socket( socket.AF_INET,     # Internet
+                              socket.SOCK_DGRAM,
+                              socket.IPPROTO_UDP)    # UDP
+
+        result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        result.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(self.multicastAddress) + socket.inet_aton(self.localIPAddress))
+
+        result.bind( (self.localIPAddress, port) )
+
+        return result
+
+    # Create a command socket to attach to the NatNet stream
+    def __createCommandSocket( self ):
+        result = socket.socket( socket.AF_INET, socket.SOCK_DGRAM )
+        result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        result.bind( ('', 0) )
+        result.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+
+        return result
+
+    # Unpack a rigid body object from a data packet
+    def __unpackRigidBody( self, data ):
+        offset = 0
+
+        # ID (4 bytes)
+        id = int.from_bytes( data[offset:offset+4], byteorder='little' )
+        offset += 4
+        trace( "ID:", id )
+
+        # Position and orientation
+        pos = Vector3.unpack( data[offset:offset+12] )
+        offset += 12
+        trace( "\tPosition:", pos[0],",", pos[1],",", pos[2] )
+        rot = Quaternion.unpack( data[offset:offset+16] )
+        offset += 16
+        trace( "\tOrientation:", rot[0],",", rot[1],",", rot[2],",", rot[3] )
+
+        # Send information to any listener.
+        if self.rigidBodyListener is not None:
+            self.rigidBodyListener( id, pos, rot )
+
+        # RB Marker Data ( Before version 3.0. After Version 3.0 Marker data is in description )
+        if( self.__natNetStreamVersion[0] < 3 and self.__natNetStreamVersion[0] != 0 ):
+            # Marker count (4 bytes)
+            markerCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
+            offset += 4
+            markerCountRange = range( 0, markerCount )
+            trace( "\tMarker Count:", markerCount )
+
+            # Marker positions
+            for i in markerCountRange:
+                pos = Vector3.unpack( data[offset:offset+12] )
+                offset += 12
+                trace( "\tMarker", i, ":", pos[0],",", pos[1],",", pos[2] )
+
+            if( self.__natNetStreamVersion[0] >= 2 ):
+                # Marker ID's
+                for i in markerCountRange:
+                    id = int.from_bytes( data[offset:offset+4], byteorder='little' )
+                    offset += 4
+                    trace( "\tMarker ID", i, ":", id )
+
+                # Marker sizes
+                for i in markerCountRange:
+                    size = FloatValue.unpack( data[offset:offset+4] )
+                    offset += 4
+                    trace( "\tMarker Size", i, ":", size[0] )
+
+        if( self.__natNetStreamVersion[0] >= 2 ):
+            markerError, = FloatValue.unpack( data[offset:offset+4] )
+            offset += 4
+            trace( "\tMarker Error:", markerError )
+
+        # Version 2.6 and later
+        if( ( ( self.__natNetStreamVersion[0] == 2 ) and ( self.__natNetStreamVersion[1] >= 6 ) ) or self.__natNetStreamVersion[0] > 2 or self.__natNetStreamVersion[0] == 0 ):
+            param, = struct.unpack( 'h', data[offset:offset+2] )
+            trackingValid = ( param & 0x01 ) != 0
+            offset += 2
+            trace( "\tTracking Valid:", 'True' if trackingValid else 'False' )
+
+        return offset
+
+    # Unpack a skeleton object from a data packet
+    def __unpackSkeleton( self, data ):
+        offset = 0
+
+        id = int.from_bytes( data[offset:offset+4], byteorder='little' )
+        offset += 4
+        trace( "ID:", id )
+
+        rigidBodyCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
+        offset += 4
+        trace( "Rigid Body Count:", rigidBodyCount )
+        for j in range( 0, rigidBodyCount ):
+            offset += self.__unpackRigidBody( data[offset:] )
+
+        return offset
+
+    # Unpack data from a motion capture frame message
+    def __unpackMocapData( self, data ):
+        trace( "Begin MoCap Frame\n-----------------\n" )
+
+        data = memoryview( data )
+        offset = 0
+
+        # Frame number (4 bytes)
+        frameNumber = int.from_bytes( data[offset:offset+4], byteorder='little' )
+        offset += 4
+        trace( "Frame #:", frameNumber )
+
+        # Marker set count (4 bytes)
+        markerSetCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
+        offset += 4
+        trace( "Marker Set Count:", markerSetCount )
+
+        for i in range( 0, markerSetCount ):
+            # Model name
+            modelName, separator, remainder = bytes(data[offset:]).partition( b'\0' )
+            offset += len( modelName ) + 1
+            trace( "Model Name:", modelName.decode( 'utf-8' ) )
+
+            # Marker count (4 bytes)
+            markerCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
+            offset += 4
+            trace( "Marker Count:", markerCount )
+
+            for j in range( 0, markerCount ):
+                pos = Vector3.unpack( data[offset:offset+12] )
+                offset += 12
+                #trace( "\tMarker", j, ":", pos[0],",", pos[1],",", pos[2] )
+
+        # Unlabeled markers count (4 bytes)
+        unlabeledMarkersCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
+        offset += 4
+        trace( "Unlabeled Markers Count:", unlabeledMarkersCount )
+
+        for i in range( 0, unlabeledMarkersCount ):
+            pos = Vector3.unpack( data[offset:offset+12] )
+            offset += 12
+            trace( "\tMarker", i, ":", pos[0],",", pos[1],",", pos[2] )
+
+        # Rigid body count (4 bytes)
+        rigidBodyCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
+        offset += 4
+        trace( "Rigid Body Count:", rigidBodyCount )
+
+        for i in range( 0, rigidBodyCount ):
+            offset += self.__unpackRigidBody( data[offset:] )
+
+        # Version 2.1 and later
+        skeletonCount = 0
+        if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] > 0 ) or self.__natNetStreamVersion[0] > 2 ):
+            skeletonCount = int.from_bytes( 
data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "Skeleton Count:", skeletonCount ) + for i in range( 0, skeletonCount ): + offset += self.__unpackSkeleton( data[offset:] ) + + # Labeled markers (Version 2.3 and later) + labeledMarkerCount = 0 + if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] > 3 ) or self.__natNetStreamVersion[0] > 2 ): + labeledMarkerCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "Labeled Marker Count:", labeledMarkerCount ) + for i in range( 0, labeledMarkerCount ): + id = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + pos = Vector3.unpack( data[offset:offset+12] ) + offset += 12 + size = FloatValue.unpack( data[offset:offset+4] ) + offset += 4 + + # Version 2.6 and later + if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 6 ) or self.__natNetStreamVersion[0] > 2 or major == 0 ): + param, = struct.unpack( 'h', data[offset:offset+2] ) + offset += 2 + occluded = ( param & 0x01 ) != 0 + pointCloudSolved = ( param & 0x02 ) != 0 + modelSolved = ( param & 0x04 ) != 0 + + # Version 3.0 and later + if( ( self.__natNetStreamVersion[0] >= 3 ) or major == 0 ): + residual, = FloatValue.unpack( data[offset:offset+4] ) + offset += 4 + trace( "Residual:", residual ) + + # Force Plate data (version 2.9 and later) + if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 9 ) or self.__natNetStreamVersion[0] > 2 ): + forcePlateCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "Force Plate Count:", forcePlateCount ) + for i in range( 0, forcePlateCount ): + # ID + forcePlateID = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "Force Plate", i, ":", forcePlateID ) + + # Channel Count + forcePlateChannelCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + # Channel Data + for j in range( 0, forcePlateChannelCount ): + trace( "\tChannel", j, ":", forcePlateID ) + forcePlateChannelFrameCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + for k in range( 0, forcePlateChannelFrameCount ): + forcePlateChannelVal = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "\t\t", forcePlateChannelVal ) + + # Device data (version 2.11 and later) + if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 11 ) or self.__natNetStreamVersion[0] > 2 ): + deviceCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "Device Count:", deviceCount ) + for i in range( 0, deviceCount ): + # ID + deviceID = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "Device", i, ":", deviceID ) + + # Channel Count + deviceChannelCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + # Channel Data + for j in range( 0, deviceChannelCount ): + trace( "\tChannel", j, ":", deviceID ) + deviceChannelFrameCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + for k in range( 0, deviceChannelFrameCount ): + deviceChannelVal = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "\t\t", deviceChannelVal ) + + # Timecode + timecode = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + timecodeSub = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + # Timestamp (increased to 
double precision in 2.7 and later) + if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 7 ) or self.__natNetStreamVersion[0] > 2 ): + timestamp, = DoubleValue.unpack( data[offset:offset+8] ) + offset += 8 + else: + timestamp, = FloatValue.unpack( data[offset:offset+4] ) + offset += 4 + + # Hires Timestamp (Version 3.0 and later) + if( ( self.__natNetStreamVersion[0] >= 3 ) or major == 0 ): + stampCameraExposure = int.from_bytes( data[offset:offset+8], byteorder='little' ) + offset += 8 + stampDataReceived = int.from_bytes( data[offset:offset+8], byteorder='little' ) + offset += 8 + stampTransmit = int.from_bytes( data[offset:offset+8], byteorder='little' ) + offset += 8 + # Frame parameters + param, = struct.unpack( 'h', data[offset:offset+2] ) + isRecording = ( param & 0x01 ) != 0 + trackedModelsChanged = ( param & 0x02 ) != 0 + offset += 2 + + # Send information to any listener. + if self.newFrameListener is not None: + print(frameNumber) + self.newFrameListener( frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount, + labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ) + + # Unpack a marker set description packet + def __unpackMarkerSetDescription( self, data ): + offset = 0 + + name, separator, remainder = bytes(data[offset:]).partition( b'\0' ) + offset += len( name ) + 1 + trace( "Markerset Name:", name.decode( 'utf-8' ) ) + + markerCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + for i in range( 0, markerCount ): + name, separator, remainder = bytes(data[offset:]).partition( b'\0' ) + offset += len( name ) + 1 + trace( "\tMarker Name:", name.decode( 'utf-8' ) ) + + return offset + + # Unpack a rigid body description packet + def __unpackRigidBodyDescription( self, data ): + offset = 0 + + # Version 2.0 or higher + if( self.__natNetStreamVersion[0] >= 2 ): + name, separator, remainder = bytes(data[offset:]).partition( b'\0' ) + offset += len( name ) + 1 + trace( "\tRigidBody Name:", name.decode( 'utf-8' ) ) + + id = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + parentID = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + timestamp = Vector3.unpack( data[offset:offset+12] ) + offset += 12 + + # Version 3.0 and higher, rigid body marker information contained in description + if (self.__natNetStreamVersion[0] >= 3 or self.__natNetStreamVersion[0] == 0 ): + markerCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + trace( "\tRigidBody Marker Count:", markerCount ) + + markerCountRange = range( 0, markerCount ) + for marker in markerCountRange: + markerOffset = Vector3.unpack(data[offset:offset+12]) + offset +=12 + for marker in markerCountRange: + activeLabel = int.from_bytes(data[offset:offset+4],byteorder = 'little') + offset += 4 + + return offset + + # Unpack a skeleton description packet + def __unpackSkeletonDescription( self, data ): + offset = 0 + + name, separator, remainder = bytes(data[offset:]).partition( b'\0' ) + offset += len( name ) + 1 + trace( "\tMarker Name:", name.decode( 'utf-8' ) ) + + id = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + rigidBodyCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + for i in range( 0, rigidBodyCount ): + offset += self.__unpackRigidBodyDescription( data[offset:] ) + + return offset + + # Unpack a data description packet + def __unpackDataDescriptions( 
self, data ): + offset = 0 + datasetCount = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + + for i in range( 0, datasetCount ): + type = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + if( type == 0 ): + offset += self.__unpackMarkerSetDescription( data[offset:] ) + elif( type == 1 ): + offset += self.__unpackRigidBodyDescription( data[offset:] ) + elif( type == 2 ): + offset += self.__unpackSkeletonDescription( data[offset:] ) + + def __dataThreadFunction( self, socket ): + while True: + # Block for input + data, addr = socket.recvfrom(1024) # 32k byte buffer size + if( len( data ) > 0 ): + self.__processMessage( data ) + + def __processMessage( self, data ): + trace( "Begin Packet\n------------\n" ) + + messageID = int.from_bytes( data[0:2], byteorder='little' ) + trace( "Message ID:", messageID ) + print(messageID) + print(self.NAT_FRAMEOFDATA) + + packetSize = int.from_bytes( data[2:4], byteorder='little' ) + trace( "Packet Size:", packetSize ) + + offset = 4 + if( messageID == self.NAT_FRAMEOFDATA ): + self.__unpackMocapData( data[offset:] ) + elif( messageID == self.NAT_MODELDEF ): + self.__unpackDataDescriptions( data[offset:] ) + elif( messageID == self.NAT_PINGRESPONSE ): + offset += 256 # Skip the sending app's Name field + offset += 4 # Skip the sending app's Version info + self.__natNetStreamVersion = struct.unpack( 'BBBB', data[offset:offset+4] ) + offset += 4 + elif( messageID == self.NAT_RESPONSE ): + if( packetSize == 4 ): + commandResponse = int.from_bytes( data[offset:offset+4], byteorder='little' ) + offset += 4 + else: + message, separator, remainder = bytes(data[offset:]).partition( b'\0' ) + offset += len( message ) + 1 + trace( "Command response:", message.decode( 'utf-8' ) ) + elif( messageID == self.NAT_UNRECOGNIZED_REQUEST ): + trace( "Received 'Unrecognized request' from server" ) + elif( messageID == self.NAT_MESSAGESTRING ): + message, separator, remainder = bytes(data[offset:]).partition( b'\0' ) + offset += len( message ) + 1 + trace( "Received message from server:", message.decode( 'utf-8' ) ) + else: + trace( "ERROR: Unrecognized packet type" ) + + trace( "End Packet\n----------\n" ) + print('Finished Processing') + + def sendCommand( self, command, commandStr, socket, address ): + # Compose the message in our known message format + if( command == self.NAT_REQUEST_MODELDEF or command == self.NAT_REQUEST_FRAMEOFDATA ): + packetSize = 0 + commandStr = "" + elif( command == self.NAT_REQUEST ): + packetSize = len( commandStr ) + 1 + elif( command == self.NAT_PING ): + commandStr = "Ping" + packetSize = len( commandStr ) + 1 + + data = command.to_bytes( 2, byteorder='little' ) + data += packetSize.to_bytes( 2, byteorder='little' ) + + data += commandStr.encode( 'utf-8' ) + data += b'\0' + + socket.sendto( data, address ) + + def run( self ): + # Create the data socket + self.dataSocket = self.__createDataSocket( self.dataPort ) + if( self.dataSocket is None ): + print( "Could not open data channel" ) + exit + + # Create the command socket + self.commandSocket = self.__createCommandSocket() + if( self.commandSocket is None ): + print( "Could not open command channel" ) + exit + + # Create a separate thread for receiving data packets + dataThread = Thread( target = self.__dataThreadFunction, args = (self.dataSocket, )) + dataThread.start() + + # Create a separate thread for receiving command packets + commandThread = Thread( target = self.__dataThreadFunction, args = (self.commandSocket, )) + 
commandThread.start() + + self.sendCommand( self.NAT_REQUEST_MODELDEF, "", self.commandSocket, (self.serverIPAddress, self.commandPort) ) diff --git a/riglib/optitrack_client/PythonSample.py b/riglib/optitrack_client/PythonSample.py new file mode 100644 index 00000000..a511a1a8 --- /dev/null +++ b/riglib/optitrack_client/PythonSample.py @@ -0,0 +1,41 @@ +#Copyright © 2018 Naturalpoint +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +#http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + + +# OptiTrack NatNet direct depacketization sample for Python 3.x +# +# Uses the Python NatNetClient.py library to establish a connection (by creating a NatNetClient), +# and receive data via a NatNet connection and decode it using the NatNetClient library. + +from NatNetClient import NatNetClient + +# This is a callback function that gets connected to the NatNet client and called once per mocap frame. +def receiveNewFrame( frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount, + labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ): + print( "Received frame", frameNumber ) + +# This is a callback function that gets connected to the NatNet client. It is called once per rigid body per frame +def receiveRigidBodyFrame( id, position, rotation ): + print( "Received frame for rigid body", position ) + +# This will create a new NatNet client +streamingClient = NatNetClient() + +# Configure the streaming client to call our rigid body handler on the emulator to send data out. +streamingClient.newFrameListener = receiveNewFrame +streamingClient.rigidBodyListener = receiveRigidBodyFrame + +# Start up the streaming client now that the callbacks are set up. +# This will run perpetually, and operate on a separate thread. 
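
One caveat before the `run()` call that follows: `NatNetClient`'s double-underscore methods are name-mangled, so module-level code such as the per-frame test later in this patch cannot call `streamingClient.__createDataSocket(...)` directly (Python raises `AttributeError`), and `streamingClient__processMessage(data)` there is missing a dot and raises `NameError`. If those internals must be exercised from outside the class, the mangled names look like this sketch:

```python
# Name mangling rewrites __name to _ClassName__name outside the class body.
streamingClient.dataSocket = streamingClient._NatNetClient__createDataSocket(streamingClient.dataPort)
streamingClient.commandSocket = streamingClient._NatNetClient__createCommandSocket()

data, addr = streamingClient.dataSocket.recvfrom(32768)  # buffer size, in bytes
if len(data) > 0:
    streamingClient._NatNetClient__processMessage(data)
```
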
+streamingClient.run() diff --git a/riglib/optitrack_client/test_NatNetClient_perframe.py b/riglib/optitrack_client/test_NatNetClient_perframe.py new file mode 100644 index 00000000..c6510fb8 --- /dev/null +++ b/riglib/optitrack_client/test_NatNetClient_perframe.py @@ -0,0 +1,22 @@ +from NatNetClient import NatNetClient + +# This will create a new NatNet client +streamingClient = NatNetClient() + +streamingClient.dataSocket = streamingClient.__createDataSocket(streamingClient.dataPort) +if (streamingClient.dataSocket is None): + print("Could not open data channel") + exit + +# Create the command socket +streamingClient.commandSocket = streamingClient.__createCommandSocket() +if (streamingClient.commandSocket is None): + print("Could not open command channel") + exit + + +# receive some data + +data, addr = streamingClient.dataSocket.recvfrom(1024) # 32k byte buffer size +if (len(data) > 0): + streamingClient__processMessage(data) From 9422b6783b8068a6ee42ad7bed6810fc3f6346e2 Mon Sep 17 00:00:00 2001 From: leoscholl Date: Tue, 7 Jul 2020 12:58:48 -0700 Subject: [PATCH 041/242] fix relative imports --- built_in_tasks/bmimultitasks.py | 2 +- built_in_tasks/manualcontrolmultitasks.py | 4 ++-- built_in_tasks/passivetasks.py | 4 +++- db/tracker/ajax.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index a9c0de18..01781c99 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -29,7 +29,7 @@ from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceNLinkPlanarChain -from manualcontrolmultitasks import ManualControlMulti +from built_in_tasks.manualcontrolmultitasks import ManualControlMulti target_colors = {"blue":(0,0,1,0.5), "yellow": (1,1,0,0.5), diff --git a/built_in_tasks/manualcontrolmultitasks.py b/built_in_tasks/manualcontrolmultitasks.py index d5f1a3ef..a8c81779 100644 --- a/built_in_tasks/manualcontrolmultitasks.py +++ b/built_in_tasks/manualcontrolmultitasks.py @@ -16,7 +16,7 @@ from riglib.stereo_opengl.render import stereo, Renderer from riglib.stereo_opengl.utils import cloudy_tex -from plantlist import plantlist +from built_in_tasks.plantlist import plantlist from riglib.stereo_opengl import ik import os @@ -31,7 +31,7 @@ GOLD = (1., 0.843, 0., 0.5) mm_per_cm = 1./10 -from target_graphics import * +from built_in_tasks.target_graphics import * target_colors = { "yellow": (1,1,0,0.75), diff --git a/built_in_tasks/passivetasks.py b/built_in_tasks/passivetasks.py index a38f17fd..d8e8777d 100644 --- a/built_in_tasks/passivetasks.py +++ b/built_in_tasks/passivetasks.py @@ -24,7 +24,7 @@ from riglib.stereo_opengl.window import WindowDispl2D, FakeWindow -from .bmimultitasks import BMIControlMulti +from built_in_tasks.bmimultitasks import BMIControlMulti bmi_ssm_options = ['Endpt2D', 'Tentacle', 'Joint2L'] @@ -69,6 +69,8 @@ def _test_start_trial(self, ts): @classmethod def get_desc(cls, params, report): + if report == None: + return "Broken report" duration = report[-1][-1] - report[0][-1] reward_count = 0 for item in report: diff --git a/db/tracker/ajax.py b/db/tracker/ajax.py index 2121b2eb..41448542 100644 --- a/db/tracker/ajax.py +++ b/db/tracker/ajax.py @@ -14,7 +14,7 @@ from .models import TaskEntry, Feature, Sequence, Task, Generator, Subject, DataFile, System, Decoder -import trainbmi +import db.trainbmi as trainbmi import logging import io, traceback From 6acf6277edb3a1641314b26fa3f4ccf021ee16ed Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Sun, 12 
Jul 2020 18:48:43 -0700 Subject: [PATCH 042/242] add a few more dependencies, take pyaudio for now --- requirements.txt | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 78146a86..b3d0139e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ django_celery==3.3.1 namelist==0.1.0 OWL==0.2.1 Phidgets==2.1.8 -pyaudio==0.2.11 +#pyaudio==0.2.11 pyautogui==0.9.50 PyOpenGL==3.1.5 pyserial==3.4 @@ -26,4 +26,8 @@ robotframework==3.2.1 scikit_learn==0.23.1 tabulate==0.8.7 tasks==2.5.0 -testing==0.0 +Sphinx == 0.0.0 +ipdb == 0.0.0 +numpydoc == 0.0.0 +pylibftdi == 0.0.0 + From 402e5480e6144e97731a354b2b06b1ce1ef0fc89 Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Mon, 13 Jul 2020 18:00:20 -0700 Subject: [PATCH 043/242] correct the window size to 2D monitor --- riglib/stereo_opengl/window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/riglib/stereo_opengl/window.py b/riglib/stereo_opengl/window.py index 43451656..c77f16c9 100644 --- a/riglib/stereo_opengl/window.py +++ b/riglib/stereo_opengl/window.py @@ -41,7 +41,7 @@ class Window(LogExperiment): #window_size = traits.Tuple((1920*2, 1080), descr='window size, in pixels') #XPS computer - window_size = traits.Tuple((1280, 720), descr='window size, in pixels') + window_size = traits.Tuple((1280, 1080), descr='window size, in pixels') # window_size = (1920*2, 1080) background = (0,0,0,1) From c8a6ef4a4f62dc0582d38474300dbc69766732dc Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Mon, 13 Jul 2020 18:00:54 -0700 Subject: [PATCH 044/242] specify oS and python versions --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index c3ba4b23..a53c88d2 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,10 @@ Getting started # Dependencies ## Linux/OS X (none at this time) +16.04 64bit + +#python version +3.7.8 ## Windows Visual C++ Build tools (for the 'traits' package) From 6e4df9c26f24b4cf5667713d9c0c7f248906800e Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Tue, 14 Jul 2020 14:57:03 -0700 Subject: [PATCH 045/242] the interface bet. mocap NatNet interface and BMI3 --- riglib/optitrack_client/NatNetClient.py | 8 +-- riglib/optitrack_client/__init__.py | 0 .../optitrack_client/optitrack_interface.py | 50 +++++++++++++++++++ riglib/optitrack_client/test_optitrack.py | 10 ++++ 4 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 riglib/optitrack_client/__init__.py create mode 100644 riglib/optitrack_client/optitrack_interface.py create mode 100644 riglib/optitrack_client/test_optitrack.py diff --git a/riglib/optitrack_client/NatNetClient.py b/riglib/optitrack_client/NatNetClient.py index 2fcc28dc..f23ba9cc 100644 --- a/riglib/optitrack_client/NatNetClient.py +++ b/riglib/optitrack_client/NatNetClient.py @@ -333,7 +333,7 @@ def __unpackMocapData( self, data ): # Send information to any listener. 
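
The hunks in this patch mostly mute debug prints, but the one below also documents the frame-callback contract: whatever callable is assigned to `newFrameListener` receives the eleven per-frame summary arguments. A minimal listener sketch (the `client` name is assumed):

```python
def on_frame(frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount,
             skeletonCount, labeledMarkerCount, timecode, timecodeSub, timestamp,
             isRecording, trackedModelsChanged):
    # flag discontinuities in the frame counter, e.g. dropped UDP packets
    last = getattr(on_frame, 'last', frameNumber - 1)
    if frameNumber != last + 1:
        print('gap of', frameNumber - last - 1, 'frame(s) before', frameNumber)
    on_frame.last = frameNumber

client = NatNetClient()
client.newFrameListener = on_frame
client.run()
```
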
if self.newFrameListener is not None: - print(frameNumber) + #print(frameNumber) self.newFrameListener( frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount, labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ) @@ -437,8 +437,8 @@ def __processMessage( self, data ): messageID = int.from_bytes( data[0:2], byteorder='little' ) trace( "Message ID:", messageID ) - print(messageID) - print(self.NAT_FRAMEOFDATA) + #print(messageID) + #print(self.NAT_FRAMEOFDATA) packetSize = int.from_bytes( data[2:4], byteorder='little' ) trace( "Packet Size:", packetSize ) @@ -471,7 +471,7 @@ def __processMessage( self, data ): trace( "ERROR: Unrecognized packet type" ) trace( "End Packet\n----------\n" ) - print('Finished Processing') + #print('Finished Processing') def sendCommand( self, command, commandStr, socket, address ): # Compose the message in our known message format diff --git a/riglib/optitrack_client/__init__.py b/riglib/optitrack_client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/riglib/optitrack_client/optitrack_interface.py b/riglib/optitrack_client/optitrack_interface.py new file mode 100644 index 00000000..4651bcc3 --- /dev/null +++ b/riglib/optitrack_client/optitrack_interface.py @@ -0,0 +1,50 @@ +from .NatNetClient import NatNetClient as TestClient +import numpy as np +from multiprocessing import Process,Lock +import pickle + +mutex = Lock() + +class MotionData(object): + """ + this is is the dataSource interface for getting the mocap at BMI3D's reqeust + """ + + def __init__(self, num_length): + self.test_client = TestClient() + #self.data_array = np.zeros(num_length) + #self.data_array = np.zeros(num_length) + self.data_array = [None] * num_length + self.num_length = num_length + + # This is a callback function that gets connected to the NatNet client and called once per mocap frame. + def receiveNewFrame(self, frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount, + labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ): + #print( "Received frame", frameNumber ) + pass + + # This is a callback function that gets connected to the NatNet client. 
It is called once per rigid body per frame + def receiveRigidBodyFrame(self, id, position, rotation ): + #print( "Received frame for rigid body", position ) + + #save to the running buffer with a lock + with mutex: + self.data_array.insert(0,position) + self.data_array.pop() + + + def start(self): + self.test_client.newFrameListener = self.receiveNewFrame + self.test_client.rigidBodyListener =self.receiveRigidBodyFrame + self.test_client.run() + print('Started the interface thread') + + def stop(self): + pass + + def get(self): + current_value = None + with mutex: + current_value = self.data_array[0] + #return the latest saved data + return current_value \ No newline at end of file diff --git a/riglib/optitrack_client/test_optitrack.py b/riglib/optitrack_client/test_optitrack.py new file mode 100644 index 00000000..635e7892 --- /dev/null +++ b/riglib/optitrack_client/test_optitrack.py @@ -0,0 +1,10 @@ +from optitrack_interface import MotionData +import time + +num_length = 10 +motion_data = MotionData(num_length) +motion_data.start() + +while True: + print(motion_data.get()) + time.sleep(1) \ No newline at end of file From d1d581dfb31838c2555dfd0dc2be7f042b4bf015 Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Tue, 14 Jul 2020 21:44:26 -0700 Subject: [PATCH 046/242] test optitrack interface --- .../cursorControlTasks_optitrack.py | 124 ++++++++++++++++++ features/optitrack_feature.py | 116 ++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 built_in_tasks/cursorControlTasks_optitrack.py create mode 100644 features/optitrack_feature.py diff --git a/built_in_tasks/cursorControlTasks_optitrack.py b/built_in_tasks/cursorControlTasks_optitrack.py new file mode 100644 index 00000000..c37d9046 --- /dev/null +++ b/built_in_tasks/cursorControlTasks_optitrack.py @@ -0,0 +1,124 @@ + +from manualcontrolmultitasks import ManualControlMulti +from riglib.stereo_opengl.window import WindowDispl2D +#from bmimultitasks import BMIControlMulti +import pygame +import numpy as np +import copy + +#from riglib.bmi.extractor import DummyExtractor +#from riglib.bmi.state_space_models import StateSpaceEndptVel2D +#from riglib.bmi.bmi import Decoder, BMISystem, GaussianStateHMM, BMILoop, GaussianState, MachineOnlyFilter +from riglib import experiment + + +class CursorControl(ManualControlMulti, WindowDispl2D): + ''' + this class implements a python cursor control task for human + ''' + + def __init__(self, *args, **kwargs): + # just run the parent ManualControlMulti's initialization + self.move_step = 1 + + # Initialize target location variable + #target location and index have been initializd + + super(CursorControl, self).__init__(*args, **kwargs) + + def init(self): + pygame.init() + + + + self.assist_level = (0, 0) + super(CursorControl, self).init() + + # override the _cycle function + def _cycle(self): + #print(self.state) + + #target and plant data have been saved in + #the parent manualcontrolmultitasks + + self.move_effector_cursor() + super(CursorControl, self)._cycle() + + # do nothing + def move_effector(self): + pass + + def move_plant(self, **kwargs): + pass + + # use keyboard to control the task + def move_effector_cursor(self): + np.array([0., 0., 0.]) + curr_pos = copy.deepcopy(self.plant.get_endpoint_pos()) + + for event in pygame.event.get(): + if event.type == pygame.KEYUP: + if event.type == pygame.K_q: + pygame.quit() + quit() + if event.key == pygame.K_LEFT: + curr_pos[0] -= self.move_step + if event.key == pygame.K_RIGHT: + curr_pos[0] += self.move_step + if event.key == 
pygame.K_UP: + curr_pos[2] += self.move_step + if event.key == pygame.K_DOWN: + curr_pos[2] -= self.move_step + #print('Current position: ') + #print(curr_pos) + + # set the current position + self.plant.set_endpoint_pos(curr_pos) + + def _start_wait(self): + self.wait_time = 0. + super(CursorControl, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + +#this task can be run on its +#we will not involve database at this time +target_pos_radius = 10 + +def target_seq_generator(n_targs, n_trials): + #generate targets + angles = np.transpose(np.arange(0,2*np.pi,2*np.pi / n_targs)) + unit_targets = targets = np.stack((np.cos(angles), np.sin(angles)),1) + targets = unit_targets * target_pos_radius + + center = np.array((0,0)) + + target_inds = np.random.randint(0, n_targs, n_trials) + target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + + k = 0 + while k < n_trials: + targ = targets[target_inds[k], :] + yield np.array([[center[0], 0, center[1]], + [targ[0], 0, targ[1]]]) + k += 1 + + +if __name__ == "__main__": + print('Remember to set window size in stereoOpenGL class') + gen = target_seq_generator(8, 2) + + #incorporate the saveHDF feature by blending code + #see tests\start_From_cmd_line_sim + from features.optitrack_feature import MotionData + + base_class = CursorControl + feats = [MotionData] + Exp = experiment.make(base_class, feats=feats) + print(Exp) + + exp = Exp(gen) + exp.init() + exp.run() #start the task + diff --git a/features/optitrack_feature.py b/features/optitrack_feature.py new file mode 100644 index 00000000..8b5df0f3 --- /dev/null +++ b/features/optitrack_feature.py @@ -0,0 +1,116 @@ +''' +Features for the phasephase motiontracker +''' + +import time +import tempfile +import random +import traceback +import numpy as np +#import fnmatch +import os + +from riglib import calibrations + +import os +import subprocess + +import time +from riglib.experiment import traits + +######################################################################################################## +# Phasespace datasources +######################################################################################################## +class MotionData(traits.HasTraits): + ''' + Enable reading of raw motiontracker data from Phasespace system + ''' + marker_count = traits.Int(8, desc="Number of markers to return") + + def init(self): + ''' + Secondary init function. See riglib.experiment.Experiment.init() + Prior to starting the task, this 'init' sets up the DataSource for interacting with the + motion tracker system and registers the source with the SinkRegister so that the data gets saved to file as it is collected. + ''' + from riglib import source + src, mkw = self.source_class + #will make available to self.motiondata that other features can access + self.motiondata = source.DataSource(src, **mkw) + + #save to the sink + from riglib import sink + self.sinks = sink.sinks + self.sinks.register(self.motiondata) + + super(MotionData, self).init() + + @property + def source_class(self): + ''' + Specify the source class as a function in case future descendant classes want to use a different type of source + ''' + from riglib.optitrack_client.optitrack_interface import MotionData + return MotionData, dict() + + def run(self): + ''' + Code to execute immediately prior to the beginning of the task FSM executing, or after the FSM has finished running. + See riglib.experiment.Experiment.run(). 
This 'run' method starts the motiontracker source prior to starting the experiment's + main thread/process, and handle any errors by stopping the source + ''' + self.motiondata.start() + try: + super(MotionData, self).run() + finally: + self.motiondata.stop() + + def join(self): + ''' + See riglib.experiment.Experiment.join(). Re-join the 'motiondata' source process before cleaning up the experiment thread + ''' + self.motiondata.join() + super(MotionData, self).join() + + def _start_None(self): + ''' + Code to run before the 'None' state starts (i.e., the task stops) + ''' + self.motiondata.stop() + super(MotionData, self)._start_None() + + +class MotionSimulate(MotionData): + ''' + Simulate presence of raw motiontracking system using a randomized spatial function + ''' + @property + def source_class(self): + ''' + Specify the source class as a function in case future descendant classes want to use a different type of source + ''' + from riglib import motiontracker + cls = motiontracker.make(self.marker_count, cls=motiontracker.Simulate) + return cls, dict(radius=(100,100,50), offset=(-150,0,0)) + + +class MotionAutoAlign(MotionData): + '''Creates an auto-aligning motion tracker, for use with the 6-point alignment system''' + autoalign = traits.Instance(calibrations.AutoAlign) + + def init(self): + ''' + Secondary init function. See riglib.experiment.Experiment.init() + Prior to starting the task, this 'init' adds a filter onto the motiondata source. See MotionData for further details. + ''' + super(MotionAutoAlign, self).init() + self.motiondata.filter = self.autoalign + + @property + def source_class(self): + ''' + Specify the source class as a function in case future descendant classes want to use a different type of source + ''' + from riglib import motiontracker + cls = motiontracker.make(self.marker_count, cls=motiontracker.AligningSystem) + return cls, dict() From a821be51a6b474a89eaa49a2979e4dce3fa8c185 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Wed, 15 Jul 2020 11:40:00 -0700 Subject: [PATCH 047/242] add update_freq and dtype to optitrack client --- features/optitrack_feature.py | 6 +++--- riglib/optitrack_client/optitrack_interface.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/features/optitrack_feature.py b/features/optitrack_feature.py index 8b5df0f3..a97bbedf 100644 --- a/features/optitrack_feature.py +++ b/features/optitrack_feature.py @@ -19,7 +19,7 @@ from riglib.experiment import traits ######################################################################################################## -# Phasespace datasources +# Optitrack datasources ######################################################################################################## class MotionData(traits.HasTraits): ''' @@ -50,8 +50,8 @@ def source_class(self): ''' Specify the source class as a function in case future descendant classes want to use a different type of source ''' - from riglib.optitrack_client.optitrack_interface import MotionData - return MotionData, dict() + from riglib.optitrack_client.optitrack_interface import MotionData as op_client + return op_client, dict() def run(self): ''' diff --git a/riglib/optitrack_client/optitrack_interface.py b/riglib/optitrack_client/optitrack_interface.py index 4651bcc3..5e31c527 100644 --- a/riglib/optitrack_client/optitrack_interface.py +++ b/riglib/optitrack_client/optitrack_interface.py @@ -8,14 +8,18 @@ class MotionData(object): """ this is is the dataSource interface for getting the mocap at BMI3D's reqeust + 
compatible with DataSourceSystem """ + update_freq = 120 # Hz + rigid_body_count = 1 #for now,only one rigid body - def __init__(self, num_length): + dtype = np.dtype((np.float, (rigid_body_count, 4))) + + + def __init__(self): self.test_client = TestClient() - #self.data_array = np.zeros(num_length) - #self.data_array = np.zeros(num_length) - self.data_array = [None] * num_length - self.num_length = num_length + self.num_length = 10 # slots for buffer + self.data_array = [None] * self.num_length # This is a callback function that gets connected to the NatNet client and called once per mocap frame. def receiveNewFrame(self, frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount, From 082522f720019327d52353ba01397e28a25be881 Mon Sep 17 00:00:00 2001 From: Sijia Li Date: Wed, 15 Jul 2020 12:22:05 -0700 Subject: [PATCH 048/242] added simulation feature to the optitrack client --- features/optitrack_feature.py | 9 +++--- .../optitrack_client/optitrack_interface.py | 28 +++++++++++++------ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/features/optitrack_feature.py b/features/optitrack_feature.py index a97bbedf..9e887e81 100644 --- a/features/optitrack_feature.py +++ b/features/optitrack_feature.py @@ -50,8 +50,8 @@ def source_class(self): ''' Specify the source class as a function in case future descendant classes want to use a different type of source ''' - from riglib.optitrack_client.optitrack_interface import MotionData as op_client - return op_client, dict() + from riglib.optitrack_client.optitrack_interface import System + return System, dict() def run(self): ''' @@ -89,9 +89,8 @@ def source_class(self): ''' Specify the source class as a function in case future descendant classes want to use a different type of source ''' - from riglib import motiontracker - cls = motiontracker.make(self.marker_count, cls=motiontracker.Simulate) - return cls, dict(radius=(100,100,50), offset=(-150,0,0)) + from riglib.optitrack_client.optitrack_interface import Simulation + return Simulation, dict() class MotionAutoAlign(MotionData): diff --git a/riglib/optitrack_client/optitrack_interface.py b/riglib/optitrack_client/optitrack_interface.py index 5e31c527..d3322199 100644 --- a/riglib/optitrack_client/optitrack_interface.py +++ b/riglib/optitrack_client/optitrack_interface.py @@ -5,18 +5,17 @@ mutex = Lock() -class MotionData(object): +class System(object): """ this is is the dataSource interface for getting the mocap at BMI3D's reqeust compatible with DataSourceSystem """ - update_freq = 120 # Hz - rigid_body_count = 1 #for now,only one rigid body - - dtype = np.dtype((np.float, (rigid_body_count, 4))) - - + rigidBodyCount = 1 + dtype = np.dtype((np.float, (rigidBodyCount, 6))) #6 degress of freedo def __init__(self): + self.update_freq = 120 #Hz + self.rigid_body_count = 1 #for now,only one rigid body + self.test_client = TestClient() self.num_length = 10 # slots for buffer self.data_array = [None] * self.num_length @@ -51,4 +50,17 @@ def get(self): with mutex: current_value = self.data_array[0] #return the latest saved data - return current_value \ No newline at end of file + return current_value + +class Simulation(System): + ''' + this class does all the things except when the optitrack is not broadcasting data + the get function starts to return random numbers + ''' + update_freq = 10 #Hz + + def get(self): + mag_fac = 10 + current_value = np.random.rand(self.rigidBodyCount, 6) * mag_fac + return current_value + \ No newline at end of file From 
5e84e1d10cbee05db8273e12736993f724e4b400 Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Thu, 16 Jul 2020 14:04:50 -0700 Subject: [PATCH 049/242] inital mocap and BMI3D integration --- .../cursorControlTasks_optitrack.py | 43 ++++++++++++++++--- .../optitrack_client/optitrack_interface.py | 20 +++++++-- riglib/optitrack_client/test_optitrack.py | 6 +-- 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/built_in_tasks/cursorControlTasks_optitrack.py b/built_in_tasks/cursorControlTasks_optitrack.py index c37d9046..da203040 100644 --- a/built_in_tasks/cursorControlTasks_optitrack.py +++ b/built_in_tasks/cursorControlTasks_optitrack.py @@ -28,9 +28,6 @@ def __init__(self, *args, **kwargs): def init(self): pygame.init() - - - self.assist_level = (0, 0) super(CursorControl, self).init() @@ -41,18 +38,54 @@ def _cycle(self): #target and plant data have been saved in #the parent manualcontrolmultitasks - self.move_effector_cursor() + self.move_effector() super(CursorControl, self)._cycle() # do nothing def move_effector(self): - pass + self.scale_factor = 50 + + #get data from motion tracker- take average of all data points since last poll + #the default regid body yields a 6 degree of freedom + #so, its 1 by 6 vector for now + pt = self.motiondata.get() + + + if len(pt) > 0:#check if there is avaialble data + + #does some transformation + #centering + #transformation + pt = pt[:,0,:] + #average the data in the buffer + + pt = pt.mean(0) + pt = pt[:3] * self.scale_factor + + #limited to 2D, set the y direction to + if self.limit2d: + pt[1] = 0 + self.no_data_count = 0 + + + else: #if no new data + self.no_data_count +=1 + pt = None + + # Set the plant's endpoint to the position + # determined by the motiontracker, + # unless there is no data available + if pt is not None: + self.plant.set_endpoint_pos(pt) + + def move_plant(self, **kwargs): pass # use keyboard to control the task def move_effector_cursor(self): + #incremental adding np.array([0., 0., 0.]) curr_pos = copy.deepcopy(self.plant.get_endpoint_pos()) diff --git a/riglib/optitrack_client/optitrack_interface.py b/riglib/optitrack_client/optitrack_interface.py index d3322199..c5e5de3f 100644 --- a/riglib/optitrack_client/optitrack_interface.py +++ b/riglib/optitrack_client/optitrack_interface.py @@ -9,16 +9,18 @@ class System(object): """ this is is the dataSource interface for getting the mocap at BMI3D's reqeust compatible with DataSourceSystem + uses data_array to keep track of the lastest buffer """ rigidBodyCount = 1 + update_freq = 120 dtype = np.dtype((np.float, (rigidBodyCount, 6))) #6 degress of freedo def __init__(self): - self.update_freq = 120 #Hz self.rigid_body_count = 1 #for now,only one rigid body self.test_client = TestClient() self.num_length = 10 # slots for buffer self.data_array = [None] * self.num_length + self.rotation_buffer = [None] * self.num_length # This is a callback function that gets connected to the NatNet client and called once per mocap frame. 
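
Before the interface hunks continue, note that the `move_effector` added above is the heart of this patch: it collapses whatever accumulated in the motion buffer, an `(n, 1, 6)` array of position-plus-rotation rows, into one scaled cursor point. Distilled below (the function and argument names are mine, not the repo's):

```python
import numpy as np

def mocap_to_cursor(samples, scale_factor=50, limit2d=True):
    """samples: (n, 1, 6) rows of x, y, z plus three rotation values."""
    pt = samples[:, 0, :]                # drop the single rigid-body axis
    pt = pt.mean(0)[:3] * scale_factor   # average the buffer, keep x, y, z
    if limit2d:
        pt[1] = 0                        # pin the depth axis for a planar task
    return pt
```
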
 def receiveNewFrame(self, frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount,
@@ -34,6 +36,9 @@ def receiveRigidBodyFrame(self, id, position, rotation ):
         with mutex:
             self.data_array.insert(0,position)
             self.data_array.pop()
+            #save to rotation buffer list
+            self.rotation_buffer.insert(0,rotation)
+            self.rotation_buffer.pop()
 
 
     def start(self):
@@ -47,17 +52,26 @@ def stop(self):
 
     def get(self):
         current_value = None
+        rotation_value = None
+        pos_rot = None
+
         with mutex:
             current_value = self.data_array[0]
+            rotation_value = self.rotation_buffer[0]
+
         #return the latest saved data
-        return current_value
+        if (not current_value is None) and (not rotation_value is None):
+            pos_rot = np.concatenate((np.asarray(current_value),np.asarray(rotation_value)))
+
+
+        return pos_rot #return that (x,y,z, rotation matrix)
 
 class Simulation(System):
     '''
     this class does all the things except when the optitrack is not broadcasting data
     the get function starts to return random numbers
     '''
-    update_freq = 10 #Hz
+    update_freq = 60 #Hz
 
     def get(self):
         mag_fac = 10
diff --git a/riglib/optitrack_client/test_optitrack.py b/riglib/optitrack_client/test_optitrack.py
index 635e7892..dc7ee8aa 100644
--- a/riglib/optitrack_client/test_optitrack.py
+++ b/riglib/optitrack_client/test_optitrack.py
@@ -1,10 +1,10 @@
-from optitrack_interface import MotionData
+from riglib.optitrack_client.optitrack_interface import System
 import time
 
 num_length = 10
-motion_data = MotionData(num_length)
+motion_data = System()
 motion_data.start()
 
 while True:
     print(motion_data.get())
-    time.sleep(1)
\ No newline at end of file
+    time.sleep(0.05)
\ No newline at end of file

From bac4a97a75ab7a29015e87254a33188297159239 Mon Sep 17 00:00:00 2001
From: Si Jia Li
Date: Mon, 20 Jul 2020 17:51:29 -0700
Subject: [PATCH 050/242] fixed array compatibility in saving hdf data

---
 riglib/optitrack_client/optitrack_interface.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/riglib/optitrack_client/optitrack_interface.py b/riglib/optitrack_client/optitrack_interface.py
index c5e5de3f..8c92da27 100644
--- a/riglib/optitrack_client/optitrack_interface.py
+++ b/riglib/optitrack_client/optitrack_interface.py
@@ -63,7 +63,7 @@ def get(self):
         if (not current_value is None) and (not rotation_value is None):
             pos_rot = np.concatenate((np.asarray(current_value),np.asarray(rotation_value)))
 
-
+        pos_rot = np.expand_dims(pos_rot, axis = 0)
         return pos_rot #return that (x,y,z, rotation matrix)
 
 class Simulation(System):
@@ -76,5 +76,6 @@ class Simulation(System):
     def get(self):
         mag_fac = 10
         current_value = np.random.rand(self.rigidBodyCount, 6) * mag_fac
+        current_value = np.expand_dims(current_value, axis = 0)
         return current_value
-    
\ No newline at end of file
+    
\ No newline at end of file

From 894cacdfeba24f5c5720acdf619dbcedc2efb73c Mon Sep 17 00:00:00 2001
From: Si Jia Li
Date: Tue, 28 Jul 2020 17:34:58 -0700
Subject: [PATCH 051/242] new optitrack interface with c# streamer

---
 .../cursorControlTasks_optitrack.py           |  12 +-
 .../optitrack_client/optitrack_direct_pack.py |  81 +++++++++++
 .../optitrack_client/optitrack_interface.py   | 133 +++++++++++-------
 3 files changed, 175 insertions(+), 51 deletions(-)
 create mode 100644 riglib/optitrack_client/optitrack_direct_pack.py

diff --git a/built_in_tasks/cursorControlTasks_optitrack.py b/built_in_tasks/cursorControlTasks_optitrack.py
index da203040..2a6eb996 100644
--- a/built_in_tasks/cursorControlTasks_optitrack.py
+++ b/built_in_tasks/cursorControlTasks_optitrack.py
@@ -10,6 +10,7 @@
 #from 
riglib.bmi.state_space_models import StateSpaceEndptVel2D #from riglib.bmi.bmi import Decoder, BMISystem, GaussianStateHMM, BMILoop, GaussianState, MachineOnlyFilter from riglib import experiment +from features.optitrack_feature import MotionSimulate class CursorControl(ManualControlMulti, WindowDispl2D): @@ -41,10 +42,12 @@ def _cycle(self): self.move_effector() super(CursorControl, self)._cycle() - # do nothing def move_effector(self): self.scale_factor = 50 + if isinstance(self, MotionSimulate): + self.scale_factor = 1 + #get data from motion tracker- take average of all data points since last poll #the default regid body yields a 6 degree of freedom #so, its 1 by 6 vector for now @@ -137,17 +140,18 @@ def target_seq_generator(n_targs, n_trials): [targ[0], 0, targ[1]]]) k += 1 - if __name__ == "__main__": print('Remember to set window size in stereoOpenGL class') gen = target_seq_generator(8, 2) #incorporate the saveHDF feature by blending code #see tests\start_From_cmd_line_sim - from features.optitrack_feature import MotionData + from features.optitrack_feature import MotionSimulate + from features.hdf_features import SaveHDF + base_class = CursorControl - feats = [MotionData] + feats = [MotionSimulate, SaveHDF] Exp = experiment.make(base_class, feats=feats) print(Exp) diff --git a/riglib/optitrack_client/optitrack_direct_pack.py b/riglib/optitrack_client/optitrack_direct_pack.py new file mode 100644 index 00000000..8c92da27 --- /dev/null +++ b/riglib/optitrack_client/optitrack_direct_pack.py @@ -0,0 +1,81 @@ +from .NatNetClient import NatNetClient as TestClient +import numpy as np +from multiprocessing import Process,Lock +import pickle + +mutex = Lock() + +class System(object): + """ + this is is the dataSource interface for getting the mocap at BMI3D's reqeust + compatible with DataSourceSystem + uses data_array to keep track of the lastest buffer + """ + rigidBodyCount = 1 + update_freq = 120 + dtype = np.dtype((np.float, (rigidBodyCount, 6))) #6 degress of freedo + def __init__(self): + self.rigid_body_count = 1 #for now,only one rigid body + + self.test_client = TestClient() + self.num_length = 10 # slots for buffer + self.data_array = [None] * self.num_length + self.rotation_buffer = [None] * self.num_length + + # This is a callback function that gets connected to the NatNet client and called once per mocap frame. + def receiveNewFrame(self, frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount, + labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ): + #print( "Received frame", frameNumber ) + pass + + # This is a callback function that gets connected to the NatNet client. 
+    # This is a callback function that gets connected to the NatNet client. It is called once per rigid body per frame
+    def receiveRigidBodyFrame(self, id, position, rotation ):
+        #print( "Received frame for rigid body", position )
+
+        #save to the running buffer with a lock
+        with mutex:
+            self.data_array.insert(0,position)
+            self.data_array.pop()
+            #save to rotation buffer list
+            self.rotation_buffer.insert(0,rotation)
+            self.rotation_buffer.pop()
+
+
+    def start(self):
+        self.test_client.newFrameListener = self.receiveNewFrame
+        self.test_client.rigidBodyListener = self.receiveRigidBodyFrame
+        self.test_client.run()
+        print('Started the interface thread')
+
+    def stop(self):
+        pass
+
+    def get(self):
+        current_value = None
+        rotation_value = None
+        pos_rot = None
+
+        with mutex:
+            current_value = self.data_array[0]
+            rotation_value = self.rotation_buffer[0]
+
+        #return the latest saved data
+        if (not current_value is None) and (not rotation_value is None):
+            pos_rot = np.concatenate((np.asarray(current_value),np.asarray(rotation_value)))
+
+        pos_rot = np.expand_dims(pos_rot, axis = 0)
+        return pos_rot #return that (x,y,z, rotation matrix)
+
+class Simulation(System):
+    '''
+    this class does all the things except when the optitrack is not broadcasting data
+    the get function starts to return random numbers
+    '''
+    update_freq = 60 #Hz
+
+    def get(self):
+        mag_fac = 10
+        current_value = np.random.rand(self.rigidBodyCount, 6) * mag_fac
+        current_value = np.expand_dims(current_value, axis = 0)
+        return current_value
+
\ No newline at end of file
diff --git a/riglib/optitrack_client/optitrack_interface.py b/riglib/optitrack_client/optitrack_interface.py
index 8c92da27..1c169e58 100644
--- a/riglib/optitrack_client/optitrack_interface.py
+++ b/riglib/optitrack_client/optitrack_interface.py
@@ -1,70 +1,104 @@
-from .NatNetClient import NatNetClient as TestClient
 import numpy as np
-from multiprocessing import Process,Lock
-import pickle
+import sys, time
+import socket
 
-mutex = Lock()
+N_TEST_FRAMES = 1 #number of testing frames during start
 
 class System(object):
     """
     this is the dataSource interface for getting the mocap data at BMI3D's request
     compatible with DataSourceSystem
     uses data_array to keep track of the latest buffer
     """
+    port_num = 1230 #same as the optitrack #default to 1230
+    HEADERSIZE = 10
+    rece_byte_size = 512
+    debug = True
+    optitrack_ip_addr = "10.155.206.1"
+
+
     rigidBodyCount = 1
     update_freq = 120
     dtype = np.dtype((np.float, (rigidBodyCount, 6))) #6 degrees of freedom
 
     def __init__(self):
         self.rigid_body_count = 1 #for now, only one rigid body
 
-        self.test_client = TestClient()
-        self.num_length = 10 # slots for buffer
-        self.data_array = [None] * self.num_length
-        self.rotation_buffer = [None] * self.num_length
-
-    # This is a callback function that gets connected to the NatNet client and called once per mocap frame.
-    def receiveNewFrame(self, frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount,
-                        labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ):
-        #print( "Received frame", frameNumber )
-        pass
-
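The HEADERSIZE constant above underpins the simple length-prefixed framing used by the new socket interface: every command is sent as a 10-character, left-justified ASCII length followed by the payload (see send_command and send_and_receive later in this patch). A round-trip sketch of that framing (frame/unframe are hypothetical helper names; the C# server's actual parser is not shown anywhere in the series):

HEADERSIZE = 10

def frame(msg):
    # mirrors send_command: fixed-width ASCII length header, then the payload
    return (f"{len(msg):<{HEADERSIZE}}" + msg).encode("ascii")

def unframe(raw):
    # what a receiver would do: read the header, then exactly that many bytes
    length = int(raw[:HEADERSIZE].decode("ascii"))
    return raw[HEADERSIZE:HEADERSIZE + length].decode("ascii")

assert unframe(frame("get")) == "get"

-    # This is a callback function that gets connected to the NatNet client.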
It is called once per rigid body per frame - def receiveRigidBodyFrame(self, id, position, rotation ): - #print( "Received frame for rigid body", position ) + - #save to the running buffer with a lock - with mutex: - self.data_array.insert(0,position) - self.data_array.pop() - #save to rotation buffer list - self.rotation_buffer.insert(0,position) - self.rotation_buffer.pop() - - def start(self): - self.test_client.newFrameListener = self.receiveNewFrame - self.test_client.rigidBodyListener =self.receiveRigidBodyFrame - self.test_client.run() - print('Started the interface thread') - + #start to connect to the client + #set up the socket + self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + + print("connecting to the c# server") + ''' + self.s.bind(('', 1230)) #bind to all incoming request + self.s.listen() #listen to one clinet + ''' + try: + #clientsocket, address = self.s.accept() + self.s.connect((self.optitrack_ip_addr, self.port_num)) + except: + print("cannot connect to Motive") + print("Is the c# server running?") + + #otherwise it works as expected and set client to be a + #class property + #self.clientsocket = clientsocket + print(f"Connection to c# client \ + {self.optitrack_ip_addr} has been established.") + + #automatically pull 10 frames + # and cal the mean round trip time + t1 = time.perf_counter() + for i in range(N_TEST_FRAMES): self.get() + t2 = time.perf_counter() + print(f'time to grab {N_TEST_FRAMES} frames : \ + {(t2 - t1)} s ') + + def stop(self): - pass + msg = "stop" + self.send_command(msg) + #close the socket + #self.s.close() + print("socket closed!") def get(self): - current_value = None - rotation_value = None - pos_rot = None + #the property that gets one frame of data + # 3 positions and 3 angles + #the last element is frame number + msg = "get" + result_string = self.send_and_receive(msg) + motive_frame = np.fromstring(result_string, sep=',') + current_value = motive_frame[:6] #only using the motion data - with mutex: - current_value = self.data_array[0] - rotation_value = self.rotation_buffer[0] + #for some weird reason, the string needs to be expanded.. 
+ #just send the motion data for now + current_value = np.expand_dims(current_value, axis = 0) + return current_value - #return the latest saved data - if (not current_value is None) and (not rotation_value is None): - pos_rot = np.concatenate((np.asarray(current_value),np.asarray(rotation_value))) - - pos_rot = np.expand_dims(pos_rot, axis = 0) - return pos_rot #return that (x,y,z, rotation matrix) + + + def send_command(self, msg): + #get the message in string and encode in bytes and send to the socket + msg = f"{len(msg):<{self.HEADERSIZE}}"+msg + msg_ascii = msg.encode("ascii") + self.s.send(msg_ascii) + + def send_and_receive(self, msg): + #this function sends a command + #and then wait for a response + msg = f"{len(msg):<{self.HEADERSIZE}}"+msg + msg_ascii = msg.encode("ascii") + self.s.send(msg_ascii) + result_in_bytes = self.s.recv(self.rece_byte_size) + return str(result_in_bytes,encoding="ASCII") + class Simulation(System): ''' @@ -72,10 +106,15 @@ class Simulation(System): the get function starts to return random numbers ''' update_freq = 60 #Hz - def get(self): mag_fac = 10 current_value = np.random.rand(self.rigidBodyCount, 6) * mag_fac current_value = np.expand_dims(current_value, axis = 0) return current_value - \ No newline at end of file + + +if __name__ == "__main__": + s = System() + s.start() + print(s.get()) + s.stop() From 93c570ace6da1aaff0cce2d920dff48ca698bf56 Mon Sep 17 00:00:00 2001 From: Si Jia Li Date: Mon, 10 Aug 2020 18:10:35 -0700 Subject: [PATCH 052/242] tests on optitrack --- .../cursorControlTasks_optitrack.py | 8 +++---- .../optitrack_client/optitrack_direct_pack.py | 12 +++++++--- .../optitrack_client/optitrack_interface.py | 10 ++++++-- riglib/optitrack_client/test_control.py | 23 +++++++++++++++++++ riglib/optitrack_client/test_optitrack.py | 2 +- 5 files changed, 45 insertions(+), 10 deletions(-) create mode 100644 riglib/optitrack_client/test_control.py diff --git a/built_in_tasks/cursorControlTasks_optitrack.py b/built_in_tasks/cursorControlTasks_optitrack.py index 2a6eb996..fca59a50 100644 --- a/built_in_tasks/cursorControlTasks_optitrack.py +++ b/built_in_tasks/cursorControlTasks_optitrack.py @@ -43,7 +43,7 @@ def _cycle(self): super(CursorControl, self)._cycle() def move_effector(self): - self.scale_factor = 50 + self.scale_factor = 0.1 if isinstance(self, MotionSimulate): self.scale_factor = 1 @@ -146,16 +146,16 @@ def target_seq_generator(n_targs, n_trials): #incorporate the saveHDF feature by blending code #see tests\start_From_cmd_line_sim - from features.optitrack_feature import MotionSimulate + from features.optitrack_feature import MotionData from features.hdf_features import SaveHDF base_class = CursorControl - feats = [MotionSimulate, SaveHDF] + feats = [MotionData, SaveHDF] Exp = experiment.make(base_class, feats=feats) print(Exp) exp = Exp(gen) exp.init() exp.run() #start the task - + \ No newline at end of file diff --git a/riglib/optitrack_client/optitrack_direct_pack.py b/riglib/optitrack_client/optitrack_direct_pack.py index 8c92da27..c9006af9 100644 --- a/riglib/optitrack_client/optitrack_direct_pack.py +++ b/riglib/optitrack_client/optitrack_direct_pack.py @@ -1,4 +1,4 @@ -from .NatNetClient import NatNetClient as TestClient +from riglib.optitrack_client.NatNetClient import NatNetClient as TestClient import numpy as np from multiprocessing import Process,Lock import pickle @@ -62,8 +62,9 @@ def get(self): #return the latest saved data if (not current_value is None) and (not rotation_value is None): pos_rot = 
np.concatenate((np.asarray(current_value),np.asarray(rotation_value))) - + pos_rot = np.expand_dims(pos_rot, axis = 0) + print(pos_rot.shape) return pos_rot #return that (x,y,z, rotation matrix) class Simulation(System): @@ -78,4 +79,9 @@ def get(self): current_value = np.random.rand(self.rigidBodyCount, 6) * mag_fac current_value = np.expand_dims(current_value, axis = 0) return current_value - \ No newline at end of file + + +if __name__ == "__main__": + s = System() + s.start() + s.get() \ No newline at end of file diff --git a/riglib/optitrack_client/optitrack_interface.py b/riglib/optitrack_client/optitrack_interface.py index 1c169e58..54d59184 100644 --- a/riglib/optitrack_client/optitrack_interface.py +++ b/riglib/optitrack_client/optitrack_interface.py @@ -15,6 +15,7 @@ class System(object): rece_byte_size = 512 debug = True optitrack_ip_addr = "10.155.206.1" + TIME_OUT_TIME = 2 rigidBodyCount = 1 @@ -32,7 +33,7 @@ def start(self): #set up the socket self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - + self.s.settimeout(self.TIME_OUT_TIME) print("connecting to the c# server") ''' @@ -76,10 +77,13 @@ def get(self): result_string = self.send_and_receive(msg) motive_frame = np.fromstring(result_string, sep=',') current_value = motive_frame[:6] #only using the motion data + current_value.transpose() + #for some weird reason, the string needs to be expanded.. #just send the motion data for now current_value = np.expand_dims(current_value, axis = 0) + current_value = np.expand_dims(current_value, axis = 0) return current_value @@ -116,5 +120,7 @@ def get(self): if __name__ == "__main__": s = System() s.start() - print(s.get()) + s.send_command("start_rec") + time.sleep(5) s.stop() + print("finished") diff --git a/riglib/optitrack_client/test_control.py b/riglib/optitrack_client/test_control.py new file mode 100644 index 00000000..6b5879dd --- /dev/null +++ b/riglib/optitrack_client/test_control.py @@ -0,0 +1,23 @@ +from NatNetClient import NatNetClient + + +# This is a callback function that gets connected to the NatNet client and called once per mocap frame. +def receiveNewFrame( frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount, + labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ): + #print( "Received frame", frameNumber ) + pass + +# This is a callback function that gets connected to the NatNet client. It is called once per rigid body per frame +def receiveRigidBodyFrame( id, position, rotation ): + #print( "Received frame for rigid body", position ) + pass + +# This will create a new NatNet client +test_client = NatNetClient() + +# Configure the streaming client to call our rigid body handler on the emulator to send data out. 
+test_client.newFrameListener = receiveNewFrame
+test_client.rigidBodyListener = receiveRigidBodyFrame
+
+test_client.sendCommand( test_client.NAT_REQUEST_MODELDEF, "", test_client.commandSocket,
+                        (test_client.serverIPAddress, test_client.commandPort) )
\ No newline at end of file
diff --git a/riglib/optitrack_client/test_optitrack.py b/riglib/optitrack_client/test_optitrack.py
index dc7ee8aa..5a23eac5 100644
--- a/riglib/optitrack_client/test_optitrack.py
+++ b/riglib/optitrack_client/test_optitrack.py
@@ -1,4 +1,4 @@
-from riglib.optitrack_client.optitrack_interface import System
+from riglib.optitrack_client.optitrack_direct_pack import System
 import time
 
 num_length = 10

From dbe2501839ebb4b81deb5fa2fd6825ce81525668 Mon Sep 17 00:00:00 2001
From: Si Jia Li
Date: Mon, 10 Aug 2020 18:23:50 -0700
Subject: [PATCH 053/242] Including package for reward system

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index b3d0139e..949d4571 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,4 +30,4 @@ Sphinx == 0.0.0
 ipdb == 0.0.0
 numpydoc == 0.0.0
 pylibftdi == 0.0.0
-
+pyfirmata == 1.1.0

From 374049a1814b07c7514dbab2d08be7f9245c3b8f Mon Sep 17 00:00:00 2001
From: Pavi Raj
Date: Mon, 10 Aug 2020 18:31:39 -0700
Subject: [PATCH 054/242] Reward system calibration and test file

---
 tests/test_reward_python2arduino.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)
 create mode 100644 tests/test_reward_python2arduino.py

diff --git a/tests/test_reward_python2arduino.py b/tests/test_reward_python2arduino.py
new file mode 100644
index 00000000..794a744a
--- /dev/null
+++ b/tests/test_reward_python2arduino.py
@@ -0,0 +1,29 @@
+import pyfirmata
+import time
+#not sure about the protocol
+com_port = '/dev/ttyACM0' #specify which port; you can find it in the IDE
+board = pyfirmata.Arduino(com_port)
+
+def test_reward_system():
+    while True:
+        board.digital[13].write(1)
+        time.sleep(1) #in seconds
+        print('ON')
+        board.digital[13].write(0)
+        time.sleep(2) #in seconds
+        print('OFF')
+
+def calibrate_reward():
+    board.digital[13].write(1)
+    time.sleep(72) # it takes around 126 seconds to drain 200 ml of fluid
+    board.digital[13].write(0)
+    print('Check the beaker for calibration. You should notice 200 ml of fluid')
+
+user_input = input("Enter 1 - to test Arduino connection; 2 - to calibrate:")
+
+if user_input == '1':
+    print('Testing Reward System')
+    test_reward_system()
+elif user_input == '2':
+    print('Calibrating Reward System')
+    calibrate_reward()

From 57c89e8524409d280703c908103d60fd9ae3ce78 Mon Sep 17 00:00:00 2001
From: Pavi Raj
Date: Mon, 10 Aug 2020 18:49:24 -0700
Subject: [PATCH 055/242] reward system test

---
 tests/test_reward_python2arduino.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_reward_python2arduino.py b/tests/test_reward_python2arduino.py
index 794a744a..bcf20513 100644
--- a/tests/test_reward_python2arduino.py
+++ b/tests/test_reward_python2arduino.py
@@ -15,7 +15,7 @@ def test_reward_system():
 
 def calibrate_reward():
     board.digital[13].write(1)
-    time.sleep(72) # it takes around 126 seconds to drain 200 ml of fluid
+    time.sleep(72) # it takes around 72 seconds to drain 200 ml of fluid - Flow rate: 2.8 mL/s
     board.digital[13].write(0)
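Patch 055's corrected comment makes the calibration arithmetic explicit: at a flow rate of 2.8 mL/s, draining 200 mL takes 200 / 2.8 ≈ 71.4 s, hence the 72-second valve opening. A small helper in the same spirit (hypothetical, not part of the patch) generalizes the calculation to other volumes:

FLOW_RATE_ML_PER_S = 2.8   # measured for this valve and tubing

def drain_time(volume_ml, flow_rate=FLOW_RATE_ML_PER_S):
    """Seconds to keep the valve open to dispense the requested volume."""
    return volume_ml / flow_rate

assert round(drain_time(200)) == 71    # rounded up to 72 s in the patch

    print('Check the beaker for calibration.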
You should notice 200 ml of fluid') From e7e862237a4788f41cc44a1fa47bc463b88a32f2 Mon Sep 17 00:00:00 2001 From: leoscholl Date: Tue, 11 Aug 2020 13:55:44 -0700 Subject: [PATCH 056/242] update mrograph --- docs/MROgraph.py | 44 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/docs/MROgraph.py b/docs/MROgraph.py index 727d04b7..e4f87113 100644 --- a/docs/MROgraph.py +++ b/docs/MROgraph.py @@ -9,7 +9,7 @@ Requires: Python 2.3, dot, standard Unix tools """ -import os,itertools +import os,itertools,argparse PSVIEWER='gv' # you may change these with PNGVIEWER='evince' # your preferred viewers @@ -39,6 +39,7 @@ def __init__(self,*classes,**options): filename=options.get('filename',"MRO_of_%s.ps" % classes[0].__name__) self.labels=options.get('labels',2) caption=options.get('caption',False) + self.nometa=options.get('nometa',False) setup=options.get('setup','') name,dotformat=os.path.splitext(filename) format=dotformat[1:] @@ -54,7 +55,15 @@ def __init__(self,*classes,**options): codeiter=itertools.chain(*[self.genMROcode(cls) for cls in classes]) self.dotcode='digraph %s{\n%s%s}' % ( name,setupcode,'\n'.join(codeiter)) - os.system("echo '%s' | dot -T%s > %s; %s %s&" % + if os.name == 'nt': + with open('graph.gv', mode='w') as f: + f.write(self.dotcode) + os.system("dot -T%s graph.gv -o %s" % + (format,filename)) + os.remove('graph.gv') + os.startfile(filename) + else: + os.system("echo '%s' | dot -T%s > %s; %s %s&" % (self.dotcode,format,filename,viewer,filename)) def genMROcode(self,cls): "Generates the dot code for the MRO of a given class" @@ -76,7 +85,7 @@ def genMROcode(self,cls): '[%s]' % label, '[shape=box,%s]' % label) yield(' %s %s;\n' % (name,option)) - if type(c) is not type: # c has a custom metaclass + if not self.nometa and type(c) is not type: # c has a custom metaclass metaname=type(c).__name__ yield ' edge [style=dashed]; %s -> %s;' % (metaname,name) def __repr__(self): @@ -98,6 +107,33 @@ class A(B,C): pass return MROgraph(A,M,**options) if __name__=="__main__": - testHierarchy() # generates a postscript diagram of A and M hierarchies + parser = argparse.ArgumentParser(description='Draw an MRO graph.') + parser.add_argument('classname', type=str, nargs='+', + help='name of the class to draw') + parser.add_argument('filename', type=str, + help='resulting name and filetype') + parser.add_argument('-l', '--no-labels', dest='labels', + action='store_true', default=False, + help='remove numerical labels from each node and edge') + parser.add_argument('-c', '--caption', dest='caption', + action='store_true', default=False, + help='include caption listing each class') + parser.add_argument('-m', '--hide-meta', dest='meta', + action='store_true', default=False, + help='do not include metaclasses in the graph') + parser.add_argument('-s', '--size', type=str, default="16,12", + help='width,height size of graph in inches (default: 16,12)') + parser.add_argument('-r', '--ratio', type=float, default=0.7) + parser.add_argument('-e', '--edge-color', type=str, default='blue') + parser.add_argument('-n', '--node-color', type=str, default='red') + + args = parser.parse_args() + cls = [] + for c in args.classname: + pack = c.split('.') + exec('import ' + '.'.join(pack[0:-1])) + cls.append(eval(c)) + opt = 'size="%s"; ratio=%f; edge [color=%s]; node [color=%s];' % (args.size, args.ratio, args.edge_color, args.node_color) + MROgraph(*cls, filename=args.filename, labels=0 if args.labels else 2, caption=args.caption, nometa=args.meta, 
setup=opt) # From ec9e04a115fd7c6a2c962f16361e4d23eb30b526 Mon Sep 17 00:00:00 2001 From: leoscholl Date: Tue, 11 Aug 2020 14:00:39 -0700 Subject: [PATCH 057/242] init linear decoder --- built_in_tasks/bmimultitasks.py | 87 +++++++++++++++++++++------- built_in_tasks/cursorControlTasks.py | 4 +- db/dbfunctions.py | 2 +- db/tracker/json_param.py | 2 +- features/simulation_features.py | 14 +++++ riglib/bmi/extractor.py | 3 +- riglib/bmi/feedback_controllers.py | 15 +++++ riglib/bmi/goal_calculators.py | 4 +- riglib/bmi/lindecoder.py | 65 +++++++++++++++++++++ riglib/bmi/sim_neurons.py | 28 +++++++++ riglib/bmi/state_space_models.py | 20 +++++++ riglib/plants.py | 5 ++ 12 files changed, 222 insertions(+), 27 deletions(-) create mode 100644 riglib/bmi/lindecoder.py diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index 01781c99..b72f16ae 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -26,7 +26,7 @@ from riglib.stereo_opengl.primitives import Line -from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceNLinkPlanarChain +from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceNLinkPlanarChain, StateSpaceEndptPos3D from built_in_tasks.manualcontrolmultitasks import ManualControlMulti @@ -238,7 +238,7 @@ def create_assister(self): if isinstance(self.decoder.ssm, StateSpaceEndptVel2D) and isinstance(self.decoder, ppfdecoder.PPFDecoder): self.assister = OFCEndpointAssister() - elif isinstance(self.decoder.ssm, StateSpaceEndptVel2D): + elif isinstance(self.decoder.ssm, StateSpaceEndptVel2D) or isinstance(self.decoder.ssm, StateSpaceEndptPos3D): self.assister = SimpleEndpointAssister(**kwargs) ## elif (self.decoder.ssm == namelist.tentacle_2D_state_space) or (self.decoder.ssm == namelist.joint_2D_state_space): ## # kin_chain = self.plant.kin_chain @@ -257,6 +257,8 @@ def create_assister(self): def create_goal_calculator(self): if isinstance(self.decoder.ssm, StateSpaceEndptVel2D): self.goal_calculator = goal_calculators.ZeroVelocityGoal(self.decoder.ssm) + elif isinstance(self.decoder.ssm, StateSpaceEndptPos3D): + self.goal_calculator = goal_calculators.ZeroVelocityGoal(self.decoder.ssm) elif isinstance(self.decoder.ssm, StateSpaceNLinkPlanarChain) and self.decoder.ssm.n_links == 2: self.goal_calculator = goal_calculators.PlanarMultiLinkJointGoal(self.decoder.ssm, self.plant.base_loc, self.plant.kin_chain, multiproc=False, init_resp=None) elif isinstance(self.decoder.ssm, StateSpaceNLinkPlanarChain) and self.decoder.ssm.n_links == 4: @@ -479,33 +481,25 @@ def _start_wait(self, *args, **kwargs): ######################### ######## Simulation tasks ######################### -from features.simulation_features import SimKalmanEnc, SimKFDecoderSup, SimCosineTunedEnc -from riglib.bmi.feedback_controllers import LQRController -class SimBMIControlMulti(SimCosineTunedEnc, SimKFDecoderSup, BMIControlMulti): +from features.simulation_features import SimKalmanEnc, SimKFDecoderSup, SimCosineTunedEnc, SimNormCosineTunedEnc +from riglib.bmi.feedback_controllers import LQRController, PosFeedbackController +class SimBMIControlMulti(BMIControlMulti2DWindow): win_res = (250, 140) - sequence_generators = ['sim_target_seq_generator_multi'] + sequence_generators = ManualControlMulti.sequence_generators + ['sim_target_seq_generator_multi'] def __init__(self, *args, **kwargs): from riglib.bmi.state_space_models import StateSpaceEndptVel2D - ssm = StateSpaceEndptVel2D() if 'sim_C' in kwargs: self.sim_C = kwargs['sim_C'] + 
else: + raise Exception("Need sim_C") if 'assist_level' in kwargs: self.assist_level = kwargs['assist_level'] - - - - A, B, W = ssm.get_ssm_matrices() - - Q = np.mat(np.diag([1., 1, 1, 0, 0, 0, 0])) - R = 10000*np.mat(np.diag([1., 1., 1.])) - self.fb_ctrl = LQRController(A, B, Q, R) - - self.ssm = ssm + else: + self.assist_level = (0, 0) super(SimBMIControlMulti, self).__init__(*args, **kwargs) - def _start_wait(self): self.wait_time = 0. super()._start_wait() @@ -514,7 +508,7 @@ def _test_start_trial(self, ts): return ts > self.wait_time and not self.pause @staticmethod - def sim_target_seq_generator_multi(n_targs, n_trials): + def sim_target_seq_generator_multi(n_targs=8, n_trials=8): ''' Simulated generator for simulations of the BMIControlMulti and CLDAControlMulti tasks ''' @@ -527,4 +521,57 @@ def sim_target_seq_generator_multi(n_targs, n_trials): for k in range(n_trials): targ = targets[target_inds[k], :] yield np.array([[center[0], 0, center[1]], - [targ[0], 0, targ[1]]]) + [targ[0], 0, targ[1]]]) + +class SimBMICosEncKFDec(SimCosineTunedEnc, SimKFDecoderSup, SimBMIControlMulti): + def __init__(self, *args, **kwargs): + N_NEURONS = 4 + N_STATES = 7 # 3 positions and 3 velocities and an offset + + # build the observation matrix + sim_C = np.zeros((N_NEURONS, N_STATES)) + # control x positive directions + sim_C[0, :] = np.array([0, 0, 0, 1, 0, 0, 0]) + sim_C[1, :] = np.array([0, 0, 0, -1, 0, 0, 0]) + # control z positive directions + sim_C[2, :] = np.array([0, 0, 0, 0, 0, 1, 0]) + sim_C[3, :] = np.array([0, 0, 0, 0, 0, -1, 0]) + + kwargs['sim_C'] = sim_C + + ssm = StateSpaceEndptVel2D() + A, B, W = ssm.get_ssm_matrices() + Q = np.mat(np.diag([1., 1, 1, 0, 0, 0, 0])) + R = 10000*np.mat(np.diag([1., 1., 1.])) + self.fb_ctrl = LQRController(A, B, Q, R) + self.ssm = ssm + + super(SimBMICosEncKFDec, self).__init__(*args, **kwargs) + +from riglib.bmi.lindecoder import LinearScaleFilter +from riglib.bmi.bmi import Decoder +class SimBMICosEncLinDec(SimNormCosineTunedEnc, SimBMIControlMulti): + def __init__(self, *args, **kwargs): + + # build the observation matrix + sim_C = np.zeros((2, 3)) + # control x positive directions + sim_C[0, :] = np.array([1, 0, 0]) + sim_C[1, :] = np.array([0, 1, 1]) + + kwargs['sim_C'] = sim_C + kwargs['assist_level'] = (0, 0) # TODO: implement assister for 3D Pos ssm + + ssm = StateSpaceEndptPos3D() + self.fb_ctrl = PosFeedbackController() + self.ssm = ssm + + super(SimBMICosEncLinDec, self).__init__(*args, **kwargs) + + def load_decoder(self): + units = self.encoder.get_units() + ssm = StateSpaceEndptPos3D() + filt = LinearScaleFilter(10000, 6, ssm.n_states, len(units)) + self.decoder = Decoder(filt, units, ssm, binlen=0.1, subbins=1) + self.decoder.n_features = len(units) + self.decoder.binlen = 0.1 \ No newline at end of file diff --git a/built_in_tasks/cursorControlTasks.py b/built_in_tasks/cursorControlTasks.py index 13549d5d..7c3a69b8 100644 --- a/built_in_tasks/cursorControlTasks.py +++ b/built_in_tasks/cursorControlTasks.py @@ -1,7 +1,7 @@ -from manualcontrolmultitasks import ManualControlMulti +from built_in_tasks.manualcontrolmultitasks import ManualControlMulti from riglib.stereo_opengl.window import WindowDispl2D -from bmimultitasks import BMIControlMulti +from built_in_tasks.bmimultitasks import BMIControlMulti import pygame import numpy as np import copy diff --git a/db/dbfunctions.py b/db/dbfunctions.py index 9257e59c..c607a7e0 100644 --- a/db/dbfunctions.py +++ b/db/dbfunctions.py @@ -24,7 +24,7 @@ except: pass -from tracker import models +from 
db.tracker import models # default DB, change this variable from python session to switch to other database db_name = 'default' diff --git a/db/tracker/json_param.py b/db/tracker/json_param.py index b38306c0..92010a2d 100644 --- a/db/tracker/json_param.py +++ b/db/tracker/json_param.py @@ -111,7 +111,7 @@ def norm_trait(trait, value): #use Cast to validate the value try: - return trait.cast(value) + return trait.validate(trait.name, '', value) except: f = open(os.path.join(config.log_path, "trait_log"), 'w') f.write('Error with type for trait %s, %s, value %s' % (str(trait), str(ttype), str(value))) diff --git a/features/simulation_features.py b/features/simulation_features.py index 216e3011..37e9e338 100644 --- a/features/simulation_features.py +++ b/features/simulation_features.py @@ -223,6 +223,20 @@ def create_feature_extractor(self): n_subbins=self.decoder.n_subbins, units=self.decoder.units, task=self) self._add_feature_extractor_dtype() +class SimNormCosineTunedEnc(SimNeuralEnc): + + def _init_neural_encoder(self): + from riglib.bmi.sim_neurons import NormalizedCosEnc + self.encoder = NormalizedCosEnc(self.plant.endpt_bounds, self.sim_C, self.ssm, return_ts=False, DT=0.1, call_ds_rate=1) + + def create_feature_extractor(self): + ''' + Create the feature extractor object + ''' + self.extractor = extractor.SimDirectObsExtractor(self.fb_ctrl, self.encoder, + n_subbins=self.decoder.n_subbins, units=self.decoder.units, task=self) + self._add_feature_extractor_dtype() + class SimFAEnc(SimCosineTunedEnc): def __init__(self, *args, **kwargs): self.FACosEnc_kwargs = kwargs.pop('SimFAEnc_kwargs', dict()) diff --git a/riglib/bmi/extractor.py b/riglib/bmi/extractor.py index f15de59d..22a8e0f5 100644 --- a/riglib/bmi/extractor.py +++ b/riglib/bmi/extractor.py @@ -744,8 +744,7 @@ class SimDirectObsExtractor(SimBinnedSpikeCountsExtractor): ''' def __call__(self, start_time, *args, **kwargs): y_t = self.get_spike_ts(*args, **kwargs) - return dict(spike_counts=y_t) - + return dict(spike_counts=np.reshape(y_t, (len(self.units),self.n_subbins))) ############################################# diff --git a/riglib/bmi/feedback_controllers.py b/riglib/bmi/feedback_controllers.py index 7c73e803..6afa5caa 100644 --- a/riglib/bmi/feedback_controllers.py +++ b/riglib/bmi/feedback_controllers.py @@ -381,3 +381,18 @@ def get(self, cur_target, cur_pos, keys_pressed=None): joint_pos, joint_vel = ik.inv_kin_2D(pos, self.link_lengths[0], self.link_lengths[1], vel) return joint_vel[0]['sh_vabd'], joint_vel[0]['el_vflex'] +class PosFeedbackController(FeedbackController): + ''' + Dumb controller that just spits back the target + ''' + def __init__(self, *args, **kwargs): + pass + + def calc_next_state(self, current_state, target_state, mode=None): + return target_state + + def __call__(self, current_state, target_state, mode=None): + return target_state + + def get(self, current_state, target_state, mode=None): + return target_state \ No newline at end of file diff --git a/riglib/bmi/goal_calculators.py b/riglib/bmi/goal_calculators.py index 3e9fb1cc..023e14f7 100644 --- a/riglib/bmi/goal_calculators.py +++ b/riglib/bmi/goal_calculators.py @@ -200,8 +200,10 @@ def __call__(self, target_pos, **kwargs): target_vel = np.zeros_like(target_pos) offset_val = 1 target_state = np.hstack([target_pos, target_vel, 1]).reshape(-1, 1) - else: + elif len(target_pos) == n_pos_vel_states: target_state = np.hstack([target_pos, 1]).reshape(-1, 1) + else: + target_state = np.hstack(target_pos).reshape(-1, 1) # don't add offset error = 0 
return (target_state, error), True diff --git a/riglib/bmi/lindecoder.py b/riglib/bmi/lindecoder.py new file mode 100644 index 00000000..7cb38a04 --- /dev/null +++ b/riglib/bmi/lindecoder.py @@ -0,0 +1,65 @@ +''' +Classes for BMI decoding using linear scaling. +''' +import numpy as np + +class State(object): + '''For compatibility with other BMI decoding implementations, literally just holds the state''' + + def __init__(self, mean, *args, **kwargs): + self.mean = mean + +class LinearScaleFilter(object): + + def __init__(self, n_counts, window, n_states, n_units): + ''' + Parameters: + + n_counts How many observations to hold + window How many observations to average + n_states How many state space variables are there + n_units Number of neural units + ''' + self.state = State(np.zeros([n_states,1])) + self.obs = np.zeros((n_counts, n_units)) + self.n_states = n_states + self.window = window + self.n_units = n_units + self.count = 0 + + def _init_state(self): + pass + + def get_mean(self): + return np.array(self.state.mean).ravel() + + def __call__(self, obs, **kwargs): + self.state = self._normalize(obs, **kwargs) + + def _normalize(self, obs,**kwargs): + ''' Function to compute normalized scaling of new observations''' + + self.obs[:-1, :] = self.obs[1:, :] + self.obs[-1, :] = np.squeeze(obs) + if self.count < len(self.obs): + self.count += 1 + + m_win = np.squeeze(np.mean(self.obs[-self.window:, :], axis=0)) + m = np.median(self.obs[-self.count:, :], axis=0) + # range = max(1, np.amax(self.obs[-self.count:, :]) - np.amin(self.obs[-self.count:, :])) + range = np.std(self.obs[-self.count:, :], axis=0)*3 + range[range < 1] = 1 + x = (m_win - m) / range + x = np.squeeze(np.asarray(x)) * 20 # hack for 14x14 cursor + + # Arrange output + if self.n_states == self.n_units: + return State(x) + elif self.n_states == 3 and self.n_units == 2: + mean = np.zeros([self.n_states,1]) + mean[0] = x[0] + mean[2] = x[1] + return State(mean) + else: + raise NotImplementedError() + \ No newline at end of file diff --git a/riglib/bmi/sim_neurons.py b/riglib/bmi/sim_neurons.py index dc91d60c..bd047978 100644 --- a/riglib/bmi/sim_neurons.py +++ b/riglib/bmi/sim_neurons.py @@ -363,6 +363,34 @@ def y2_eq_r2_min_x2(self, x_arr, r2): y.append(-1*np.sqrt(r2 - x**2)) return np.array(y) +class NormalizedCosEnc(GenericCosEnc): + + def __init__(self, bounds, *args, **kwargs): + self.min = np.array([bounds[0], bounds[2], bounds[4]]) + self.range = np.array([bounds[1] - bounds[0], bounds[3] - bounds[2], bounds[5] - bounds[4]]) + self.range[self.range == 0] = 1 + self.gain = 100 + super(NormalizedCosEnc, self).__init__(*args, **kwargs) + + def gen_spikes(self, next_state, mode=None): + """ + Simulate the spikes + + Parameters + ---------- + next_state : np.array of shape (N, 1) + The "next state" to be encoded by this population of neurons + + Returns + ------- + time stamps or counts + Either spike time stamps or a vector of unit spike counts is returned, depending on whether the 'return_ts' attribute is True + + """ + norm_state = np.divide(np.subtract(np.squeeze(next_state), self.min), self.range) + rates = np.dot(self.C, norm_state) * self.gain + return self.return_spikes(rates, mode=mode) + def from_file_to_FACosEnc(plot=False): from riglib.bmi import state_space_models as ssm import pickle diff --git a/riglib/bmi/state_space_models.py b/riglib/bmi/state_space_models.py index b8d70591..c0d204f4 100644 --- a/riglib/bmi/state_space_models.py +++ b/riglib/bmi/state_space_models.py @@ -381,6 +381,26 @@ def 
__setstate__(self, state): if not hasattr(self, 'w'): self.w = 7 +class StateSpaceEndptPos3D(StateSpace): + ''' StateSpace for 3D pos control''' + def __init__(self, **kwargs): + self.states = [ + State('hand_px', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0), + State('hand_py', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0), + State('hand_pz', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0) + ] + + def __setstate__(self, state): + self.__dict__ = state + if not hasattr(self, 'Delta'): + self.Delta = 0.1 + + if not hasattr(self, 'vel_decay'): + self.vel_decay = 0.8 + + if not hasattr(self, 'w'): + self.w = 7 + ############################ ##### Helper functions ##### ############################ diff --git a/riglib/plants.py b/riglib/plants.py index 0be35669..58ba76e6 100644 --- a/riglib/plants.py +++ b/riglib/plants.py @@ -240,6 +240,9 @@ def set_visibility(self, visible): def _bound(self, pos, vel): pos = pos.copy() vel = vel.copy() + if len(vel) == 0: + vel_wall = self.vel_wall # don't worry about vel if it's empty + self.vel_wall = False if self.endpt_bounds is not None: if pos[0] < self.endpt_bounds[0]: pos[0] = self.endpt_bounds[0] @@ -261,6 +264,8 @@ def _bound(self, pos, vel): if pos[2] > self.endpt_bounds[5]: pos[2] = self.endpt_bounds[5] if self.vel_wall: vel[2] = 0 + if len(vel) == 0: + self.vel_wall = vel_wall # restore previous value return pos, vel def drive(self, decoder): From 41817af631ac7dc360946c14b7509c759f5f3ff7 Mon Sep 17 00:00:00 2001 From: leoscholl Date: Tue, 11 Aug 2020 15:07:00 -0700 Subject: [PATCH 058/242] add windows batch testing script --- .gitignore | 1 + tests/unit_tests/coverage.bat | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 tests/unit_tests/coverage.bat diff --git a/.gitignore b/.gitignore index aae994b8..7085146f 100644 --- a/.gitignore +++ b/.gitignore @@ -43,5 +43,6 @@ riglib/fsm/fsm.egg-info/* env/* tests/*.mat tests/*.hdf +tests/unit_tests/htmlcov *.h5 *.dot diff --git a/tests/unit_tests/coverage.bat b/tests/unit_tests/coverage.bat new file mode 100644 index 00000000..288a359d --- /dev/null +++ b/tests/unit_tests/coverage.bat @@ -0,0 +1,3 @@ +coverage3 run --source ../../riglib run_unit_tests.py +coverage3 report -m +coverage3 html \ No newline at end of file From 01253eeb6bdc2cfbb3aa83107d34135960e688d4 Mon Sep 17 00:00:00 2001 From: leoscholl Date: Tue, 11 Aug 2020 15:17:42 -0700 Subject: [PATCH 059/242] small changes to db and sim bmi test --- db/settings.py | 4 ++-- requirements.txt | 1 - {built_in_tasks => tests}/test_SimBMIControlMulti.py | 0 3 files changed, 2 insertions(+), 3 deletions(-) rename {built_in_tasks => tests}/test_SimBMIControlMulti.py (100%) diff --git a/db/settings.py b/db/settings.py index b95fd8e7..6143731b 100644 --- a/db/settings.py +++ b/db/settings.py @@ -50,8 +50,8 @@ def get_sqlite3_databases(): return dbs -DATABASES = get_sqlite3_databases() -#DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': 'mydatabase', } } +#DATABASES = get_sqlite3_databases() +DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': 'db.sql', } } ''' DATABASES = { diff --git a/requirements.txt b/requirements.txt index 78146a86..2352c450 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,4 +26,3 @@ robotframework==3.2.1 scikit_learn==0.23.1 tabulate==0.8.7 tasks==2.5.0 -testing==0.0 diff --git a/built_in_tasks/test_SimBMIControlMulti.py b/tests/test_SimBMIControlMulti.py similarity index 100% rename from 
built_in_tasks/test_SimBMIControlMulti.py rename to tests/test_SimBMIControlMulti.py From 4c6c5496870b8415bce1d1d5ba9c760e42e684e1 Mon Sep 17 00:00:00 2001 From: Pavi Raj Date: Tue, 11 Aug 2020 18:55:39 -0700 Subject: [PATCH 060/242] initial upload for ecube neural streaming lib --- riglib/ecube/__init__.py | 98 +++++++++++++++++++++++++++++++++++ riglib/ecube/pyeCubeStream.py | 43 +++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 riglib/ecube/__init__.py create mode 100644 riglib/ecube/pyeCubeStream.py diff --git a/riglib/ecube/__init__.py b/riglib/ecube/__init__.py new file mode 100644 index 00000000..68bd2868 --- /dev/null +++ b/riglib/ecube/__init__.py @@ -0,0 +1,98 @@ +from pyeCubeStream import eCubeStream +import numpy as np +import time + +''' +#to do list +#a method to check connection instead of pulling data once +# modify the return type of the get function +''' + +from riglib.source import DataSourceSystem +class LFP(DataSourceSystem): + ''' + wrapper class for pyecubestream + pyecube already implemented, start, stop, and get + here, we just wrap it under DataSourceSystem + ''' + #as required by a DataSourceSystem + update_freq = 1000. + dtype = np.dtype('float') + + #need to decide if we do processing here? + #for future work, we need to decide + # which channels + # chan_offset if there needs to be + + def __init__(self): + ''' + Constructor for ecube.LFP + + Parameters + ---------- + + + Returns + ------- + ecube.LFP instance + ''' + + #ecubeStream by defaults stream from HS, but just to make it super clear + self.conn = eCubeStream(source='Headstages') + + + + #then we can select channels if channel selection is available + + def check_conn_by_pull_data(self): + ''' + quickly pull data twice + print out the process time and the + ''' + t1 = time.perf_counter() + + test_conn_data = self.get() + ecube_timestamp_1 = test_conn_data[0] + + test_conn_data = self.get() + ecube_timestamp_2 = test_conn_data[0] + + t2 = time.perf_counter() + ecube_ts_in_ms = (ecube_timestamp_2 - ecube_timestamp_1)/1e6 + + print(f'takes {(t2 - t1)*1000} ms to pull two frames') + print(f'the delta time stamp difference is {ecube_ts_in_ms} ms') + print(f'data has the shape {test_conn_data[1].shape}') + + #wrapper functions to pyecubestream + + def start(self): + self.conn.start() + #quickly to pull data twice to check connections + try: + self.check_conn_by_pull_data() + except: + raise Exception('Connection to ecube failed') + + def stop(self): + self.conn.stop() + + def get(self): + ''' + do some data attenuation before returning + ''' + data_block = self.conn.get() #in the form of (time_stamp, data) + return data_block + + + +#quick test +if __name__ == "__main__": + #this will pull data twice and pull bunch of things + lfp = LFP() + lfp.start() + lfp.stop() + + + + diff --git a/riglib/ecube/pyeCubeStream.py b/riglib/ecube/pyeCubeStream.py new file mode 100644 index 00000000..782a69f7 --- /dev/null +++ b/riglib/ecube/pyeCubeStream.py @@ -0,0 +1,43 @@ +import zmq +import numpy as np + +class eCubeStream: + + dtypedict = { + 'Headstages': np.int16, + 'AnalogPanel': np.int16, + 'DigitalPanel': np.uint64 + } + + def __init__(self, source='Headstages', address='127.0.0.1', port=49676): + self.source = source + self.addrstr = 'tcp://'+address+':'+str(port) + self.ctx = zmq.Context.instance() + self.sock = self.ctx.socket(zmq.SUB) + self.dtype = self.dtypedict[source] + + def start(self): + self.sock.connect(self.addrstr) + self.sock.setsockopt_string(zmq.SUBSCRIBE, self.source) + + def 
get(self):
+        datamsg = self.sock.recv_multipart()
+
+        samples = int.from_bytes(datamsg[2], byteorder='little', signed=True)
+        if self.source == 'DigitalPanel':
+            channels = 1
+        else:
+            channels = int.from_bytes(datamsg[3], byteorder='little', signed=False)
+
+        if len(datamsg[4]) != channels * samples * np.dtype(self.dtype).itemsize:
+            raise ValueError("Error: data packet size {} did not match {} samples of {} channels".format(
+                len(datamsg[4]), samples, channels))
+
+        timestamp = int.from_bytes(datamsg[1], byteorder='little', signed=False)
+        data = np.frombuffer(datamsg[4], dtype=self.dtype).reshape((samples, channels))
+
+        return (timestamp, data)
+
+    def stop(self):
+        self.sock.setsockopt_string(zmq.UNSUBSCRIBE, self.source)
+        self.sock.disconnect(self.addrstr)
\ No newline at end of file

From d1cc436ce43039459fa16e5ce92d47d088d596de Mon Sep 17 00:00:00 2001
From: leoscholl
Date: Fri, 21 Aug 2020 12:32:42 -0700
Subject: [PATCH 061/242] update linear decoder

---
 built_in_tasks/bmimultitasks.py | 57 ++++++++++++++-----
 features/simulation_features.py | 20 ++++++-
 riglib/bmi/bmi.py               |  3 +-
 riglib/bmi/extractor.py         | 63 +++++++++++++++++++++
 riglib/bmi/lindecoder.py        | 77 +++++++++++++++++---------
 riglib/bmi/sim_neurons.py       | 98 ++++++++++++++++++++++++++++++---
 tests/run_unit_tests.py         | 14 +++++
 tests/test_decoders.py          | 66 ++++++++++++++++++++++
 8 files changed, 350 insertions(+), 48 deletions(-)
 create mode 100644 tests/run_unit_tests.py
 create mode 100644 tests/test_decoders.py

diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py
index b72f16ae..96eb5a26 100644
--- a/built_in_tasks/bmimultitasks.py
+++ b/built_in_tasks/bmimultitasks.py
@@ -161,6 +161,33 @@ def endpoint_assist_simple(cursor_pos, target_pos, decoder_binlen=0.1, speed=0.5
         x_assist = np.mat(x_assist.reshape(-1,1))
         return x_assist
 
+class SimplePosAssister(SimpleEndpointAssister):
+
+    @staticmethod
+    def endpoint_assist_simple(cursor_pos, target_pos, decoder_binlen=0.1, speed=0.5, target_radius=2., assist_level=0.):
+        '''
+        Estimate the next state using a constant velocity estimate moving toward the specified target
+
+        Parameters
+        ----------
+        see SimpleEndpointAssister for docs
+
+        Returns
+        -------
+        assist_cursor_pos : np.ndarray of shape (3,)
+            Assisted cursor position, blended with the decoded state to pull the cursor toward the target.
+ ''' + diff_vec = target_pos - cursor_pos + dist_to_target = np.linalg.norm(diff_vec) + dir_to_target = diff_vec / (np.spacing(1) + dist_to_target) + + if dist_to_target > target_radius: + assist_cursor_pos = cursor_pos + speed*dir_to_target + else: + assist_cursor_pos = cursor_pos + speed*diff_vec/2 + + return assist_cursor_pos.ravel() + class SimpleEndpointAssisterLFC(feedback_controllers.MultiModalLFC): ''' Docstring @@ -238,8 +265,10 @@ def create_assister(self): if isinstance(self.decoder.ssm, StateSpaceEndptVel2D) and isinstance(self.decoder, ppfdecoder.PPFDecoder): self.assister = OFCEndpointAssister() - elif isinstance(self.decoder.ssm, StateSpaceEndptVel2D) or isinstance(self.decoder.ssm, StateSpaceEndptPos3D): + elif isinstance(self.decoder.ssm, StateSpaceEndptVel2D): self.assister = SimpleEndpointAssister(**kwargs) + elif isinstance(self.decoder.ssm, StateSpaceEndptPos3D): + self.assister = SimplePosAssister(**kwargs) ## elif (self.decoder.ssm == namelist.tentacle_2D_state_space) or (self.decoder.ssm == namelist.joint_2D_state_space): ## # kin_chain = self.plant.kin_chain ## # A, B, W = self.decoder.ssm.get_ssm_matrices(update_rate=self.decoder.binlen) @@ -320,12 +349,6 @@ class BMIControlMulti2DWindow(BMIControlMulti, WindowDispl2D): def __init__(self,*args, **kwargs): super(BMIControlMulti2DWindow, self).__init__(*args, **kwargs) - def create_assister(self): - kwargs = dict(decoder_binlen=self.decoder.binlen, target_radius=self.target_radius) - if hasattr(self, 'assist_speed'): - kwargs['assist_speed'] = self.assist_speed - self.assister = SimpleEndpointAssister(**kwargs) - def create_goal_calculator(self): self.goal_calculator = goal_calculators.ZeroVelocityGoal(self.decoder.ssm) @@ -481,7 +504,7 @@ def _start_wait(self, *args, **kwargs): ######################### ######## Simulation tasks ######################### -from features.simulation_features import SimKalmanEnc, SimKFDecoderSup, SimCosineTunedEnc, SimNormCosineTunedEnc +from features.simulation_features import SimKalmanEnc, SimKFDecoderSup, SimCosineTunedEnc from riglib.bmi.feedback_controllers import LQRController, PosFeedbackController class SimBMIControlMulti(BMIControlMulti2DWindow): win_res = (250, 140) @@ -548,9 +571,10 @@ def __init__(self, *args, **kwargs): super(SimBMICosEncKFDec, self).__init__(*args, **kwargs) +from features.simulation_features import SimLFPCosineTunedEnc, SimNormCosineTunedEnc from riglib.bmi.lindecoder import LinearScaleFilter from riglib.bmi.bmi import Decoder -class SimBMICosEncLinDec(SimNormCosineTunedEnc, SimBMIControlMulti): +class SimBMICosEncLinDec(SimLFPCosineTunedEnc, SimBMIControlMulti): def __init__(self, *args, **kwargs): # build the observation matrix @@ -560,18 +584,23 @@ def __init__(self, *args, **kwargs): sim_C[1, :] = np.array([0, 1, 1]) kwargs['sim_C'] = sim_C - kwargs['assist_level'] = (0, 0) # TODO: implement assister for 3D Pos ssm + kwargs['assist_level'] = (0.1, 0.1) + + # make the decoder map + self.decoder_map = np.array([[1, 0], [0, 0], [0, 1]]) ssm = StateSpaceEndptPos3D() self.fb_ctrl = PosFeedbackController() self.ssm = ssm super(SimBMICosEncLinDec, self).__init__(*args, **kwargs) - + def load_decoder(self): units = self.encoder.get_units() ssm = StateSpaceEndptPos3D() - filt = LinearScaleFilter(10000, 6, ssm.n_states, len(units)) + filt_counts = 10000 # number of observations to calculate range + filt_window = 1 # number of observations to average for each tick + filt_map = self.decoder_map # map from states to units + filt = LinearScaleFilter(filt_counts, 
ssm.n_states, len(units), map=filt_map, window=filt_window, gain=self.fov) self.decoder = Decoder(filt, units, ssm, binlen=0.1, subbins=1) - self.decoder.n_features = len(units) - self.decoder.binlen = 0.1 \ No newline at end of file + self.decoder.n_features = len(units) \ No newline at end of file diff --git a/features/simulation_features.py b/features/simulation_features.py index 37e9e338..9bba08c7 100644 --- a/features/simulation_features.py +++ b/features/simulation_features.py @@ -227,7 +227,8 @@ class SimNormCosineTunedEnc(SimNeuralEnc): def _init_neural_encoder(self): from riglib.bmi.sim_neurons import NormalizedCosEnc - self.encoder = NormalizedCosEnc(self.plant.endpt_bounds, self.sim_C, self.ssm, return_ts=False, DT=0.1, call_ds_rate=1) + self.encoder = NormalizedCosEnc(self.plant.endpt_bounds, self.sim_C, self.ssm, spike=True, return_ts=False, + DT=self.update_rate, tick=self.update_rate) def create_feature_extractor(self): ''' @@ -237,6 +238,23 @@ def create_feature_extractor(self): n_subbins=self.decoder.n_subbins, units=self.decoder.units, task=self) self._add_feature_extractor_dtype() +class SimLFPCosineTunedEnc(SimNeuralEnc): + + bands = [(51, 100)] + + def _init_neural_encoder(self): + from riglib.bmi.sim_neurons import NormalizedCosEnc + self.encoder = NormalizedCosEnc(self.plant.endpt_bounds, self.sim_C, self.ssm, spike=False, return_ts=False, + DT=self.update_rate, tick=self.update_rate, n_bands=len(self.bands)) + + def create_feature_extractor(self): + ''' + Create the feature extractor object + ''' + self.extractor = extractor.SimPowerExtractor(self.fb_ctrl, self.encoder, + channels=self.decoder.channels, bands=self.bands, task=self) + self._add_feature_extractor_dtype() + class SimFAEnc(SimCosineTunedEnc): def __init__(self, *args, **kwargs): self.FACosEnc_kwargs = kwargs.pop('SimFAEnc_kwargs', dict()) diff --git a/riglib/bmi/bmi.py b/riglib/bmi/bmi.py index 2d8113be..ba1380dd 100644 --- a/riglib/bmi/bmi.py +++ b/riglib/bmi/bmi.py @@ -467,6 +467,7 @@ def __init__(self, filt, units, ssm, binlen=0.1, n_subbins=1, tslice=[-1,-1], ca self.ssm = ssm self.units = np.array(units, dtype=np.int32) + self.channels = np.unique(self.units[:,0]) self.binlen = binlen self.bounding_box = ssm.bounding_box self.states = ssm.state_names @@ -774,7 +775,7 @@ def predict(self, neural_obs, assist_level=0.0, weighted_avg_lfc=False, **kwargs self.filt.state.mean = np.mat(tmp).T else: - self.filt.state.mean = (1-assist_level)*self.filt.state.mean + assist_level * x_assist + self.filt.state.mean = (1-assist_level)*self.filt.state.mean + assist_level * x_assist.reshape(-1,1) # Bound cursor, if any hard bounds for states are applied if hasattr(self, 'bounder'): diff --git a/riglib/bmi/extractor.py b/riglib/bmi/extractor.py index 22a8e0f5..58f1afba 100644 --- a/riglib/bmi/extractor.py +++ b/riglib/bmi/extractor.py @@ -738,6 +738,69 @@ def get_spike_ts(self): ts_data = self.encoder(ctrl) return ts_data +class SimPowerExtractor(LFPMTMPowerExtractor): + ''' + This extractor pretends to generate MTMPower estimates but really it just + fetches the powers from a given encoder at the appropriate freq bands + ''' + def __init__(self, input_device, encoder, channels=[], bands=default_bands, task=None, **kwargs): + ''' + Constructor for SimPowerExtractor + + Parameters + ---------- + input_device: object with a "calc_next_state" method + Generate the "intended" next state, e.g., by feedback control policy + encoder: callable with 1 argument + Maps the "control" input into the spike timestamps of a set of 
neurons + channels : list + LFP electrode indices to use for feature extraction + bands : list of tuples + Each tuple defines a frequency band of interest as (start frequency, end frequency) + + Returns + ------- + SimPowerExtractor instance + ''' + self.input_device = input_device + self.encoder = encoder + self.channels = channels + self.bands = bands + self.feature_dtype = ('lfp_power', 'f8', (len(channels)*len(bands), 1)) + self.task = task + extractor_kwargs = dict() + extractor_kwargs['channels'] = channels + extractor_kwargs['bands'] = bands + self.extractor_kwargs = extractor_kwargs + + def get_cont_samples(self, *args, **kwargs): + ''' + see LFPMTMPowerExtractor.get_cont_samples for docs + ''' + current_state = self.task.get_current_state() + target_state = self.task.get_target_BMI_state() + ctrl = self.input_device.calc_next_state(current_state, target_state) + cont_data = self.encoder(ctrl) + return cont_data + + def __call__(self, start_time, *args, **kwargs): + ''' + Parameters + ---------- + start_time : float + Absolute time from the task event loop. This is unused by LFP extractors in their current implementation + and only passed in to ensure that function signatures are the same across extractors. + *args, **kwargs : optional positional/keyword arguments + These are passed to the source, or ignored (not needed for this extractor). + + Returns + ------- + dict + Extracted features to be saved in the task. + ''' + lfp_power = self.get_cont_samples(*args, **kwargs) # dims of channels x bands + return dict(lfp_power=lfp_power) + class SimDirectObsExtractor(SimBinnedSpikeCountsExtractor): ''' This extractor just passes back the observation vector generated by the encoder diff --git a/riglib/bmi/lindecoder.py b/riglib/bmi/lindecoder.py index 7cb38a04..55c835a8 100644 --- a/riglib/bmi/lindecoder.py +++ b/riglib/bmi/lindecoder.py @@ -11,21 +11,43 @@ def __init__(self, mean, *args, **kwargs): class LinearScaleFilter(object): - def __init__(self, n_counts, window, n_states, n_units): + def __init__(self, n_counts, n_states, n_units, map=None, window=1, gain=20): ''' - Parameters: + Constructor for LinearScaleFilter - n_counts How many observations to hold - window How many observations to average - n_states How many state space variables are there - n_units Number of neural units + Parameters + ---------- + n_counts : Number of observations to hold + Range is computed over the whole observation matrix size (N, D) + where N is the number of observations and D is the number of units + n_states : How many state variables are there + For example, a one-dim decoder has one state variable + n_units : Number of neural units + Can be number of isolated spiking units or number of channels for lfp + map : Which units to assign to which states (default = None) + Floating point matrix of size (S, D) where S is the number of + states and D is the number of units, assigning a weight to each pair + Sum along each row must equal 1.0 + window : How many observations to average to smooth output (default = 1) + gain : How far to move the plant for a normalized output of 1.0 (default = 20) + + Returns + ------- + LinearScaleFilter instance ''' self.state = State(np.zeros([n_states,1])) self.obs = np.zeros((n_counts, n_units)) self.n_states = n_states - self.window = window self.n_units = n_units + self.window = window + self.map = map + if map is None: + # Generate a default map where one unit controls one state + self.map = np.identity(max(n_states, n_units)) + self.map = np.resize(self.map, 
(n_states, n_units)) + self.gain = gain self.count = 0 + self.fixed = False def _init_state(self): pass @@ -39,27 +61,32 @@ def __call__(self, obs, **kwargs): def _normalize(self, obs,**kwargs): ''' Function to compute normalized scaling of new observations''' - self.obs[:-1, :] = self.obs[1:, :] - self.obs[-1, :] = np.squeeze(obs) - if self.count < len(self.obs): - self.count += 1 + # Update observation matrix, unless it has been fixed + if not self.fixed: + self.obs[:-1, :] = self.obs[1:, :] + self.obs[-1, :] = np.squeeze(obs) + if self.count < len(self.obs): + self.count += 1 + # Normalize latest observation(s) m_win = np.squeeze(np.mean(self.obs[-self.window:, :], axis=0)) m = np.median(self.obs[-self.count:, :], axis=0) # range = max(1, np.amax(self.obs[-self.count:, :]) - np.amin(self.obs[-self.count:, :])) - range = np.std(self.obs[-self.count:, :], axis=0)*3 + range = 3 * np.std(self.obs[-self.count:, :], axis=0) range[range < 1] = 1 - x = (m_win - m) / range - x = np.squeeze(np.asarray(x)) * 20 # hack for 14x14 cursor + x = (m_win - m) / range * self.gain - # Arrange output - if self.n_states == self.n_units: - return State(x) - elif self.n_states == 3 and self.n_units == 2: - mean = np.zeros([self.n_states,1]) - mean[0] = x[0] - mean[2] = x[1] - return State(mean) - else: - raise NotImplementedError() - \ No newline at end of file + # Arrange output according to map + out = np.matmul(self.map, x).reshape(-1,1) + return State(out) + + def save_obs(self): + raise NotImplementedError() + + def fix_obs(self): + self.fixed = True + + def load_and_fix_obs(self, file): + raise NotImplementedError() + self.count = len(self.obs) + self.fix_obs() \ No newline at end of file diff --git a/riglib/bmi/sim_neurons.py b/riglib/bmi/sim_neurons.py index bd047978..87e5d0d9 100644 --- a/riglib/bmi/sim_neurons.py +++ b/riglib/bmi/sim_neurons.py @@ -1,6 +1,6 @@ #!/usr/bin/python """ -Classes to simulate neural activity (spike firing rates) by various methods. +Classes to simulate neural activity (spike firing rates and lfp) by various methods. """ import os @@ -9,7 +9,6 @@ from scipy.io import loadmat import numpy as np -from numpy.random import poisson, rand from scipy.io import loadmat, savemat @@ -133,7 +132,7 @@ def gen_spikes(self, next_state, mode=None): def return_spikes(self, rates, mode=None): rates[rates < 0] = 0 # Floor firing rates at 0 Hz - counts = poisson(rates * self.DT) + counts = np.random.poisson(rates * self.DT) if np.logical_or(mode=='ts', np.logical_and(mode is None, self.return_ts)): ts = [] @@ -352,7 +351,7 @@ def gen_spikes(self, next_state, mode=None): def mod_poisson(self, x, dt=0.1): x[x<0] = 0 - return poisson(x*dt) + return np.random.poisson(x*dt) def y2_eq_r2_min_x2(self, x_arr, r2): y = [] @@ -364,13 +363,48 @@ def y2_eq_r2_min_x2(self, x_arr, r2): return np.array(y) class NormalizedCosEnc(GenericCosEnc): + ''' + Generates neural observations (spikes or LFP) based on normalized scaling within the bounds of + the task, instead of DC rectifying the output. Generates simulated spiking or LFP power data + ''' + + def __init__(self, bounds, C, ssm, spike=True, return_ts=False, DT=0.1, tick=1/60, n_bands=1): + ''' + Constructor for NormalizedCosEnc - def __init__(self, bounds, *args, **kwargs): + Parameters + ---------- + bounds : array of [min_x, max_x, min_y, max_y, min_z, max_z] + Extreme plant coordinates + C : np.ndarray of shape (N, K) + N is the number of simulated neurons, K is the number of covariates driving neuronal activity. 
+ The product of C and the hidden state vector x should give the intended spike rates in Hz + ssm : state_space_models.StateSpace instance + ARG_DESCR + spike : bool, optional, default=True + Determines whether simultated output is spike or LFP data + return_ts : bool, optional, default=False + If True, fake timestamps are returned for each spike event in the same format + as real spike data would be delivered over the network during a real experiment. + If False, a vector of counts is returned instead. Specify True or False depending on + which type of feature extractor you're using for your simulated task. + DT : float, optional, default=0.1 + Sampling interval to come up with new spike processes + tick : float, optional, default=1/60 + Refresh rate of main experiment tick + + Returns + ------- + GenericCosEnc instance + ''' self.min = np.array([bounds[0], bounds[2], bounds[4]]) self.range = np.array([bounds[1] - bounds[0], bounds[3] - bounds[2], bounds[5] - bounds[4]]) self.range[self.range == 0] = 1 - self.gain = 100 - super(NormalizedCosEnc, self).__init__(*args, **kwargs) + self.gain = 10 + self.spike = spike + self.n_bands = n_bands + call_ds_rate = DT / tick + super(NormalizedCosEnc, self).__init__(C, ssm, return_ts, DT, call_ds_rate) def gen_spikes(self, next_state, mode=None): """ @@ -391,6 +425,56 @@ def gen_spikes(self, next_state, mode=None): rates = np.dot(self.C, norm_state) * self.gain return self.return_spikes(rates, mode=mode) + def gen_power(self, next_state, mode=None): + """ + Simulate the LFP powers + + Parameters + ---------- + next_state : np.array of shape (N, 1) + The "next state" to be encoded by this population of neurons + + Returns + ------- + powers : np.array of shape (N, P) -> flattened + N number of neurons, P number of power bands, determined by n_bands + """ + norm_state = np.divide(np.subtract(np.squeeze(next_state), self.min), self.range) + ideal = np.dot(self.C, norm_state) * self.gain + # Generate gaussian noise + noise = np.random.normal(0, 0, size=(len(ideal), self.n_bands)) + # Replicate across frequency bands + power = np.tile(ideal.reshape((-1,1)), (1, self.n_bands)) + noise + return power.reshape(-1,1) + + def __call__(self, next_state, mode=None): + ''' + See CosEnc.__call__ for docs + ''' + if self.spike: + if self.call_count % self.call_ds_rate == 0: + ts_data = self.gen_spikes(next_state, mode=mode) + + else: + if self.return_ts: + # return an empty list of time stamps + ts_data = np.array([]) + else: + # return a vector of 0's + ts_data = np.zeros(self.n_neurons) + + self.call_count += 1 + return ts_data + else: + if self.call_count % self.call_ds_rate == 0: + lfp_data = self.gen_power(next_state, mode=mode) + + else: + lfp_data = np.zeros((self.n_neurons*self.n_bands, 1)) + + self.call_count += 1 + return lfp_data + def from_file_to_FACosEnc(plot=False): from riglib.bmi import state_space_models as ssm import pickle diff --git a/tests/run_unit_tests.py b/tests/run_unit_tests.py new file mode 100644 index 00000000..1246c0f1 --- /dev/null +++ b/tests/run_unit_tests.py @@ -0,0 +1,14 @@ +import unittest, os + +from tests.test_decoders import TestLinDec +test_classes = [ + TestLinDec +] + +suite = unittest.TestSuite() + +for cls in test_classes: + suite.addTest(unittest.makeSuite(cls)) + +runner = unittest.TextTestRunner() +runner_output = runner.run(suite) diff --git a/tests/test_decoders.py b/tests/test_decoders.py new file mode 100644 index 00000000..ea814346 --- /dev/null +++ b/tests/test_decoders.py @@ -0,0 +1,66 @@ +from riglib.bmi 
import lindecoder +from built_in_tasks.bmimultitasks import SimBMICosEncLinDec +from riglib import experiment +import numpy as np + +import unittest + +class TestLinDec(unittest.TestCase): + + def setUp(self): + pass + + def test_sanity(self): + simple_filt = lindecoder.LinearScaleFilter(100, 1, 1) + self.assertEqual(0, simple_filt.get_mean()) + + for i in range(50): + simple_filt([1]) + + self.assertEqual(0.5, np.mean(simple_filt.obs)) + self.assertEqual(0, simple_filt.get_mean()) + + for i in range(250): + simple_filt(i) + + self.assertTrue(simple_filt.get_mean() > 0.5) # 0.9 not working because of normalization by std instead of range + + def test_filter(self): + filt = lindecoder.LinearScaleFilter(100, 3, 2) + self.assertListEqual([0,0,0], filt.get_mean().tolist()) + for i in range(100): + filt([0, 0]) + self.assertEqual(0, filt.state.mean[0, 0]) + self.assertEqual(0, filt.state.mean[1, 0]) + self.assertEqual(0, filt.state.mean[2, 0]) + + #@unittest.skip('msg') + def test_experiment(self): + N_TARGETS = 8 + N_TRIALS = 2 + seq = SimBMICosEncLinDec.sim_target_seq_generator_multi( + N_TARGETS, N_TRIALS) + base_class = SimBMICosEncLinDec + feats = [] + Exp = experiment.make(base_class, feats=feats) + exp = Exp(seq) + exp.init() + exp.run() + + rewards = 0 + time_penalties = 0 + hold_penalties = 0 + for s in exp.event_log: + if s[0] == 'reward': + rewards += 1 + elif s[0] == 'hold_penalty': + hold_penalties += 1 + elif s[0] == 'timeout_penalty': + time_penalties += 1 + self.assertTrue(rewards <= rewards + time_penalties + hold_penalties) + self.assertTrue(rewards > 0) + +if __name__ == '__main__': + unittest.main() + + From 4d9b1a50d446d158bdf16024f8c82cfaf8bc540f Mon Sep 17 00:00:00 2001 From: leoscholl Date: Mon, 24 Aug 2020 11:58:49 -0700 Subject: [PATCH 062/242] linear decoder working pos --- built_in_tasks/bmimultitasks.py | 11 ++++-- riglib/bmi/lindecoder.py | 65 ++++++++++++++++++++------------- riglib/bmi/sim_neurons.py | 8 ++-- tests/test_decoders.py | 2 +- 4 files changed, 53 insertions(+), 33 deletions(-) diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index 96eb5a26..46f5d9a9 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -581,10 +581,10 @@ def __init__(self, *args, **kwargs): sim_C = np.zeros((2, 3)) # control x positive directions sim_C[0, :] = np.array([1, 0, 0]) - sim_C[1, :] = np.array([0, 1, 1]) + sim_C[1, :] = np.array([0, 0, 1]) kwargs['sim_C'] = sim_C - kwargs['assist_level'] = (0.1, 0.1) + kwargs['assist_level'] = (0, 0) # make the decoder map self.decoder_map = np.array([[1, 0], [0, 0], [0, 1]]) @@ -599,8 +599,11 @@ def load_decoder(self): units = self.encoder.get_units() ssm = StateSpaceEndptPos3D() filt_counts = 10000 # number of observations to calculate range - filt_window = 1 # number of observations to average for each tick + filt_window = 3 # number of observations to average for each tick filt_map = self.decoder_map # map from states to units - filt = LinearScaleFilter(filt_counts, ssm.n_states, len(units), map=filt_map, window=filt_window, gain=self.fov) + filt = LinearScaleFilter(filt_counts, ssm.n_states, len(units), map=filt_map, window=filt_window) + gain = 2 * np.max(self.plant.endpt_bounds) + filt.update_norm_param(neural_mean=[5, 5], neural_range=[10,10], scaling_mean=[0,0], scaling_range=[gain,gain]) + filt.fix_norm_param() self.decoder = Decoder(filt, units, ssm, binlen=0.1, subbins=1) self.decoder.n_features = len(units) \ No newline at end of file diff --git 
a/riglib/bmi/lindecoder.py b/riglib/bmi/lindecoder.py index 55c835a8..fdb33c96 100644 --- a/riglib/bmi/lindecoder.py +++ b/riglib/bmi/lindecoder.py @@ -11,7 +11,7 @@ def __init__(self, mean, *args, **kwargs): class LinearScaleFilter(object): - def __init__(self, n_counts, n_states, n_units, map=None, window=1, gain=20): + def __init__(self, n_counts, n_states, n_units, map=None, window=1): ''' Constructor for LinearScaleFilter @@ -27,9 +27,7 @@ def __init__(self, n_counts, n_states, n_units, map=None, window=1, gain=20): map : Which units to assign to which states (default = None) Floating point matrix of size (S, D) where S is the number of states and D is the number of units, assigning a weight to each pair - Sum along each row must equal 1.0 window : How many observations to average to smooth output (default = 1) - gain : How far to move the plant for a normalized output of 1.0 (default = 20) Returns ------- @@ -42,11 +40,16 @@ def __init__(self, n_counts, n_states, n_units, map=None, window=1, gain=20): self.window = window self.map = map if map is None: - # Generate a default map where one unit controls one state + # Generate a default mapping where one unit controls one state self.map = np.identity(max(n_states, n_units)) self.map = np.resize(self.map, (n_states, n_units)) - self.gain = gain self.count = 0 + self.params = dict( + neural_mean = np.zeros(n_units), + neural_range = np.ones(n_units), + scaling_mean = np.zeros(n_units), + scaling_range = np.ones(n_units), + ) self.fixed = False def _init_state(self): @@ -55,38 +58,50 @@ def _init_state(self): def get_mean(self): return np.array(self.state.mean).ravel() - def __call__(self, obs, **kwargs): + def __call__(self, obs, **kwargs): # TODO need to pick single frequency band if given more than one self.state = self._normalize(obs, **kwargs) def _normalize(self, obs,**kwargs): ''' Function to compute normalized scaling of new observations''' - # Update observation matrix, unless it has been fixed - if not self.fixed: - self.obs[:-1, :] = self.obs[1:, :] - self.obs[-1, :] = np.squeeze(obs) - if self.count < len(self.obs): - self.count += 1 + # Update observation matrix + norm_obs = (obs.ravel() - self.params['neural_mean']) / self.params['neural_range'] # center on zero + self.obs[:-1, :] = self.obs[1:, :] + self.obs[-1, :] = norm_obs + if self.count < len(self.obs): + self.count += 1 - # Normalize latest observation(s) + if not self.fixed: + self._update_scale_param(obs) m_win = np.squeeze(np.mean(self.obs[-self.window:, :], axis=0)) - m = np.median(self.obs[-self.count:, :], axis=0) - # range = max(1, np.amax(self.obs[-self.count:, :]) - np.amin(self.obs[-self.count:, :])) - range = 3 * np.std(self.obs[-self.count:, :], axis=0) - range[range < 1] = 1 - x = (m_win - m) / range * self.gain + x = (m_win - self.params['scaling_mean']) * self.params['scaling_range'] # Arrange output according to map out = np.matmul(self.map, x).reshape(-1,1) return State(out) - def save_obs(self): - raise NotImplementedError() + def _update_scale_param(self, obs): + ''' Function to update the normalization parameters''' + + # Normalize latest observation(s) + mean = np.median(self.obs[-self.count:, :], axis=0) + # range = max(1, np.amax(self.obs[-self.count:, :]) - np.amin(self.obs[-self.count:, :])) + range = 3 * np.std(self.obs[-self.count:, :], axis=0) + range[range < 1] = 1 + self.update_norm_param(scaling_mean=mean, scaling_range=range) + + def update_norm_param(self, neural_mean=None, neural_range=None, scaling_mean=None, scaling_range=None): + 
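+        # Each keyword below is optional; only the parameters that are passed in are
+        # overwritten, so the neural normalization and the output scaling can be set
+        # independently. Illustrative usage (values taken from
+        # SimBMICosEncLinDec.load_decoder in this patch):
+        #     filt.update_norm_param(neural_mean=[5, 5], neural_range=[10, 10],
+        #                            scaling_mean=[0, 0], scaling_range=[gain, gain])
+        #     filt.fix_norm_param()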
if neural_mean is not None: + self.params.update(neural_mean = neural_mean) + if neural_range is not None: + self.params.update(neural_range = neural_range) + if scaling_mean is not None: + self.params.update(scaling_mean = scaling_mean) + if scaling_range is not None: + self.params.update(scaling_range = scaling_range) - def fix_obs(self): + def fix_norm_param(self): self.fixed = True - def load_and_fix_obs(self, file): - raise NotImplementedError() - self.count = len(self.obs) - self.fix_obs() \ No newline at end of file + def get_norm_param(self): + return self.params \ No newline at end of file diff --git a/riglib/bmi/sim_neurons.py b/riglib/bmi/sim_neurons.py index 87e5d0d9..57c882b7 100644 --- a/riglib/bmi/sim_neurons.py +++ b/riglib/bmi/sim_neurons.py @@ -368,7 +368,7 @@ class NormalizedCosEnc(GenericCosEnc): the task, instead of DC rectifying the output. Generates simulated spiking or LFP power data ''' - def __init__(self, bounds, C, ssm, spike=True, return_ts=False, DT=0.1, tick=1/60, n_bands=1): + def __init__(self, bounds, C, ssm, spike=True, return_ts=False, DT=0.1, tick=1/60, n_bands=1, gain=10): ''' Constructor for NormalizedCosEnc @@ -400,9 +400,9 @@ def __init__(self, bounds, C, ssm, spike=True, return_ts=False, DT=0.1, tick=1/6 self.min = np.array([bounds[0], bounds[2], bounds[4]]) self.range = np.array([bounds[1] - bounds[0], bounds[3] - bounds[2], bounds[5] - bounds[4]]) self.range[self.range == 0] = 1 - self.gain = 10 self.spike = spike self.n_bands = n_bands + self.gain = gain call_ds_rate = DT / tick super(NormalizedCosEnc, self).__init__(C, ssm, return_ts, DT, call_ds_rate) @@ -441,8 +441,10 @@ def gen_power(self, next_state, mode=None): """ norm_state = np.divide(np.subtract(np.squeeze(next_state), self.min), self.range) ideal = np.dot(self.C, norm_state) * self.gain + # Generate gaussian noise - noise = np.random.normal(0, 0, size=(len(ideal), self.n_bands)) + noise = np.random.normal(0, 0.05 * self.gain, size=(len(ideal), self.n_bands)) + # Replicate across frequency bands power = np.tile(ideal.reshape((-1,1)), (1, self.n_bands)) + noise return power.reshape(-1,1) diff --git a/tests/test_decoders.py b/tests/test_decoders.py index ea814346..d1de650f 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -37,7 +37,7 @@ def test_filter(self): #@unittest.skip('msg') def test_experiment(self): N_TARGETS = 8 - N_TRIALS = 2 + N_TRIALS = 100 seq = SimBMICosEncLinDec.sim_target_seq_generator_multi( N_TARGETS, N_TRIALS) base_class = SimBMICosEncLinDec From 9004459a3590c5d19549f1de9e844386c78c3d6a Mon Sep 17 00:00:00 2001 From: leoscholl Date: Mon, 24 Aug 2020 12:59:44 -0700 Subject: [PATCH 063/242] init vel scale decoder --- built_in_tasks/bmimultitasks.py | 10 +++++----- riglib/bmi/lindecoder.py | 29 ++++++++++++++++++++++++++--- riglib/bmi/sim_neurons.py | 4 ++-- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index 46f5d9a9..6959aeff 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -572,7 +572,7 @@ def __init__(self, *args, **kwargs): super(SimBMICosEncKFDec, self).__init__(*args, **kwargs) from features.simulation_features import SimLFPCosineTunedEnc, SimNormCosineTunedEnc -from riglib.bmi.lindecoder import LinearScaleFilter +from riglib.bmi.lindecoder import PosVelScaleFilter from riglib.bmi.bmi import Decoder class SimBMICosEncLinDec(SimLFPCosineTunedEnc, SimBMIControlMulti): def __init__(self, *args, **kwargs): @@ -589,7 +589,7 @@ def 
__init__(self, *args, **kwargs): # make the decoder map self.decoder_map = np.array([[1, 0], [0, 0], [0, 1]]) - ssm = StateSpaceEndptPos3D() + ssm = StateSpaceEndptVel2D() self.fb_ctrl = PosFeedbackController() self.ssm = ssm @@ -597,13 +597,13 @@ def __init__(self, *args, **kwargs): def load_decoder(self): units = self.encoder.get_units() - ssm = StateSpaceEndptPos3D() filt_counts = 10000 # number of observations to calculate range filt_window = 3 # number of observations to average for each tick filt_map = self.decoder_map # map from states to units - filt = LinearScaleFilter(filt_counts, ssm.n_states, len(units), map=filt_map, window=filt_window) + vel_control = False + filt = PosVelScaleFilter(vel_control, filt_counts, self.ssm.n_states, len(units), map=filt_map, window=filt_window) gain = 2 * np.max(self.plant.endpt_bounds) filt.update_norm_param(neural_mean=[5, 5], neural_range=[10,10], scaling_mean=[0,0], scaling_range=[gain,gain]) filt.fix_norm_param() - self.decoder = Decoder(filt, units, ssm, binlen=0.1, subbins=1) + self.decoder = Decoder(filt, units, self.ssm, binlen=0.1, subbins=1) self.decoder.n_features = len(units) \ No newline at end of file diff --git a/riglib/bmi/lindecoder.py b/riglib/bmi/lindecoder.py index fdb33c96..fd0271f2 100644 --- a/riglib/bmi/lindecoder.py +++ b/riglib/bmi/lindecoder.py @@ -4,7 +4,7 @@ import numpy as np class State(object): - '''For compatibility with other BMI decoding implementations, literally just holds the state''' + '''For compatibility with other BMI decoding implementations''' def __init__(self, mean, *args, **kwargs): self.mean = mean @@ -65,7 +65,7 @@ def _normalize(self, obs,**kwargs): ''' Function to compute normalized scaling of new observations''' # Update observation matrix - norm_obs = (obs.ravel() - self.params['neural_mean']) / self.params['neural_range'] # center on zero + norm_obs = (np.squeeze(obs) - self.params['neural_mean']) / self.params['neural_range'] # center on zero self.obs[:-1, :] = self.obs[1:, :] self.obs[-1, :] = norm_obs if self.count < len(self.obs): @@ -104,4 +104,27 @@ def fix_norm_param(self): self.fixed = True def get_norm_param(self): - return self.params \ No newline at end of file + return self.params + +class PosVelState(State): + + def __init__(self, vel_control, *args, **kwargs): + self.vel_control = vel_control + self.mean = np.zeros((7,1)) + + def update(self, mean): + if self.vel_control: + self.mean[3:6] = mean + self.mean[0:3] = self.mean[3:6] + self.mean[0:3] + else: + self.mean = np.zeros((7,1)) + self.mean[0:3] = mean + +class PosVelScaleFilter(LinearScaleFilter): + def __init__(self, vel_control, *args, **kwargs): + super(PosVelScaleFilter, self).__init__(*args, **kwargs) + self.state = PosVelState(vel_control) + + def __call__(self, obs, **kwargs): + state = self._normalize(obs, **kwargs) + self.state.update(state.mean) \ No newline at end of file diff --git a/riglib/bmi/sim_neurons.py b/riglib/bmi/sim_neurons.py index 57c882b7..eaa0688f 100644 --- a/riglib/bmi/sim_neurons.py +++ b/riglib/bmi/sim_neurons.py @@ -421,7 +421,7 @@ def gen_spikes(self, next_state, mode=None): Either spike time stamps or a vector of unit spike counts is returned, depending on whether the 'return_ts' attribute is True """ - norm_state = np.divide(np.subtract(np.squeeze(next_state), self.min), self.range) + norm_state = np.divide(np.subtract(np.squeeze(next_state[0:3]), self.min), self.range) rates = np.dot(self.C, norm_state) * self.gain return self.return_spikes(rates, mode=mode) @@ -439,7 +439,7 @@ def 
gen_power(self, next_state, mode=None):
         powers : np.array of shape (N, P) -> flattened
             N number of neurons, P number of power bands, determined by n_bands
         """
-        norm_state = np.divide(np.subtract(np.squeeze(next_state), self.min), self.range)
+        norm_state = np.divide(np.subtract(np.squeeze(next_state[0:3]), self.min), self.range)
         ideal = np.dot(self.C, norm_state) * self.gain

From b0dcd206dc14521caa998777858d3ccf75accfda Mon Sep 17 00:00:00 2001
From: Pavi Raj
Date: Mon, 24 Aug 2020 14:30:20 -0700
Subject: [PATCH 064/242] aolab reward system interface with bmi3D

---
 built_in_tasks/manualcontrolmultitasks.py |   4 +-
 features/reward_features.py               |  35 +-
 riglib/reward.py                          | 415 +++-------------------
 tests/test_reward.py                      |   7 +-
 tests/test_reward_ao.py                   |  20 +-
 tests/test_reward_python2arduino.py       |  10 +-
 6 files changed, 119 insertions(+), 372 deletions(-)

diff --git a/built_in_tasks/manualcontrolmultitasks.py b/built_in_tasks/manualcontrolmultitasks.py
index a8c81779..37023d2f 100644
--- a/built_in_tasks/manualcontrolmultitasks.py
+++ b/built_in_tasks/manualcontrolmultitasks.py
@@ -6,7 +6,7 @@ from collections import OrderedDict
 import time

-from riglib import reward
+from riglib import reward # This import now corresponds to the Orsborn lab reward system
 from riglib.experiment import traits, Sequence
 from riglib.stereo_opengl.window import Window, FPScontrol, WindowDispl2D

@@ -97,7 +97,7 @@ class ManualControlMulti(Sequence, Window):

     # Runtime settable traits
-    reward_time = traits.Float(.5, desc="Length of juice reward")
+    reward_time = traits.Float(.2, desc="Length of juice reward")
     target_radius = traits.Float(2, desc="Radius of targets in cm")
     hold_time = traits.Float(.2, desc="Length of hold required at targets")

diff --git a/features/reward_features.py b/features/reward_features.py
index a3f2d335..3e81136a 100644
--- a/features/reward_features.py
+++ b/features/reward_features.py
@@ -16,16 +16,45 @@
 ###### CONSTANTS
 sec_per_min = 60
-
 class RewardSystem(traits.HasTraits):
     '''
-    Feature for the Crist solenoid reward system
+    Feature for the current reward system in Amy Orsborn Lab - Aug 2020
     '''
     trials_per_reward = traits.Float(1, desc='Number of successful trials before solenoid is opened')
+
     def __init__(self, *args, **kwargs):
         from riglib import reward
         super(RewardSystem, self).__init__(*args, **kwargs)
-        self.reward = reward.open()
+        self.reward = reward.open() # There is no open function in our reward.py
+
+    def _start_reward(self):
+        self.reward_start = self.get_time()
+        if self.reward is not None:
+            self.reportstats['Reward #'] += 1
+            if self.reportstats['Reward #'] % self.trials_per_reward == 0:
+                self.reward.reward(self.reward_time)
+            #super(RewardSystem, self).reward(self.reward_time)
+
+    def _test_reward_end(self, ts):
+        if self.reportstats['Reward #'] % self.trials_per_reward == 0:
+            return ts > self.reward_time
+        else:
+            return True
+
+
+""" BELOW THIS IS ALL THE OLD CODE ASSOCIATED WITH REWARD FEATURES """
+
+
+class RewardSystem_Crist(traits.HasTraits):
+    '''
+    Feature for the Crist solenoid reward system
+    '''
+    trials_per_reward = traits.Float(1, desc='Number of successful trials before solenoid is opened')
+
+    def __init__(self, *args, **kwargs):
+        from riglib import reward_crist
+        super(RewardSystem, self).__init__(*args, **kwargs)
+        self.reward = reward_crist.open()

     def _start_reward(self):
         self.reward_start = self.get_time()
diff --git a/riglib/reward.py b/riglib/reward.py
index 2300ccaa..4e8a4613 100644
--- a/riglib/reward.py
+++
b/riglib/reward.py @@ -1,373 +1,72 @@ -''' -Code for interacting with the Crist reward system(s). Consult the Crist manual for the command protocol -''' - - -import glob +""" +Written by Pavi - Aug 2020 for reward system integration +Code for reward system used in Amy Orsborn lab +Functions list: +-------------- +Class Basic -> reward(reward_time_s), test, calibrate, drain(drain_time_s) +""" + +# import functions +import pyfirmata import time -import struct -import binascii -import threading -import io -import traceback - - -import serial -import time - -import numpy - -try: - import traits.api as traits -except: - import enthought.traits.api as traits - -def _xsum(msg): - ''' - Compute the checksums for the messages which must be sent as part of the packet - - Parameters - ---------- - msg : string - Message to be sent over the serial port - - Returns - ------- - char - The 8-bit checksum of the entire message - ''' - chrval = [int(''.join(x), 16) for x in zip(*[iter(binascii.b2a_hex(msg))]*2)] - return chr(sum(chrval) % 256) class Basic(object): - ''' - Bare-bones interface for the Crist reward system. Can give timed reward and drain on/off. - This class is sufficient for all the tasks implemented as of Aug. 2014. - ''' - response_message_length = 7 - def __init__(self): - ''' - Constructor for basic reward system interface - - Parameters - ---------- - None - - Returns - ------- - Basic instance - ''' - self.port = serial.Serial('/dev/crist_reward', baudrate=38400) - from config import config - self.version = int(config.reward_sys['version']) - if self.version==1: self.set_beeper_volume(128) - time.sleep(.5) - self.reset() - - def _write(self, msg): - ''' - Send an arbitrary message over the serial port - - Parameters - ---------- - msg : string - Message to be sent over the serial port - - Returns - ------- - msg_out : string - Response from crist system after sending command - ''' - fmsg = msg+_xsum(msg) - self.port.flushOutput() - self.port.flushInput() - self.port.write(fmsg) - msg_out = self.port.read(self.port.inWaiting()) - return msg_out - - def reward(self, length): - ''' - Open the solenoid for some length of time - - Parameters - ---------- - length : float - Duration of time the solenoid should be open, in seconds. NOTE: in some versions of the system, there appears to be max of ~5s + # com_port = '/dev/ttyACM0' # specify the port, based on windows/Unix, can find it on IDE or terminal + # board = pyfirmata.Arduino(com_port) - Returns - ------- - None - ''' - length /= .1 - length = int(length) - if self.version==0: - self._write(struct.pack('= 0 and volume <= 255): - raise ValueError("Invalid beeper volume: %g" % volume) - return self._write('@CS' + '%c' % volume + 'E' + struct.pack('xxx')) - - def reset(self): - ''' - Send the system reset command - ''' - if self.version==0: - self._write("@CPSNNN") - elif self.version==1: - cmd = ['@', 'C', '1', 'P', '%c' % 0b10000000, '%c' % 0, '%c' % 0, 'D'] - stuff = ''.join(cmd) - self._write(stuff) - else: - raise Exception("Unrecognized reward system version!") - self.last_response = self.port.read(self.port.inWaiting()) - - def drain(self, drain_time=1200): - ''' - Turns on the reward system drain for specified amount of time (in seconds) - - Parameters - ---------- - drain_time : float - Time to drain the system, in seconds. 
- - Returns - ------- - None - ''' - assert drain_time > 0 - assert drain_time < 9999 - if self.version == 0: #have to wait and manually tell it to turn off - self._write("@CNSENN") - time.sleep(drain_time) - self._write("@CNSDNN") - elif self.version == 1: - self._write('@M1' + struct.pack('H', drain_time) + 'D' + struct.pack('xx')) - else: - raise Exception("Unrecognized reward system version!") - - def drain_off(self): - ''' - Turns off drain if currently on - ''' - if self.version==0: - self._write("@CNSDNN") - elif self.version==1: - self._write('@M1' + struct.pack('H', 0) + 'A' + struct.pack('xx')) - else: - raise Exception("Unrecognized reward system version!") - - -########################################## -##### Code below this line is unused ##### -########################################## -class _parse_num(object): - types = {1:' 2: - self._messages[header][-1](msg) - else: - print(self._messages[header], repr(msg)) - except: - traceback.print_exc() - time.sleep(10) - print(repr(msg),repr(self.port.read(self.port.inWaiting()))) - - def reward(self, time=500, volume=None): - '''Returns the string used to output a time or volume reward. - - Parameters - ---------- - time : int - Time in milliseconds to turn on the reward - volume: int - volume in microliters - ''' - assert (volume is None and time is not None) or \ - (volume is not None and time is None) - time /= .1 - self._write(struct.pack(' Date: Mon, 24 Aug 2020 14:30:51 -0700 Subject: [PATCH 065/242] ao lab reward system interface --- built_in_tasks/cursorControlTasks_reward.py | 129 +++++++ db/tracker/ajax.py | 1 + riglib/reward_crist.py | 373 ++++++++++++++++++++ 3 files changed, 503 insertions(+) create mode 100644 built_in_tasks/cursorControlTasks_reward.py create mode 100644 riglib/reward_crist.py diff --git a/built_in_tasks/cursorControlTasks_reward.py b/built_in_tasks/cursorControlTasks_reward.py new file mode 100644 index 00000000..e6312720 --- /dev/null +++ b/built_in_tasks/cursorControlTasks_reward.py @@ -0,0 +1,129 @@ +""" +This script is a modified version of cursorControlTasks_optitrack to include reward feature in cursor control task + +modified by Pavi Aug 2020 +""" + +from manualcontrolmultitasks import ManualControlMulti +from riglib.stereo_opengl.window import WindowDispl2D +# from bmimultitasks import BMIControlMulti +import pygame +import numpy as np +import copy + +# from riglib.bmi.extractor import DummyExtractor +# from riglib.bmi.state_space_models import StateSpaceEndptVel2D +# from riglib.bmi.bmi import Decoder, BMISystem, GaussianStateHMM, BMILoop, GaussianState, MachineOnlyFilter +from riglib import experiment +from features.hdf_features import SaveHDF +from features.reward_features import RewardSystem + +class CursorControl(ManualControlMulti, WindowDispl2D): + ''' + this class implements a python cursor control task for human + ''' + + def __init__(self, *args, **kwargs): + # just run the parent ManualControlMulti's initialization + self.move_step = 1 + + # Initialize target location variable + #target location and index have been initializd + + super(CursorControl, self).__init__(*args, **kwargs) + + def init(self): + pygame.init() + self.assist_level = (0, 0) + super(CursorControl, self).init() + + # override the _cycle function + def _cycle(self): + #print(self.state) + + #target and plant data have been saved in + #the parent manualcontrolmultitasks + + self.move_effector_cursor() + super(CursorControl, self)._cycle() + + # do nothing + def move_effector(self): + pass + + def 
move_plant(self, **kwargs):
+        pass
+
+    # use keyboard to control the task
+    def move_effector_cursor(self):
+        np.array([0., 0., 0.])
+        curr_pos = copy.deepcopy(self.plant.get_endpoint_pos())
+
+        for event in pygame.event.get():
+            if event.type == pygame.KEYUP:
+                if event.key == pygame.K_q:
+                    pygame.quit()
+                    quit()
+                if event.key == pygame.K_LEFT:
+                    curr_pos[0] -= self.move_step
+                if event.key == pygame.K_RIGHT:
+                    curr_pos[0] += self.move_step
+                if event.key == pygame.K_UP:
+                    curr_pos[2] += self.move_step
+                if event.key == pygame.K_DOWN:
+                    curr_pos[2] -= self.move_step
+                #print('Current position: ')
+                #print(curr_pos)
+
+        # set the current position
+        self.plant.set_endpoint_pos(curr_pos)
+
+    def _start_wait(self):
+        self.wait_time = 0.
+        super(CursorControl, self)._start_wait()
+
+    def _test_start_trial(self, ts):
+        return ts > self.wait_time and not self.pause
+
+#this task can be run on its own
+#we will not involve the database at this time
+target_pos_radius = 10
+
+def target_seq_generator(n_targs, n_trials):
+    #generate targets
+    angles = np.transpose(np.arange(0,2*np.pi,2*np.pi / n_targs))
+    unit_targets = np.stack((np.cos(angles), np.sin(angles)),1)
+    targets = unit_targets * target_pos_radius
+
+    center = np.array((0,0))
+
+    target_inds = np.random.randint(0, n_targs, n_trials)
+    target_inds[0:n_targs] = np.arange(min(n_targs, n_trials))
+
+    k = 0
+    while k < n_trials:
+        targ = targets[target_inds[k], :]
+        yield np.array([[center[0], 0, center[1]],
+                        [targ[0], 0, targ[1]]])
+        k += 1
+
+
+
+if __name__ == "__main__":
+    print('Remember to set window size in stereoOpenGL class')
+    gen = target_seq_generator(8, 2)
+
+    # incorporate the saveHDF feature by blending code
+    # see tests\start_From_cmd_line_sim
+    from features.reward_features import RewardSystem
+    #from features.optitrack_feature import MotionData
+    from features.hdf_features import SaveHDF
+
+    base_class = CursorControl
+    feats = [RewardSystem, SaveHDF]
+    Exp = experiment.make(base_class, feats=feats)
+    print(Exp)
+
+    exp = Exp(gen)
+    exp.init()
+    exp.run() # start the task
\ No newline at end of file
diff --git a/db/tracker/ajax.py b/db/tracker/ajax.py
index 41448542..2581ee05 100644
--- a/db/tracker/ajax.py
+++ b/db/tracker/ajax.py
@@ -373,6 +373,7 @@ def save_notes(request, idx):
 def reward_drain(request, onoff):
     '''
     Start/stop the "drain" of a solenoid reward remotely
+    This function now uses the Orsborn lab reward system - see reward.py for the available functions
     '''
     from riglib import reward
     r = reward.Basic()
diff --git a/riglib/reward_crist.py b/riglib/reward_crist.py
new file mode 100644
index 00000000..2300ccaa
--- /dev/null
+++ b/riglib/reward_crist.py
@@ -0,0 +1,373 @@
+'''
+Code for interacting with the Crist reward system(s). Consult the Crist manual for the command protocol
+'''
+
+
+import glob
+import time
+
+import struct
+import binascii
+import threading
+import io
+import traceback
+
+
+import serial
+import time
+
+import numpy
+
+try:
+    import traits.api as traits
+except:
+    import enthought.traits.api as traits
+
+def _xsum(msg):
+    '''
+    Compute the checksums for the messages which must be sent as part of the packet
+
+    Parameters
+    ----------
+    msg : string
+        Message to be sent over the serial port
+
+    Returns
+    -------
+    char
+        The 8-bit checksum of the entire message
+    '''
+    chrval = [int(''.join(x), 16) for x in zip(*[iter(binascii.b2a_hex(msg))]*2)]
+    return chr(sum(chrval) % 256)
+
+class Basic(object):
+    '''
+    Bare-bones interface for the Crist reward system.
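+    (This module preserves the old Crist code that riglib/reward.py contained before
+    the Aug 2020 rewrite; riglib/reward.py now targets the Orsborn-lab Arduino/pyfirmata
+    reward system.)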
Can give timed reward and drain on/off. + This class is sufficient for all the tasks implemented as of Aug. 2014. + ''' + response_message_length = 7 + def __init__(self): + ''' + Constructor for basic reward system interface + + Parameters + ---------- + None + + Returns + ------- + Basic instance + ''' + self.port = serial.Serial('/dev/crist_reward', baudrate=38400) + from config import config + self.version = int(config.reward_sys['version']) + if self.version==1: self.set_beeper_volume(128) + time.sleep(.5) + self.reset() + + def _write(self, msg): + ''' + Send an arbitrary message over the serial port + + Parameters + ---------- + msg : string + Message to be sent over the serial port + + Returns + ------- + msg_out : string + Response from crist system after sending command + ''' + fmsg = msg+_xsum(msg) + self.port.flushOutput() + self.port.flushInput() + self.port.write(fmsg) + msg_out = self.port.read(self.port.inWaiting()) + return msg_out + + def reward(self, length): + ''' + Open the solenoid for some length of time + + Parameters + ---------- + length : float + Duration of time the solenoid should be open, in seconds. NOTE: in some versions of the system, there appears to be max of ~5s + + Returns + ------- + None + ''' + length /= .1 + length = int(length) + if self.version==0: + self._write(struct.pack('= 0 and volume <= 255): + raise ValueError("Invalid beeper volume: %g" % volume) + return self._write('@CS' + '%c' % volume + 'E' + struct.pack('xxx')) + + def reset(self): + ''' + Send the system reset command + ''' + if self.version==0: + self._write("@CPSNNN") + elif self.version==1: + cmd = ['@', 'C', '1', 'P', '%c' % 0b10000000, '%c' % 0, '%c' % 0, 'D'] + stuff = ''.join(cmd) + self._write(stuff) + else: + raise Exception("Unrecognized reward system version!") + self.last_response = self.port.read(self.port.inWaiting()) + + def drain(self, drain_time=1200): + ''' + Turns on the reward system drain for specified amount of time (in seconds) + + Parameters + ---------- + drain_time : float + Time to drain the system, in seconds. + + Returns + ------- + None + ''' + assert drain_time > 0 + assert drain_time < 9999 + if self.version == 0: #have to wait and manually tell it to turn off + self._write("@CNSENN") + time.sleep(drain_time) + self._write("@CNSDNN") + elif self.version == 1: + self._write('@M1' + struct.pack('H', drain_time) + 'D' + struct.pack('xx')) + else: + raise Exception("Unrecognized reward system version!") + + def drain_off(self): + ''' + Turns off drain if currently on + ''' + if self.version==0: + self._write("@CNSDNN") + elif self.version==1: + self._write('@M1' + struct.pack('H', 0) + 'A' + struct.pack('xx')) + else: + raise Exception("Unrecognized reward system version!") + + +########################################## +##### Code below this line is unused ##### +########################################## +class _parse_num(object): + types = {1:' 2: + self._messages[header][-1](msg) + else: + print(self._messages[header], repr(msg)) + except: + traceback.print_exc() + time.sleep(10) + print(repr(msg),repr(self.port.read(self.port.inWaiting()))) + + def reward(self, time=500, volume=None): + '''Returns the string used to output a time or volume reward. 
+
+        Parameters
+        ----------
+        time : int
+            Time in milliseconds to turn on the reward
+        volume: int
+            volume in microliters
+        '''
+        assert (volume is None and time is not None) or \
+            (volume is not None and time is None)
+        time /= .1
+        self._write(struct.pack(' Date: Mon, 24 Aug 2020 16:19:17 -0700
Subject: [PATCH 066/242] update in reward timing variable names

---
 features/reward_features.py | 2 +-
 riglib/reward.py            | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/features/reward_features.py b/features/reward_features.py
index 3e81136a..b2ce6d9d 100644
--- a/features/reward_features.py
+++ b/features/reward_features.py
@@ -32,7 +32,7 @@ def _start_reward(self):
         if self.reward is not None:
             self.reportstats['Reward #'] += 1
             if self.reportstats['Reward #'] % self.trials_per_reward == 0:
-                self.reward.reward(self.reward_time)
+                self.reward.reward(self.reward_time*10)
             #super(RewardSystem, self).reward(self.reward_time)
diff --git a/riglib/reward.py b/riglib/reward.py
index 4e8a4613..5ea980f8 100644
--- a/riglib/reward.py
+++ b/riglib/reward.py
@@ -19,10 +19,10 @@ def __init__(self):
         com_port = '/dev/ttyACM0' # specify the port, based on windows/Unix, can find it on IDE or terminal
         self.board = pyfirmata.Arduino(com_port)

-    def reward(self, reward_time_s):
+    def reward(self, reward_time_s=0.2):
         """Open the solenoid for some length of time. This function does not run the loop infinitely"""
         self.board.digital[13].write(1) # send a high signal to Pin 13 on the arduino which should be connected to the reward system
-        time.sleep(rewardtime) # in second
+        time.sleep(reward_time_s) # in second
         print('ON')
         self.board.digital[13].write(0)
         print('OFF')
From 2575e2a705939556f2089a100c8bfd19393198db Mon Sep 17 00:00:00 2001
From: leoscholl
Date: Mon, 24 Aug 2020 21:57:15 -0700
Subject: [PATCH 067/242] mouse and keyboard control

---
 features/input_device_features.py | 77 +++++++++++++++++++++++++++++++
 tests/run_unit_tests.py           |  5 +-
 tests/test_decoders.py            |  2 +-
 tests/test_features.py            | 36 +++++++++++++++
 4 files changed, 118 insertions(+), 2 deletions(-)
 create mode 100644 features/input_device_features.py
 create mode 100644 tests/test_features.py

diff --git a/features/input_device_features.py b/features/input_device_features.py
new file mode 100644
index 00000000..7b3cfa5a
--- /dev/null
+++ b/features/input_device_features.py
@@ -0,0 +1,77 @@
+
+import pygame
+import numpy as np
+import copy
+
+class KeyboardControl(object):
+    '''
+    This class implements a Python cursor control task for a human user
+    '''
+
+    def __init__(self, *args, **kwargs):
+        self.move_step = 1
+        self.assist_level = (0.5, 0.5)
+        super(KeyboardControl, self).__init__(*args, **kwargs)
+
+    # override the _cycle function
+    def _cycle(self):
+        self.move_effector_cursor()
+        super(KeyboardControl, self)._cycle()
+
+    def move_effector(self):
+        pass
+
+    def move_plant(self, **kwargs):
+        pass
+
+    # use keyboard to control the task
+    def move_effector_cursor(self):
+        curr_pos = copy.deepcopy(self.plant.get_endpoint_pos())
+
+        for event in pygame.event.get():
+            if event.type == pygame.KEYUP:
+                if event.key == pygame.K_q:
+                    pygame.quit()
+                    quit()
+                if event.key == pygame.K_LEFT:
+                    curr_pos[0] -= self.move_step
+                if event.key == pygame.K_RIGHT:
+                    curr_pos[0] += self.move_step
+                if event.key == pygame.K_UP:
+                    curr_pos[2] += self.move_step
+                if event.key == pygame.K_DOWN:
+                    curr_pos[2] -= self.move_step
+                #print('Current position: ')
+                #print(curr_pos)
+
+        # set the current position
self.plant.set_endpoint_pos(curr_pos) + + def _start_wait(self): + self.wait_time = 0. + super(KeyboardControl, self)._start_wait() + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + +class MouseControl(KeyboardControl): + + def init(self, *args, **kwargs): + self.center = [self.window_size[0] / 2, self.window_size[1] / 2] + self.pos = self.plant.get_endpoint_pos() + super(MouseControl, self).init(*args, **kwargs) + + def move_effector_cursor(self): + + # Update position but keep mouse in center + pygame.mouse.set_visible(False) + pygame.event.set_grab(True) + rel = pygame.mouse.get_rel() + self.pos[0] += rel[0] / self.window_size[0] * self.fov + self.pos[2] -= rel[1] / self.window_size[1] * self.fov + pos, _ = self.plant._bound(self.pos, []) + self.plant.set_endpoint_pos(pos) + + def cleanup(self, *args, **kwargs): + pygame.mouse.set_visible(True) + super(MouseControl, self).cleanup(*args, **kwargs) diff --git a/tests/run_unit_tests.py b/tests/run_unit_tests.py index 1246c0f1..b1e66f72 100644 --- a/tests/run_unit_tests.py +++ b/tests/run_unit_tests.py @@ -1,8 +1,11 @@ import unittest, os from tests.test_decoders import TestLinDec +from tests.test_features import TestKeyboardControl, TestMouseControl test_classes = [ - TestLinDec + #TestLinDec, + #TestKeyboardControl, + TestMouseControl, ] suite = unittest.TestSuite() diff --git a/tests/test_decoders.py b/tests/test_decoders.py index d1de650f..86ba4eee 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -37,7 +37,7 @@ def test_filter(self): #@unittest.skip('msg') def test_experiment(self): N_TARGETS = 8 - N_TRIALS = 100 + N_TRIALS = 3 seq = SimBMICosEncLinDec.sim_target_seq_generator_multi( N_TARGETS, N_TRIALS) base_class = SimBMICosEncLinDec diff --git a/tests/test_features.py b/tests/test_features.py new file mode 100644 index 00000000..3df99f9c --- /dev/null +++ b/tests/test_features.py @@ -0,0 +1,36 @@ +from riglib import experiment +from built_in_tasks.manualcontrolmultitasks import ManualControlMulti +from riglib.stereo_opengl.window import WindowDispl2D +from features.input_device_features import KeyboardControl, MouseControl +import numpy as np + +import unittest + +def init_exp(base_class, feats): + blocks = 1 + targets = 3 + seq = ManualControlMulti.centerout_2D_discrete(blocks, targets) + Exp = experiment.make(base_class, feats=feats) + exp = Exp(seq) + exp.init() + return exp + +class TestKeyboardControl(unittest.TestCase): + + def setUp(self): + pass + + def test_exp(self): + exp = init_exp(ManualControlMulti, [KeyboardControl, WindowDispl2D]) + exp.run() + +class TestMouseControl(unittest.TestCase): + + def test_exp(self): + exp = init_exp(ManualControlMulti, [MouseControl, WindowDispl2D]) + exp.run() + +if __name__ == '__main__': + unittest.main() + + From a8a5816142a81047c398c690b412d56c9a056745 Mon Sep 17 00:00:00 2001 From: leoscholl Date: Fri, 4 Sep 2020 18:24:48 -0700 Subject: [PATCH 068/242] cleanup for code review --- built_in_tasks/bmimultitasks.py | 100 +- features/simulation_features.py | 4 +- riglib/bmi/extractor.py | 6 +- riglib/bmi/lindecoder.py | 134 +- riglib/bmi/sim_neurons.py | 39 +- riglib/bmi/state_space_models.py | 20 - .../blackrock_lfp_train_test.py | 0 .../ex_redundant_ik_grad_descent.py | 0 tests/{ => old_tests}/graphics_demo.py | 0 tests/{ => old_tests}/hdfwrite.py | 0 tests/{ => old_tests}/lfp_daq_test.py | 0 tests/{ => old_tests}/lfp_train_test.py | 0 tests/{ => old_tests}/plexnet.py | 0 tests/{ => old_tests}/plexstream.py | 0 tests/{ => 
old_tests}/plexstream_bmi.py | 0 tests/{ => old_tests}/psth.py | 0 tests/{ => old_tests}/spikebin.py | 0 .../start_task_from_cmd_line.py | 0 .../start_task_from_cmd_line_sim.py | 0 .../test_SimBMIControlMulti.py | 0 .../test_adapting_ppf_with_assist.py | 0 .../test_arduino_dio_omniplex.py | 0 .../test_blackrock_cbpy_streaming.py | 0 .../{ => old_tests}/test_blackrock_rstart.py | 0 tests/{ => old_tests}/test_bmi_bin_edges.py | 0 tests/{ => old_tests}/test_bmi_recon.py | 0 .../{ => old_tests}/test_constrained_sskf.py | 0 tests/{ => old_tests}/test_data.py | 0 tests/{ => old_tests}/test_db.py | 0 tests/{ => old_tests}/test_dbfunctions.py | 0 .../{ => old_tests}/test_decoder_training.py | 0 tests/{ => old_tests}/test_fa_decoding.py | 0 tests/{ => old_tests}/test_fa_tasks.py | 0 tests/{ => old_tests}/test_fixed_ppf.py | 0 tests/{ => old_tests}/test_kinarm_stream.py | 0 .../test_lfp_cursor_fixed_decoder.py | 0 tests/{ => old_tests}/test_map_arduino_dio.py | 0 tests/{ => old_tests}/test_npz_to_hdf.py | 0 tests/{ => old_tests}/test_onedim_lfp.py | 0 .../{ => old_tests}/test_phidgets_joystick.py | 0 .../test_phidgets_joystick_acqu.py | 0 tests/{ => old_tests}/test_plx_rstart.py | 0 tests/{ => old_tests}/test_rat_bmi.py | 0 .../test_reconst_bmi_trajectory.py | 0 tests/{ => old_tests}/test_reward.py | 0 tests/{ => old_tests}/test_reward_ao.py | 0 tests/{ => old_tests}/test_rml_cg.py | 0 .../test_rml_cg_param_updates.py | 0 .../{ => old_tests}/test_rstart_barebones.py | 0 .../test_tentacle_attractor.py | 0 .../test_train_kfdecoder_brain_control.py | 0 .../test_train_kfdecoder_manual_control.py | 0 tests/{ => old_tests}/test_window_2d.py | 0 tests/{ => old_tests}/vfb_joint_space.py | 0 tests/run_unit_tests.py | 4 +- tests/test_decoders.py | 59 +- .../coverage.bat | 0 .../coverage.sh | 0 ...n-python-interface_riglib___init___py.html | 64 + ...ython-interface_riglib_arduino_imu_py.html | 196 +++ ...-interface_riglib_arduino_joystick_py.html | 194 +++ ...nterface_riglib_blackrock___init___py.html | 130 ++ ...erface_riglib_blackrock_brMiscFxns_py.html | 139 ++ ...interface_riglib_blackrock_brpylib_py.html | 1325 +++++++++++++++ ...nterface_riglib_blackrock_cerelink_py.html | 248 +++ ...thon-interface_riglib_bmi___init___py.html | 69 + ...n-interface_riglib_bmi_accumulator_py.html | 183 ++ ...python-interface_riglib_bmi_assist_py.html | 201 +++ ...in-python-interface_riglib_bmi_bmi_py.html | 1496 +++++++++++++++++ ...n-python-interface_riglib_bmi_clda_py.html | 1182 +++++++++++++ ...hon-interface_riglib_bmi_extractor_py.html | 1424 ++++++++++++++++ ...ce_riglib_bmi_feedback_controllers_py.html | 462 +++++ ...erface_riglib_bmi_goal_calculators_py.html | 422 +++++ ...nterface_riglib_bmi_kfdecoder_fcns_py.html | 476 ++++++ ...hon-interface_riglib_bmi_kfdecoder_py.html | 979 +++++++++++ ...on-interface_riglib_bmi_lindecoder_py.html | 129 ++ ...face_riglib_bmi_onedim_lfp_decoder_py.html | 306 ++++ ...on-interface_riglib_bmi_ppfdecoder_py.html | 574 +++++++ ...terface_riglib_bmi_rat_bmi_decoder_py.html | 646 +++++++ ...on-interface_riglib_bmi_robot_arms_py.html | 785 +++++++++ ...n-interface_riglib_bmi_sim_neurons_py.html | 921 ++++++++++ ...n-interface_riglib_bmi_sskfdecoder_py.html | 241 +++ ...face_riglib_bmi_state_space_models_py.html | 594 +++++++ ...-python-interface_riglib_bmi_train_py.html | 1476 ++++++++++++++++ ...ain-python-interface_riglib_button_py.html | 166 ++ ...thon-interface_riglib_calibrations_py.html | 371 ++++ ...thon-interface_riglib_dio___init___py.html | 65 + 
...nterface_riglib_dio_nidaq___init___py.html | 230 +++ ...-python-interface_riglib_dio_parse_py.html | 235 +++ ...interface_riglib_experiment_Pygame_py.html | 158 ++ ...terface_riglib_experiment___init___py.html | 162 ++ ...rface_riglib_experiment_experiment_py.html | 707 ++++++++ ...terface_riglib_experiment_generate_py.html | 203 +++ ...-interface_riglib_experiment_mocks_py.html | 225 +++ ...interface_riglib_experiment_report_py.html | 152 ++ ...python-interface_riglib_eyetracker_py.html | 298 ++++ ...ain-python-interface_riglib_filter_py.html | 117 ++ ...thon-interface_riglib_fsm___init___py.html | 65 + ...-interface_riglib_fsm_fsm___init___py.html | 65 + ...ython-interface_riglib_fsm_fsm_fsm_py.html | 428 +++++ ...-python-interface_riglib_fsm_setup_py.html | 82 + ...nterface_riglib_hdfwriter___init___py.html | 68 + ...iglib_hdfwriter_hdfwriter___init___py.html | 66 + ...glib_hdfwriter_hdfwriter_hdfwriter_py.html | 202 +++ ...n-interface_riglib_hdfwriter_setup_py.html | 82 + ...python-interface_riglib_kinarmdata_py.html | 136 ++ ...thon-interface_riglib_kinarmsocket_py.html | 146 ++ ...nterface_riglib_master8stimulation_py.html | 174 ++ ...hon-interface_riglib_motiontracker_py.html | 364 ++++ ...in-python-interface_riglib_mp_calc_py.html | 284 ++++ ...n-python-interface_riglib_mp_proxy_py.html | 118 ++ ...glib_optitrack_client_NatNetClient_py.html | 580 +++++++ ...glib_optitrack_client_PythonSample_py.html | 105 ++ ...e_riglib_optitrack_client___init___py.html | 64 + ...track_client_optitrack_direct_pack_py.html | 151 ++ ...titrack_client_optitrack_interface_py.html | 190 +++ ..._client_test_NatNetClient_perframe_py.html | 86 + ...glib_optitrack_client_test_control_py.html | 87 + ...ib_optitrack_client_test_optitrack_py.html | 74 + ...n-python-interface_riglib_phidgets_py.html | 191 +++ ...ain-python-interface_riglib_plants_py.html | 624 +++++++ ...n-interface_riglib_plexon___init___py.html | 287 ++++ ...n-interface_riglib_plexon_checkbin_py.html | 77 + ...on-interface_riglib_plexon_plexnet_py.html | 479 ++++++ ...plexon_plexnet_softserver_oldfiles_py.html | 437 +++++ ...hon-interface_riglib_plexon_source_py.html | 97 ++ ...erface_riglib_plexon_test_plexfile_py.html | 80 + ...terface_riglib_positioner___init___py.html | 652 +++++++ ...-interface_riglib_positioner_calib_py.html | 140 ++ ...ain-python-interface_riglib_reward_py.html | 437 +++++ ...python-interface_riglib_serial_dio_py.html | 208 +++ ...rain-python-interface_riglib_setup_py.html | 108 ++ ...brain-python-interface_riglib_sink_py.html | 335 ++++ ...ain-python-interface_riglib_source_py.html | 768 +++++++++ ...face_riglib_stereo_opengl___init___py.html | 74 + ...e_riglib_stereo_opengl_environment_py.html | 111 ++ ...-interface_riglib_stereo_opengl_ik_py.html | 746 ++++++++ ...erface_riglib_stereo_opengl_models_py.html | 367 ++++ ...ce_riglib_stereo_opengl_primitives_py.html | 381 +++++ ...glib_stereo_opengl_render___init___py.html | 70 + ...ce_riglib_stereo_opengl_render_fbo_py.html | 173 ++ ...riglib_stereo_opengl_render_render_py.html | 227 +++ ...riglib_stereo_opengl_render_shader_py.html | 181 ++ ...e_riglib_stereo_opengl_render_ssao_py.html | 139 ++ ...riglib_stereo_opengl_render_stereo_py.html | 182 ++ ...face_riglib_stereo_opengl_textures_py.html | 161 ++ ...terface_riglib_stereo_opengl_utils_py.html | 150 ++ ...erface_riglib_stereo_opengl_window_py.html | 497 ++++++ ...interface_riglib_stereo_opengl_xfm_py.html | 310 ++++ ...on-interface_riglib_stimulus_pulse_py.html | 93 + 
...python-interface_riglib_touch_data_py.html | 163 ++ .../htmlcov/coverage_html.js | 589 +++++++ .../htmlcov/index.html | 728 ++++++++ .../jquery.ba-throttle-debounce.min.js | 9 + .../htmlcov/jquery.hotkeys.js | 99 ++ .../htmlcov/jquery.isonscreen.js | 53 + .../htmlcov/jquery.min.js | 4 + .../htmlcov/jquery.tablesorter.min.js | 2 + .../htmlcov/keybd_closed.png | Bin 0 -> 112 bytes .../htmlcov/keybd_open.png | Bin 0 -> 112 bytes .../htmlcov/status.json | 1 + .../unit_tests (deprecated)/htmlcov/style.css | 291 ++++ .../mocks.py | 0 .../reqlib.py | 0 .../requirements.py | 0 .../run_unit_tests.py | 0 .../test_built_in_vfb_task.py | 0 .../test_feature_savehdf.py | 0 .../test_mixin_features.py | 0 .../test_riglib_bmi.py | 0 .../test_riglib_experiment.py | 0 .../test_riglib_hdfwriter.py | 0 .../test_riglib_source.py | 0 .../test_riglib_traits.py | 0 174 files changed, 32510 insertions(+), 145 deletions(-) rename tests/{ => old_tests}/blackrock_lfp_train_test.py (100%) rename tests/{ => old_tests}/ex_redundant_ik_grad_descent.py (100%) rename tests/{ => old_tests}/graphics_demo.py (100%) rename tests/{ => old_tests}/hdfwrite.py (100%) rename tests/{ => old_tests}/lfp_daq_test.py (100%) rename tests/{ => old_tests}/lfp_train_test.py (100%) rename tests/{ => old_tests}/plexnet.py (100%) rename tests/{ => old_tests}/plexstream.py (100%) rename tests/{ => old_tests}/plexstream_bmi.py (100%) rename tests/{ => old_tests}/psth.py (100%) rename tests/{ => old_tests}/spikebin.py (100%) rename tests/{ => old_tests}/start_task_from_cmd_line.py (100%) rename tests/{ => old_tests}/start_task_from_cmd_line_sim.py (100%) rename tests/{ => old_tests}/test_SimBMIControlMulti.py (100%) rename tests/{ => old_tests}/test_adapting_ppf_with_assist.py (100%) rename tests/{ => old_tests}/test_arduino_dio_omniplex.py (100%) rename tests/{ => old_tests}/test_blackrock_cbpy_streaming.py (100%) rename tests/{ => old_tests}/test_blackrock_rstart.py (100%) rename tests/{ => old_tests}/test_bmi_bin_edges.py (100%) rename tests/{ => old_tests}/test_bmi_recon.py (100%) rename tests/{ => old_tests}/test_constrained_sskf.py (100%) rename tests/{ => old_tests}/test_data.py (100%) rename tests/{ => old_tests}/test_db.py (100%) rename tests/{ => old_tests}/test_dbfunctions.py (100%) rename tests/{ => old_tests}/test_decoder_training.py (100%) rename tests/{ => old_tests}/test_fa_decoding.py (100%) rename tests/{ => old_tests}/test_fa_tasks.py (100%) rename tests/{ => old_tests}/test_fixed_ppf.py (100%) rename tests/{ => old_tests}/test_kinarm_stream.py (100%) rename tests/{ => old_tests}/test_lfp_cursor_fixed_decoder.py (100%) rename tests/{ => old_tests}/test_map_arduino_dio.py (100%) rename tests/{ => old_tests}/test_npz_to_hdf.py (100%) rename tests/{ => old_tests}/test_onedim_lfp.py (100%) rename tests/{ => old_tests}/test_phidgets_joystick.py (100%) rename tests/{ => old_tests}/test_phidgets_joystick_acqu.py (100%) rename tests/{ => old_tests}/test_plx_rstart.py (100%) rename tests/{ => old_tests}/test_rat_bmi.py (100%) rename tests/{ => old_tests}/test_reconst_bmi_trajectory.py (100%) rename tests/{ => old_tests}/test_reward.py (100%) rename tests/{ => old_tests}/test_reward_ao.py (100%) rename tests/{ => old_tests}/test_rml_cg.py (100%) rename tests/{ => old_tests}/test_rml_cg_param_updates.py (100%) rename tests/{ => old_tests}/test_rstart_barebones.py (100%) rename tests/{ => old_tests}/test_tentacle_attractor.py (100%) rename tests/{ => old_tests}/test_train_kfdecoder_brain_control.py (100%) rename tests/{ => 
old_tests}/test_train_kfdecoder_manual_control.py (100%) rename tests/{ => old_tests}/test_window_2d.py (100%) rename tests/{ => old_tests}/vfb_joint_space.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/coverage.bat (100%) rename tests/{unit_tests => unit_tests (deprecated)}/coverage.sh (100%) mode change 100755 => 100644 create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_arduino_imu_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_arduino_joystick_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_brMiscFxns_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_brpylib_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_cerelink_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_accumulator_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_assist_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_bmi_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_clda_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_extractor_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_feedback_controllers_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_goal_calculators_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_kfdecoder_fcns_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_kfdecoder_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_lindecoder_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_onedim_lfp_decoder_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_ppfdecoder_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_rat_bmi_decoder_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_robot_arms_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_sim_neurons_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_sskfdecoder_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_state_space_models_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_train_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_button_py.html create 
mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_calibrations_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio_nidaq___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio_parse_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_Pygame_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_experiment_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_generate_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_mocks_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_report_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_eyetracker_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_filter_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_fsm___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_fsm_fsm_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_setup_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_hdfwriter___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_hdfwriter_hdfwriter_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_setup_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_kinarmdata_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_kinarmsocket_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_master8stimulation_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_motiontracker_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_mp_calc_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_mp_proxy_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_NatNetClient_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_PythonSample_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client___init___py.html create mode 100644 tests/unit_tests 
(deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_optitrack_direct_pack_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_optitrack_interface_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_NatNetClient_perframe_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_control_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_optitrack_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_phidgets_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plants_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_checkbin_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_plexnet_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_plexnet_softserver_oldfiles_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_source_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_test_plexfile_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_positioner___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_positioner_calib_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_reward_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_serial_dio_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_setup_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_sink_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_source_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_environment_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_ik_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_models_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_primitives_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render___init___py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_fbo_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_render_py.html create mode 100644 tests/unit_tests 
(deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_shader_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_ssao_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_stereo_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_textures_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_utils_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_window_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_xfm_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stimulus_pulse_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_touch_data_py.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/coverage_html.js create mode 100644 tests/unit_tests (deprecated)/htmlcov/index.html create mode 100644 tests/unit_tests (deprecated)/htmlcov/jquery.ba-throttle-debounce.min.js create mode 100644 tests/unit_tests (deprecated)/htmlcov/jquery.hotkeys.js create mode 100644 tests/unit_tests (deprecated)/htmlcov/jquery.isonscreen.js create mode 100644 tests/unit_tests (deprecated)/htmlcov/jquery.min.js create mode 100644 tests/unit_tests (deprecated)/htmlcov/jquery.tablesorter.min.js create mode 100644 tests/unit_tests (deprecated)/htmlcov/keybd_closed.png create mode 100644 tests/unit_tests (deprecated)/htmlcov/keybd_open.png create mode 100644 tests/unit_tests (deprecated)/htmlcov/status.json create mode 100644 tests/unit_tests (deprecated)/htmlcov/style.css rename tests/{unit_tests => unit_tests (deprecated)}/mocks.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/reqlib.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/requirements.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/run_unit_tests.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/test_built_in_vfb_task.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/test_feature_savehdf.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/test_mixin_features.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/test_riglib_bmi.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/test_riglib_experiment.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/test_riglib_hdfwriter.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/test_riglib_source.py (100%) rename tests/{unit_tests => unit_tests (deprecated)}/test_riglib_traits.py (100%) diff --git a/built_in_tasks/bmimultitasks.py b/built_in_tasks/bmimultitasks.py index 6959aeff..32d6554e 100644 --- a/built_in_tasks/bmimultitasks.py +++ b/built_in_tasks/bmimultitasks.py @@ -26,7 +26,7 @@ from riglib.stereo_opengl.primitives import Line -from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceNLinkPlanarChain, StateSpaceEndptPos3D +from riglib.bmi.state_space_models import StateSpaceEndptVel2D, StateSpaceNLinkPlanarChain from built_in_tasks.manualcontrolmultitasks import ManualControlMulti @@ -267,8 +267,6 @@ def create_assister(self): self.assister = OFCEndpointAssister() elif isinstance(self.decoder.ssm, StateSpaceEndptVel2D): self.assister = 
SimpleEndpointAssister(**kwargs) - elif isinstance(self.decoder.ssm, StateSpaceEndptPos3D): - self.assister = SimplePosAssister(**kwargs) ## elif (self.decoder.ssm == namelist.tentacle_2D_state_space) or (self.decoder.ssm == namelist.joint_2D_state_space): ## # kin_chain = self.plant.kin_chain ## # A, B, W = self.decoder.ssm.get_ssm_matrices(update_rate=self.decoder.binlen) @@ -286,8 +284,6 @@ def create_assister(self): def create_goal_calculator(self): if isinstance(self.decoder.ssm, StateSpaceEndptVel2D): self.goal_calculator = goal_calculators.ZeroVelocityGoal(self.decoder.ssm) - elif isinstance(self.decoder.ssm, StateSpaceEndptPos3D): - self.goal_calculator = goal_calculators.ZeroVelocityGoal(self.decoder.ssm) elif isinstance(self.decoder.ssm, StateSpaceNLinkPlanarChain) and self.decoder.ssm.n_links == 2: self.goal_calculator = goal_calculators.PlanarMultiLinkJointGoal(self.decoder.ssm, self.plant.base_loc, self.plant.kin_chain, multiproc=False, init_resp=None) elif isinstance(self.decoder.ssm, StateSpaceNLinkPlanarChain) and self.decoder.ssm.n_links == 4: @@ -508,7 +504,7 @@ def _start_wait(self, *args, **kwargs): from riglib.bmi.feedback_controllers import LQRController, PosFeedbackController class SimBMIControlMulti(BMIControlMulti2DWindow): win_res = (250, 140) - sequence_generators = ManualControlMulti.sequence_generators + ['sim_target_seq_generator_multi'] + sequence_generators = ManualControlMulti.sequence_generators + ['sim_target_seq_generator_multi', 'sim_target_no_center'] def __init__(self, *args, **kwargs): from riglib.bmi.state_space_models import StateSpaceEndptVel2D @@ -546,6 +542,20 @@ def sim_target_seq_generator_multi(n_targs=8, n_trials=8): yield np.array([[center[0], 0, center[1]], [targ[0], 0, targ[1]]]) + @staticmethod + def sim_target_no_center(n_targs=8, n_trials=8): + ''' + Simulated generator for simulations of the BMIControlMulti and CLDAControlMulti tasks + ''' + pi = np.pi + targets = 8*np.vstack([[np.cos(pi/4*k), np.sin(pi/4*k)] for k in range(8)]) + + target_inds = np.random.randint(0, n_targs, n_trials) + target_inds[0:n_targs] = np.arange(min(n_targs, n_trials)) + for k in range(n_trials): + targ = targets[target_inds[k], :] + yield np.array([[targ[0], 0, targ[1]]]) + class SimBMICosEncKFDec(SimCosineTunedEnc, SimKFDecoderSup, SimBMIControlMulti): def __init__(self, *args, **kwargs): N_NEURONS = 4 @@ -577,33 +587,73 @@ def __init__(self, *args, **kwargs): class SimBMICosEncLinDec(SimLFPCosineTunedEnc, SimBMIControlMulti): def __init__(self, *args, **kwargs): - # build the observation matrix - sim_C = np.zeros((2, 3)) - # control x positive directions - sim_C[0, :] = np.array([1, 0, 0]) - sim_C[1, :] = np.array([0, 0, 1]) - - kwargs['sim_C'] = sim_C - kwargs['assist_level'] = (0, 0) + ssm = StateSpaceEndptVel2D() - # make the decoder map - self.decoder_map = np.array([[1, 0], [0, 0], [0, 1]]) + # build the observation matrix + sim_C = np.zeros((2, 7)) - ssm = StateSpaceEndptVel2D() + # control x and z position + sim_C[0, :] = np.array([1, 0, 0, 0, 0, 0, 0]) + sim_C[1, :] = np.array([0, 0, 1, 0, 0, 0, 0]) + self.vel_control = False self.fb_ctrl = PosFeedbackController() + + # map neurons (2) to states (7) using C + self.decoder_map = sim_C.T self.ssm = ssm + kwargs['sim_C'] = sim_C + kwargs['assist_level'] = (0, 0) super(SimBMICosEncLinDec, self).__init__(*args, **kwargs) + def init(self, *args, **kwargs): + self.max_attempts = 1 + self.timeout_time = 1 + super(SimBMICosEncLinDec, self).init(*args, **kwargs) + def load_decoder(self): units = 
self.encoder.get_units() filt_counts = 10000 # number of observations to calculate range - filt_window = 3 # number of observations to average for each tick + filt_window = 1 # number of observations to average for each tick filt_map = self.decoder_map # map from states to units - vel_control = False - filt = PosVelScaleFilter(vel_control, filt_counts, self.ssm.n_states, len(units), map=filt_map, window=filt_window) - gain = 2 * np.max(self.plant.endpt_bounds) - filt.update_norm_param(neural_mean=[5, 5], neural_range=[10,10], scaling_mean=[0,0], scaling_range=[gain,gain]) - filt.fix_norm_param() - self.decoder = Decoder(filt, units, self.ssm, binlen=0.1, subbins=1) - self.decoder.n_features = len(units) \ No newline at end of file + filt = PosVelScaleFilter(self.vel_control, filt_counts, self.ssm.n_states, \ + len(units), map=filt_map, window=filt_window, call_rate=self.fps, + plant_gain=2*np.max(self.plant.endpt_bounds)) + + # supply some known good attributes + neural_gain = self.fov + scaling_gain = 1 + filt.update_norm_attr(neural_mean=[neural_gain/2, neural_gain/2], neural_std=[neural_gain,neural_gain], \ + scaling_mean=[0,0], scaling_std=[scaling_gain,scaling_gain]) + filt.fix_norm_attr() + + # or allow decoder to figure it out + # neural_gain = self.fov * 1.1 + # filt.update_norm_attr(neural_mean=[neural_gain/2, neural_gain/2], neural_std=[neural_gain,neural_gain]) + + self.decoder = Decoder(filt, units, self.ssm, binlen=0.1, subbins=1, call_rate=self.fps) + self.decoder.n_features = len(units) + +class SimBMIVelocityLinDec(SimBMICosEncLinDec): + def __init__(self, *args, **kwargs): + + ssm = StateSpaceEndptVel2D() + + # control x and z velocity + sim_C = np.zeros((2, 7)) + sim_C[0, :] = np.array([0, 0, 0, 1, 0, 0, 0]) + sim_C[1, :] = np.array([0, 0, 0, 0, 0, 1, 0]) + self.vel_control = True + A, B, W = ssm.get_ssm_matrices() + Q = np.mat(np.diag([1., 1, 1, 0, 0, 0, 0])) + R = 10000*np.mat(np.diag([1., 1., 1.])) + self.fb_ctrl = LQRController(A, B, Q, R) + + # map neurons (2) to states (7) using C + self.decoder_map = sim_C.T + self.ssm = ssm + kwargs['sim_C'] = sim_C + kwargs['assist_level'] = (0, 0) + + super(SimBMICosEncLinDec, self).__init__(*args, **kwargs) + \ No newline at end of file diff --git a/features/simulation_features.py b/features/simulation_features.py index 9bba08c7..f24ed484 100644 --- a/features/simulation_features.py +++ b/features/simulation_features.py @@ -228,7 +228,7 @@ class SimNormCosineTunedEnc(SimNeuralEnc): def _init_neural_encoder(self): from riglib.bmi.sim_neurons import NormalizedCosEnc self.encoder = NormalizedCosEnc(self.plant.endpt_bounds, self.sim_C, self.ssm, spike=True, return_ts=False, - DT=self.update_rate, tick=self.update_rate) + DT=self.update_rate, tick=self.update_rate, gain=self.fov) def create_feature_extractor(self): ''' @@ -245,7 +245,7 @@ class SimLFPCosineTunedEnc(SimNeuralEnc): def _init_neural_encoder(self): from riglib.bmi.sim_neurons import NormalizedCosEnc self.encoder = NormalizedCosEnc(self.plant.endpt_bounds, self.sim_C, self.ssm, spike=False, return_ts=False, - DT=self.update_rate, tick=self.update_rate, n_bands=len(self.bands)) + DT=self.update_rate, tick=self.update_rate, n_bands=len(self.bands), gain=self.fov) def create_feature_extractor(self): ''' diff --git a/riglib/bmi/extractor.py b/riglib/bmi/extractor.py index 58f1afba..2a13d279 100644 --- a/riglib/bmi/extractor.py +++ b/riglib/bmi/extractor.py @@ -651,11 +651,11 @@ def __call__(self, *args, **kwargs): self.idx += 1 return output -class 
ReplayLFPPowerExtractor(BinnedSpikeCountsExtractor): +class ReplayLFPPowerExtractor(LFPMTMPowerExtractor): ''' A "feature extractor" that replays LFP power estimates from an HDF file ''' - feature_type = 'lfp_power' + def __init__(self, hdf_table, source='lfp_power'): ''' Constructor for ReplayLFPPowerExtractor @@ -682,7 +682,7 @@ def __init__(self, hdf_table, source='lfp_power'): def __call__(self, *args, **kwargs): ''' - See BinnedSpikeCountsExtractor.__call__ for documentation + See LFPMTMPowerExtractor.__call__ for documentation ''' output = self.hdf_table[self.idx][self.source] self.idx += 1 diff --git a/riglib/bmi/lindecoder.py b/riglib/bmi/lindecoder.py index fd0271f2..f7f5746c 100644 --- a/riglib/bmi/lindecoder.py +++ b/riglib/bmi/lindecoder.py @@ -11,7 +11,10 @@ def __init__(self, mean, *args, **kwargs): class LinearScaleFilter(object): - def __init__(self, n_counts, n_states, n_units, map=None, window=1): + model_attrs = ['attr'] + attrs_to_pickle = ['attr', 'obs', 'map'] + + def __init__(self, n_counts, n_states, n_units, map=None, window=1, plant_gain=20): ''' Constructor for LinearScaleFilter @@ -28,103 +31,136 @@ def __init__(self, n_counts, n_states, n_units, map=None, window=1): Floating point matrix of size (S, D) where S is the number of states and D is the number of units, assigning a weight to each pair window : How many observations to average to smooth output (default = 1) + plant_gain : how big is the screen, basically (default = 20) + Maps from normalized output (0,1) to plant coordinates Returns ------- LinearScaleFilter instance ''' - self.state = State(np.zeros([n_states,1])) self.obs = np.zeros((n_counts, n_units)) self.n_states = n_states self.n_units = n_units self.window = window self.map = map + self.plant_gain = plant_gain if map is None: # Generate a default mapping where one unit controls one state self.map = np.identity(max(n_states, n_units)) self.map = np.resize(self.map, (n_states, n_units)) self.count = 0 - self.params = dict( + self.attr = dict( neural_mean = np.zeros(n_units), - neural_range = np.ones(n_units), + neural_std = np.ones(n_units), scaling_mean = np.zeros(n_units), - scaling_range = np.ones(n_units), + scaling_std = np.ones(n_units), ) self.fixed = False - - def _init_state(self): - pass + self._init_state() def get_mean(self): return np.array(self.state.mean).ravel() def __call__(self, obs, **kwargs): # TODO need to pick single frequency band if given more than one - self.state = self._normalize(obs, **kwargs) + self._add_obs(obs, **kwargs) + if not self.fixed: + self._update_scale_attr() + self._init_state() - def _normalize(self, obs,**kwargs): - ''' Function to compute normalized scaling of new observations''' + def update_norm_attr(self, neural_mean=None, neural_std=None, scaling_mean=None, scaling_std=None): + ''' Public method to set mean and std attributes''' + if neural_mean is not None: + self.attr.update(neural_mean = neural_mean) + if neural_std is not None: + self.attr.update(neural_std = neural_std) + if scaling_mean is not None: + self.attr.update(scaling_mean = scaling_mean) + if scaling_std is not None: + self.attr.update(scaling_std = scaling_std) + def fix_norm_attr(self): + ''' Stop fliter from self updating its attributes''' + self.fixed = True + + def _pickle_init(self): + self.fixed = True + + def _init_state(self): + out = self._scale() + self.state = State(out) + + def _add_obs(self, obs,**kwargs): + ''' Normalize new observations and add them to the observation matrix''' + + # Z-score neural data + norm_obs = 
(np.squeeze(obs) - self.attr['neural_mean']) / self.attr['neural_std'] + # Update observation matrix - norm_obs = (np.squeeze(obs) - self.params['neural_mean']) / self.params['neural_range'] # center on zero - self.obs[:-1, :] = self.obs[1:, :] - self.obs[-1, :] = norm_obs if self.count < len(self.obs): self.count += 1 - - if not self.fixed: - self._update_scale_param(obs) - m_win = np.squeeze(np.mean(self.obs[-self.window:, :], axis=0)) - x = (m_win - self.params['scaling_mean']) * self.params['scaling_range'] + self.obs[:-1, :] = self.obs[1:, :] + self.obs[-1, :] = norm_obs + + def _scale(self): + ''' Scale the (normalized) observations within the window''' + + # Normalize windowed average to 'scaling' mean and range + if self.count == 0: + m_win = np.zeros(np.size(self.obs, axis=1)) + elif self.count < self.window: + m_win = np.squeeze(np.mean(self.obs[-self.count:, :], axis=0)) + else: + m_win = np.squeeze(np.mean(self.obs[-self.window:, :], axis=0)) + x = (m_win - self.attr['scaling_mean']) / self.attr['scaling_std'] # Arrange output according to map - out = np.matmul(self.map, x).reshape(-1,1) - return State(out) + out = np.matmul(self.map, x).reshape(-1,1) * self.plant_gain + return out - def _update_scale_param(self, obs): - ''' Function to update the normalization parameters''' + def _update_scale_attr(self): + ''' Update the normalization parameters''' # Normalize latest observation(s) mean = np.median(self.obs[-self.count:, :], axis=0) # range = max(1, np.amax(self.obs[-self.count:, :]) - np.amin(self.obs[-self.count:, :])) - range = 3 * np.std(self.obs[-self.count:, :], axis=0) - range[range < 1] = 1 - self.update_norm_param(scaling_mean=mean, scaling_range=range) + std = np.std(self.obs[-self.count:, :], axis=0) + std[std == 0] = 1 # Hopefully this never happens + self.update_norm_attr(scaling_mean=mean, scaling_std=std) - def update_norm_param(self, neural_mean=None, neural_range=None, scaling_mean=None, scaling_range=None): - if neural_mean is not None: - self.params.update(neural_mean = neural_mean) - if neural_range is not None: - self.params.update(neural_range = neural_range) - if scaling_mean is not None: - self.params.update(scaling_mean = scaling_mean) - if scaling_range is not None: - self.params.update(scaling_range = scaling_range) - - def fix_norm_param(self): - self.fixed = True - - def get_norm_param(self): - return self.params class PosVelState(State): + ''' Simple state with the ability to integrate velocity over time''' - def __init__(self, vel_control, *args, **kwargs): + def __init__(self, vel_control, call_rate=60): self.vel_control = vel_control + self.call_rate = call_rate self.mean = np.zeros((7,1)) def update(self, mean): if self.vel_control: - self.mean[3:6] = mean - self.mean[0:3] = self.mean[3:6] + self.mean[0:3] + self.mean[3:6] = mean[3:6] + + # Add the velocity (units/s) to the position (units) + self.mean[0:3] = self.mean[3:6] / self.call_rate + self.mean[0:3] else: - self.mean = np.zeros((7,1)) - self.mean[0:3] = mean + self.mean = mean class PosVelScaleFilter(LinearScaleFilter): + ''' Linear filter that holds a position and velocity state''' + def __init__(self, vel_control, *args, **kwargs): + self.call_rate = kwargs.pop('call_rate') + self.vel_control = vel_control super(PosVelScaleFilter, self).__init__(*args, **kwargs) - self.state = PosVelState(vel_control) + + def _init_state(self): + self.state = PosVelState(self.vel_control, self.call_rate) + out = self._scale() + self.state.update(out) def __call__(self, obs, **kwargs): - state = 
self._normalize(obs, **kwargs) - self.state.update(state.mean) \ No newline at end of file + self._add_obs(obs, **kwargs) + if not self.fixed: + self._update_scale_attr() + out = self._scale() + self.state.update(out) \ No newline at end of file diff --git a/riglib/bmi/sim_neurons.py b/riglib/bmi/sim_neurons.py index eaa0688f..6d63a6d1 100644 --- a/riglib/bmi/sim_neurons.py +++ b/riglib/bmi/sim_neurons.py @@ -362,6 +362,7 @@ def y2_eq_r2_min_x2(self, x_arr, r2): y.append(-1*np.sqrt(r2 - x**2)) return np.array(y) +from riglib.bmi.state_space_models import StateSpaceEndptVel2D class NormalizedCosEnc(GenericCosEnc): ''' Generates neural observations (spikes or LFP) based on normalized scaling within the bounds of @@ -406,56 +407,60 @@ def __init__(self, bounds, C, ssm, spike=True, return_ts=False, DT=0.1, tick=1/6 call_ds_rate = DT / tick super(NormalizedCosEnc, self).__init__(C, ssm, return_ts, DT, call_ds_rate) - def gen_spikes(self, next_state, mode=None): + def gen_spikes(self, rates, mode=None): """ Simulate the spikes Parameters ---------- - next_state : np.array of shape (N, 1) - The "next state" to be encoded by this population of neurons - + rates : np.array of shape (N, 1) + Returns ------- time stamps or counts Either spike time stamps or a vector of unit spike counts is returned, depending on whether the 'return_ts' attribute is True """ - norm_state = np.divide(np.subtract(np.squeeze(next_state[0:3]), self.min), self.range) - rates = np.dot(self.C, norm_state) * self.gain return self.return_spikes(rates, mode=mode) - def gen_power(self, next_state, mode=None): + def gen_power(self, ideal, mode=None): """ - Simulate the LFP powers + Simulate the LFP powers by adding gaussian noise to ideal powers Parameters ---------- - next_state : np.array of shape (N, 1) - The "next state" to be encoded by this population of neurons + ideal : np.array of shape (N, 1) Returns ------- powers : np.array of shape (N, P) -> flattened N number of neurons, P number of power bands, determined by n_bands """ - norm_state = np.divide(np.subtract(np.squeeze(next_state[0:3]), self.min), self.range) - ideal = np.dot(self.C, norm_state) * self.gain # Generate gaussian noise - noise = np.random.normal(0, 0.05 * self.gain, size=(len(ideal), self.n_bands)) + noise = np.random.normal(0, 0.02 * self.gain, size=(len(ideal), self.n_bands)) # Replicate across frequency bands power = np.tile(ideal.reshape((-1,1)), (1, self.n_bands)) + noise return power.reshape(-1,1) - + def __call__(self, next_state, mode=None): ''' See CosEnc.__call__ for docs - ''' + ''' + next_state = np.squeeze(np.asarray(next_state)) + + if isinstance(self.ssm, StateSpaceEndptVel2D): + norm_pos = np.divide(np.subtract(next_state[0:3], self.min), self.range) + norm_vel = np.divide(np.subtract(next_state[3:6], self.min), self.range) + out = np.dot(self.C, [norm_pos[0], norm_pos[1], norm_pos[2], norm_vel[0], norm_vel[1], norm_vel[2], 0]) * self.gain + else: + raise NotImplementedError() + + if self.spike: if self.call_count % self.call_ds_rate == 0: - ts_data = self.gen_spikes(next_state, mode=mode) + ts_data = self.gen_spikes(out, mode=mode) else: if self.return_ts: @@ -469,7 +474,7 @@ def __call__(self, next_state, mode=None): return ts_data else: if self.call_count % self.call_ds_rate == 0: - lfp_data = self.gen_power(next_state, mode=mode) + lfp_data = self.gen_power(out, mode=mode) else: lfp_data = np.zeros((self.n_neurons*self.n_bands, 1)) diff --git a/riglib/bmi/state_space_models.py b/riglib/bmi/state_space_models.py index 
c0d204f4..b8d70591 100644 --- a/riglib/bmi/state_space_models.py +++ b/riglib/bmi/state_space_models.py @@ -381,26 +381,6 @@ def __setstate__(self, state): if not hasattr(self, 'w'): self.w = 7 -class StateSpaceEndptPos3D(StateSpace): - ''' StateSpace for 3D pos control''' - def __init__(self, **kwargs): - self.states = [ - State('hand_px', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0), - State('hand_py', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0), - State('hand_pz', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0) - ] - - def __setstate__(self, state): - self.__dict__ = state - if not hasattr(self, 'Delta'): - self.Delta = 0.1 - - if not hasattr(self, 'vel_decay'): - self.vel_decay = 0.8 - - if not hasattr(self, 'w'): - self.w = 7 - ############################ ##### Helper functions ##### ############################ diff --git a/tests/blackrock_lfp_train_test.py b/tests/old_tests/blackrock_lfp_train_test.py similarity index 100% rename from tests/blackrock_lfp_train_test.py rename to tests/old_tests/blackrock_lfp_train_test.py diff --git a/tests/ex_redundant_ik_grad_descent.py b/tests/old_tests/ex_redundant_ik_grad_descent.py similarity index 100% rename from tests/ex_redundant_ik_grad_descent.py rename to tests/old_tests/ex_redundant_ik_grad_descent.py diff --git a/tests/graphics_demo.py b/tests/old_tests/graphics_demo.py similarity index 100% rename from tests/graphics_demo.py rename to tests/old_tests/graphics_demo.py diff --git a/tests/hdfwrite.py b/tests/old_tests/hdfwrite.py similarity index 100% rename from tests/hdfwrite.py rename to tests/old_tests/hdfwrite.py diff --git a/tests/lfp_daq_test.py b/tests/old_tests/lfp_daq_test.py similarity index 100% rename from tests/lfp_daq_test.py rename to tests/old_tests/lfp_daq_test.py diff --git a/tests/lfp_train_test.py b/tests/old_tests/lfp_train_test.py similarity index 100% rename from tests/lfp_train_test.py rename to tests/old_tests/lfp_train_test.py diff --git a/tests/plexnet.py b/tests/old_tests/plexnet.py similarity index 100% rename from tests/plexnet.py rename to tests/old_tests/plexnet.py diff --git a/tests/plexstream.py b/tests/old_tests/plexstream.py similarity index 100% rename from tests/plexstream.py rename to tests/old_tests/plexstream.py diff --git a/tests/plexstream_bmi.py b/tests/old_tests/plexstream_bmi.py similarity index 100% rename from tests/plexstream_bmi.py rename to tests/old_tests/plexstream_bmi.py diff --git a/tests/psth.py b/tests/old_tests/psth.py similarity index 100% rename from tests/psth.py rename to tests/old_tests/psth.py diff --git a/tests/spikebin.py b/tests/old_tests/spikebin.py similarity index 100% rename from tests/spikebin.py rename to tests/old_tests/spikebin.py diff --git a/tests/start_task_from_cmd_line.py b/tests/old_tests/start_task_from_cmd_line.py similarity index 100% rename from tests/start_task_from_cmd_line.py rename to tests/old_tests/start_task_from_cmd_line.py diff --git a/tests/start_task_from_cmd_line_sim.py b/tests/old_tests/start_task_from_cmd_line_sim.py similarity index 100% rename from tests/start_task_from_cmd_line_sim.py rename to tests/old_tests/start_task_from_cmd_line_sim.py diff --git a/tests/test_SimBMIControlMulti.py b/tests/old_tests/test_SimBMIControlMulti.py similarity index 100% rename from tests/test_SimBMIControlMulti.py rename to tests/old_tests/test_SimBMIControlMulti.py diff --git a/tests/test_adapting_ppf_with_assist.py b/tests/old_tests/test_adapting_ppf_with_assist.py 
similarity index 100% rename from tests/test_adapting_ppf_with_assist.py rename to tests/old_tests/test_adapting_ppf_with_assist.py diff --git a/tests/test_arduino_dio_omniplex.py b/tests/old_tests/test_arduino_dio_omniplex.py similarity index 100% rename from tests/test_arduino_dio_omniplex.py rename to tests/old_tests/test_arduino_dio_omniplex.py diff --git a/tests/test_blackrock_cbpy_streaming.py b/tests/old_tests/test_blackrock_cbpy_streaming.py similarity index 100% rename from tests/test_blackrock_cbpy_streaming.py rename to tests/old_tests/test_blackrock_cbpy_streaming.py diff --git a/tests/test_blackrock_rstart.py b/tests/old_tests/test_blackrock_rstart.py similarity index 100% rename from tests/test_blackrock_rstart.py rename to tests/old_tests/test_blackrock_rstart.py diff --git a/tests/test_bmi_bin_edges.py b/tests/old_tests/test_bmi_bin_edges.py similarity index 100% rename from tests/test_bmi_bin_edges.py rename to tests/old_tests/test_bmi_bin_edges.py diff --git a/tests/test_bmi_recon.py b/tests/old_tests/test_bmi_recon.py similarity index 100% rename from tests/test_bmi_recon.py rename to tests/old_tests/test_bmi_recon.py diff --git a/tests/test_constrained_sskf.py b/tests/old_tests/test_constrained_sskf.py similarity index 100% rename from tests/test_constrained_sskf.py rename to tests/old_tests/test_constrained_sskf.py diff --git a/tests/test_data.py b/tests/old_tests/test_data.py similarity index 100% rename from tests/test_data.py rename to tests/old_tests/test_data.py diff --git a/tests/test_db.py b/tests/old_tests/test_db.py similarity index 100% rename from tests/test_db.py rename to tests/old_tests/test_db.py diff --git a/tests/test_dbfunctions.py b/tests/old_tests/test_dbfunctions.py similarity index 100% rename from tests/test_dbfunctions.py rename to tests/old_tests/test_dbfunctions.py diff --git a/tests/test_decoder_training.py b/tests/old_tests/test_decoder_training.py similarity index 100% rename from tests/test_decoder_training.py rename to tests/old_tests/test_decoder_training.py diff --git a/tests/test_fa_decoding.py b/tests/old_tests/test_fa_decoding.py similarity index 100% rename from tests/test_fa_decoding.py rename to tests/old_tests/test_fa_decoding.py diff --git a/tests/test_fa_tasks.py b/tests/old_tests/test_fa_tasks.py similarity index 100% rename from tests/test_fa_tasks.py rename to tests/old_tests/test_fa_tasks.py diff --git a/tests/test_fixed_ppf.py b/tests/old_tests/test_fixed_ppf.py similarity index 100% rename from tests/test_fixed_ppf.py rename to tests/old_tests/test_fixed_ppf.py diff --git a/tests/test_kinarm_stream.py b/tests/old_tests/test_kinarm_stream.py similarity index 100% rename from tests/test_kinarm_stream.py rename to tests/old_tests/test_kinarm_stream.py diff --git a/tests/test_lfp_cursor_fixed_decoder.py b/tests/old_tests/test_lfp_cursor_fixed_decoder.py similarity index 100% rename from tests/test_lfp_cursor_fixed_decoder.py rename to tests/old_tests/test_lfp_cursor_fixed_decoder.py diff --git a/tests/test_map_arduino_dio.py b/tests/old_tests/test_map_arduino_dio.py similarity index 100% rename from tests/test_map_arduino_dio.py rename to tests/old_tests/test_map_arduino_dio.py diff --git a/tests/test_npz_to_hdf.py b/tests/old_tests/test_npz_to_hdf.py similarity index 100% rename from tests/test_npz_to_hdf.py rename to tests/old_tests/test_npz_to_hdf.py diff --git a/tests/test_onedim_lfp.py b/tests/old_tests/test_onedim_lfp.py similarity index 100% rename from tests/test_onedim_lfp.py rename to 
tests/old_tests/test_onedim_lfp.py diff --git a/tests/test_phidgets_joystick.py b/tests/old_tests/test_phidgets_joystick.py similarity index 100% rename from tests/test_phidgets_joystick.py rename to tests/old_tests/test_phidgets_joystick.py diff --git a/tests/test_phidgets_joystick_acqu.py b/tests/old_tests/test_phidgets_joystick_acqu.py similarity index 100% rename from tests/test_phidgets_joystick_acqu.py rename to tests/old_tests/test_phidgets_joystick_acqu.py diff --git a/tests/test_plx_rstart.py b/tests/old_tests/test_plx_rstart.py similarity index 100% rename from tests/test_plx_rstart.py rename to tests/old_tests/test_plx_rstart.py diff --git a/tests/test_rat_bmi.py b/tests/old_tests/test_rat_bmi.py similarity index 100% rename from tests/test_rat_bmi.py rename to tests/old_tests/test_rat_bmi.py diff --git a/tests/test_reconst_bmi_trajectory.py b/tests/old_tests/test_reconst_bmi_trajectory.py similarity index 100% rename from tests/test_reconst_bmi_trajectory.py rename to tests/old_tests/test_reconst_bmi_trajectory.py diff --git a/tests/test_reward.py b/tests/old_tests/test_reward.py similarity index 100% rename from tests/test_reward.py rename to tests/old_tests/test_reward.py diff --git a/tests/test_reward_ao.py b/tests/old_tests/test_reward_ao.py similarity index 100% rename from tests/test_reward_ao.py rename to tests/old_tests/test_reward_ao.py diff --git a/tests/test_rml_cg.py b/tests/old_tests/test_rml_cg.py similarity index 100% rename from tests/test_rml_cg.py rename to tests/old_tests/test_rml_cg.py diff --git a/tests/test_rml_cg_param_updates.py b/tests/old_tests/test_rml_cg_param_updates.py similarity index 100% rename from tests/test_rml_cg_param_updates.py rename to tests/old_tests/test_rml_cg_param_updates.py diff --git a/tests/test_rstart_barebones.py b/tests/old_tests/test_rstart_barebones.py similarity index 100% rename from tests/test_rstart_barebones.py rename to tests/old_tests/test_rstart_barebones.py diff --git a/tests/test_tentacle_attractor.py b/tests/old_tests/test_tentacle_attractor.py similarity index 100% rename from tests/test_tentacle_attractor.py rename to tests/old_tests/test_tentacle_attractor.py diff --git a/tests/test_train_kfdecoder_brain_control.py b/tests/old_tests/test_train_kfdecoder_brain_control.py similarity index 100% rename from tests/test_train_kfdecoder_brain_control.py rename to tests/old_tests/test_train_kfdecoder_brain_control.py diff --git a/tests/test_train_kfdecoder_manual_control.py b/tests/old_tests/test_train_kfdecoder_manual_control.py similarity index 100% rename from tests/test_train_kfdecoder_manual_control.py rename to tests/old_tests/test_train_kfdecoder_manual_control.py diff --git a/tests/test_window_2d.py b/tests/old_tests/test_window_2d.py similarity index 100% rename from tests/test_window_2d.py rename to tests/old_tests/test_window_2d.py diff --git a/tests/vfb_joint_space.py b/tests/old_tests/vfb_joint_space.py similarity index 100% rename from tests/vfb_joint_space.py rename to tests/old_tests/vfb_joint_space.py diff --git a/tests/run_unit_tests.py b/tests/run_unit_tests.py index b1e66f72..b0ef16f2 100644 --- a/tests/run_unit_tests.py +++ b/tests/run_unit_tests.py @@ -3,8 +3,8 @@ from tests.test_decoders import TestLinDec from tests.test_features import TestKeyboardControl, TestMouseControl test_classes = [ - #TestLinDec, - #TestKeyboardControl, + TestLinDec, + TestKeyboardControl, TestMouseControl, ] diff --git a/tests/test_decoders.py b/tests/test_decoders.py index 86ba4eee..583b2f73 100644 --- 
a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -1,5 +1,5 @@ from riglib.bmi import lindecoder -from built_in_tasks.bmimultitasks import SimBMICosEncLinDec +from built_in_tasks.bmimultitasks import SimBMICosEncLinDec, SimBMIVelocityLinDec from riglib import experiment import numpy as np @@ -7,9 +7,6 @@ class TestLinDec(unittest.TestCase): - def setUp(self): - pass - def test_sanity(self): simple_filt = lindecoder.LinearScaleFilter(100, 1, 1) self.assertEqual(0, simple_filt.get_mean()) @@ -36,29 +33,37 @@ def test_filter(self): #@unittest.skip('msg') def test_experiment(self): - N_TARGETS = 8 - N_TRIALS = 3 - seq = SimBMICosEncLinDec.sim_target_seq_generator_multi( - N_TARGETS, N_TRIALS) - base_class = SimBMICosEncLinDec - feats = [] - Exp = experiment.make(base_class, feats=feats) - exp = Exp(seq) - exp.init() - exp.run() - - rewards = 0 - time_penalties = 0 - hold_penalties = 0 - for s in exp.event_log: - if s[0] == 'reward': - rewards += 1 - elif s[0] == 'hold_penalty': - hold_penalties += 1 - elif s[0] == 'timeout_penalty': - time_penalties += 1 - self.assertTrue(rewards <= rewards + time_penalties + hold_penalties) - self.assertTrue(rewards > 0) + for cls in [SimBMICosEncLinDec, SimBMIVelocityLinDec]: + N_TARGETS = 8 + N_TRIALS = 16 + seq = cls.sim_target_no_center( + N_TARGETS, N_TRIALS) + base_class = cls + feats = [] + Exp = experiment.make(base_class, feats=feats) + exp = Exp(seq) + exp.init() + + exp.run() + + rewards, time_penalties, hold_penalties = calculate_rewards(exp) + self.assertTrue(rewards <= rewards + time_penalties + hold_penalties) + self.assertTrue(rewards > 0) + + +def calculate_rewards(exp): + rewards = 0 + time_penalties = 0 + hold_penalties = 0 + for s in exp.event_log: + if s[0] == 'reward': + rewards += 1 + elif s[0] == 'hold_penalty': + hold_penalties += 1 + elif s[0] == 'timeout_penalty': + time_penalties += 1 + return rewards, time_penalties, hold_penalties + if __name__ == '__main__': unittest.main() diff --git a/tests/unit_tests/coverage.bat b/tests/unit_tests (deprecated)/coverage.bat similarity index 100% rename from tests/unit_tests/coverage.bat rename to tests/unit_tests (deprecated)/coverage.bat diff --git a/tests/unit_tests/coverage.sh b/tests/unit_tests (deprecated)/coverage.sh old mode 100755 new mode 100644 similarity index 100% rename from tests/unit_tests/coverage.sh rename to tests/unit_tests (deprecated)/coverage.sh diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib___init___py.html new file mode 100644 index 00000000..9053a915 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib___init___py.html @@ -0,0 +1,64 @@ + + + + + + Coverage for C:\GitHub\brain-python-interface\riglib\__init__.py: 100% + + + + + + + + + +
[coverage.py page body omitted: keyboard-shortcut chrome and footer for riglib/__init__.py, reported 100% covered.]
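The rewritten LinearScaleFilter earlier in this patch reduces to a four-step pipeline: z-score each incoming observation against the fixed neural statistics, average a sliding window of the buffer, re-standardize with the scaling statistics, then map units onto states and multiply by plant_gain. A minimal standalone sketch of that pipeline, assuming the normalization attributes are already fixed (i.e. fix_norm_attr() has been called) and using hypothetical names:

    import numpy as np

    def scale_step(obs_buffer, new_obs, attr, state_map, window=1, plant_gain=20):
        """One filter tick: push an observation, return the scaled state vector."""
        # Z-score the raw observation with the (fixed) neural statistics
        norm_obs = (new_obs - attr['neural_mean']) / attr['neural_std']
        # Shift the running buffer and append the newest sample
        obs_buffer[:-1] = obs_buffer[1:]
        obs_buffer[-1] = norm_obs
        # Average the most recent `window` samples, then re-standardize
        m_win = obs_buffer[-window:].mean(axis=0)
        x = (m_win - attr['scaling_mean']) / attr['scaling_std']
        # Map unit activity onto state dimensions and scale to plant coordinates
        return plant_gain * (state_map @ x)

    # Two units driving the x and z position of a cursor on a +/-20-unit screen
    attr = dict(neural_mean=np.array([5.0, 5.0]), neural_std=np.array([10.0, 10.0]),
                scaling_mean=np.zeros(2), scaling_std=np.ones(2))
    state_map = np.array([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])  # (x, y, z) <- units
    buf = np.zeros((100, 2))
    print(scale_step(buf, np.array([7.5, 2.5]), attr, state_map))  # [ 5.  0. -5.]

In velocity mode, PosVelScaleFilter treats this output as a velocity and integrates it once per tick (position += velocity / call_rate), which is what PosVelState.update() in the hunk above does.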
+ + + diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_arduino_imu_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_arduino_imu_py.html new file mode 100644 index 00000000..29cbd0aa --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_arduino_imu_py.html @@ -0,0 +1,196 @@ + + + + + + Coverage for C:\GitHub\brain-python-interface\riglib\arduino_imu.py: 0% + + + + + + + + + +
[coverage.py page body omitted: riglib/arduino_imu.py rendered with per-line hit annotations, reported 0% covered. The module defines a serial-polling DataSourceSystem subclass (System) that reads n_sensors floats per tick at 20 Hz, plus a make() factory that builds a subclass carrying the matching numpy dtype.]
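Both Arduino System classes on these pages (the IMU above and the joystick that follows) pace their get() calls the same way: measure the time since the last read and sleep off whatever remains of the sampling interval before touching the serial port again. A runnable sketch of that throttle (PacedReader and the fake sample are hypothetical stand-ins for the serial read):

    import time

    class PacedReader:
        update_freq = 20  # Hz, matching the IMU source above

        def __init__(self):
            self.interval = 1.0 / self.update_freq
            self.tic = time.time()

        def get(self):
            # Sleep only for the unspent remainder of the interval
            toc = time.time() - self.tic
            if 0 < toc < self.interval:
                time.sleep(self.interval - toc)
            sample = time.time()  # stand-in for float(self.port.readline())
            self.tic = time.time()
            return sample

    reader = PacedReader()
    stamps = [reader.get() for _ in range(3)]
    print([round(b - a, 3) for a, b in zip(stamps, stamps[1:])])  # ~[0.05, 0.05]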
+ + + diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_arduino_joystick_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_arduino_joystick_py.html new file mode 100644 index 00000000..ec497467 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_arduino_joystick_py.html @@ -0,0 +1,194 @@ + + + + + + Coverage for C:\GitHub\brain-python-interface\riglib\arduino_joystick.py: 0% + + + + + + + + + +
[coverage.py page body omitted: riglib/arduino_joystick.py, reported 0% covered. Same System/make() structure as arduino_imu.py, but polling two joystick axes at 1000 Hz and normalizing readings to 0-1 by dividing by 1023.]
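Back in the riglib/bmi/sim_neurons.py hunk above, the patch moves state normalization into NormalizedCosEnc.__call__: position and velocity are rescaled against the plant bounds, stacked with a zero offset term, projected through the observation matrix C, and multiplied by the gain. A small sketch with hypothetical bounds and gain:

    import numpy as np

    def encode(next_state, C, lo=-10.0, rng=20.0, gain=120.0):
        # Normalize pos and vel to (0, 1) against the plant bounds, as in
        # NormalizedCosEnc.__call__ for a StateSpaceEndptVel2D state space
        s = np.squeeze(np.asarray(next_state))
        norm_pos = (s[0:3] - lo) / rng
        norm_vel = (s[3:6] - lo) / rng
        return C @ np.r_[norm_pos, norm_vel, 0.0] * gain

    # Two units reading x and z position, as in SimBMICosEncLinDec above
    C = np.zeros((2, 7)); C[0, 0] = 1.0; C[1, 2] = 1.0
    print(encode(np.zeros(7), C))  # mid-range state -> [60. 60.]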
+ + + diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock___init___py.html new file mode 100644 index 00000000..1288a8fa --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock___init___py.html @@ -0,0 +1,130 @@ + + + + + + Coverage for C:\GitHub\brain-python-interface\riglib\blackrock\__init__.py: 52% + + + + + + + + + +
[coverage.py page body omitted: riglib/blackrock/__init__.py, reported 52% covered. Defines the Spikes and LFP DataSourceSystem classes that stream event and continuous data from the Blackrock NSP through cerelink.Connection.]
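The Spikes class summarized above hands each spike event downstream as a one-row structured array keyed by ts/chan/unit/arrival_ts, with the sample-clock timestamp converted to seconds. A sketch of that record pattern (plain float stands in for the original's deprecated np.float alias):

    import numpy as np

    spike_dtype = np.dtype([("ts", float),
                            ("chan", np.int32),
                            ("unit", np.int32),
                            ("arrival_ts", np.float64)])

    update_freq = 30000.0  # NSP sample clock, per the source above

    def make_record(ts_samples, chan, unit, arrival_ts):
        # Divide by the sample clock to get seconds, as Spikes.get() does
        return np.array([(ts_samples / update_freq, chan, unit, arrival_ts)],
                        dtype=spike_dtype)

    rec = make_record(90000, chan=5, unit=1, arrival_ts=1234.5)
    print(rec["ts"], rec["chan"])  # [3.] [5]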
+ + + diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_brMiscFxns_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_brMiscFxns_py.html new file mode 100644 index 00000000..3c27c466 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_brMiscFxns_py.html @@ -0,0 +1,139 @@ + + + + + + Coverage for C:\GitHub\brain-python-interface\riglib\blackrock\brMiscFxns.py: 0% + + + + + + + + + +
[coverage.py page body omitted: riglib/blackrock/brMiscFxns.py v1.2.0, reported 0% covered. Provides openfilecheck(), which prompts or browses for a file and validates its extension before opening, and checkequal().]
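The brpylib page that follows is organized around processheaders(): each header field carries a struct format string, the strings are concatenated behind a '<' (little-endian) prefix, the whole packet is unpacked in one call, and the values are fed through per-field formatting functions that share one iterator, so a formatter may consume more than one raw value. A condensed, runnable sketch of that FieldDef pattern:

    from collections import namedtuple
    from io import BytesIO
    from struct import calcsize, pack, unpack

    FieldDef = namedtuple('FieldDef', ['name', 'formatStr', 'formatFnc'])

    def format_none(it):
        return next(it)

    def format_filespec(it):
        return str(next(it)) + '.' + str(next(it))  # two bytes -> '2.3'

    def processheaders(curr_file, packet_fields):
        fmt = '<' + ''.join(f.formatStr for f in packet_fields)
        data_iter = iter(unpack(fmt, curr_file.read(calcsize(fmt))))
        return {f.name: f.formatFnc(data_iter) for f in packet_fields}

    fields = [FieldDef('FileSpec', '2B', format_filespec),
              FieldDef('BytesInHeader', 'I', format_none)]
    stream = BytesIO(pack('<2BI', 2, 3, 336))
    print(processheaders(stream, fields))  # {'FileSpec': '2.3', 'BytesInHeader': 336}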
+ + + diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_brpylib_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_brpylib_py.html new file mode 100644 index 00000000..1b0ec0a1 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_brpylib_py.html @@ -0,0 +1,1325 @@ + + + + + + Coverage for C:\GitHub\brain-python-interface\riglib\blackrock\brpylib.py: 0% + + + + + + + + + +
[coverage.py page body omitted: riglib/blackrock/brpylib.py v1.3.2, reported 0% covered. Header-parsing utilities for Blackrock NEV/NSx files: module-level constants, FieldDef named tuples, processheaders(), and the per-field format_* helpers. The excerpt ends partway through this page.]
99CHARSET_ROI = 255 

+

100 

+

101COMM_RGBA = 0 

+

102COMM_TIME = 1 

+

103 

+

104BUTTON_PRESS = 1 

+

105BUTTON_RESET = 2 

+

106 

+

107CHG_NORMAL = 0 

+

108CHG_CRITICAL = 1 

+

109 

+

110ENTER_EVENT = 1 

+

111EXIT_EVENT = 2 

+

112# </editor-fold> 

+

113 

+

114# Define a named tuple that has information about header/packet fields 

+

115FieldDef = namedtuple('FieldDef', ['name', 'formatStr', 'formatFnc']) 

+

116 

+

117 

+

118# <editor-fold desc="Header processing functions"> 

+

119def processheaders(curr_file, packet_fields): 

+

120 """ 

+

121 :param curr_file: {file} the current BR datafile to be processed 

+

122 :param packet_fields : {named tuple} the specific binary fields for the given header 

+

123 :return: a fully unpacked and formatted tuple set of header information 

+

124 

+

125 Read a packet from a binary data file and return a list of fields 

+

126 The amount and format of data read will be specified by the 

+

127 packet_fields container 

+

128 """ 

+

129 

+

130 # This is a lot in one line. First I pull out all the format strings from 

+

131 # the basic_header_fields named tuple, then concatenate them into a string 

+

132 # with '<' at the front (for little endian format) 

+

133 packet_format_str = '<' + ''.join([fmt for name, fmt, fun in packet_fields]) 

+

134 

+

135 # Calculate how many bytes to read based on the format strings of the header fields 

+

136 bytes_in_packet = calcsize(packet_format_str) 

+

137 packet_binary = curr_file.read(bytes_in_packet) 

+

138 

+

139 # unpack the binary data from the header based on the format strings of each field. 

+

140 # This returns a list of data, but it's not always correctly formatted (eg, FileSpec 

+

141 # is read as ints 2 and 3 but I want it as '2.3' 

+

142 packet_unpacked = unpack(packet_format_str, packet_binary) 

+

143 

+

144 # Create a iterator from the data list. This allows a formatting function 

+

145 # to use more than one item from the list if needed, and the next formatting 

+

146 # function can pick up on the correct item in the list 

+

147 data_iter = iter(packet_unpacked) 

+

148 

+

149 # create an empty dictionary from the name field of the packet_fields. 

+

150 # The loop below will fill in the values with formatted data by calling 

+

151 # each field's formatting function 

+

152 packet_formatted = dict.fromkeys([name for name, fmt, fun in packet_fields]) 

+

153 for name, fmt, fun in packet_fields: 

+

154 packet_formatted[name] = fun(data_iter) 

+

155 

+

156 return packet_formatted 

+

157 

+

158 

+

def format_filespec(header_list):
    return str(next(header_list)) + '.' + str(next(header_list))  # e.g., 2.3


def format_timeorigin(header_list):
    year = next(header_list)
    month = next(header_list)
    _ = next(header_list)  # day of week, unused
    day = next(header_list)
    hour = next(header_list)
    minute = next(header_list)
    second = next(header_list)
    millisecond = next(header_list)
    return datetime(year, month, day, hour, minute, second, millisecond * 1000)


def format_stripstring(header_list):
    string = bytes.decode(next(header_list), 'latin-1')
    return string.split(STRING_TERMINUS, 1)[0]


def format_none(header_list):
    return next(header_list)


def format_freq(header_list):
    return str(float(next(header_list)) / 1000) + ' Hz'


def format_filter(header_list):
    filter_type = next(header_list)
    if filter_type == NO_FILTER:       return "none"
    elif filter_type == BUTTER_FILTER: return "butterworth"


def format_charstring(header_list):
    return int(next(header_list))


def format_digconfig(header_list):
    config = next(header_list) & FIRST_BIT_MASK
    if config: return 'active'
    else:      return 'ignored'


def format_anaconfig(header_list):
    config = next(header_list)
    if config & FIRST_BIT_MASK:  return 'low_to_high'
    if config & SECOND_BIT_MASK: return 'high_to_low'
    else:                        return 'none'


def format_digmode(header_list):
    dig_mode = next(header_list)
    if dig_mode == SERIAL_MODE: return 'serial'
    else:                       return 'parallel'


def format_trackobjtype(header_list):
    trackobj_type = next(header_list)
    if trackobj_type == UNDEFINED:     return 'undefined'
    elif trackobj_type == RB2D_MARKER: return '2D RB markers'
    elif trackobj_type == RB2D_BLOB:   return '2D RB blob'
    elif trackobj_type == RB3D_MARKER: return '3D RB markers'
    elif trackobj_type == BOUNDARY_2D: return '2D boundary'
    elif trackobj_type == MARKER_SIZE: return 'marker size'
    else:                              return 'error'


def getdigfactor(ext_headers, idx):
    max_analog = ext_headers[idx]['MaxAnalogValue']
    min_analog = ext_headers[idx]['MinAnalogValue']
    max_digital = ext_headers[idx]['MaxDigitalValue']
    min_digital = ext_headers[idx]['MinDigitalValue']
    return float(max_analog - min_analog) / float(max_digital - min_digital)
# </editor-fold>
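
[Editor's note] A quick, hedged illustration of the digitization factor: the header values below are made up (real values come from the NSx extended headers parsed later in this file).

    fake_ext_headers = [{'MaxAnalogValue': 5000, 'MinAnalogValue': -5000,
                         'MaxDigitalValue': 32767, 'MinDigitalValue': -32767}]
    print(getdigfactor(fake_ext_headers, 0))  # 10000/65534, i.e., ~0.1526 units per bit
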

# <editor-fold desc="Header dictionaries">
nev_header_dict = {
    'basic': [FieldDef('FileTypeID', '8s', format_stripstring),            # 8 bytes   - 8 char array
              FieldDef('FileSpec', '2B', format_filespec),                 # 2 bytes   - 2 unsigned char
              FieldDef('AddFlags', 'H', format_none),                      # 2 bytes   - uint16
              FieldDef('BytesInHeader', 'I', format_none),                 # 4 bytes   - uint32
              FieldDef('BytesInDataPackets', 'I', format_none),            # 4 bytes   - uint32
              FieldDef('TimeStampResolution', 'I', format_none),           # 4 bytes   - uint32
              FieldDef('SampleTimeResolution', 'I', format_none),          # 4 bytes   - uint32
              FieldDef('TimeOrigin', '8H', format_timeorigin),             # 16 bytes  - 8 x uint16
              FieldDef('CreatingApplication', '32s', format_stripstring),  # 32 bytes  - 32 char array
              FieldDef('Comment', '256s', format_stripstring),             # 256 bytes - 256 char array
              FieldDef('NumExtendedHeaders', 'I', format_none)],           # 4 bytes   - uint32

    'ARRAYNME': FieldDef('ArrayName', '24s', format_stripstring),          # 24 bytes - 24 char array
    'ECOMMENT': FieldDef('ExtraComment', '24s', format_stripstring),       # 24 bytes - 24 char array
    'CCOMMENT': FieldDef('ContComment', '24s', format_stripstring),        # 24 bytes - 24 char array
    'MAPFILE': FieldDef('MapFile', '24s', format_stripstring),             # 24 bytes - 24 char array

    'NEUEVWAV': [FieldDef('ElectrodeID', 'H', format_none),                # 2 bytes - uint16
                 FieldDef('PhysicalConnector', 'B', format_charstring),    # 1 byte  - unsigned char
                 FieldDef('ConnectorPin', 'B', format_charstring),         # 1 byte  - unsigned char
                 FieldDef('DigitizationFactor', 'H', format_none),         # 2 bytes - uint16
                 FieldDef('EnergyThreshold', 'H', format_none),            # 2 bytes - uint16
                 FieldDef('HighThreshold', 'h', format_none),              # 2 bytes - int16
                 FieldDef('LowThreshold', 'h', format_none),               # 2 bytes - int16
                 FieldDef('NumSortedUnits', 'B', format_charstring),       # 1 byte  - unsigned char
                 FieldDef('BytesPerWaveform', 'B', format_charstring),     # 1 byte  - unsigned char
                 FieldDef('SpikeWidthSamples', 'H', format_none),          # 2 bytes - uint16
                 FieldDef('EmptyBytes', '8s', format_none)],               # 8 bytes - empty

    'NEUEVLBL': [FieldDef('ElectrodeID', 'H', format_none),                # 2 bytes  - uint16
                 FieldDef('Label', '16s', format_stripstring),             # 16 bytes - 16 char array
                 FieldDef('EmptyBytes', '6s', format_none)],               # 6 bytes  - empty

    'NEUEVFLT': [FieldDef('ElectrodeID', 'H', format_none),                # 2 bytes - uint16
                 FieldDef('HighFreqCorner', 'I', format_freq),             # 4 bytes - uint32
                 FieldDef('HighFreqOrder', 'I', format_none),              # 4 bytes - uint32
                 FieldDef('HighFreqType', 'H', format_filter),             # 2 bytes - uint16
                 FieldDef('LowFreqCorner', 'I', format_freq),              # 4 bytes - uint32
                 FieldDef('LowFreqOrder', 'I', format_none),               # 4 bytes - uint32
                 FieldDef('LowFreqType', 'H', format_filter),              # 2 bytes - uint16
                 FieldDef('EmptyBytes', '2s', format_none)],               # 2 bytes - empty

    'DIGLABEL': [FieldDef('Label', '16s', format_stripstring),             # 16 bytes - 16 char array
                 FieldDef('Mode', '?', format_digmode),                    # 1 byte   - boolean
                 FieldDef('EmptyBytes', '7s', format_none)],               # 7 bytes  - empty

    'NSASEXEV': [FieldDef('Frequency', 'H', format_none),                  # 2 bytes - uint16
                 FieldDef('DigitalInputConfig', 'B', format_digconfig),    # 1 byte  - unsigned char
                 FieldDef('AnalogCh1Config', 'B', format_anaconfig),       # 1 byte  - unsigned char
                 FieldDef('AnalogCh1DetectVal', 'h', format_none),         # 2 bytes - int16
                 FieldDef('AnalogCh2Config', 'B', format_anaconfig),       # 1 byte  - unsigned char
                 FieldDef('AnalogCh2DetectVal', 'h', format_none),         # 2 bytes - int16
                 FieldDef('AnalogCh3Config', 'B', format_anaconfig),       # 1 byte  - unsigned char
                 FieldDef('AnalogCh3DetectVal', 'h', format_none),         # 2 bytes - int16
                 FieldDef('AnalogCh4Config', 'B', format_anaconfig),       # 1 byte  - unsigned char
                 FieldDef('AnalogCh4DetectVal', 'h', format_none),         # 2 bytes - int16
                 FieldDef('AnalogCh5Config', 'B', format_anaconfig),       # 1 byte  - unsigned char
                 FieldDef('AnalogCh5DetectVal', 'h', format_none),         # 2 bytes - int16
                 FieldDef('EmptyBytes', '6s', format_none)],               # 6 bytes - empty

    'VIDEOSYN': [FieldDef('VideoSourceID', 'H', format_none),              # 2 bytes  - uint16
                 FieldDef('VideoSource', '16s', format_stripstring),       # 16 bytes - 16 char array
                 FieldDef('FrameRate', 'f', format_none),                  # 4 bytes  - single float
                 FieldDef('EmptyBytes', '2s', format_none)],               # 2 bytes  - empty

    'TRACKOBJ': [FieldDef('TrackableType', 'H', format_trackobjtype),      # 2 bytes  - uint16
                 FieldDef('TrackableID', 'H', format_none),                # 2 bytes  - uint16
                 FieldDef('PointCount', 'H', format_none),                 # 2 bytes  - uint16
                 FieldDef('VideoSource', '16s', format_stripstring),       # 16 bytes - 16 char array
                 FieldDef('EmptyBytes', '2s', format_none)]                # 2 bytes  - empty
}

nsx_header_dict = {
    'basic_21': [FieldDef('Label', '16s', format_stripstring),             # 16 bytes - 16 char array
                 FieldDef('Period', 'I', format_none),                     # 4 bytes  - uint32
                 FieldDef('ChannelCount', 'I', format_none)],              # 4 bytes  - uint32

    'basic': [FieldDef('FileSpec', '2B', format_filespec),                 # 2 bytes   - 2 unsigned char
              FieldDef('BytesInHeader', 'I', format_none),                 # 4 bytes   - uint32
              FieldDef('Label', '16s', format_stripstring),                # 16 bytes  - 16 char array
              FieldDef('Comment', '256s', format_stripstring),             # 256 bytes - 256 char array
              FieldDef('Period', 'I', format_none),                        # 4 bytes   - uint32
              FieldDef('TimeStampResolution', 'I', format_none),           # 4 bytes   - uint32
              FieldDef('TimeOrigin', '8H', format_timeorigin),             # 16 bytes  - 8 x uint16
              FieldDef('ChannelCount', 'I', format_none)],                 # 4 bytes   - uint32

    'extended': [FieldDef('Type', '2s', format_stripstring),               # 2 bytes  - 2 char array
                 FieldDef('ElectrodeID', 'H', format_none),                # 2 bytes  - uint16
                 FieldDef('ElectrodeLabel', '16s', format_stripstring),    # 16 bytes - 16 char array
                 FieldDef('PhysicalConnector', 'B', format_none),          # 1 byte   - uint8
                 FieldDef('ConnectorPin', 'B', format_none),               # 1 byte   - uint8
                 FieldDef('MinDigitalValue', 'h', format_none),            # 2 bytes  - int16
                 FieldDef('MaxDigitalValue', 'h', format_none),            # 2 bytes  - int16
                 FieldDef('MinAnalogValue', 'h', format_none),             # 2 bytes  - int16
                 FieldDef('MaxAnalogValue', 'h', format_none),             # 2 bytes  - int16
                 FieldDef('Units', '16s', format_stripstring),             # 16 bytes - 16 char array
                 FieldDef('HighFreqCorner', 'I', format_freq),             # 4 bytes  - uint32
                 FieldDef('HighFreqOrder', 'I', format_none),              # 4 bytes  - uint32
                 FieldDef('HighFreqType', 'H', format_filter),             # 2 bytes  - uint16
                 FieldDef('LowFreqCorner', 'I', format_freq),              # 4 bytes  - uint32
                 FieldDef('LowFreqOrder', 'I', format_none),               # 4 bytes  - uint32
                 FieldDef('LowFreqType', 'H', format_filter)],             # 2 bytes  - uint16

    'data': [FieldDef('Header', 'B', format_none),                         # 1 byte  - uint8
             FieldDef('Timestamp', 'I', format_none),                      # 4 bytes - uint32
             FieldDef('NumDataPoints', 'I', format_none)]                  # 4 bytes - uint32
}
# </editor-fold>
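
[Editor's note] As a hedged sanity check on the field definitions above, the struct sizes of the NEV basic header fields should sum to 336 bytes (the fixed basic-header size documented in the NEV spec):

    from struct import calcsize

    basic_fmt = '<' + ''.join(f.formatStr for f in nev_header_dict['basic'])
    print(calcsize(basic_fmt))  # expected: 336
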

# <editor-fold desc="Safety check functions">
def check_elecid(elec_ids):
    if type(elec_ids) is str and elec_ids != ELEC_ID_DEF:
        print("\n*** WARNING: Electrode IDs must be 'all', a single integer, or a list of integers.")
        print("             Setting elec_ids to 'all'")
        elec_ids = ELEC_ID_DEF
    if elec_ids != ELEC_ID_DEF and type(elec_ids) is not list:
        if type(elec_ids) == range: elec_ids = list(elec_ids)
        elif type(elec_ids) == int: elec_ids = [elec_ids]
    return elec_ids


def check_starttime(start_time_s):
    if not isinstance(start_time_s, (int, float)) or \
            (isinstance(start_time_s, (int, float)) and start_time_s < START_TIME_DEF):
        print("\n*** WARNING: Start time is not valid, setting start_time_s to 0")
        start_time_s = START_TIME_DEF
    return start_time_s


def check_datatime(data_time_s):
    if (type(data_time_s) is str and data_time_s != DATA_TIME_DEF) or \
            (isinstance(data_time_s, (int, float)) and data_time_s < 0):
        print("\n*** WARNING: Data time is not valid, setting data_time_s to 'all'")
        data_time_s = DATA_TIME_DEF
    return data_time_s


def check_downsample(downsample):
    if not isinstance(downsample, int) or downsample < DOWNSAMPLE_DEF:
        print("\n*** WARNING: Downsample must be an integer value greater than 0. "
              "Setting downsample to 1 (no downsampling)")
        downsample = DOWNSAMPLE_DEF
    return downsample


def check_dataelecid(elec_ids, all_elec_ids):
    unique_elec_ids = set(elec_ids)
    all_elec_ids = set(all_elec_ids)

    # If some of the electrodes asked for don't exist, reset the list to those that do,
    # or print an error and return None
    if not unique_elec_ids.issubset(all_elec_ids):
        if not unique_elec_ids & all_elec_ids:
            print('\nNone of the elec_ids passed exist in the data, returning None')
            return None
        else:
            print("\n*** WARNING: Channels " + str(sorted(list(unique_elec_ids - all_elec_ids))) +
                  " do not exist in the data")
            unique_elec_ids = unique_elec_ids & all_elec_ids

    return sorted(list(unique_elec_ids))


def check_filesize(file_size):
    if file_size < DATA_FILE_SIZE_MIN:
        print('\nfile_size must be larger than 10 MB, setting file_size to 10 MB')
        return DATA_FILE_SIZE_MIN
    else:
        return int(file_size)
# </editor-fold>
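
[Editor's note] A hedged illustration of how check_elecid() normalizes the accepted input forms (this assumes ELEC_ID_DEF == 'all', as the defaults elsewhere in this file suggest):

    print(check_elecid(7))             # -> [7]
    print(check_elecid(range(1, 4)))   # -> [1, 2, 3]
    print(check_elecid('everything'))  # warns and falls back to 'all'
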

class NevFile:
    """
    Attributes and methods for all BR event data files. Initialization opens the file and extracts the
    basic header information.
    """

    def __init__(self, datafile=''):
        self.datafile = datafile
        self.basic_header = {}
        self.extended_headers = []

        # Run openfilecheck and open the file passed, or allow the user to browse to one
        self.datafile = openfilecheck('rb', file_name=self.datafile, file_ext='.nev', file_type='Blackrock NEV Files')

        # Extract basic header information
        self.basic_header = processheaders(self.datafile, nev_header_dict['basic'])

        # Extract extended headers
        for i in range(self.basic_header['NumExtendedHeaders']):
            self.extended_headers.append({})
            header_string = bytes.decode(unpack('<8s', self.datafile.read(8))[0], 'latin-1')
            self.extended_headers[i]['PacketID'] = header_string.split(STRING_TERMINUS, 1)[0]
            self.extended_headers[i].update(
                processheaders(self.datafile, nev_header_dict[self.extended_headers[i]['PacketID']]))

            # Must set this for file spec 2.1 and 2.2
            if header_string == 'NEUEVWAV' and float(self.basic_header['FileSpec']) < 2.3:
                self.extended_headers[i]['SpikeWidthSamples'] = WAVEFORM_SAMPLES_21

    def getdata(self, elec_ids='all'):
        """
        This function is used to return a set of data from the NEV datafile.

        :param elec_ids: [optional] {list} User selection of elec_ids to extract specific spike waveforms (e.g., [13])
        :return: output: {Dictionary} with one or more of the following dictionaries (all include TimeStamps)
                   dig_events:            Reason, Data, [for file spec 2.2 and below, AnalogData and AnalogDataUnits]
                   spike_events:          Units='nV', ChannelID, NEUEVWAV_HeaderIndices, Classification, Waveforms
                   comments:              CharSet, Flag, Data, Comment
                   video_sync_events:     VideoFileNum, VideoFrameNum, VideoElapsedTime_ms, VideoSourceID
                   tracking_events:       ParentID, NodeID, NodeCount, PointCount, TrackingPoints
                   button_trigger_events: TriggerType
                   configuration_events:  ConfigChangeType, ConfigChanged

        Note: For digital and neural data, TimeStamps, Classification, and Data can be lists of lists when more
        than one digital type or spike event exists for a channel.
        """

        # Initialize the output dictionary and reset the position in the file (if read before, it may have moved)
        output = dict()
        self.datafile.seek(self.basic_header['BytesInHeader'], 0)

        # Safety checks
        elec_ids = check_elecid(elec_ids)

        # Must go through each data packet and process it separately until the end of the file
        while self.datafile.tell() != ospath.getsize(self.datafile.name):

            time_stamp = unpack('<I', self.datafile.read(4))[0]
            packet_id = unpack('<H', self.datafile.read(2))[0]

            # Skip unwanted neural data packets if only asking for certain channels
            if not (elec_ids == 'all' or ((packet_id in elec_ids) and
                                          NEURAL_PACKET_ID_MIN <= packet_id <= NEURAL_PACKET_ID_MAX)):
                self.datafile.seek(self.basic_header['BytesInDataPackets'] - 6, 1)
                continue

            # For digital event data: read the reason, skip one byte (reserved), read the digital value,
            # and skip the remaining (reserved) bytes
            if packet_id == DIGITAL_PACKET_ID:

                # See if the dictionary exists in output
                if 'dig_events' not in output:
                    output['dig_events'] = {'Reason': [], 'TimeStamps': [], 'Data': []}

                reason = unpack('B', self.datafile.read(1))[0]
                if reason == PARALLEL_REASON:   reason = 'parallel'
                elif reason == PERIODIC_REASON: reason = 'periodic'
                elif reason == SERIAL_REASON:   reason = 'serial'
                else:                           reason = 'unknown'
                self.datafile.seek(1, 1)

                # Check if this type of data already exists; if not, create an empty list, then append data
                if reason in output['dig_events']['Reason']:
                    idx = output['dig_events']['Reason'].index(reason)
                else:
                    idx = -1
                    output['dig_events']['Reason'].append(reason)
                    output['dig_events']['TimeStamps'].append([])
                    output['dig_events']['Data'].append([])

                output['dig_events']['TimeStamps'][idx].append(time_stamp)
                output['dig_events']['Data'][idx].append(unpack('<H', self.datafile.read(2))[0])

                # For serial data, strip off the upper byte
                if reason == 'serial':
                    output['dig_events']['Data'][idx][-1] &= LOWER_BYTE_MASK

                # For File Spec < 2.3, also capture analog data; otherwise skip the remaining packet bytes
                if float(self.basic_header['FileSpec']) < 2.3:
                    if 'AnalogData' not in output['dig_events']:
                        # Initialize the analog container alongside its units (the original code set only
                        # the units key, which raised a KeyError on the first append)
                        output['dig_events']['AnalogData'] = []
                        output['dig_events']['AnalogDataUnits'] = 'mv'

                    output['dig_events']['AnalogData'].append([])
                    for j in range(5):
                        output['dig_events']['AnalogData'][-1].append(unpack('<h', self.datafile.read(2))[0])
                else:
                    self.datafile.seek(self.basic_header['BytesInDataPackets'] - 10, 1)

            # For neural waveforms: read the classifier, skip one byte (reserved), and read the waveform data
            elif NEURAL_PACKET_ID_MIN <= packet_id <= NEURAL_PACKET_ID_MAX:

                # See if the dictionary exists in output; if not, create it
                if 'spike_events' not in output:
                    output['spike_events'] = {'Units': 'nV', 'ChannelID': [], 'TimeStamps': [],
                                              'NEUEVWAV_HeaderIndices': [], 'Classification': [], 'Waveforms': []}

                classifier = unpack('B', self.datafile.read(1))[0]
                if classifier == UNDEFINED:                          classifier = 'none'
                elif CLASSIFIER_MIN <= classifier <= CLASSIFIER_MAX: pass  # keep the sorted unit number
                elif classifier == CLASSIFIER_NOISE:                 classifier = 'noise'
                else:                                                classifier = 'error'
                self.datafile.seek(1, 1)

                # Check if data for this electrode exists and update the parameters accordingly
                if packet_id in output['spike_events']['ChannelID']:
                    idx = output['spike_events']['ChannelID'].index(packet_id)
                else:
                    idx = -1
                    output['spike_events']['ChannelID'].append(packet_id)
                    output['spike_events']['TimeStamps'].append([])
                    output['spike_events']['Classification'].append([])

                    # Find the NEUEVWAV extended header for this electrode for use in calculating data info
                    output['spike_events']['NEUEVWAV_HeaderIndices'].append(
                        next(item for (item, d) in enumerate(self.extended_headers)
                             if d["ElectrodeID"] == packet_id and d["PacketID"] == 'NEUEVWAV'))

                output['spike_events']['TimeStamps'][idx].append(time_stamp)
                output['spike_events']['Classification'][idx].append(classifier)

                # Use the extended header idx to get specific data information
                ext_header_idx = output['spike_events']['NEUEVWAV_HeaderIndices'][idx]
                samples = self.extended_headers[ext_header_idx]['SpikeWidthSamples']
                dig_factor = self.extended_headers[ext_header_idx]['DigitizationFactor']
                num_bytes = self.extended_headers[ext_header_idx]['BytesPerWaveform']
                if num_bytes <= 1:   data_type = np.int8
                elif num_bytes == 2: data_type = np.int16

                # Extract and scale the data
                if idx == -1:
                    output['spike_events']['Waveforms'].append(
                        [np.fromfile(file=self.datafile, dtype=data_type, count=samples).astype(np.int32) * dig_factor])
                else:
                    try:
                        output['spike_events']['Waveforms'][idx] = \
                            np.append(output['spike_events']['Waveforms'][idx],
                                      [np.fromfile(file=self.datafile, dtype=data_type, count=samples).astype(np.int32) *
                                       dig_factor], axis=0)
                    except Exception:
                        # A truncated read at the end of the file gives a shape mismatch; pad with zeros
                        output['spike_events']['Waveforms'][idx] = \
                            np.append(output['spike_events']['Waveforms'][idx],
                                      [np.zeros((samples,)).astype(np.int32) * dig_factor], axis=0)
                        print('adding zero waveform')

            # For comment events
            elif packet_id == COMMENT_PACKET_ID:

                # See if the dictionary exists in output; if not, create it
                if 'comments' not in output:
                    output['comments'] = {'TimeStamps': [], 'CharSet': [], 'Flag': [], 'Data': [], 'Comment': []}

                output['comments']['TimeStamps'].append(time_stamp)

                char_set = unpack('B', self.datafile.read(1))[0]
                if char_set == CHARSET_ANSI:  output['comments']['CharSet'].append('ANSI')
                elif char_set == CHARSET_UTF: output['comments']['CharSet'].append('UTF-16')
                elif char_set == CHARSET_ROI: output['comments']['CharSet'].append('NeuroMotive ROI')
                else:                         output['comments']['CharSet'].append('error')

                comm_flag = unpack('B', self.datafile.read(1))[0]
                if comm_flag == COMM_RGBA:   output['comments']['Flag'].append('RGBA color code')
                elif comm_flag == COMM_TIME: output['comments']['Flag'].append('timestamp')
                else:                        output['comments']['Flag'].append('error')

                output['comments']['Data'].append(unpack('<I', self.datafile.read(4))[0])

                samples = self.basic_header['BytesInDataPackets'] - 12
                comm_string = bytes.decode(self.datafile.read(samples), 'latin-1')
                output['comments']['Comment'].append(comm_string.split(STRING_TERMINUS, 1)[0])

            # For video sync events
            elif packet_id == VIDEO_SYNC_PACKET_ID:

                # See if the dictionary exists in output; if not, create it
                if 'video_sync_events' not in output:
                    output['video_sync_events'] = {'TimeStamps': [], 'VideoFileNum': [], 'VideoFrameNum': [],
                                                   'VideoElapsedTime_ms': [], 'VideoSourceID': []}

                output['video_sync_events']['TimeStamps'].append(time_stamp)
                output['video_sync_events']['VideoFileNum'].append(unpack('<H', self.datafile.read(2))[0])
                output['video_sync_events']['VideoFrameNum'].append(unpack('<I', self.datafile.read(4))[0])
                output['video_sync_events']['VideoElapsedTime_ms'].append(unpack('<I', self.datafile.read(4))[0])
                output['video_sync_events']['VideoSourceID'].append(unpack('<I', self.datafile.read(4))[0])
                self.datafile.seek(self.basic_header['BytesInDataPackets'] - 20, 1)

            # For tracking events
            elif packet_id == TRACKING_PACKET_ID:

                # See if the dictionary exists in output; if not, create it
                if 'tracking_events' not in output:
                    output['tracking_events'] = {'TimeStamps': [], 'ParentID': [], 'NodeID': [], 'NodeCount': [],
                                                 'PointCount': [], 'TrackingPoints': []}

                output['tracking_events']['TimeStamps'].append(time_stamp)
                output['tracking_events']['ParentID'].append(unpack('<H', self.datafile.read(2))[0])
                output['tracking_events']['NodeID'].append(unpack('<H', self.datafile.read(2))[0])
                output['tracking_events']['NodeCount'].append(unpack('<H', self.datafile.read(2))[0])
                output['tracking_events']['PointCount'].append(unpack('<H', self.datafile.read(2))[0])
                samples = (self.basic_header['BytesInDataPackets'] - 14) // 2
                output['tracking_events']['TrackingPoints'].append(
                    np.fromfile(file=self.datafile, dtype=np.uint16, count=samples))

            # For button trigger events
            elif packet_id == BUTTON_PACKET_ID:

                # See if the dictionary exists in output; if not, create it
                if 'button_trigger_events' not in output:
                    output['button_trigger_events'] = {'TimeStamps': [], 'TriggerType': []}

                output['button_trigger_events']['TimeStamps'].append(time_stamp)
                trigger_type = unpack('<H', self.datafile.read(2))[0]
                if trigger_type == UNDEFINED:      output['button_trigger_events']['TriggerType'].append('undefined')
                elif trigger_type == BUTTON_PRESS: output['button_trigger_events']['TriggerType'].append('button press')
                elif trigger_type == BUTTON_RESET: output['button_trigger_events']['TriggerType'].append('event reset')
                else:                              output['button_trigger_events']['TriggerType'].append('error')
                self.datafile.seek(self.basic_header['BytesInDataPackets'] - 8, 1)

            # For configuration log events
            elif packet_id == CONFIGURATION_PACKET_ID:

                # See if the dictionary exists in output; if not, create it
                if 'configuration_events' not in output:
                    output['configuration_events'] = {'TimeStamps': [], 'ConfigChangeType': [], 'ConfigChanged': []}

                output['configuration_events']['TimeStamps'].append(time_stamp)
                change_type = unpack('<H', self.datafile.read(2))[0]
                if change_type == CHG_NORMAL:     output['configuration_events']['ConfigChangeType'].append('normal')
                elif change_type == CHG_CRITICAL: output['configuration_events']['ConfigChangeType'].append('critical')
                else:                             output['configuration_events']['ConfigChangeType'].append('error')

                samples = self.basic_header['BytesInDataPackets'] - 8
                output['configuration_events']['ConfigChanged'].append(unpack(('<' + str(samples) + 's'),
                                                                              self.datafile.read(samples))[0])

            # Otherwise, the packet is unknown; skip to the next packet
            else:
                self.datafile.seek(self.basic_header['BytesInDataPackets'] - 6, 1)

        return output

    def processroicomments(self, comments):
        """
        Used to process the comment data packets associated with NeuroMotive region-of-interest enter/exit events.
        Requires that getdata() has already been run.
        :return: roi_events: a dictionary of regions, enter timestamps, and exit timestamps for each region
        """

        roi_events = {'Regions': [], 'EnterTimeStamps': [], 'ExitTimeStamps': []}

        for i in range(len(comments['TimeStamps'])):
            if comments['CharSet'][i] == 'NeuroMotive ROI':

                temp_data = pack('<I', comments['Data'][i])
                roi = unpack_from('<B', temp_data)[0]
                event = unpack_from('<B', temp_data, 1)[0]

                # Determine the label of the region source
                source_label = next(d['VideoSource'] for d in self.extended_headers if d["TrackableID"] == roi)

                # Update the timestamps for events
                if source_label in roi_events['Regions']:
                    idx = roi_events['Regions'].index(source_label)
                else:
                    idx = -1
                    roi_events['Regions'].append(source_label)
                    roi_events['EnterTimeStamps'].append([])
                    roi_events['ExitTimeStamps'].append([])

                # Note: the comments dictionary built by getdata() uses the key 'TimeStamps'
                # (the original code indexed a nonexistent 'TimeStamp' key)
                if event == ENTER_EVENT:  roi_events['EnterTimeStamps'][idx].append(comments['TimeStamps'][i])
                elif event == EXIT_EVENT: roi_events['ExitTimeStamps'][idx].append(comments['TimeStamps'][i])

        return roi_events

    def close(self):
        name = self.datafile.name
        self.datafile.close()
        print('\n' + name.split('/')[-1] + ' closed')
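
[Editor's note] A hedged usage sketch of NevFile (the file name is hypothetical, and which event dictionaries appear in the output depends on what the recording actually contains):

    nev = NevFile('datafile.nev')
    events = nev.getdata(elec_ids=[1, 2])
    if 'spike_events' in events:
        print(events['spike_events']['ChannelID'])
    if 'comments' in events:
        roi = nev.processroicomments(events['comments'])
    nev.close()
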

class NsxFile:
    """
    Attributes and methods for all BR continuous data files. Initialization opens the file and extracts the
    basic header information.
    """

    def __init__(self, datafile=''):

        self.datafile = datafile
        self.basic_header = {}
        self.extended_headers = []

        # Run openfilecheck and open the file passed, or allow the user to browse to one
        self.datafile = openfilecheck('rb', file_name=self.datafile, file_ext='.ns*', file_type='Blackrock NSx Files')

        # Determine the File ID to check for File Spec 2.1
        self.basic_header['FileTypeID'] = bytes.decode(self.datafile.read(8), 'latin-1')

        # Extract basic and extended header information based on File Spec
        if self.basic_header['FileTypeID'] == 'NEURALSG':
            self.basic_header.update(processheaders(self.datafile, nsx_header_dict['basic_21']))
            self.basic_header['FileSpec'] = '2.1'
            self.basic_header['TimeStampResolution'] = 30000
            self.basic_header['BytesInHeader'] = 32 + 4 * self.basic_header['ChannelCount']
            shape = (1, self.basic_header['ChannelCount'])
            self.basic_header['ChannelID'] = \
                list(np.fromfile(file=self.datafile, dtype=np.uint32,
                                 count=self.basic_header['ChannelCount']).reshape(shape)[0])
        else:
            self.basic_header.update(processheaders(self.datafile, nsx_header_dict['basic']))
            for i in range(self.basic_header['ChannelCount']):
                self.extended_headers.append(processheaders(self.datafile, nsx_header_dict['extended']))

    def getdata(self, elec_ids='all', start_time_s=0, data_time_s='all', downsample=1):
        """
        This function is used to return a set of data from the NSx datafile.

        :param elec_ids:     [optional] {list}  List of elec_ids to extract (e.g., [13])
        :param start_time_s: [optional] {float} Starting time for data extraction (e.g., 1.0)
        :param data_time_s:  [optional] {float} Length of time of data to return (e.g., 30.0)
        :param downsample:   [optional] {int}   Downsampling factor (e.g., 2)
        :return: output: {Dictionary} of:  data_headers: {list}  dictionaries of all data headers
                                           elec_ids:     {list}  elec_ids that were extracted (sorted)
                                           start_time_s: {float} starting time for data extraction
                                           data_time_s:  {float} length of time of data returned
                                           downsample:   {int}   data downsampling factor
                                           samp_per_s:   {float} output data samples per second
                                           data:         {numpy array} continuous data in a 2D numpy array

        Parameters elec_ids, start_time_s, data_time_s, and downsample are not mandatory. Defaults assume all
        electrodes and all data points starting at time(0) are to be read. Data is returned as a 2D numpy array
        with each row being the data set for one electrode (e.g., output['data'][0] for output['elec_ids'][0]).
        """

        # Safety checks
        start_time_s = check_starttime(start_time_s)
        data_time_s = check_datatime(data_time_s)
        downsample = check_downsample(downsample)
        elec_ids = check_elecid(elec_ids)

        # Initialize parameters
        output = dict()
        output['elec_ids'] = elec_ids
        output['start_time_s'] = float(start_time_s)
        output['data_time_s'] = data_time_s
        output['downsample'] = downsample
        output['data'] = []
        output['data_headers'] = []
        output['ExtendedHeaderIndices'] = []

        datafile_samp_per_sec = self.basic_header['TimeStampResolution'] / self.basic_header['Period']
        data_pt_size = self.basic_header['ChannelCount'] * DATA_BYTE_SIZE
        elec_id_indices = []
        front_end_idxs = []
        analog_input_idxs = []
        front_end_idx_cont = True
        analog_input_idx_cont = True
        hit_start = False
        hit_stop = False
        d_ptr = 0

        # Move the file position to the start of the data (if read before, it may have moved)
        self.datafile.seek(self.basic_header['BytesInHeader'], 0)

        # Based on FileSpec, set other parameters
        if self.basic_header['FileSpec'] == '2.1':
            output['elec_ids'] = self.basic_header['ChannelID']
            output['data_headers'].append({})
            output['data_headers'][0]['Timestamp'] = TIMESTAMP_NULL_21
            output['data_headers'][0]['NumDataPoints'] = (ospath.getsize(self.datafile.name) - self.datafile.tell()) \
                // (DATA_BYTE_SIZE * self.basic_header['ChannelCount'])
        else:
            output['elec_ids'] = [d['ElectrodeID'] for d in self.extended_headers]

        # Determine the start and stop index for the data
        if start_time_s == START_TIME_DEF: start_idx = START_OFFSET_MIN
        else:                              start_idx = int(round(start_time_s * datafile_samp_per_sec))
        if data_time_s == DATA_TIME_DEF:   stop_idx = STOP_OFFSET_MIN
        else:                              stop_idx = int(round((start_time_s + data_time_s) * datafile_samp_per_sec))

        # If a subset of electrodes is requested: error check, determine elec indices, and reduce the headers
        if elec_ids != ELEC_ID_DEF:
            elec_ids = check_dataelecid(elec_ids, output['elec_ids'])
            if not elec_ids: return output
            else:
                elec_id_indices = [output['elec_ids'].index(e) for e in elec_ids]
                output['elec_ids'] = elec_ids
        num_elecs = len(output['elec_ids'])

        # Determine the extended header indices and idxs for Front End vs. Analog Input channels
        if self.basic_header['FileSpec'] != '2.1':
            for i in range(num_elecs):
                idx = next(item for (item, d) in enumerate(self.extended_headers)
                           if d["ElectrodeID"] == output['elec_ids'][i])
                output['ExtendedHeaderIndices'].append(idx)

                if self.extended_headers[idx]['PhysicalConnector'] < 5: front_end_idxs.append(i)
                else:                                                   analog_input_idxs.append(i)

            # Determine if front_end_idxs and analog_input_idxs are contiguous (default = True)
            if any(np.diff(np.array(front_end_idxs)) != 1):    front_end_idx_cont = False
            if any(np.diff(np.array(analog_input_idxs)) != 1): analog_input_idx_cont = False

        # Pre-allocate the output data based on the data packet info (timestamp + num pts) and/or data_time_s:
        # 1)  determine the number of samples in all data packets to set the possible number of output pts
        # 1a) for file spec > 2.1, skip to the last data packet quickly to determine the total possible output length
        # 2)  if the possible output length is bigger than requested, set the output based on the request
        if self.basic_header['FileSpec'] == '2.1':
            timestamp = TIMESTAMP_NULL_21
            num_data_pts = output['data_headers'][0]['NumDataPoints']
        else:
            while self.datafile.tell() != ospath.getsize(self.datafile.name):
                self.datafile.seek(1, 1)  # skip the header byte value
                timestamp = unpack('<I', self.datafile.read(4))[0]
                num_data_pts = unpack('<I', self.datafile.read(4))[0]
                self.datafile.seek(num_data_pts * self.basic_header['ChannelCount'] * DATA_BYTE_SIZE, 1)

        stop_idx_output = ceil(timestamp / self.basic_header['Period']) + num_data_pts
        if data_time_s != DATA_TIME_DEF and stop_idx < stop_idx_output: stop_idx_output = stop_idx
        total_samps = int(ceil((stop_idx_output - start_idx) / downsample))

        if (total_samps * self.basic_header['ChannelCount'] * DATA_BYTE_SIZE) > DATA_PAGING_SIZE:
            print("\nOutput data requested is larger than 1 GB, attempting to preallocate output now")

        # If the data output is bigger than available memory, let the user know it is too big and that they must
        # request a subset of electrodes, a subset of data, or use savesubsetnsx() to create smaller files;
        # otherwise, pre-allocate the data
        try: output['data'] = np.zeros((total_samps, num_elecs), dtype=np.float32)
        except MemoryError as err:
            err.args += (" Output data size requested is larger than available memory. Use the parameters\n"
                         "   for getdata(), e.g., 'elec_ids', to request a subset of the data or use\n"
                         "   NsxFile.savesubsetnsx() to create subsets of the main nsx file\n", )
            raise

        # Reset the file position to the start of data header #1, loop through all data packets,
        # process each header, and add the data
        self.datafile.seek(self.basic_header['BytesInHeader'], 0)
        while not hit_stop:

            # Read the header and check that it is valid (i.e., Header field != 0). There is currently a
            # bug in the NSP where pausing creates a 0-sample packet before the next real data packet; these need
            # to be skipped, including any tiny packets that have fewer samples than the downsample factor
            if self.basic_header['FileSpec'] != '2.1':
                output['data_headers'].append(processheaders(self.datafile, nsx_header_dict['data']))
                if output['data_headers'][-1]['Header'] == 0: print('Invalid Header. File may be corrupt')
                if output['data_headers'][-1]['NumDataPoints'] < downsample:
                    self.datafile.seek(self.basic_header['ChannelCount'] * output['data_headers'][-1]['NumDataPoints']
                                       * DATA_BYTE_SIZE, 1)
                    continue

            # Determine the sample value for the current packet timestamp
            timestamp_sample = int(round(output['data_headers'][-1]['Timestamp'] / self.basic_header['Period']))

            # For now, we need a patch for file sync, which syncs two NSP clocks by starting a new data packet
            # that may be backwards in time w.r.t. the end of data packet 1. When this happens, we need to treat
            # data packet 2 as if it were packet 1, and start this process over.
            if timestamp_sample < d_ptr:
                d_ptr = 0
                hit_start = False
                output['data_headers'] = []
                self.datafile.seek(-9, 1)
                continue

            # Check to see if the stop index is before the first data packet
            if len(output['data_headers']) == 1 and (STOP_OFFSET_MIN < stop_idx < timestamp_sample):
                print("\nData requested is before any data was saved, which starts at t = {0:.6f} s".format(
                    output['data_headers'][0]['Timestamp'] / self.basic_header['TimeStampResolution']))
                return

            # For the first data packet to be read
            if not hit_start:

                # Check for the starting point of the data request
                start_offset = start_idx - timestamp_sample

                # If start_offset is outside of this packet, skip the current packet;
                # if we've reached the end of the file, break, otherwise continue to the next packet
                if start_offset > output['data_headers'][-1]['NumDataPoints']:
                    self.datafile.seek(output['data_headers'][-1]['NumDataPoints'] * data_pt_size, 1)
                    if self.datafile.tell() == ospath.getsize(self.datafile.name): break
                    else: continue

                else:
                    # If the start_offset is before the current packet, check to ensure that the stop index
                    # is not also in the paused area, then create padded data for the pause time
                    if start_offset < 0:
                        if STOP_OFFSET_MIN < stop_idx < timestamp_sample:
                            print("\nBecause of pausing, data section requested is during pause period")
                            return
                        else:
                            print("\nFirst data packet requested begins at t = {0:.6f} s, "
                                  "initial section padded with zeros".format(
                                      output['data_headers'][-1]['Timestamp'] /
                                      self.basic_header['TimeStampResolution']))
                            start_offset = START_OFFSET_MIN
                            d_ptr = (timestamp_sample - start_idx) // downsample
                    hit_start = True

            # For all other packets
            else:
                # Check to see if padded data is needed, including hitting the stop index
                if STOP_OFFSET_MIN < stop_idx < timestamp_sample:
                    print("\nSection padded with zeros due to file pausing")
                    hit_stop = True; break

                elif (timestamp_sample - start_idx) > d_ptr:
                    print("\nSection padded with zeros due to file pausing")
                    start_offset = START_OFFSET_MIN
                    d_ptr = (timestamp_sample - start_idx) // downsample

            # Set the number of samples to be read based on whether the start/stop sample falls in this data packet
            if STOP_OFFSET_MIN < stop_idx <= (timestamp_sample + output['data_headers'][-1]['NumDataPoints']):
                total_pts = stop_idx - timestamp_sample - start_offset
                hit_stop = True
            else:
                total_pts = output['data_headers'][-1]['NumDataPoints'] - start_offset

            # Save the current file position because the memory map will reset it
            curr_file_pos = self.datafile.tell()

            # Determine the starting position to read from the memory map
            file_offset = int(curr_file_pos + start_offset * data_pt_size)

            # Extract data no more than 1 GB at a time (or based on DATA_PAGING_SIZE).
            # Determine the shape of the data to map based on file sizing and position, then map it
            downsample_data_size = data_pt_size * downsample
            max_length = (DATA_PAGING_SIZE // downsample_data_size) * downsample_data_size
            num_loops = int(ceil(total_pts * data_pt_size / max_length))

            for loop in range(num_loops):
                if loop == 0:
                    if num_loops == 1: num_pts = total_pts
                    else:              num_pts = max_length // data_pt_size

                else:
                    file_offset += max_length
                    if loop == (num_loops - 1): num_pts = ((total_pts * data_pt_size) % max_length) // data_pt_size
                    else:                       num_pts = max_length // data_pt_size

                if num_loops != 1: print('Data extraction requires paging: {0} of {1}'.format(loop + 1, num_loops))

                num_pts = int(num_pts)
                shape = (num_pts, self.basic_header['ChannelCount'])
                mm = np.memmap(self.datafile, dtype=np.int16, mode='r', offset=file_offset, shape=shape)

                # Append data based on the downsample slice and elec_ids indexing, then clear the memory map
                if downsample != 1: mm = mm[::downsample]
                if elec_id_indices:
                    output['data'][d_ptr:d_ptr + mm.shape[0]] = np.array(mm[:, elec_id_indices]).astype(np.float32)
                else:
                    output['data'][d_ptr:d_ptr + mm.shape[0]] = np.array(mm).astype(np.float32)
                d_ptr += mm.shape[0]
                del mm

            # Reset the current file position for file position checking and possibly the next header
            curr_file_pos += self.basic_header['ChannelCount'] * output['data_headers'][-1]['NumDataPoints'] \
                * DATA_BYTE_SIZE
            self.datafile.seek(curr_file_pos, 0)
            if curr_file_pos == ospath.getsize(self.datafile.name): hit_stop = True

        # Safety checks for start and stop times
        if not hit_stop and start_idx > START_OFFSET_MIN:
            raise Exception('Error: End of file found before start_time_s')
        elif not hit_stop and stop_idx:
            print("\n*** WARNING: End of file found before stop_time_s, returning all data in file")

        # Transpose the data so that it has one row per electrode, not one row per sample time
        output['data'] = output['data'].transpose()

        # All data must be scaled based on the scaling factors from the extended headers
        if self.basic_header['FileSpec'] == '2.1': output['data'] *= UV_PER_BIT_21
        else:
            if front_end_idxs:
                if front_end_idx_cont:
                    output['data'][front_end_idxs[0]:front_end_idxs[-1] + 1] *= \
                        getdigfactor(self.extended_headers, output['ExtendedHeaderIndices'][front_end_idxs[0]])
                else:
                    for i in front_end_idxs:
                        output['data'][i] *= getdigfactor(self.extended_headers, output['ExtendedHeaderIndices'][i])

            if analog_input_idxs:
                if analog_input_idx_cont:
                    output['data'][analog_input_idxs[0]:analog_input_idxs[-1] + 1] *= \
                        getdigfactor(self.extended_headers, output['ExtendedHeaderIndices'][analog_input_idxs[0]])
                else:
                    for i in analog_input_idxs:
                        output['data'][i] *= getdigfactor(self.extended_headers, output['ExtendedHeaderIndices'][i])

        # Update the parameters based on the data extracted
        output['samp_per_s'] = float(datafile_samp_per_sec / downsample)
        output['data_time_s'] = len(output['data'][0]) / output['samp_per_s']

        return output
1020 def savesubsetnsx(self, elec_ids='all', file_size=None, file_time_s=None, file_suffix=''): 

+

1021 """ 

+

1022 This function is used to save a subset of data based on electrode IDs, file sizing, or file data time. If 

+

1023 both file_time_s and file_size are passed, it will default to file_time_s and determine sizing accordingly. 

+

1024 

+

1025 :param elec_ids: [optional] {list} List of elec_ids to extract (e.g., [13]) 

+

1026 :param file_size: [optional] {int} Byte size of each subset file to save (e.g., 1024**3 = 1 Gb). If nothing 

+

1027 is passed, file_size will be all data points. 

+

1028 :param file_time_s: [optional] {float} Time length of data for each subset file, in seconds (e.g. 60.0). If 

+

1029 nothing is passed, file_size will be used as default. 

+

1030 :param file_suffix: [optional] {str} Suffix to append to NSx datafile name for subset files. If nothing is 

+

1031 passed, default will be "_subset". 

+

1032 :return: None - None of the electrodes requested exist in the data 

+

1033 SUCCESS - All file subsets extracted and saved 

+

1034 """ 

+

1035 

+

1036 # Initializations 

+

1037 elec_id_indices = [] 

+

1038 file_num = 1 

+

1039 pausing = False 

+

1040 datafile_datapt_size = self.basic_header['ChannelCount'] * DATA_BYTE_SIZE 

+

1041 self.datafile.seek(0, 0) 

+

1042 

+

1043 # Run electrode id checks and set num_elecs 

+

1044 elec_ids = check_elecid(elec_ids) 

+

1045 if self.basic_header['FileSpec'] == '2.1': all_elec_ids = self.basic_header['ChannelID'] 

+

1046 else: all_elec_ids = [x['ElectrodeID'] for x in self.extended_headers] 

+

1047 

+

1048 if elec_ids == ELEC_ID_DEF: 

+

1049 elec_ids = all_elec_ids 

+

1050 else: 

+

1051 elec_ids = check_dataelecid(elec_ids, all_elec_ids) 

+

1052 if not elec_ids: return None 

+

1053 else: elec_id_indices = [all_elec_ids.index(x) for x in elec_ids] 

+

1054 

+

1055 num_elecs = len(elec_ids) 

+

1056 

+

1057 # If file_size or file_time_s passed, check it and set file_sizing accordingly 

+

1058 if file_time_s: 

+

1059 if file_time_s and file_size: 

+

1060 print("\nWARNING: Only one of file_size or file_time_s can be passed, defaulting to file_time_s.") 

+

1061 file_size = int(num_elecs * DATA_BYTE_SIZE * file_time_s * 

+

1062 self.basic_header['TimeStampResolution'] / self.basic_header['Period']) 

+

1063 if self.basic_header['FileSpec'] == '2.1': 

+

1064 file_size += 32 + 4 * num_elecs 

+

1065 else: 

+

1066 file_size += NSX_BASIC_HEADER_BYTES_22 + NSX_EXT_HEADER_BYTES_22 * num_elecs + 5 

+

1067 print(("\nBased on timing request, file size will be {0:d} Mb".format(int(file_size / 1024**2)))) 

+

1068 elif file_size: 

+

1069 file_size = check_filesize(file_size) 

+

1070 

+

1071 # Create and open subset file as writable binary, if it already exists ask user for overwrite permission 

+

1072 file_name, file_ext = ospath.splitext(self.datafile.name) 

+

1073 if file_suffix: file_name += '_' + file_suffix 

+

1074 else: file_name += '_subset' 

+

1075 

+

1076 if ospath.isfile(file_name + "_000" + file_ext): 

+

1077 if 'y' != eval(input("\nFile '" + file_name.split('/')[-1] + "_xxx" + file_ext + 

+

1078 "' already exists, overwrite [y/n]: ")): 

+

1079 print("\nExiting, no overwrite, returning None"); return None 

+

1080 else: 

+

1081 print("\n*** Overwriting existing subset files ***") 

+

1082 

+

1083 subset_file = open(file_name + "_000" + file_ext, 'wb') 

+

1084 print(("\nWriting subset file: " + ospath.split(subset_file.name)[1])) 

+

1085 

+

1086 # For file spec 2.1: 

+

1087 # 1) copy the first 28 bytes from the datafile (these are unchanged) 

+

1088 # 2) write subset channel count and channel ID to file 

+

1089 # 3) skip ahead in datafile the number of bytes in datafile ChannelCount(4) plus ChannelID (4*ChannelCount) 

+

1090 if self.basic_header['FileSpec'] == '2.1': 

+

1091 subset_file.write(self.datafile.read(28)) 

+

1092 subset_file.write(np.array(num_elecs).astype(np.uint32).tobytes()) 

+

1093 subset_file.write(np.array(elec_ids).astype(np.uint32).tobytes()) 

+

1094 self.datafile.seek(4 + 4 * self.basic_header['ChannelCount'], 1) 

+

1095 

+

1096 # For file spec 2.2 and above 

+

1097 # 1) copy the first 10 bytes from the datafile (unchanged) 

+

1098 # 2) write subset bytes-in-headers and skip 4 bytes in datafile, noting position of this for update later 

+

1099 # 3) copy the next 296 bytes from datafile (unchanged) 

+

1100 # 4) write subset channel-count value and skip 4 bytes in datafile 

+

1101 # 5) append extended headers based on the channel ID. Must read the first 4 bytes, determine if correct 

+

1102 # Channel ID, repack first 4 bytes, write to disk, then copy remaining 62 (66-4) bytes 

+

1103 else: 

+

1104 subset_file.write(self.datafile.read(10)) 

+

1105 bytes_in_headers = NSX_BASIC_HEADER_BYTES_22 + NSX_EXT_HEADER_BYTES_22 * num_elecs 

+

1106 num_pts_header_pos = bytes_in_headers + 5 

+

1107 subset_file.write(np.array(bytes_in_headers).astype(np.uint32).tobytes()) 

+

1108 self.datafile.seek(4, 1) 

+

1109 subset_file.write(self.datafile.read(296)) 

+

1110 subset_file.write(np.array(num_elecs).astype(np.uint32).tobytes()) 

+

1111 self.datafile.seek(4, 1) 

+

1112 

+

1113 for i in range(len(self.extended_headers)): 

+

1114 h_type = self.datafile.read(2) 

+

1115 chan_id = self.datafile.read(2) 

+

1116 if unpack('<H', chan_id)[0] in elec_ids: 

+

1117 subset_file.write(h_type) 

+

1118 subset_file.write(chan_id) 

+

1119 subset_file.write(self.datafile.read(62)) 

+

1120 else: 

+

1121 self.datafile.seek(62, 1) 

+

1122 

+

1123 # For all file types, loop through all data packets, extracting data based on page sizing 

+

1124 while self.datafile.tell() != ospath.getsize(self.datafile.name): 

+

1125 

+

1126 # pull and set data packet header info 

+

1127 if self.basic_header['FileSpec'] == '2.1': 

+

1128 packet_pts = (ospath.getsize(self.datafile.name) - self.datafile.tell()) \ 

+

1129 / (DATA_BYTE_SIZE * self.basic_header['ChannelCount']) 

+

1130 else: 

+

1131 header_binary = self.datafile.read(1) 

+

1132 timestamp_binary = self.datafile.read(4) 

+

1133 packet_pts_binary = self.datafile.read(4) 

+

1134 packet_pts = unpack('<I', packet_pts_binary)[0] 

+

1135 if packet_pts == 0: continue 

+

1136 

+

1137 subset_file.write(header_binary) 

+

1138 subset_file.write(timestamp_binary) 

+

1139 subset_file.write(packet_pts_binary) 

+

1140 

+

1141 # get current file position and set loop parameters 

+

1142 datafile_pos = self.datafile.tell() 

+

1143 file_offset = datafile_pos 

+

1144 mm_length = (DATA_PAGING_SIZE // datafile_datapt_size) * datafile_datapt_size 

+

1145 num_loops = int(ceil(packet_pts * datafile_datapt_size / mm_length)) 

+

1146 packet_read_pts = 0 

+

1147 subset_file_pkt_pts = 0 

+

1148 

+

1149 # Determine shape of data to map based on file sizing and position, map it, then append to file 

+

1150 for loop in range(num_loops): 

+

1151 if loop == 0: 

+

1152 if num_loops == 1: num_pts = packet_pts 

+

1153 else: num_pts = mm_length // datafile_datapt_size 

+

1154 

+

1155 else: 

+

1156 file_offset += mm_length 

+

1157 if loop == (num_loops - 1): 

+

1158 num_pts = ((packet_pts * datafile_datapt_size) % mm_length) // datafile_datapt_size 

+

1159 else: 

+

1160 num_pts = mm_length // datafile_datapt_size 

+

1161 

+

1162 shape = (int(num_pts), self.basic_header['ChannelCount']) 

+

1163 mm = np.memmap(self.datafile, dtype=np.int16, mode='r', offset=file_offset, shape=shape) 

+

1164 if elec_id_indices: mm = mm[:, elec_id_indices] 

+

1165 start_idx = 0 

+

1166 

+

1167 # Determine if we need to start an additional file 

+

1168 if file_size and (file_size - subset_file.tell()) < DATA_PAGING_SIZE: 

+

1169 

+

1170 # number of points we can possibly write to current subset file 

+

1171 pts_can_add = int((file_size - subset_file.tell()) // (num_elecs * DATA_BYTE_SIZE)) + 1 

+

1172 stop_idx = start_idx + pts_can_add 

+

1173 

+

1174 # If the pts remaining are less than exist in the data, we'll need an additional subset file 

+

1175 while pts_can_add < num_pts: 

+

1176 

+

1177 # Write pts to disk, set old file name, update pts in packet, and close last subset file 

+

1178 if elec_id_indices: subset_file.write(np.array(mm[start_idx:stop_idx]).tobytes()) 

+

1179 else: subset_file.write(mm[start_idx:stop_idx]) 

+

1180 prior_file_name = subset_file.name 

+

1181 prior_file_pkt_pts = subset_file_pkt_pts + pts_can_add 

+

1182 subset_file.close() 

+

1183 

+

1184 # We need to copy header information from last subset file and adjust some headers. 

+

1185 # For file spec 2.1, this is just the basic header. 

+

1186 # For file spec 2.2 and above: 

+

1187 # 1) copy basic and extended headers 

+

# 2) create data packet header with new timestamp and num data points (dummy numpts value)
# 3) overwrite the number of data points in the old file last header packet with true value
prior_file = open(prior_file_name, 'rb+')
if file_num < 10:          numstr = "_00" + str(file_num)
elif 10 <= file_num < 100: numstr = "_0" + str(file_num)
else:                      numstr = "_" + str(file_num)
subset_file = open(file_name + numstr + file_ext, 'wb')
print(("Writing subset file: " + ospath.split(subset_file.name)[1]))

if self.basic_header['FileSpec'] == '2.1':
    subset_file.write(prior_file.read(32 + 4 * num_elecs))
else:
    subset_file.write(prior_file.read(bytes_in_headers))
    subset_file.write(header_binary)
    timestamp_new = unpack('<I', timestamp_binary)[0] \
        + (packet_read_pts + pts_can_add) * self.basic_header['Period']
    subset_file.write(np.array(timestamp_new).astype(np.uint32).tobytes())
    subset_file.write(np.array(num_pts - pts_can_add).astype(np.uint32).tobytes())

    prior_file.seek(num_pts_header_pos, 0)
    prior_file.write(np.array(prior_file_pkt_pts).astype(np.uint32).tobytes())

    num_pts_header_pos = bytes_in_headers + 5

# Close old file and update parameters
prior_file.close()
packet_read_pts += pts_can_add
start_idx += pts_can_add
num_pts -= pts_can_add
file_num += 1
subset_file_pkt_pts = 0
pausing = False

pts_can_add = int((file_size - subset_file.tell()) // (num_elecs * DATA_BYTE_SIZE)) + 1
stop_idx = start_idx + pts_can_add

# If no additional file needed, write remaining data to disk, update parameters, and clear memory map
if elec_id_indices: subset_file.write(np.array(mm[start_idx:]).tobytes())
else:               subset_file.write(mm[start_idx:])
packet_read_pts += num_pts
subset_file_pkt_pts += num_pts
del mm

# Update num_pts header position for each packet, while saving last packet num_pts_header_pos for later
if self.basic_header['FileSpec'] != '2.1':
    curr_hdr_num_pts_pos = num_pts_header_pos
    num_pts_header_pos += 4 + subset_file_pkt_pts * num_elecs * DATA_BYTE_SIZE + 5

# Because memory map resets the file position, reset position in datafile
datafile_pos += self.basic_header['ChannelCount'] * packet_pts * DATA_BYTE_SIZE
self.datafile.seek(datafile_pos, 0)

# If using file_timing and there is pausing in data (multiple packets), let user know
if file_time_s and not pausing and (self.datafile.tell() != ospath.getsize(self.datafile.name)):
    pausing = True
    print("\n*** Because of pausing in original datafile, this file may be slightly shorter in time\n"
          "    than others, and will contain multiple data packets offset in time\n")

# Update last data header packet num data points accordingly (spec != 2.1)
if self.basic_header['FileSpec'] != '2.1':
    subset_file_pos = subset_file.tell()
    subset_file.seek(curr_hdr_num_pts_pos, 0)
    subset_file.write(np.array(subset_file_pkt_pts).astype(np.uint32).tobytes())
    subset_file.seek(subset_file_pos, 0)

# Close subset file and return success
subset_file.close()
print("\n *** All subset files written to disk and closed ***")
return "SUCCESS"

def close(self):
    name = self.datafile.name
    self.datafile.close()
    print(('\n' + name.split('/')[-1] + ' closed'))
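[Editor's note] The timestamp arithmetic above can be easy to misread: the continuation packet written into the next subset file must start exactly where the current file stopped. A minimal sketch of the same computation, with invented numbers standing in for the file's real header fields:

    # Hypothetical values -- illustration only, not taken from a real NSx file.
    old_timestamp = 120000     # timestamp of the packet being split (sample-clock counts)
    period = 1                 # basic_header['Period']: clock ticks per sample
    packet_read_pts = 5000     # points of this packet already written to earlier files
    pts_can_add = 2500         # points that still fit in the current subset file

    # The continuation packet starts where this one stops:
    timestamp_new = old_timestamp + (packet_read_pts + pts_can_add) * period
    print(timestamp_new)       # -> 127500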
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_cerelink_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_cerelink_py.html
new file mode 100644
index 00000000..2bc8ccd3
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_blackrock_cerelink_py.html
@@ -0,0 +1,248 @@
[coverage.py HTML report -- "Coverage for C:\GitHub\brain-python-interface\riglib\blackrock\cerelink.py: 21%". Page chrome stripped; the recovered source listing follows.]
'''
Client-side code that uses cbpy to configure and receive neural data from the
Blackrock Neural Signal Processor (NSP) (or nPlay).
'''

import sys
import time
from collections import namedtuple
try:
    from cerebus import cbpy
except ImportError:
    import warnings
    warnings.warn("Unable to import cerebus library. Check that it is installed if using the Blackrock NeuroPort system")

SpikeEventData = namedtuple("SpikeEventData",
                            ["chan", "unit", "ts", "arrival_ts"])
ContinuousData = namedtuple("ContinuousData",
                            ["chan", "samples", "arrival_ts"])

class Connection(object):
    '''
    A wrapper around a UDP socket which sends the Blackrock NeuroPort system commands and
    receives data. Must run in a separate process (e.g., through `riglib.source`)
    if you want to use it as part of a task (e.g., BMI control)
    '''
    debug = False
    def __init__(self):
        self.parameters = dict()
        self.parameters['inst-addr']   = '192.168.137.128'
        self.parameters['inst-port']   = 51001
        self.parameters['client-port'] = 51002

        self.channel_offset = 0  # used to be 4 -- some old bug with nPlay
        print('Using cbpy channel offset of:', self.channel_offset)

        if sys.platform == 'darwin':  # OS X
            print('Using OS X settings for cbpy')
            self.parameters['client-addr'] = '255.255.255.255'
        else:  # linux
            print('Using linux settings for cbpy')
            self.parameters['client-addr'] = '192.168.137.255'
        self.parameters['receive-buffer-size'] = 8388608

        self._init = False

        if self.debug:
            self.nsamp_recv = 0
            self.nsamp_last_print = 0

    def connect(self):
        '''Open the interface to the NSP (or nPlay).'''

        print('calling cbpy.open in cerelink.connect()')
        # try:
        #     result, return_dict = cbpy.open(connection='default')
        #     time.sleep(3)
        # except:
        result, return_dict = cbpy.open(connection='default', parameter=self.parameters)
        time.sleep(3)

        print('cbpy.open result:', result)
        print('cbpy.open return_dict:', return_dict)
        if return_dict['connection'] != 'Master':
            raise Exception
        print('')

        # return_dict = cbpy.open('default', self.parameters)  # old cbpy

        self._init = True

    def select_channels(self, channels):
        '''Sets the channels on which to receive event/continuous data.

        Parameters
        ----------
        channels : array_like
            A sorted list of channels on which you want to receive data.
        '''

        if not self._init:
            raise ValueError("Please open the interface to Central/nPlay first.")

        buffer_parameter = {'absolute': True}  # want absolute timestamps

        # ability to select desired channels not yet implemented in cbpy
        # range_parameter = dict()
        # range_parameter['begin_channel'] = channels[0]
        # range_parameter['end_channel']   = channels[-1]

        print('calling cbpy.trial_config in cerelink.select_channels()')
        result, reset = cbpy.trial_config(buffer_parameter=buffer_parameter)
        print('cbpy.trial_config result:', result)
        print('cbpy.trial_config reset:', reset)
        print('')

    def start_data(self):
        '''Start the buffering of data.'''

        if not self._init:
            raise ValueError("Please open the interface to Central/nPlay first.")

        self.streaming = True

    def stop_data(self):
        '''Stop the buffering of data.'''

        if not self._init:
            raise ValueError("Please open the interface to Central/nPlay first.")

        print('calling cbpy.trial_config in cerelink.stop()')
        result, reset = cbpy.trial_config(reset=False)
        print('cbpy.trial_config result:', result)
        print('cbpy.trial_config reset:', reset)
        print('')

        self.streaming = False

    def disconnect(self):
        '''Close the interface to the NSP (or nPlay).'''

        if not self._init:
            raise ValueError("Please open the interface to Central/nPlay first.")

        print('calling cbpy.close in cerelink.disconnect()')
        result = cbpy.close()
        print('result:', result)
        print('')

        self._init = False

    def __del__(self):
        self.disconnect()

    def get_event_data(self):
        '''A generator that yields spike event data.'''

        sleep_time = 0

        while self.streaming:

            result, trial = cbpy.trial_event(reset=True)  # TODO -- check if result = 0?
            arrival_ts = time.time()

            for list_ in trial:
                chan = list_[0]
                for unit, unit_ts in enumerate(list_[1]['timestamps']):
                    for ts in unit_ts:
                        # blackrock unit numbers are 0-based where zero is unsorted unit
                        if unit == 0:
                            # Unsorted units are unit 10 (j)
                            un = 10
                        else:
                            un = unit
                        yield SpikeEventData(chan=chan-self.channel_offset, unit=un, ts=ts, arrival_ts=arrival_ts)

            time.sleep(sleep_time)


    def get_continuous_data(self):
        '''A generator that yields continuous data.'''

        sleep_time = 0

        while self.streaming:
            result, trial = cbpy.trial_continuous(reset=True)
            arrival_ts = time.time()

            for list_ in trial:

                if self.debug:
                    chan = list_[0]
                    samples = list_[1]
                    if chan == 8:
                        self.nsamp_recv += len(samples)
                        if self.nsamp_recv > self.nsamp_last_print + 2000:
                            print("cerelink.py: # received =", self.nsamp_recv)
                            self.nsamp_last_print = self.nsamp_recv

                yield ContinuousData(chan=list_[0],
                                     samples=list_[1],
                                     arrival_ts=arrival_ts)

            time.sleep(sleep_time)
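[Editor's note] A minimal sketch of how this Connection class is meant to be driven. It assumes cbpy and a running NSP/nPlay are available; the channel list and the early break are arbitrary choices for illustration, not part of the source above:

    # Hypothetical driver for the Connection class above.
    conn = Connection()
    conn.connect()                            # cbpy.open(...)
    conn.select_channels(list(range(1, 97)))  # channel selection not yet honored by cbpy
    conn.start_data()                         # sets self.streaming = True

    for event in conn.get_event_data():
        print(event.chan, event.unit, event.ts)
        break                                 # stop after the first spike event, for brevity

    conn.stop_data()
    conn.disconnect()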
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi___init___py.html
new file mode 100644
index 00000000..a27bd650
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi___init___py.html
@@ -0,0 +1,69 @@
[coverage.py HTML report -- "Coverage for C:\GitHub\brain-python-interface\riglib\bmi\__init__.py: 100%". Page chrome stripped; the recovered source listing follows.]
from .bmi import BMI, BMISystem, Decoder, BMILoop, GaussianState, GaussianStateHMM, RectangularBounder
# import kfdecoder
# import train
# import sskfdecoder
# import clda
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_accumulator_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_accumulator_py.html
new file mode 100644
index 00000000..355b04f9
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_accumulator_py.html
@@ -0,0 +1,183 @@
[coverage.py HTML report -- "Coverage for C:\GitHub\brain-python-interface\riglib\bmi\accumulator.py: 31%". Page chrome stripped; the recovered source listing follows.]
'''
Feature accumulators: The task and the decoder may want to run at
different rates. These modules provide rate-matching
'''
import numpy as np

class FeatureAccumulator(object):
    '''Used only for type-checking'''
    pass

class RectWindowSpikeRateEstimator(FeatureAccumulator):
    '''
    Estimate spike firing rates using a rectangular window
    '''
    def __init__(self, count_max, feature_shape, feature_dtype):
        '''
        Constructor for RectWindowSpikeRateEstimator

        Parameters
        ----------
        count_max : int
            Number of bins to accumulate in the window. This is somewhat specific
            to rectangular binning
        feature_shape : np.array of shape (n_features, n_timepoints)
            Shape of the extracted features passed to the Decoder on each call
        feature_dtype : np.dtype
            Data type of feature vector. Can be "np.float64" for a vector of numbers
            or something more complicated.

        Returns
        -------
        RectWindowSpikeRateEstimator instance
        '''
        self.count_max = count_max
        self.feature_shape = feature_shape
        self.feature_dtype = feature_dtype
        self.reset()

    def reset(self):
        '''
        Reset the current estimate of the spike rates. Used at the end of the window to clear the estimator for the new window
        '''
        self.est = np.zeros(self.feature_shape, dtype=self.feature_dtype)
        self.count = 0

    def __call__(self, features):
        '''
        Accumulate the current 'features' with the previous estimate

        Parameters
        ----------
        features: np.ndarray of shape self.features_shape
            self.feature_shape is declared at object creation time

        Returns
        -------
        est: np.ndarray of shape self.features_shape
            Returns current estimate of features. This estimate may or may not be
            valid depending on when the estimate is checked
        '''
        self.count += 1
        self.est += features
        est = self.est
        decode = False
        if self.count == self.count_max:
            est = self.est.copy()
            self.reset()
            decode = True
        return est, decode

class NullAccumulator(FeatureAccumulator):
    '''
    A null accumulator to use in cases when no accumulation is desired.
    '''
    def __init__(self, count_max):
        '''
        Constructor for NullAccumulator

        Parameters
        ----------
        count_max: int
            Number of bins to accumulate in the window. This is somewhat specific
            to rectangular binning

        Returns
        -------
        NullAccumulator instance
        '''
        self.count_max = count_max
        self.reset()

    def reset(self):
        '''
        Reset the counter
        '''
        self.count = 0

    def __call__(self, features):
        '''
        Accumulate the current 'features' with the previous estimate

        Parameters
        ----------
        features: np.ndarray of shape self.features_shape
            self.feature_shape is declared at object creation time

        Returns
        -------
        est: np.ndarray of shape self.features_shape
            Returns current estimate of features. This estimate may or may not be
            valid depending on when the estimate is checked
        '''
        self.count += 1
        decode = False
        if self.count == self.count_max:
            self.reset()
            decode = True
        return features, decode
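[Editor's note] The rate-matching contract here (call the accumulator on every task tick, decode only when the window fills) may be clearer with a short sketch. The 6:1 task-to-decoder ratio and the all-ones "spike counts" below are arbitrary illustration values:

    import numpy as np
    # from riglib.bmi.accumulator import RectWindowSpikeRateEstimator  (module above)

    # Task runs at 60 Hz, decoder at 10 Hz -> accumulate 6 ticks per decode.
    acc = RectWindowSpikeRateEstimator(count_max=6, feature_shape=(4, 1),
                                       feature_dtype=np.float64)

    for tick in range(12):
        spike_counts = np.ones((4, 1))      # stand-in for one tick of binned spikes
        est, decode = acc(spike_counts)
        if decode:
            # est holds the 6-tick sum; this is where the Decoder would be called
            print(tick, est.ravel())        # fires at ticks 5 and 11, est == 6.0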
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_assist_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_assist_py.html
new file mode 100644
index 00000000..f2dc7f74
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_assist_py.html
@@ -0,0 +1,201 @@
[coverage.py HTML report -- "Coverage for C:\GitHub\brain-python-interface\riglib\bmi\assist.py: 36%". Page chrome stripped; the recovered source listing follows.]
'''
Various types of "assist", i.e. different methods for shared control
between neural control and machine control. Only applies in cases where
some knowledge of the task goals is available.
'''

import numpy as np
from riglib.stereo_opengl import ik
from riglib.bmi import feedback_controllers
import pickle

from utils.angle_utils import *
from utils.constants import *

class Assister(object):
    '''
    Parent class for various methods of assistive BMI. Children of this class
    can compute an "optimal" input to the system, which is mixed in with the input
    derived from the subject's neural input. The parent exists primarily for
    interface standardization and type-checking.
    '''
    def calc_assisted_BMI_state(self, current_state, target_state, assist_level, mode=None, **kwargs):
        '''
        Main assist calculation function

        Parameters
        ----------
        current_state: np.ndarray of shape (n_states, 1)
            Vector representing the current state of the prosthesis
        target_state: np.ndarray of shape (n_states, 1)
            Vector representing the target state of the prosthesis, i.e. the optimal state for the prosthesis to be in
        assist_level: float
            Number indicating the level of the assist. This can in general have arbitrary units but most assisters
            will have this be a number in the range (0, 1) where 0 is no assist and 1 is full assist
        mode: hashable type, optional, default=None
            Indicator of which mode of the assistive controller to use. When applied, this 'mode' is used as a dictionary key and must be hashable
        kwargs: additional keyword arguments
            These are ignored

        Returns
        -------
        '''
        pass

    def __call__(self, *args, **kwargs):
        '''
        Wrapper for self.calc_assisted_BMI_state
        '''
        return self.calc_assisted_BMI_state(*args, **kwargs)

class FeedbackControllerAssist(Assister):
    '''
    Assister where the machine control is an LQR controller, possibly with different 'modes' depending on the state of the task
    '''
    def __init__(self, fb_ctrl, style='additive'):
        '''
        Parameters
        ----------
        fb_ctrl : feedback_controllers.FeedbackController instance
            TODO

        Returns
        -------
        FeedbackControllerAssist instance
        '''
        self.fb_ctrl = fb_ctrl
        self.style = style
        assert self.style in ['additive', 'mixing', 'additive_cov']

    def calc_assisted_BMI_state(self, current_state, target_state, assist_level, mode=None, **kwargs):
        '''
        See docs for Assister.calc_assisted_BMI_state
        '''
        if self.style == 'additive':
            Bu = assist_level * self.fb_ctrl(current_state, target_state, mode=mode)
            return dict(Bu=Bu, assist_level=0)
        elif self.style == 'mixing':
            x_assist = self.fb_ctrl.calc_next_state(current_state, target_state, mode=mode)
            return dict(x_assist=x_assist, assist_level=assist_level)
        elif self.style == 'additive_cov':
            F = self.get_F(assist_level)
            return dict(F=F, x_target=target_state)

class FeedbackControllerAssist_StateSpecAssistLevels(FeedbackControllerAssist):
    '''
    Assister where the machine controller is an LQR controller, but with different assist_levels for
    different control variables (e.g. X,Y,PSI in ArmAssist vs. Rehand)
    '''
    def __init__(self, fb_ctrl, style='additive', **kwargs):
        super(FeedbackControllerAssist_StateSpecAssistLevels, self).__init__(fb_ctrl, style)

        # Currently this assister assumes that plant is IsMore Plant:
        self.assist_level_state_ix = dict()
        self.assist_level_state_ix[0] = np.array([0, 1, 2, 7, 8, 9])            # ARM ASSIST
        self.assist_level_state_ix[1] = np.array([3, 4, 5, 6, 10, 11, 12, 13])  # REHAND

    def calc_assisted_BMI_state(self, current_state, target_state, assist_level, mode=None, **kwargs):
        if self.style == 'additive':
            Bu = self.fb_ctrl(current_state, target_state, mode=mode)
            for ia, al in enumerate(assist_level):
                Bu[self.assist_level_state_ix[ia]] = al*Bu[self.assist_level_state_ix[ia]]
            return dict(Bu=Bu, assist_level=0)

        elif self.style == 'mixing':
            x_assist = self.fb_ctrl.calc_next_state(current_state, target_state, mode=mode)
            return dict(x_assist=x_assist, assist_level=assist_level, assist_level_ix=self.assist_level_state_ix)


class SSMLFCAssister(FeedbackControllerAssist):
    '''
    An LFC assister where the state-space matrices (A, B) are specified from the Decoder's 'ssm' attribute
    '''
    def __init__(self, ssm, Q, R, **kwargs):
        '''
        Constructor for SSMLFCAssister

        Parameters
        ----------
        ssm: riglib.bmi.state_space_models.StateSpace instance
            The state-space model's A and B matrices represent the system to be controlled
        args: positional arguments
            These are ignored (none are necessary)
        kwargs: keyword arguments
            The constructor must be supplied with the 'kin_chain' kwarg, which must have the attribute 'link_lengths'
            This is specific to 'KinematicChain' plants.

        Returns
        -------
        SSMLFCAssister instance
        '''
        if ssm is None:
            raise ValueError("SSMLFCAssister requires a state space model!")

        A, B, W = ssm.get_ssm_matrices()
        self.lqr_controller = feedback_controllers.LQRController(A, B, Q, R)
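[Editor's note] The 'mixing' style above leaves the actual blend to the decoder; the convex combination it implies is just the line below (see Decoder.predict later in this patch). A toy calculation, numbers invented:

    import numpy as np

    assist_level = 0.25                    # 0 = pure neural control, 1 = pure machine control
    x_neural = np.array([[1.0], [0.0]])    # state decoded from neural activity
    x_assist = np.array([[0.0], [2.0]])    # state the feedback controller would command

    # The 'mixing' style resolves to this convex combination inside the decoder:
    x_mixed = (1 - assist_level) * x_neural + assist_level * x_assist
    print(x_mixed.ravel())                 # -> [0.75  0.5]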
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_bmi_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_bmi_py.html
new file mode 100644
index 00000000..dce505d4
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_bmi_py.html
@@ -0,0 +1,1496 @@
[coverage.py HTML report -- "Coverage for C:\GitHub\brain-python-interface\riglib\bmi\bmi.py: 20%". Page chrome stripped; the recovered source listing follows.]
'''
High-level classes for BMI used to tie all the BMI subcomponent systems together
'''
import numpy as np
import traceback
import re
import multiprocessing as mp
import queue

import time
import os
import tables
import datetime
import copy
import pickle  # needed by Decoder.save; missing at module level in the original listing
class GaussianState(object):
    '''
    Class representing a multivariate Gaussian. Gaussians are
    commonly used to represent the state
    of the BMI in decoders, including the KF and PPF decoders
    '''
    def __init__(self, mean, cov):
        '''
        Parameters
        mean: np.array of shape (N, 1) or (N,)
            N-dimensional vector representing the mean of the multivariate Gaussian distribution
        cov: np.array of shape (N, N)
            N-dimensional covariance matrix
        '''
        if isinstance(mean, np.matrix):
            assert mean.shape[1] == 1  # column vector
            self.mean = mean
        elif isinstance(mean, (float, int)):
            mean = float(mean)
            if isinstance(cov, float):
                self.mean = mean
            else:
                self.mean = mean * np.mat(np.ones([cov.shape[0], 1]))

        elif isinstance(mean, np.ndarray):
            if np.ndim(mean) == 1:
                mean = mean.reshape(-1,1)
            self.mean = np.mat(mean)
        else:
            raise Exception(str(type(mean)))

        # Covariance
        assert cov.shape[0] == cov.shape[1]  # Square matrix
        if isinstance(cov, np.ndarray):
            cov = np.mat(cov)
        self.cov = cov

    def __rmul__(self, other):
        '''
        Gaussian RV multiplication:
        If X ~ N(mu, sigma) and A is a matrix, then A*X ~ N(A*mu, A*sigma*A.T)
        '''
        if isinstance(other, np.matrix):
            mu = other*self.mean
            cov = other*self.cov*other.T
        elif isinstance(other, int) or isinstance(other, np.float64) or isinstance(other, float):
            mu = other*self.mean
            cov = other**2 * self.cov
        elif isinstance(other, np.ndarray):
            # This never actually happens for an array because of how numpy implements array multiplication..
            # (but if it did the following would be logical)
            other = np.mat(other)
            mu = other*self.mean
            cov = other*self.cov*other.T
        else:
            raise ValueError("Unrecognized type: ", type(other))
        return GaussianState(mu, cov)

    def __mul__(self, other):
        '''
        Gaussian RV multiplication:
        If X ~ N(mu, sigma) and A is a matrix, then A*X ~ N(A*mu, A*sigma*A.T)
        '''
        mean = other*self.mean
        if isinstance(other, int) or isinstance(other, np.float64) or isinstance(other, float):
            cov = other**2 * self.cov
        else:
            raise ValueError("Unrecognized type: ", type(other))
        return GaussianState(mean, cov)

    def __add__(self, other):
        '''
        Gaussian RV addition: If X ~ N(mu1, sigma1) and Y ~ N(mu2, sigma2), then
        X + Y ~ N(mu1 + mu2, sigma1 + sigma2). If Y is a scalar, then X + Y ~ N(mu1 + Y, sigma1)
        '''
        if isinstance(other, (int, float)) and (other == 0):
            return GaussianState(self.mean, self.cov)
        elif isinstance(other, (int, float)):
            return GaussianState(self.mean + other, self.cov)
        elif isinstance(other, GaussianState):
            return GaussianState(self.mean+other.mean, self.cov+other.cov)
        elif isinstance(other, np.matrix) and other.shape == self.mean.shape:
            return GaussianState(self.mean + other, self.cov)
        else:
            # print other
            raise ValueError("Gaussian state: cannot add type :%s" % type(other))

    def probability(self, x, calc_log_pr=False):
        """ Evaluate multivariate Gaussian probability density of input vector, given the
        mean and covariance of this object """
        assert x.shape == self.mean.shape
        k = self.mean.shape[0]

        log_det_sign, log_det_cov = np.linalg.slogdet(self.cov)
        if log_det_sign != 1.0:
            raise ValueError("Covariance matrix is not positive definite!")

        log_pr = -0.5*(k*np.log(2*np.pi) + log_det_cov + \
            (x - self.mean).T * self.cov.I * (x - self.mean))

        if calc_log_pr:
            return log_pr
        else:
            return np.exp(log_pr)

    def volume(self, boundary):
        """ Calculate volume of ellipsoid x^T * cov^-1 * x """
        from scipy.special import gamma

        cov_det = np.linalg.det(self.cov)
        n = self.mean.shape[0]
        if cov_det == 0:
            cov_eigenvals, _ = np.linalg.eig(self.cov)
            n = np.sum(cov_eigenvals > 1e-5)
            cov_det = np.product(cov_eigenvals[cov_eigenvals > 1e-5])

        n_sphere_vol = np.pi**(float(n)/2)/gamma(float(n)/2 + 1)
        return n_sphere_vol * boundary**(float(n)/2) * np.sqrt(cov_det)

    def distance(self, x, sqrt=False):
        """ Calculate Mahalanobis distance """
        if not hasattr(self, "cov_inv"):
            self.cov_inv = np.linalg.inv(self.cov)

        dist_sq = (x - self.mean).T * self.cov_inv * (x - self.mean)
        dist_sq = dist_sq[0,0]
        if sqrt:
            return np.sqrt(dist_sq)
        else:
            return dist_sq
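[Editor's note] The operator overloads above are what let the filters below write the propagation step as plain algebra (e.g. `A * state + self.state_noise` in _ssm_pred). A minimal sketch; the shapes and numbers are invented:

    import numpy as np
    # from riglib.bmi.bmi import GaussianState  (class above)

    # X ~ N(mu, P); propagate through A and add process noise w ~ N(0, W).
    X = GaussianState(np.mat([[1.0], [2.0]]), np.mat(np.eye(2)))
    A = np.mat([[1.0, 0.1], [0.0, 1.0]])
    noise = GaussianState(0.0, 0.01 * np.mat(np.eye(2)))

    X_next = A * X + noise       # dispatches to __rmul__, then __add__
    print(X_next.mean.T)         # -> [[1.2  2. ]]
    print(X_next.cov)            # A*P*A.T + W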

class GaussianStateHMM(object):
    '''
    General hidden Markov model decoder where the state is represented as a Gaussian random vector
    '''
    model_attrs = []

    # List out the attributes to save at pickle time. Might not want this to be every attribute of the decoder (e.g., no point in saving the state of the BMI at pickle-time)
    attrs_to_pickle = []
    def __init__(self, A, W):
        '''
        Constructor for GaussianStateHMM

        x_{t+1} = A*x_t + w_t; w_t ~ N(0, W)

        Parameters
        ----------
        A: np.mat of shape (N, N)
            State transition matrix
        W: np.mat of shape (N, N)
            Noise covariance
        '''
        self.A = A
        self.W = W

    def get_mean(self):
        '''
        Return just the mean of the Gaussian representing the state estimate as a 1D array
        '''
        return np.array(self.state.mean).ravel()

    def _init_state(self, init_state=None, init_cov=None):
        """
        Initialize the state of the filter with a mean and covariance (uncertainty)

        Parameters
        ----------
        init_state : np.matrix, optional
            Initial estimate of the unknown state. If unspecified, a vector of all 0's
            will be used (except for the offset state, if one exists).
        init_cov : np.matrix, optional
            Uncertainty about the initial state. If unspecified, it is assumed that there
            is no uncertainty (a matrix of all 0's).

        Returns
        -------
        None
        """
        ## Initialize the BMI state, assuming
        nS = self.n_states()  # number of state variables
        if init_state is None:
            init_state = np.mat( np.zeros([nS, 1]) )
            if self.include_offset: init_state[-1,0] = 1
        if init_cov is None:
            init_cov = np.mat( np.zeros([nS, nS]) )
        self.state = GaussianState(init_state, init_cov)
        self.init_noise_models()

    def n_states(self):
        return self.A.shape[0]

    def init_noise_models(self):
        '''
        Initialize the process and observation noise models. The state noise should be
        Gaussian (as implied by the name of this class). The observation noise may be
        non-Gaussian depending on the observation model.
        '''
        self.state_noise = GaussianState(0.0, self.W)
        self.obs_noise = GaussianState(0.0, self.Q)

    def _ssm_pred(self, state, u=None, Bu=None, target_state=None, F=None):
        '''
        Prior prediction of the hidden states using the linear directed random walk model
            x_{t+1} = A*x_t + c_t + w_t
            x_t = previous state
            c_t = control input (the "directed" part of the model)
            w_t = process noise (the "random walk" part of the model)

        Parameters
        ----------
        state : GaussianState instance
            State estimate and estimator covariance of current state
        u : np.mat, optional, default=None
            An assistive control input. Requires the filter to have an input matrix attribute, B
        Bu : np.mat of shape (N, 1)
            Assistive control input which is precomputed to already account for the control input matrix
        target_state : np.mat of shape (N, 1)
            Optimal value for x_t (defined by external factors, i.e. the task being performed)
        F : np.mat of shape (B.shape[1], N)
            Feedback control gains. Used to compute u_t = BF(x^* - x_t)

        Returns
        -------
        GaussianState instance
            Represents the mean and estimator covariance of the new state estimate
        '''
        A = self.A

        if not (Bu is None):
            c_t = Bu
        elif not (u is None):
            c_t = self.B * u
        elif not (target_state is None):
            B = self.B
            if F is None:
                F = self.F
            # if not np.all(target_state[:-1,:] == 0):
            #     import pdb; pdb.set_trace()
            A = A - B*F
            c_t = B*F*target_state
        else:
            c_t = 0

        return A * state + c_t + self.state_noise

    def __eq__(self, other):
        '''
        Determine equality of two GaussianStateHMM instances
        '''
        # import train
        return GaussianStateHMM.obj_eq(self, other, self.model_attrs)

    def __sub__(self, other):
        '''
        Subtract the model attributes of two GaussianStateHMM instances. Used to determine approximate equality, i.e., equality modulo floating point error
        '''
        # import train
        return GaussianStateHMM.obj_diff(self, other, self.model_attrs)

    @staticmethod
    def obj_eq(self, other, attrs=[]):
        '''
        Determine if two objects have matching array attributes

        Parameters
        ----------
        other : object
            If objects are not the same type, False is returned
        attrs : list, optional
            List of attributes to compare for equality. Only attributes that are common to both objects are used.
            The attributes should be np.array or similar as np.array_equal is used to determine equality

        Returns
        -------
        bool
            True value returned indicates equality between objects for the specified attributes
        '''
        if isinstance(other, type(self)):
            attrs_eq = [y for y in [x for x in attrs if x in self.__dict__] if y in other.__dict__]
            equal = [np.array_equal(getattr(self, attr), getattr(other, attr)) for attr in attrs_eq]
            return np.all(equal)
        else:
            return False

    @staticmethod
    def obj_diff(self, other, attrs=[]):
        '''
        Calculate the difference of the two objects w.r.t the specified attributes

        Parameters
        ----------
        other : object
            If objects are not the same type, False is returned
        attrs : list, optional
            List of attributes to compare for equality. Only attributes that are common to both objects are used.
            The attributes should be np.array or similar as np.array_equal is used to determine equality

        Returns
        -------
        np.array
            The difference between each of the specified 'attrs'
        '''
        if isinstance(other, type(self)):
            attrs_eq = [y for y in [x for x in attrs if x in self.__dict__] if y in other.__dict__]
            diff = [getattr(self, attr) - getattr(other, attr) for attr in attrs_eq]
            return np.array(diff)
        else:
            return False

    def __call__(self, obs, **kwargs):
        """
        When the object is called directly, it's a wrapper for the
        1-step forward inference function.
        """
        self.state = self._forward_infer(self.state, obs, **kwargs)
        return self.state.mean

    def _pickle_init(self):
        pass

    def __setstate__(self, state):
        """
        Unpickle decoders by loading all the saved parameters and then running _pickle_init

        Parameters
        ----------
        state : dict
            Provided by the unpickling system

        Returns
        -------
        None
        """
        self.__dict__ = state
        self._pickle_init()

    def __getstate__(self):
        data_to_pickle = dict()
        for attr in self.attrs_to_pickle:
            try:
                data_to_pickle[attr] = getattr(self, attr)
            except:
                print(("GaussianStateHMM: could not pickle attribute %s" % attr))
        return data_to_pickle

    def predict(self, observations, *args, **kwargs):
        if isinstance(observations, list):
            N = len(observations)
        elif isinstance(observations, np.ndarray):
            time_dim = kwargs.pop("time_dim", 1)
            if time_dim not in [0, 1]:
                raise ValueError("Can't interpret observation matrix")
            N = observations.shape[time_dim]
            if time_dim == 1:
                observations = observations.T
        state_seq = []
        data = []
        for k in range(N):
            ret = self._forward_infer(self.state, observations[k], *args, **kwargs)
            if np.iterable(ret):
                state_k, data_k = ret
            else:
                state_k = ret
                data_k = None

            self.state = state_k

            state_seq.append(state_k)
            data.append(data_k)
        return state_seq, data
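[Editor's note] The target_state branch of _ssm_pred folds the feedback law into the dynamics: x_{t+1} = (A - BF)x_t + BF x* + w_t. A toy numeric check of that rearrangement (all matrices invented):

    import numpy as np

    A = np.mat([[1.0]]); B = np.mat([[1.0]]); F = np.mat([[0.5]])
    x_t = np.mat([[0.0]]); x_star = np.mat([[10.0]])

    # Closed-loop form used by _ssm_pred ...
    x_next = (A - B*F) * x_t + B*F * x_star
    # ... equals the usual error-feedback form u_t = F*(x_star - x_t):
    assert x_next == A * x_t + B * (F * (x_star - x_t))
    print(float(x_next))   # -> 5.0, halfway to the target in one step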

class MachineOnlyFilter(GaussianStateHMM):
    '''
    A degenerate case of the GaussianStateHMM where the inputs are driven only by the
    input control term and all observations are ignored.
    In other words, W = 0 and x_{t+1} = Ax_t + Bu_t
    '''
    def __init__(self, *args, **kwargs):
        super(MachineOnlyFilter, self).__init__(*args, **kwargs)
        self.include_offset = True

    def init_noise_models(self):
        '''
        see bmi.GaussianStateHMM.init_noise_models for documentation
        '''
        self.state_noise = GaussianState(0.0, self.W)

    def _forward_infer(self, st, obs_t, Bu=None, u=None, target_state=None, **kwargs):
        if Bu is not None:
            return self.A * st + Bu
        else:
            return self.A * st


class RectangularBounder(object):
    """ Hard limit on state values """
    def __init__(self, bounding_box, states_to_bound):
        self.bounding_box = bounding_box
        self.states_to_bound = states_to_bound

    def __call__(self, state_mean, state_names):
        """
        Apply bounds on state vector, if bounding box is specified
        """
        state_mean = state_mean.copy()
        min_bounds, max_bounds = self.bounding_box

        repl_with_min = np.array(state_mean[:,0]).ravel() < min_bounds
        repl_with_max = np.array(state_mean[:,0]).ravel() > max_bounds
        state_mean[repl_with_min, :] = min_bounds[repl_with_min].reshape(-1, 1)
        # bug fix: the original listing clamped over-range states to min_bounds here
        state_mean[repl_with_max, :] = max_bounds[repl_with_max].reshape(-1, 1)
        return state_mean
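[Editor's note] A quick sketch of the bounder's clamping behavior, which also exercises the max-bounds fix noted in the comment above; the bounds, state names, and state vector are invented:

    import numpy as np
    # from riglib.bmi.bmi import RectangularBounder  (class above)

    bounds = (np.array([-10.0, -10.0]), np.array([10.0, 10.0]))   # (min, max) per state
    bounder = RectangularBounder(bounds, states_to_bound=['hand_px', 'hand_pz'])

    state = np.mat([[-12.0], [3.0]])                # first state below min, second in range
    print(bounder(state, ['hand_px', 'hand_pz']))   # -> [[-10.], [3.]]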

class Decoder(object):
    '''
    All BMI decoders should inherit from this class
    '''
    def __init__(self, filt, units, ssm, binlen=0.1, n_subbins=1, tslice=[-1,-1], call_rate=60.0, **kwargs):
        """
        Parameters
        ----------
        filt : PointProcessFilter or KalmanFilter instance
            Generic inference algorithm that does the actual observation decoding
        units : array-like
            N x 2 array of units, where each row is (chan, unit)
        ssm : state_space_models.StateSpace instance
            The state-space model describes the states tracked by the decoder, whether or not
            they are stochastic/related to the observations, bounds on the state, etc.
        binlen : float, optional, default = 0.1
            Bin-length specified in seconds. Gets rounded to a multiple of 1./60
            to match the update rate of the task
        n_subbins : int, optional, default = 3
            Neural observations are always acquired at the 60Hz screen update rate.
            This parameter explains how many bins to sub-divide the observations
            into. Default of 3 is intended to correspond to ~180Hz / 5.5ms bins
        tslice : array_like, optional, default=[-1, -1]
            start and end times for the neural data used to train, e.g. from the .plx file
        call_rate: float, optional, default = 60 Hz
            Rate in Hz at which the task will run the __call__ function.
        """

        self.filt = filt
        if not filt is None:
            self.filt._init_state()
        self.ssm = ssm

        self.units = np.array(units, dtype=np.int32)
        self.binlen = binlen
        self.bounding_box = ssm.bounding_box
        self.states = ssm.state_names

        # The tslice parameter below properly belongs in the database and
        # not in the decoder object because the Decoder object has no record of
        # which plx file it was trained from. This is a leftover from when it
        # was assumed that every decoder would be trained entirely from a plx
        # file (i.e. and not CLDA)
        self.tslice = tslice
        self.states_to_bound = ssm.states_to_bound

        self.drives_neurons = ssm.drives_obs  # drives_neurons
        self.n_subbins = n_subbins

        self.bmicount = 0
        self.bminum = int(self.binlen/(1/call_rate))
        self.spike_counts = np.zeros([len(units), 1])

        self.set_call_rate(call_rate)

        self._pickle_init()

    def _pickle_init(self):
        '''
        Functionality common to unpickling a Decoder from file and instantiating a new Decoder.
        A call to this function is the last line in __init__ as well as __setstate__.
        '''
        from . import train

        # If the decoder doesn't have an 'ssm' attribute, then it's an old
        # decoder in which case the ssm is the 2D endpoint SSM
        if not hasattr(self, 'ssm'):
            from . import state_space_models
            self.ssm = state_space_models.StateSpaceEndptVel2D()
            # self.ssm = train.endpt_2D_state_space

        # Assign a default call rate of 60 Hz and initialize the bmicount/bminum attributes
        if hasattr(self, 'call_rate'):
            self.set_call_rate(self.call_rate)
        else:
            self.set_call_rate(60.0)

    def plot_pds(self, C, ax=None, plot_states=['hand_vx', 'hand_vz'], invert=False, **kwargs):
        '''
        Plot 2D "preferred directions" of features in the Decoder

        Parameters
        ----------
        C: np.array of shape (n_features, n_states)
        ax: matplotlib.pyplot axis, default=None
            axis to plot on. If None specified, a new one is created.
        plot_states: list of strings, default=['hand_vx', 'hand_vz']
            List of decoder states to plot. Only two can be specified currently
        invert: bool, default=False
            If true, flip the signs of the arrows plotted
        kwargs: dict
            Keyword arguments for the low-level matplotlib function
        '''
        import matplotlib.pyplot as plt
        if ax == None:
            plt.figure()
            ax = plt.subplot(111)
            ax.hold(True)

        if C.shape[1] > 2:
            state_inds = [self.states.index(x) for x in plot_states]
            x, z = state_inds
        else:
            x, z = 0, 1
        n_neurons = C.shape[0]
        linestyles = ['-.', '-', '--', ':']
        if invert:
            C = C*-1
        for k in range(n_neurons):
            unit_str = '%d%s' % (self.units[k,0], chr(96 + self.units[k,1]))
            ax.plot([0, C[k, x]], [0, C[k, z]], label=unit_str, linestyle=linestyles[k//7 % len(linestyles)], **kwargs)
        ax.legend(bbox_to_anchor=(1.1, 1.05), prop=dict(size=8))
        try:
            ax.set_xlabel(plot_states[0])
            ax.set_ylabel(plot_states[1])
        except:
            pass
        ax.set_title(self)

    def plot_C(self, **kwargs):
        '''
        Plot the C matrix (see plot_pds docstring), which is used
        by the KFDecoder and the PPFDecoder
        '''
        self.plot_pds(self.filt.C, linewidth=2, **kwargs)

    def update_params(self, new_params, **kwargs):
        '''
        Method for updating the parameters of the decoder

        Parameters
        ----------
        new_params: dict
            Keys are the parameters to be replaced, values are the new value of
            the parameter to replace. In particular, the keys can be dot-separated,
            e.g. to set the attribute 'self.kf.C', the key would be 'kf.C'
        '''
        for key, val in list(new_params.items()):
            attr_list = key.split('.')
            final_attr = attr_list[-1]
            attr_list = attr_list[:-1]
            attr = self
            while len(attr_list) > 0:
                attr = getattr(self, attr_list[0])
                attr_list = attr_list[1:]

            setattr(attr, final_attr, val)

    def bound_state(self):
        """
        Apply bounds on state vector, if bounding box is specified
        """
        if not self.bounding_box is None:
            min_bounds, max_bounds = self.bounding_box
            state = self[self.states_to_bound]
            repl_with_min = state < min_bounds
            state[repl_with_min] = min_bounds[repl_with_min]

            repl_with_max = state > max_bounds
            state[repl_with_max] = max_bounds[repl_with_max]
            self[self.states_to_bound] = state

    def __getitem__(self, idx):
        """
        Get element(s) of the BMI state, indexed by name or number

        Warning: The variable 'q' is a reserved keyword, referring to all of
        the position states. This strange letter choice was made to be consistent
        with the robotics literature, where 'q' refers to the vector of
        generalized joint coordinates.

        Parameters
        ----------
        idx: int or string
            Name of the state, index of the state, or list of indices/names
            of the Decoder state(s) to return
        """
        if isinstance(idx, (int, np.int64, np.int32)):
            return self.filt.state.mean[idx, 0]
        elif idx == 'q':
            pos_states, = np.nonzero(self.ssm.state_order == 0)
            return np.array([self.__getitem__(k) for k in pos_states])
        elif idx == 'qdot':
            vel_states, = np.nonzero(self.ssm.state_order == 1)
            return np.array([self.__getitem__(k) for k in vel_states])
        elif isinstance(idx, str):
            idx = self.states.index(idx)
            return self.filt.state.mean[idx, 0]
        elif np.iterable(idx):
            return np.array([self.__getitem__(k) for k in idx])
        else:
            try:
                return self.filt.state.mean[idx, 0]
            except:
                raise ValueError("Decoder: Improper index type: %s" % type(idx))

    def __setitem__(self, idx, value):
        """
        Set element(s) of the BMI state, indexed by name or number

        Parameters
        ----------
        idx: int or string
            Name of the state, index of the state, or list of indices/names
            of the Decoder state(s) to return
        """
        if isinstance(idx, (int, np.int64, np.int32)):
            self.filt.state.mean[idx, 0] = value
        elif idx == 'q':
            pos_states, = np.nonzero(self.ssm.state_order == 0)
            self.filt.state.mean[pos_states, 0] = value
        elif idx == 'qdot':
            vel_states, = np.nonzero(self.ssm.state_order == 1)
            self.filt.state.mean[vel_states, 0] = value
        elif isinstance(idx, str):
            idx = self.states.index(idx)
            self.filt.state.mean[idx, 0] = value
        elif np.iterable(idx):
            [self.__setitem__(k, val) for k, val in zip(idx, value)]
        else:
            try:
                self.filt.state.mean[idx, 0] = value
            except:
                raise ValueError("Decoder: Improper index type: %s" % type(idx))

    def __setstate__(self, state):
        """
        Set decoder state after un-pickling
        """
        if 'db_entry' in state:
            del state['db_entry']
        self.__dict__.update(state)
        self.filt._pickle_init()
        self.filt._init_state()

        if not hasattr(self, 'n_subbins'):
            self.n_subbins = 1

        if not hasattr(self, 'interpolate_using_ssm'):
            self.interpolate_using_ssm = False

        if not hasattr(self, 'bmicount'):
            self.bmicount = 0

        if not hasattr(self, 'n_features'):
            self.n_features = len(self.units)

        self._pickle_init()

    def set_call_rate(self, call_rate):
        '''
        Function for the higher-level task to set the frequency of function calls to __call__

        Parameters
        ----------
        call_rate : float
            1./call_rate should be an integer multiple or divisor of the Decoder's 'binlen'

        Returns
        -------
        None
        '''
        self.call_rate = call_rate
        self.bmicount = 0
        self.bminum = int(self.binlen/(1./self.call_rate))
        self.n_subbins = int(np.ceil(1./self.binlen/self.call_rate))

    def get_state(self, shape=-1):
        '''
        Get the state of the decoder (mean of the Gaussian RV representing the
        state of the BMI)
        '''
        return np.asarray(self.filt.state.mean).reshape(shape)

    def predict(self, neural_obs, assist_level=0.0, weighted_avg_lfc=False, **kwargs):
        """
        Decode the spikes

        Parameters
        ----------
        neural_obs: np.array of shape (N,) or (N, 1)
            One time-point worth of neural features to decode
        assist_level: float
            Weight given to the assist term. This variable name may be a slight misnomer; a more appropriate term might be 'reweight_factor'
        Bu: np.mat of shape (N, 1)
            Assist vector to be added on to the Decoder state. Must be of the same dimension
            as the state vector.
        kwargs: dict
            Mostly for kwargs function call compatibility
        """
        if np.any(neural_obs > 1000):
            print('observations have counts >> 1000 ')

        if np.any(assist_level) > 0 and 'x_assist' not in kwargs:
            raise ValueError("Assist cannot be used if the forcing term is not specified!")

        # re-normalize the variance of the spike observations, if nec
        if hasattr(self, 'zscore') and self.zscore:
            # neural_obs = (np.asarray(neural_obs).ravel() - self.mFR_curr) * self.sdFR_ratio
            neural_obs = (np.asarray(neural_obs).ravel() - self.mFR) * (1./self.sdFR)
            # set the spike count of any unit that now has zero-mean with its original mean
            # This functionally removes it from the decoder.
            neural_obs[self.zeromeanunits] = self.mFR[self.zeromeanunits]

        # re-format as a column matrix
        neural_obs = np.mat(neural_obs.reshape(-1,1))

        x = self.filt.state.mean

        # Run the filter
        self.filt(neural_obs, **kwargs)

        if np.any(assist_level) > 0:
            x_assist = kwargs.pop('x_assist')

            if 'ortho_damp_assist' in kwargs and kwargs['ortho_damp_assist']:
                x_assist[self.drives_neurons,:] /= np.linalg.norm(x_assist[self.drives_neurons,:])
                targ_comp = float(self.filt.state.mean[self.drives_neurons,:].T*x_assist[self.drives_neurons,:])*x_assist[self.drives_neurons,:]
                orth_comp = self.filt.state.mean[self.drives_neurons,:] - targ_comp

                if type(assist_level) is np.ndarray:
                    tmp = np.mat(np.zeros((len(self.filt.state.mean)))).T
                    tmp[:7, :] = self.filt.state.mean[:7, 0]
                    assist_level_ix = kwargs['assist_level_ix']
                    for ia, al in enumerate(assist_level):
                        ix = np.nonzero(assist_level_ix[ia] <= 6)[0]
                        tmp[assist_level_ix[ia][ix]+7, :] = targ_comp[assist_level_ix[ia][ix]] + (1 - al)*orth_comp[assist_level_ix[ia][ix]]
                    self.filt.state.mean = tmp

                else:
                    # High assist damps orthogonal component a lot
                    self.filt.state.mean[self.drives_neurons,:] = targ_comp + (1 - assist_level)*orth_comp

            elif type(assist_level) is np.ndarray:
                tmp = np.zeros((len(self.filt.state.mean)))
                assist_level_ix = kwargs['assist_level_ix']
                for ia, al in enumerate(assist_level):
                    tmp[assist_level_ix[ia]] = (1-al)*self.filt.state.mean[assist_level_ix[ia]] + al*x_assist[assist_level_ix[ia]]
                self.filt.state.mean = np.mat(tmp).T

            else:
                self.filt.state.mean = (1-assist_level)*self.filt.state.mean + assist_level * x_assist

        # Bound cursor, if any hard bounds for states are applied
        if hasattr(self, 'bounder'):
            self.filt.state.mean = self.bounder(self.filt.state.mean, self.states)

        state = self.filt.get_mean()
        return state

    def decode(self, neural_obs, **kwargs):
        '''
        Decode multiple observations sequentially.

        Parameters
        ----------
        neural_obs: np.array of shape (# features, # observations)
            Independent neural observations are columns of the data
            matrix and are decoded sequentially
        kwargs: dict
            Container for special keyword-arguments for the specific decoding
            algorithm's 'predict'.
        '''
        output = []
        n_obs = neural_obs.shape[1]
        for k in range(n_obs):
            self.predict(neural_obs[:,k], **kwargs)
            output.append(self.filt.get_mean())
        return np.vstack(output)

    def __str__(self):
        if hasattr(self, 'db_entry'):
            return self.db_entry.name
        else:
            return super(Decoder, self).__str__()

    @property
    def n_states(self):
        '''
        Return the number of states represented in the Decoder
        '''
        return len(self.states)

    @property
    def n_units(self):
        '''
        Return the number of units used in the decoder. Not sure what this
        does for LFP decoders, i.e. decoders which extract multiple features from
        a single channel.
        '''
        return len(self.units)

    def __call__(self, obs_t, **kwargs):
        '''
        Wrapper for the 'predict' method

        Parameters
        ----------
        obs_t: np.array of shape (# features, # subbins)
            Neural observation vector. If the decoding_rate of the Decoder is
            greater than the control rate of the plant (e.g. 60 Hz )
        kwargs: dictionary
            Algorithm-specific arguments to be given to the Decoder.predict method
        '''

        self.predict(obs_t, **kwargs)
        return self.filt.get_mean().reshape(-1,1)

    def save(self, filename=''):
        '''
        Pickle the Decoder object to a file

        Parameters
        ----------
        filename: string, optional
            Filename to pickle the decoder to. If unspecified, a temporary file will be created.

        Returns
        -------
        filename: string
            filename of pickled Decoder object
        '''
        if filename != '':
            f = open(filename, 'wb')  # binary mode required for pickle
            pickle.dump(self, f)
            f.close()
            return filename
        else:
            import tempfile, pickle
            tf2 = tempfile.NamedTemporaryFile(delete=False)
            pickle.dump(self, tf2)
            tf2.flush()
            return tf2.name

    def save_attrs(self, hdf_filename, table_name='task'):
        '''
        Save the attributes of the Decoder to the attributes of the specified HDF table

        Parameters
        ----------
        hdf_filename: string
            HDF filename to write data to
        table_name: string, default='task'
            Specify the table within the HDF file to set attributes in.
        '''
        h5file = tables.openFile(hdf_filename, mode='a')
        table = getattr(h5file.root, table_name)
        for attr in self.filt.model_attrs:
            table.attrs[attr] = np.array(getattr(self.filt, attr))
        h5file.close()

    @property
    def state_shape_rt(self):
        '''
        Create attribute to access the shape of the accumulating spike counts feature.
        '''
        return (self.n_states, self.n_subbins)
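[Editor's note] The name/'q'/'qdot' indexing contract of __getitem__/__setitem__ is easiest to see in use. A hypothetical session; `dec` stands for an already-constructed Decoder, and the state names follow the 2D endpoint convention used elsewhere in riglib.bmi:

    # Hypothetical usage -- assumes `dec` is a trained Decoder with the usual
    # StateSpaceEndptVel2D state names.
    dec['hand_px'] = 2.5          # set one state by name
    print(dec['hand_px'])         # get it back by name
    print(dec['q'])               # all position states (state_order == 0)
    print(dec['qdot'])            # all velocity states (state_order == 1)
    print(dec[[0, 1, 2]])         # or index numerically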

class BMISystem(object):
    '''
    This class encapsulates all of the BMI decoding computations, including assist and CLDA
    '''
    def __init__(self, decoder, learner, updater, feature_accumulator):
        '''
        Instantiate the BMISystem

        Parameters
        ----------
        decoder : bmi.Decoder instance
            The decoder maps spike counts into the "state" of the prosthesis
        learner : clda.Learner instance
            The learner estimates the "intended" prosthesis state from task goals.
        updater : clda.Updater instance
            The updater remaps the decoder parameters to better match sets of
            observed spike counts and intended kinematics (from the learner)
        feature_accumulator : accumulator.FeatureAccumulator instance
            Combines features across time if necessary to perform rate matching
            between the task rate and the decoder rate.

        Returns
        -------
        BMISystem instance
        '''
        self.decoder = decoder
        self.learner = learner
        self.updater = updater
        self.feature_accumulator = feature_accumulator
        self.param_hist = []

        self.has_updater = not (self.updater is None)
        if self.has_updater:
            self.updater.init(self.decoder)

    def __call__(self, neural_obs, target_state, task_state, learn_flag=False, **kwargs):
        '''
        Main function for all BMI functions, including running the decoder, adapting the decoder
        and incorporating assistive control inputs

        Parameters
        ----------
        neural_obs : np.ndarray
            The shape of neural_obs should be [n_units, n_obs]. If multiple observations are given, then
            the decoder will run multiple times before returning.
        target_state : np.ndarray
            The assumed state that the subject is trying to drive the BMI toward, e.g. based on the
            objective of the task
        task_state : string
            State of the task. Used by CLDA so that assist is only applied during certain states,
            e.g. in some tasks, the target will be ambiguous during penalty states so CLDA should
            ignore data during those epochs.
        learn_flag : bool, optional, default=False
            Boolean specifying whether the decoder should update based on intention estimates
        **kwargs : dict
            Instance-specific arguments, e.g. RML/SmoothBatch require a 'half_life' parameter
            that is not required of other CLDA methods.

        Returns
        -------
        decoded_states : np.ndarray
            Columns of the array are vectors representing the decoder output as each of the
            observations are decoded.
        update_flag : boolean
            Boolean to indicate whether the parameters of the Decoder have changed based on the
            current function call
        '''
        n_units, n_obs = neural_obs.shape
        # If the target is specified as a 1D position, tile to match
        # the number of dimensions as the neural features
        if np.ndim(target_state) == 1 or (target_state.shape[1] == 1 and n_obs > 1):
            target_state = np.tile(target_state, [1, n_obs])

        decoded_states = np.zeros([self.decoder.n_states, n_obs])
        update_flag = False

        for k in range(n_obs):
            neural_obs_k = neural_obs[:,k].reshape(-1,1)
            target_state_k = target_state[:,k]

            # NOTE: the conditional below is *only* for compatibility with older Carmena
            # lab data collected using a different MATLAB-based system. In all python cases,
            # the task_state should never contain NaN values.
            if np.any(np.isnan(target_state_k)): task_state = 'no_target'

            #################################
            ## Decode the current observation
            #################################
            decodable_obs, decode = self.feature_accumulator(neural_obs_k)
            if decode:  # if a new decodable observation is available from the feature accumulator
                prev_state = self.decoder.get_state()

                self.decoder(decodable_obs, **kwargs)

                # Determine whether the current state or previous state should be given to the learner
                if self.learner.input_state_index == 0:
                    learner_state = self.decoder.get_state()
                elif self.learner.input_state_index == -1:
                    learner_state = prev_state
                else:
                    print(("Not implemented yet: %d" % self.learner.input_state_index))
                    learner_state = prev_state

                if learn_flag:
                    self.learner(decodable_obs.copy(), learner_state, target_state_k, self.decoder.get_state(), task_state, state_order=self.decoder.ssm.state_order)

            decoded_states[:,k] = self.decoder.get_state()

            ############################
            ## Update decoder parameters
            ############################
            if self.learner.is_ready():
                batch_data = self.learner.get_batch()
                batch_data['decoder'] = self.decoder
                kwargs.update(batch_data)
                self.updater(**kwargs)
                self.learner.disable()

            new_params = None  # by default, no new parameters are available
            if self.has_updater:
                new_params = copy.deepcopy(self.updater.get_result())

            # Update the decoder if new parameters are available
            if new_params is not None:
                self.decoder.update_params(new_params, **self.updater.update_kwargs)
                new_params['intended_kin'] = batch_data['intended_kin']
                new_params['spike_counts_batch'] = batch_data['spike_counts']

                self.learner.enable()
                update_flag = True

            # Save new parameters to parameter history
            self.param_hist.append(new_params)
        return decoded_states, update_flag

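# Usage sketch (not part of the original file): BMILoop.create_bmi_system()
# below wires the four components together; each task cycle then reduces to
#
#     decoded_states, update_flag = bmi_system(neural_obs, target_state,
#         task_state, learn_flag=learn_flag, **kwargs)
#
# where neural_obs has shape (n_units, n_obs) and update_flag reports whether
# a CLDA parameter update was applied during this call.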

class BMILoop(object):
    '''
    Container class/interface definition for BMI tasks. Intended to be used with multiple inheritance structure paired with riglib.experiment classes
    '''
    static_states = []  # states in which the decoder is not run
    decoder_sequence = ''

    def init(self):
        '''
        Secondary init function. Finishes initializing the task after all the
        constructors have run and all the required attributes have been declared
        for the task to operate.
        '''
        # Initialize the decoder
        self.load_decoder()
        self.init_decoder_state()
        if hasattr(self.decoder, 'adapting_state_inds'):
            print('Decoder has adapting state inds')

        if hasattr(self.decoder, 'adapting_neural_inds'):
            print('Decoder has adapting neural inds')

        # Declare data attributes to be stored in the sinks every iteration of the FSM
        self.add_dtype('loop_time', 'f8', (1,))
        self.add_dtype('decoder_state', 'f8', (self.decoder.n_states, 1))
        self.add_dtype('internal_decoder_state', 'f8', self.decoder.state_shape_rt)
        self.add_dtype('target_state', 'f8', self.decoder.state_shape_rt)
        self.add_dtype('update_bmi', 'f8', (1,))

        # Construct the sub-pieces of the BMI system
        self.create_assister()
        self.create_feature_extractor()
        self.create_feature_accumulator()
        self.create_goal_calculator()
        self.create_learner()
        self.create_updater()
        self.create_bmi_system()

        super(BMILoop, self).init()

    def create_bmi_system(self):
        self.bmi_system = BMISystem(self.decoder, self.learner, self.updater, self.feature_accumulator)

    def load_decoder(self):
        '''
        Shell function. In tasks launched from the GUI with the BMI feature
        enabled, the decoder attribute is automatically added to the task. This
        is for simulation purposes only (or if you want to make a version that
        launches from the command line)
        '''
        pass

    def init_decoder_state(self):
        '''
        Initialize the state of the decoder to match the initial state of the plant
        '''
        self.decoder.filt._init_state()
        try:
            self.decoder['q'] = self.plant.get_intrinsic_coordinates()
        except:
            print((self.plant.get_intrinsic_coordinates()))
            print((self.decoder['q']))
            raise Exception("Error initializing decoder state")
        self.init_decoder_mean = self.decoder.filt.state.mean

        self.decoder.set_call_rate(1./self.update_rate)

    def create_assister(self):
        '''
        The 'assister' is a callable object which, for the specific plant being controlled,
        will drive the plant toward the specified target state of the task.
        '''
        self.assister = None

    def create_feature_accumulator(self):
        '''
        Instantiate the feature accumulator used to implement rate matching between the Decoder and the task,
        e.g. using a 10 Hz KFDecoder in a 60 Hz task
        '''
        from . import accumulator
        feature_shape = [self.decoder.n_features, 1]
        feature_dtype = np.float64
        acc_len = int(self.decoder.binlen / self.update_rate)
        acc_len = max(1, acc_len)
        if self.extractor.feature_type in ['lfp_power', 'emg_amplitude']:
            self.feature_accumulator = accumulator.NullAccumulator(acc_len)
        else:
            self.feature_accumulator = accumulator.RectWindowSpikeRateEstimator(acc_len, feature_shape, feature_dtype)

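    # A worked example (not part of the original file) of the rate matching
    # above: assuming a 10 Hz decoder (binlen = 0.1 s) inside a 60 Hz task loop
    # (update_rate = 1/60 s), acc_len = int(0.1 / (1/60.)) = 6, so six task
    # cycles of features are pooled into each decodable observation.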

    def create_goal_calculator(self):
        '''
        The 'goal_calculator' is a callable object which will define the optimal state for the Decoder
        to be in for this particular task. This object is necessary for CLDA (to estimate the "error" of the decoder
        in order to adapt it) and for any assistive control (for the 'machine' controller to determine where to
        drive the plant)
        '''
        self.goal_calculator = None

    def create_feature_extractor(self):
        '''
        Create the feature extractor object. The feature extractor takes raw neural data from the streaming processor
        (e.g., spike timestamps) and outputs a decodable observation vector (e.g., counts of spikes in last 100ms from each unit)
        '''
        from . import extractor
        if hasattr(self.decoder, 'extractor_cls') and hasattr(self.decoder, 'extractor_kwargs'):
            self.extractor = self.decoder.extractor_cls(self.neurondata, **self.decoder.extractor_kwargs)
        else:
            # if using an older decoder that doesn't have extractor_cls and
            # extractor_kwargs as attributes, then create a BinnedSpikeCountsExtractor by default
            self.extractor = extractor.BinnedSpikeCountsExtractor(self.neurondata,
                n_subbins=self.decoder.n_subbins, units=self.decoder.units)

        self._add_feature_extractor_dtype()

    def _add_feature_extractor_dtype(self):
        '''
        Helper function to add the datatype of the extractor output to be saved in the HDF file. Uses a separate function
        so that simulations can overwrite.
        '''
        if isinstance(self.extractor.feature_dtype, tuple):  # Feature extractor only returns 1 type
            self.add_dtype(*self.extractor.feature_dtype)
        else:
            for x in self.extractor.feature_dtype:  # Feature extractor returns multiple named fields
                self.add_dtype(*x)

    def create_learner(self):
        '''
        The "learner" uses knowledge of the task goals to determine the "intended"
        action of the BMI subject and pairs this intention estimation with actual observations.
        '''
        from . import clda
        self.learn_flag = False
        self.learner = clda.DumbLearner()

    def create_updater(self):
        '''
        The "updater" uses the output batches of data from the learner and an update rule to
        alter the decoder parameters to better match the intention estimates.
        '''
        self.updater = None

    def call_decoder(self, neural_obs, target_state, **kwargs):
        '''
        Run the decoder computations

        Parameters
        ----------
        neural_obs : object, typically np.array of shape (n_features, n_subbins)
            n_features is the number of neural features the decoder is expecting to decode from.
            n_subbins is the number of simultaneous observations which will be decoded (typically 1)
        target_state : np.array of shape (n_states, 1)
            The current optimal state to be in to accomplish the task. In this function call, this gets
            used when adapting the decoder using CLDA
        kwargs : optional keyword arguments
            Optional arguments to CLDA, assist, etc.
        '''
        # Get the decoder output
        decoder_output, update_flag = self.bmi_system(neural_obs, target_state, self.state, learn_flag=self.learn_flag, **kwargs)

        self.task_data['update_bmi'] = int(update_flag)

        return decoder_output

    def get_features(self):
        '''
        Run the feature extractor to get any new features to be decoded. Called by move_plant
        '''
        start_time = self.get_time()
        return self.extractor(start_time)

    def move_plant(self, **kwargs):
        '''
        The main function to retrieve raw observations from the neural data source and convert them to movement of the plant

        Parameters
        ----------
        **kwargs : optional keyword arguments
            optional arguments for the decoder, assist, CLDA, etc. fed to the BMISystem

        Returns
        -------
        decoder_state : np.mat
            (N, 1) vector representing the state decoded by the BMI
        '''
        # Run the feature extractor
        feature_data = self.get_features()

        # Save the "neural features" (e.g., spike counts vector) to HDF file
        for key, val in list(feature_data.items()):
            self.task_data[key] = val

        # Determine the target_state and save to file
        current_assist_level = self.get_current_assist_level()
        if np.any(current_assist_level > 0) or self.learn_flag:
            target_state = self.get_target_BMI_state(self.decoder.states)
        else:
            target_state = np.ones([self.decoder.n_states, self.decoder.n_subbins]) * np.nan

        # Determine the assistive control inputs to the Decoder
        if np.any(current_assist_level > 0):
            current_state = self.get_current_state()

            if target_state.shape[1] > 1:
                assist_kwargs = self.assister(current_state, target_state[:,0].reshape(-1,1), current_assist_level, mode=self.state)
            else:
                assist_kwargs = self.assister(current_state, target_state, current_assist_level, mode=self.state)

            kwargs.update(assist_kwargs)

        # Run the decoder
        if self.state not in self.static_states:
            neural_features = feature_data[self.extractor.feature_type]

            tmp = self.call_decoder(neural_features, target_state, **kwargs)
            self.task_data['internal_decoder_state'] = tmp

        # Drive the plant to the decoded state, if permitted by the constraints of the plant
        # If not possible, plant.drive should also take care of setting the decoder's
        # state as close as possible to physical reality
        self.plant.drive(self.decoder)
        try:
            self.dec_cnt += 1
        except:
            self.dec_cnt = 0

        self.task_data['decoder_state'] = decoder_state = self.decoder.get_state(shape=(-1,1))
        return decoder_state

    def get_current_assist_level(self):
        return self.current_assist_level

    def get_current_state(self):
        '''
        In most cases, the current state of the plant needed for calculating assistive control inputs will be stored in the decoder
        '''
        return self.decoder.filt.state.mean

    def get_target_BMI_state(self, *args):
        '''
        Run the goal calculator to determine what the target state of the task is.
        Since this is not a real task, this function must be
        overridden in child classes if any of the assist/CLDA functionality is to be used.
        '''
        raise NotImplementedError

    def _cycle(self):
        self.move_plant()

        # save loop time to HDF file
        self.task_data['loop_time'] = self.iter_time()
        super(BMILoop, self)._cycle()

    def enable_clda(self):
        print("CLDA enabled")
        self.learn_flag = True

    def disable_clda(self):
        print(("CLDA disabled after %d successful trials" % self.calc_state_occurrences('reward')))
        self.learn_flag = False

    def cleanup_hdf(self):
        '''
        Re-open the HDF file and save any extra task data kept in RAM
        '''
        super(BMILoop, self).cleanup_hdf()
        log_file = open(os.path.join(os.getenv("HOME"), 'code/bmi3d/log/clda_log'), 'w')
        log_file.write(str(self.state) + '\n')
        try:
            from . import clda
            if len(self.bmi_system.param_hist) > 0 and not self.updater is None:
                log_file.write('n_updates: %g\n' % len(self.bmi_system.param_hist))
                ignore_none = self.learner.batch_size > 1
                log_file.write('Ignoring "None" values: %s\n' % str(ignore_none))

                self.write_clda_data_to_hdf_table(
                    self.h5file.name, self.bmi_system.param_hist,
                    ignore_none=ignore_none)
        except:
            import traceback
            traceback.print_exc(file=log_file)
        log_file.close()

    @staticmethod
    def write_clda_data_to_hdf_table(hdf_fname, data, ignore_none=False):
        '''
        Save CLDA data generated during the experiment to the specified HDF file

        Parameters
        ----------
        hdf_fname : string
            filename of HDF file
        data : list
            list of dictionaries with the same keys and same dtypes for values

        Returns
        -------
        None
        '''
        log_file = open(os.path.expandvars('$HOME/code/bmi3d/log/clda_hdf_log'), 'w')

        compfilt = tables.Filters(complevel=5, complib="zlib", shuffle=True)
        if len(data) > 0:
            # Find the first parameter update dictionary
            k = 0
            first_update = data[k]
            while first_update is None:
                k += 1
                first_update = data[k]

            table_col_names = list(first_update.keys())
            print(table_col_names)
            dtype = []
            shapes = []
            for col_name in table_col_names:
                if isinstance(first_update[col_name], float):
                    shape = (1,)
                else:
                    shape = first_update[col_name].shape
                dtype.append((col_name.replace('.', '_'), 'f8', shape))
                shapes.append(shape)

            log_file.write(str(dtype))
            # Create the HDF table with the datatype above
            dtype = np.dtype(dtype)

            h5file = tables.openFile(hdf_fname, mode='a')
            arr = h5file.createTable("/", 'clda', dtype, filters=compfilt)

            null_update = np.zeros((1,), dtype=dtype)
            for col_name in table_col_names:
                null_update[col_name.replace('.', '_')] *= np.nan

            for k, param_update in enumerate(data):
                log_file.write('%d, %s\n' % (k, str(ignore_none)))
                if param_update is None:
                    if ignore_none:
                        continue
                    else:
                        data_row = null_update
                else:
                    data_row = np.zeros((1,), dtype=dtype)
                    for col_name in table_col_names:
                        data_row[col_name.replace('.', '_')] = np.asarray(param_update[col_name])

                arr.append(data_row)
            h5file.close()

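    # Usage sketch (not part of the original file) for reading the 'clda'
    # table written above back out with the same PyTables 2.x-style calls:
    #
    #     h5file = tables.openFile(hdf_fname, mode='r')
    #     clda_hist = h5file.root.clda[:]   # structured array, one row per update
    #     h5file.close()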

    def cleanup(self, database, saveid, **kwargs):
        super(BMILoop, self).cleanup(database, saveid, **kwargs)

        # Resave decoder with drift-parameter saved as prev_task_drift_corr:
        if hasattr(self.decoder.filt, 'drift_corr'):
            print(('saving decoder: ', self.decoder.filt.drift_corr, self.decoder.filt.prev_drift_corr))
            decoder_name = self.decoder.name + '_d' + str(saveid)
            decoder_tempfilename = self.decoder.save()

            # Link the pickled decoder file to the associated task entry in the database
            dbname = kwargs['dbname'] if 'dbname' in kwargs else 'default'
            if dbname == 'default':
                database.save_bmi(decoder_name, saveid, decoder_tempfilename)
            else:
                database.save_bmi(decoder_name, saveid, decoder_tempfilename, dbname=dbname)

        # Open a log file in case of error b/c errors not visible to console
        # at this point
        from config import config
        f = open(os.path.join(config.log_path, 'clda_cleanup_log'), 'w')
        f.write('Opening log file\n')

        f.write('# of parameter updates: %d\n' % len(self.bmi_system.param_hist))

        # save out the parameter history and new decoder unless task was stopped
        # before 1st update
        try:
            if len(self.bmi_system.param_hist) > 0 and not self.updater is None:
                # create name for new decoder
                now = datetime.datetime.now()
                decoder_name = self.decoder_sequence + now.strftime('%m%d%H%M')

                # Pickle the decoder
                decoder_tempfilename = self.decoder.save()

                # Link the pickled decoder file to the associated task entry in the database
                dbname = kwargs['dbname'] if 'dbname' in kwargs else 'default'
                if dbname == 'default':
                    database.save_bmi(decoder_name, saveid, decoder_tempfilename)
                else:
                    database.save_bmi(decoder_name, saveid, decoder_tempfilename, dbname=dbname)
        except:
            traceback.print_exc(file=f)
        f.close()


class BMI(object):
    '''
    Legacy class, used only for unpickling super old Decoder objects. Ignore completely.
    '''
    pass

diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_clda_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_clda_py.html
new file mode 100644
index 00000000..6225fbab
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_clda_py.html
@@ -0,0 +1,1182 @@
[autogenerated coverage.py report: "Coverage for C:\GitHub\brain-python-interface\riglib\bmi\clda.py: 16%"; page chrome and keyboard-shortcut widgets omitted, rendered source follows]
'''
Closed-loop decoder adaptation (CLDA) classes. There are two types of classes,
"Learners" and "Updaters". Learners implement various methods to estimate the
"intended" BMI movements of the user. Updaters implement various methods for
updating the decoder parameters.
'''
import multiprocessing as mp
import queue
import numpy as np
from . import kfdecoder, ppfdecoder, train, bmi, feedback_controllers
import time
import cmath

import tables
import re
from . import assist
import os
import scipy
import copy

from utils.angle_utils import *

inv = np.linalg.inv

try:
    from numpy.linalg import lapack_lite
    lapack_routine = lapack_lite.dgesv
except:
    pass


def fast_inv(A):
    '''
    This method represents a way to speed up matrix inverse computations when
    several independent matrix inverses of all the same shape must be taken all
    at once. This is used by the PPFContinuousBayesianUpdater. Without this method,
    the updates could not be performed in real time with ~30 cells (compute complexity
    is linear in the number of units, so it is possible that fewer units would not
    have had this issue).

    Code stolen from:
    # http://stackoverflow.com/questions/11972102/is-there-a-way-to-efficiently-invert-an-array-of-matrices-with-numpy
    '''
    b = np.identity(A.shape[2], dtype=A.dtype)

    n_eq = A.shape[1]
    n_rhs = A.shape[2]
    pivots = np.zeros(n_eq, np.intc)
    identity = np.eye(n_eq)
    def lapack_inverse(a):
        b = np.copy(identity)
        pivots = np.zeros(n_eq, np.intc)
        results = lapack_lite.dgesv(n_eq, n_rhs, a, n_eq, pivots, b, n_eq, 0)
        if results['info'] > 0:
            raise np.linalg.LinAlgError('Singular matrix')
        return b

    return np.array([lapack_inverse(a) for a in A])

def slow_inv(A):
    return np.array([np.linalg.inv(a) for a in A])

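def _demo_batch_inverse():
    # Usage sketch (not part of the original module): invert a stack of ten
    # well-conditioned 3x3 matrices at once and check each result against the
    # identity. fast_inv does the same thing faster when lapack_lite is available.
    A = np.random.randn(10, 3, 3) + 4 * np.eye(3)
    A_inv = slow_inv(A)
    assert all(np.allclose(np.dot(a, a_i), np.eye(3)) for a, a_i in zip(A, A_inv))
    return A_inv
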

##############################################################################
## Learners
##############################################################################
class Learner(object):
    '''
    Classes for estimating the 'intention' of the BMI operator, inferring the intention from task goals.
    '''
    def __init__(self, batch_size, *args, **kwargs):
        '''
        Instantiate a Learner for estimating intention during CLDA

        Parameters
        ----------
        batch_size: int
            number of samples used to estimate each new decoder parameter setting
        done_states: list of strings, optional
            states of the task which end a batch, regardless of the length of the batch. default = []
        reset_states: list of strings, optional
            states of the task which, if encountered, reset the batch regardless of its length. default = []
        '''
        self.done_states = kwargs.pop('done_states', [])
        self.reset_states = kwargs.pop('reset_states', [])
        print("Reset states for learner: ")
        print(self.reset_states)
        print("Done states for learner: ")
        print(self.done_states)
        self.batch_size = batch_size
        self.passed_done_state = False
        self.enabled = True
        self.input_state_index = -1
        self.reset()

    def disable(self):
        '''Set a flag to disable forming intention estimates from new incoming data'''
        self.enabled = False

    def enable(self):
        '''Set a flag to enable forming intention estimates from new incoming data'''
        self.enabled = True

    def reset(self):
        '''Reset the lists of saved intention estimates and corresponding neural data'''
        self.kindata = []
        self.neuraldata = []
        self.obs_value = []

    def __call__(self, spike_counts, decoder_state, target_state, decoder_output, task_state, state_order=None, **kwargs):
        """
        Calculate the intended kinematics and pair with the neural data

        Parameters
        ----------
        spike_counts : np.mat of shape (K, 1)
            Neural observations used to decode 'decoder_state'
        decoder_state : np.mat of shape (N, 1)
            State estimate output from the decoder
        target_state : np.mat of shape (N, 1)
            For the current time, this is the optimal state for the Decoder as specified by the task
        decoder_output : np.mat of shape (N, 1)
            ... this seems like the same as decoder_state
        task_state : string
            Name of the task state; some learners (e.g., the cursorGoal learner) have different intention estimates depending on the phase of the task/trial
        state_order : np.ndarray of shape (N,), optional
            Order of each state in the decoder; see riglib.bmi.state_space_models.State
        **kwargs: dict
            Optional keyword arguments for the 'value' calculator

        Returns
        -------
        None
        """
        if task_state in self.reset_states:
            print("resetting CLDA batch")
            self.reset()

        int_kin = self.calc_int_kin(decoder_state, target_state, decoder_output, task_state, state_order=state_order)
        obs_value = self.calc_value(decoder_state, target_state, decoder_output, task_state, state_order=state_order, **kwargs)

        if self.passed_done_state and self.enabled:
            if task_state in ['hold', 'target']:
                self.passed_done_state = False

        if self.enabled and not self.passed_done_state and int_kin is not None:
            self.kindata.append(int_kin)
            self.neuraldata.append(spike_counts)
            self.obs_value.append(obs_value)

        if task_state in self.done_states:
            self.passed_done_state = True

    def calc_value(self, *args, **kwargs):
        '''
        Calculate a "value", i.e. a usefulness, for a particular observation.
        Can override in child classes for RL-style updates, but a priori all observations are equally informative
        '''
        return 1.

    def is_ready(self):
        '''
        Returns True if the collected estimates of the subject's intention are ready for processing into new decoder parameters
        '''
        _is_ready = len(self.kindata) >= self.batch_size or ((len(self.kindata) > 0) and self.passed_done_state)
        return _is_ready

    def get_batch(self):
        '''
        Returns all the data from the last 'batch' of observations of intended kinematics and neural decoder inputs
        '''
        kindata = np.hstack(self.kindata)
        neuraldata = np.hstack(self.neuraldata)
        self.reset()
        return dict(intended_kin=kindata, spike_counts=neuraldata)
        # return kindata, neuraldata

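class _DemoLearner(Learner):
    # A minimal sketch (not part of the original module): the base class leaves
    # calc_int_kin to subclasses, so this one just relabels the target state as
    # the intention, which is enough to exercise the batching machinery above.
    def calc_int_kin(self, decoder_state, target_state, decoder_output, task_state, state_order=None):
        return np.mat(target_state).reshape(-1, 1)

def _demo_learner_batch():
    # Usage sketch: after batch_size paired samples, is_ready() goes True and
    # get_batch() hands back the stacked intention/observation arrays.
    learner = _DemoLearner(batch_size=2)
    for _ in range(2):
        learner(np.zeros((4, 1)), np.zeros((3, 1)), np.ones((3, 1)), np.zeros((3, 1)), 'target')
    assert learner.is_ready()
    return learner.get_batch()   # dict with 'intended_kin' and 'spike_counts'
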

class DumbLearner(Learner):
    '''
    A learner that never learns anything. Used to make non-adaptive BMI tasks interface the same as CLDA tasks.
    '''
    def __init__(self, *args, **kwargs):
        '''
        Constructor for DumbLearner

        Parameters
        ----------
        args, kwargs: positional and keyword arguments
            Ignored, none are needed

        Returns
        -------
        DumbLearner instance
        '''
        self.enabled = False
        self.input_state_index = 0

    def __call__(self, *args, **kwargs):
        """
        Do nothing; hence the name of the class

        Parameters
        ----------
        args, kwargs: positional and keyword arguments
            Ignored, none are needed

        Returns
        -------
        None
        """
        pass

    def is_ready(self):
        '''DumbLearner is never ready to tell you what it learned'''
        return False

    def get_batch(self):
        '''DumbLearner never has any 'batch' data to retrieve'''
        raise NotImplementedError


class FeedbackControllerLearner(Learner):
    '''
    An intention estimator where the subject is assumed to operate like a state feedback controller
    '''
    def __init__(self, batch_size, fb_ctrl, *args, **kwargs):
        self.fb_ctrl = fb_ctrl
        self.style = kwargs.pop('style', 'mixing')
        super(FeedbackControllerLearner, self).__init__(batch_size, *args, **kwargs)

    def calc_int_kin(self, current_state, target_state, decoder_output, task_state, state_order=None):
        """
        Used by __call__ to figure out the next state vector to pair to the neural activity in the batch.

        Parameters
        ----------
        [same as OFCLearner.calc_int_kin]
        current_state : np.mat of shape (N, 1)
            State estimate output from the decoder.
        target_state : np.mat of shape (N, 1)
            For the current time, this is the optimal state for the Decoder as specified by the task
        decoder_output : np.mat of shape (N, 1)
            State estimate output from the decoder, after the current observations (may be one step removed from 'current_state')
        task_state : string
            Name of the task state; some learners (e.g., the cursorGoal learner) have different intention estimates depending on the phase of the task/trial
        state_order : np.ndarray of shape (N,), optional
            Order of each state in the decoder; see riglib.bmi.state_space_models.State

        Returns
        -------
        np.mat of shape (N, 1)
            Optimal next state to pair to neural activity
        """
        try:
            if self.style == 'additive':
                output = self.fb_ctrl(current_state, target_state, mode=task_state)
            elif self.style == 'mixing':
                output = self.fb_ctrl.calc_next_state(current_state, target_state, mode=task_state)
            return output
        except:
            # Key errors happen when the feedback controller doesn't have a policy for the current task state
            import traceback
            traceback.print_exc()
            return None


class OFCLearner(Learner):
    '''
    An intention estimator where the subject is assumed to operate like a multi-modal LQR controller
    '''
    def __init__(self, batch_size, A, B, F_dict, *args, **kwargs):
        '''
        Constructor for OFCLearner

        Parameters
        ----------
        batch_size : int
            size of batch of samples to pass to the Updater to estimate new decoder parameters
        A : np.mat
            State transition matrix of the modeled discrete-time system
        B : np.mat
            Control input matrix of the modeled discrete-time system
        F_dict : dict
            Keys match names of task states, values are feedback matrices (size n_inputs x n_states)
        *args : additional comma-separated args
            Passed to super constructor
        **kwargs : additional keyword args
            Passed to super constructor

        Returns
        -------
        OFCLearner instance
        '''
        super(OFCLearner, self).__init__(batch_size, *args, **kwargs)
        self.B = B
        self.F_dict = F_dict
        self.A = A

    def calc_int_kin(self, current_state, target_state, decoder_output, task_state, state_order=None):
        '''
        Calculate intended kinematics as
            x_t^{int} = A*x_t + B*F*(x^* - x_t)

        Parameters
        ----------
        [same as FeedbackControllerLearner.calc_int_kin]
        current_state : np.mat of shape (N, 1)
            State estimate output from the decoder.
        target_state : np.mat of shape (N, 1)
            For the current time, this is the optimal state for the Decoder as specified by the task
        decoder_output : np.mat of shape (N, 1)
            State estimate output from the decoder, after the current observations (may be one step removed from 'current_state')
        task_state : string
            Name of the task state; some learners (e.g., the cursorGoal learner) have different intention estimates depending on the phase of the task/trial
        state_order : np.ndarray of shape (N,), optional
            Order of each state in the decoder; see riglib.bmi.state_space_models.State

        Returns
        -------
        np.mat of shape (N, 1)
            Estimate of intended next state for BMI
        '''
        try:
            current_state = np.mat(current_state).reshape(-1,1)
            target_state = np.mat(target_state).reshape(-1,1)
            F = self.F_dict[task_state]
            A = self.A
            B = self.B

            return A*current_state + B*F*(target_state - current_state)
        except KeyError:
            return None

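def _demo_ofc_intention():
    # Worked example (not part of the original module) of the intention
    # estimate x_int = A*x + B*F*(x* - x) for an illustrative 1D
    # position/velocity/offset plant; the matrices below are made up.
    A = np.mat([[1., 0.1, 0.],
                [0., 0.8, 0.],
                [0., 0.,  1.]])
    B = np.mat([[0.], [0.1], [0.]])
    F = np.mat([[5., 1., 0.]])               # feedback gains on pos/vel error
    learner = OFCLearner(10, A, B, {'target': F})
    x      = np.mat([[0.], [0.], [1.]])      # current state: at the origin
    x_star = np.mat([[1.], [0.], [1.]])      # target one unit to the right
    int_kin = learner.calc_int_kin(x, x_star, x, 'target')
    assert int_kin.shape == (3, 1)           # intended next pos/vel/offset
    return int_kin
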

class RegexKeyDict(dict):
    '''
    Dictionary where key matching applies regular expressions in addition to exact matches
    '''
    def __getitem__(self, key):
        '''
        Lookup key in dictionary by finding exactly one dict key which, by regex, matches the input argument 'key'
        '''
        keys = list(self.keys())
        matching_keys = [x for x in keys if re.match(x, key)]
        if len(matching_keys) == 0:
            raise KeyError("No matching keys were found!")
        elif len(matching_keys) > 1:
            raise ValueError("Multiple keys match!")
        else:
            return super(RegexKeyDict, self).__getitem__(matching_keys[0])

    def __contains__(self, key):
        '''
        Determine if a key is in the dictionary using regular expression matching
        '''
        keys = list(self.keys())
        matching_keys = [x for x in keys if re.match(x, key)]
        if len(matching_keys) == 0:
            return False
        elif len(matching_keys) > 1:
            raise ValueError("Multiple keys match!")
        else:
            return True

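def _demo_regex_key_dict():
    # Usage sketch (not part of the original module): because keys are treated
    # as regular expressions, one entry can cover a whole family of task states.
    F_dict = RegexKeyDict()
    F_dict['target.*'] = 'target policy'     # matches 'target', 'target1', ...
    assert 'target1' in F_dict
    assert F_dict['target1'] == 'target policy'
    assert 'reward' not in F_dict
    return F_dict
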

##############################################################################
## Updaters
##############################################################################
from riglib.mp_calc import MPCompute
class Updater(object):
    '''
    Wrapper for MPCompute computations running in another process
    '''
    def __init__(self, fn, multiproc=False, verbose=False):
        self.verbose = verbose
        self.multiproc = multiproc
        if self.multiproc:
            # create the queues
            self.work_queue = mp.Queue()
            self.result_queue = mp.Queue()

            # Instantiate the process
            self.calculator = MPCompute(self.work_queue, self.result_queue, fn)

            # spawn the process
            self.calculator.start()
        else:
            self.fn = fn

        self._result = None
        self.waiting = False

    def init(self, decoder):
        pass

    def __call__(self, *args, **kwargs):
        input_data = (args, kwargs)
        if self.multiproc:
            if self.verbose: print("queuing job")
            self.work_queue.put(input_data)
            self.prev_input = input_data
            self.waiting = True
        else:
            self._result = self.fn(*args, **kwargs)

    def get_result(self):
        if self.multiproc:
            try:
                output_data = self.result_queue.get_nowait()
                self.prev_result = output_data
                self.waiting = False
                return output_data
            except queue.Empty:
                return None
            except:
                import traceback
                traceback.print_exc()
        else:
            if self._result is not None:
                res = self._result
            else:
                res = None
            self._result = None
            return res

    def __del__(self):
        '''
        Stop the child process if one was spawned
        '''
        if self.multiproc:
            self.calculator.stop()

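def _demo_updater_sync():
    # Usage sketch (not part of the original module): with multiproc=False the
    # wrapped function runs inline, and get_result() hands back (then clears)
    # the most recent return value.
    upd = Updater(lambda x: x + 1, multiproc=False)
    upd(41)
    assert upd.get_result() == 42
    assert upd.get_result() is None   # a result is consumed once retrieved
    return upd
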

class PPFContinuousBayesianUpdater(Updater):
    '''
    Adapt the parameters of a PPFDecoder using an HMM to implement a gradient-descent type parameter update.

    (currently only works for PPFs which do not also include the self-history or correlational elements)

    See Shanechi and Carmena, "Optimal feedback-controlled point process decoder for
    adaptation and assisted training in brain-machine interfaces", IEEE EMBC, 2014
    for mathematical details
    '''
    update_kwargs = dict()
    def __init__(self, decoder, units='cm', param_noise_scale=1., param_noise_variances=None):
        '''
        Constructor for PPFContinuousBayesianUpdater

        Parameters
        ----------
        decoder : bmi.ppfdecoder.PPFDecoder instance
            Should have a 'filt' attribute which is a PointProcessFilter instance
        units : string
            Unit of distance used by the plant ('m' or 'cm'); sets the scale of the default parameter noise
        param_noise_scale : float
            Multiplicative factor to increase the parameter "process noise". Higher values result in faster but less stable parameter convergence.
        '''
        super(PPFContinuousBayesianUpdater, self).__init__(self.calc, multiproc=False)

        self.n_units = decoder.filt.C.shape[0]
        if param_noise_variances is None:
            if units == 'm':
                vel_gain = 1e-4
            elif units == 'cm':
                vel_gain = 1e-8

            print("Updater param noise scale %g" % param_noise_scale)
            vel_gain *= param_noise_scale
            param_noise_variances = np.array([vel_gain*0.13, vel_gain*0.13, 1e-4*0.06/50])
        self.W = np.tile(np.diag(param_noise_variances), [self.n_units, 1, 1])

        self.P_params_est = self.W.copy()

        self.neuron_driving_state_inds = np.nonzero(decoder.drives_neurons)[0]
        self.neuron_driving_states = list(np.take(decoder.states, np.nonzero(decoder.drives_neurons)[0]))
        self.n_states = len(decoder.states)
        self.full_size = len(decoder.states)

        self.dt = decoder.filt.dt
        self.beta_est = np.array(decoder.filt.C)

    def calc(self, intended_kin=None, spike_counts=None, decoder=None, **kwargs):
        '''
        Run one Bayesian update of the estimated observation model 'filt.C' from a batch of intended kinematics and (binarized) spike observations.
        '''
        if (intended_kin is None) or (spike_counts is None) or (decoder is None):
            raise ValueError("must specify intended_kin, spike_counts and decoder objects for the updater to work!")

        if 0:
            print(np.array(intended_kin).ravel())

        int_kin_full = intended_kin
        spike_obs_full = spike_counts
        n_samples = int_kin_full.shape[1]

        # Squash any observed spike counts which are greater than 1
        spike_obs_full[spike_obs_full > 1] = 1
        for k in range(n_samples):
            spike_obs = spike_obs_full[:,k]
            int_kin = int_kin_full[:,k]

            beta_est = self.beta_est[:,self.neuron_driving_state_inds]
            int_kin = np.asarray(int_kin).ravel()[self.neuron_driving_state_inds]
            Loglambda_predict = np.dot(int_kin, beta_est.T)
            rates = np.exp(Loglambda_predict)
            if np.any(rates > 1):
                print('rates > 1!')
                rates[rates > 1] = 1
            unpred_spikes = np.asarray(spike_obs).ravel() - rates

            C_xpose_C = np.outer(int_kin, int_kin)

            self.P_params_est += self.W
            try:
                P_params_est_inv = fast_inv(self.P_params_est)
            except:
                P_params_est_inv = slow_inv(self.P_params_est)
            L = np.dstack([rates[c] * C_xpose_C for c in range(self.n_units)]).transpose([2,0,1])

            try:
                self.P_params_est = fast_inv(P_params_est_inv + L)
            except:
                self.P_params_est = slow_inv(P_params_est_inv + L)

            beta_est += (unpred_spikes * np.dot(int_kin, self.P_params_est).T).T

            # store beta_est
            self.beta_est[:, self.neuron_driving_state_inds] = beta_est

        return {'filt.C': np.mat(self.beta_est.copy())}


class KFRML(Updater):
    '''
    Calculate updates for KF parameters using the recursive maximum likelihood (RML) method.
    See (Dangi et al, Neural Computation, 2014) for mathematical details.
    '''
    update_kwargs = dict(steady_state=False)
    def __init__(self, batch_time, half_life, adapt_C_xpose_Q_inv_C=True, regularizer=None):
        '''
        Constructor for KFRML

        Parameters
        ----------
        batch_time : float
            Size of data batch to use for each update. Specify in seconds.
        half_life : float
            Amount of time (in seconds) before parameters are half-overwritten by new data.
        adapt_C_xpose_Q_inv_C : bool
            Flag specifying whether to update the decoder property C^T Q^{-1} C, which
            defines the feedback dynamics of the final closed-loop system if A and W are known
        regularizer : float
            Defines the lambda regularizer to use in the calculation of the C matrix: C = (X*X.T + lambda*eye).I * (X*Y)

        Returns
        -------
        KFRML instance
        '''
        super(KFRML, self).__init__(self.calc, multiproc=False)
        self.batch_time = batch_time
        self.half_life = half_life
        self.rho = np.exp(np.log(0.5) / (self.half_life/batch_time))
        self.adapt_C_xpose_Q_inv_C = adapt_C_xpose_Q_inv_C
        self.regularizer = regularizer
        self._new_params = None

    @staticmethod
    def compute_suff_stats(hidden_state, obs, include_offset=True):
        '''
        Calculate initial estimates of the parameter sufficient statistics used in the RML update rules

        Parameters
        ----------
        hidden_state : np.ndarray of shape (n_states, n_samples)
            Examples of the hidden state x_t taken from training seed data.
        obs : np.ndarray of shape (n_features, n_samples)
            Multiple neural observations paired with each of the hidden state examples
        include_offset : bool, optional
            If true, a state of all 1's is added to the hidden_state to represent mean offsets. True by default

        Returns
        -------
        R : np.ndarray of shape (n_states, n_states)
            Proportional to covariance of the hidden state samples
        S : np.ndarray of shape (n_features, n_states)
            Proportional to cross-covariance between the neural observations and the hidden state samples
        T : np.ndarray of shape (n_features, n_features)
            Proportional to covariance of the neural observations
        ESS : float
            Effective number of samples. In the initialization, this is just the
            dimension of the array passed in, but the parameter can become non-integer
            during the update procedure as old parameters are "forgotten".
        '''
        assert hidden_state.shape[1] == obs.shape[1]

        if isinstance(hidden_state, np.ma.core.MaskedArray):
            mask = ~hidden_state.mask[0,:]  # NOTE THE INVERTER
            inds = np.nonzero([ mask[k]*mask[k+1] for k in range(len(mask)-1)])[0]

            X = np.mat(hidden_state[:,mask])
            n_pts = len(np.nonzero(mask)[0])

            Y = np.mat(obs[:,mask])
            if include_offset:
                X = np.vstack([ X, np.ones([1,n_pts]) ])
        else:
            num_hidden_state, n_pts = hidden_state.shape
            X = np.mat(hidden_state)
            if include_offset:
                X = np.vstack([ X, np.ones([1,n_pts]) ])
            Y = np.mat(obs)
        X = np.mat(X, dtype=np.float64)

        R = (X * X.T)
        S = (Y * X.T)
        T = (Y * Y.T)
        ESS = n_pts

        return (R, S, T, ESS)


    def init(self, decoder):
        '''
        Retrieve sufficient statistics from the seed decoder.

        Parameters
        ----------
        decoder : bmi.Decoder instance
            The seed decoder before any adaptation runs.

        Returns
        -------
        None
        '''
        self.R = decoder.filt.R
        self.S = decoder.filt.S
        self.T = decoder.filt.T
        self.ESS = decoder.filt.ESS

        # Neural indices that will be adapted / stable are defined here:
        self.feature_inds = np.arange(decoder.n_features)

        # Units that you want to stay stable
        self.stable_inds = []

        # By default, tuning parameters for all features will adapt
        if hasattr(decoder, 'adapting_neur_inds'):
            self.set_stable_inds(None, adapting_inds=decoder.adapting_neur_inds)
        else:
            self.adapting_inds = self.feature_inds.copy()
            self.stable_inds_independent = False

        self.adapting_inds_mesh = np.ix_(self.adapting_inds, self.adapting_inds)

        # Are stable units independent from other units? If yes, Q[stable_unit, other_units] = 0
        # State space indices that will be adapted:
        self.state_inds = np.arange(len(decoder.states))
        if hasattr(decoder, 'adapting_state_inds'):
            if type(decoder.adapting_state_inds) is not list:
                ad = [i for i, j in enumerate(decoder.states) if j in decoder.adapting_state_inds.state_names]
            else:
                ad = decoder.adapting_state_inds
            self.set_stable_states(None, adapting_state_inds=ad)
        else:
            self.state_adapting_inds = np.arange(decoder.n_states)

        self.neur_by_state_adapting_inds_mesh = np.ix_(self.adapting_inds, self.state_adapting_inds)

        if hasattr(decoder, 'adapt_mFR_stats'):
            print('setting adapting mFR. updater', decoder.adapt_mFR_stats)
            self.adapt_mFR_stats = decoder.adapt_mFR_stats
        else:
            self.adapt_mFR_stats = False

    def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=None, values=None, **kwargs):
        '''
        Parameters
        ----------
        intended_kin : np.ndarray of shape (n_states, batch_size)
            Batch of estimates of intended kinematics, from the learner
        spike_counts : np.ndarray of shape (n_features, batch_size)
            Batch of observations of decoder features, from the learner
        decoder : bmi.Decoder instance
            Reference to the Decoder instance
        half_life : float, optional
            Half-life to use to calculate the parameter change step size. If not specified, the half-life specified when the Updater was constructed is used.
        values : np.ndarray, optional
            Relative value of each sample of the batch. If not specified, each sample is assumed to have equal value.
        kwargs : dict
            Optional keyword arguments, ignored

        Returns
        -------
        new_params : dict
            New parameters to feed back to the Decoder in use by the task.
        '''
        if intended_kin is None or spike_counts is None or decoder is None:
            raise ValueError("must specify intended_kin, spike_counts and decoder objects for the updater to work!")

        # Calculate the step size based on the half life and the number of samples to train from
        batch_size = intended_kin.shape[1]
        batch_time = batch_size * decoder.binlen

        if half_life is not None:
            rho = np.exp(np.log(0.5)/(half_life/batch_time))
        else:
            rho = self.rho

        # update driver of neurons
        try:
            drives_neurons = decoder.drives_neurons.copy()
            mFR_old = decoder.mFR.copy()
            sdFR_old = decoder.sdFR.copy()
        except:
            drives_neurons = decoder.drives_neurons
            mFR_old = decoder.mFR
            sdFR_old = decoder.sdFR

        x = np.mat(intended_kin)
        y = np.mat(spike_counts)
        # limit x to the indices that can adapt:
        #x = x[self.state_adapting_inds, :]

        # limit y to the features which are permitted to adapt
        #y = y[self.adapting_inds, :]

        if values is not None:
            n_samples = np.sum(values)
            B = np.mat(np.diag(values))
        else:
            n_samples = spike_counts.shape[1]
            B = np.mat(np.eye(n_samples))

        if self.adapt_C_xpose_Q_inv_C:
            #self.R[self.state_adapting_inds_mesh] = rho*self.R[self.state_adapting_inds_mesh] + (x*B*x.T)
            self.R = rho*self.R + (x*B*x.T)

        if np.any(np.isnan(self.R)):
            print('np.nan in self.R in riglib/bmi/clda.py!')

        #self.S[self.neur_by_state_adapting_inds_mesh] = rho*self.S[self.neur_by_state_adapting_inds_mesh] + (y*B*x.T)
        #self.T[self.adapting_inds_mesh] = rho*self.T[self.adapting_inds_mesh] + np.dot(y, B*y.T)

        self.S[:, decoder.drives_neurons] = rho*self.S[:, decoder.drives_neurons] + (y*B*x[decoder.drives_neurons, :].T)
        self.T = rho*self.T + np.dot(y, B*y.T)
        self.ESS = rho*self.ESS + n_samples

        R_inv = np.mat(np.zeros(self.R.shape))

        try:
            if self.regularizer is None:
                R_inv[np.ix_(drives_neurons, drives_neurons)] = np.linalg.pinv(self.R[np.ix_(drives_neurons, drives_neurons)])
            else:
                dn = np.sum(drives_neurons)
                R_inv[np.ix_(drives_neurons, drives_neurons)] = np.linalg.pinv(self.R[np.ix_(drives_neurons, drives_neurons)]+self.regularizer*np.eye(dn))
        except:
            print(self.R)
            print('Error with pinv in riglib/bmi/clda.py')

        C_new = self.S * R_inv
        C = copy.deepcopy(decoder.filt.C)
        C[np.ix_(self.adapting_inds, self.state_adapting_inds)] = C_new[np.ix_(self.adapting_inds, self.state_adapting_inds)]

        Q = (1./self.ESS) * (self.T - self.S*C.T)
        if hasattr(self, 'stable_inds_mesh'):
            if len(self.stable_inds) > 0:
                print('stable inds mesh: ', self.stable_inds, self.stable_inds_mesh)
                Q_old = decoder.filt.Q[self.stable_inds_mesh].copy()
                Q[self.stable_inds_mesh] = Q_old

            if self.stable_inds_independent:
                Q[np.ix_(self.stable_inds, self.adapting_inds)] = 0
                Q[np.ix_(self.adapting_inds, self.stable_inds)] = 0

        # mFR and sdFR are not exempt from the 'adapting_inds'
        try:
            mFR = mFR_old.copy()
            sdFR = sdFR_old.copy()
        except:
            mFR = 0.
            sdFR = 1.

        if self.adapt_mFR_stats:
            mFR[self.adapting_inds] = (1-rho)*np.mean(spike_counts[self.adapting_inds,:].T, axis=0) + rho*mFR_old[self.adapting_inds]
            sdFR[self.adapting_inds] = (1-rho)*np.std(spike_counts[self.adapting_inds,:].T, axis=0) + rho*sdFR_old[self.adapting_inds]

        C_xpose_Q_inv = C.T * np.linalg.pinv(Q)
        new_params = {'filt.C':C, 'filt.Q':Q, 'filt.C_xpose_Q_inv':C_xpose_Q_inv,
            'mFR':mFR, 'sdFR':sdFR, 'kf.ESS':self.ESS, 'filt.S':self.S, 'filt.T':self.T}

        if self.adapt_C_xpose_Q_inv_C:
            C_xpose_Q_inv_C = C_xpose_Q_inv * C
            new_params['filt.C_xpose_Q_inv_C'] = C_xpose_Q_inv_C
            new_params['filt.C_xpose_Q_inv'] = C_xpose_Q_inv
            new_params['filt.R'] = self.R
        else:
            new_params['filt.C_xpose_Q_inv_C'] = decoder.filt.C_xpose_Q_inv_C
            new_params['filt.R'] = decoder.filt.R

        self._new_params = new_params
        return new_params


    def set_stable_inds(self, stable_inds, adapting_inds=None, stable_inds_independent=False):
        '''
        Set certain neural tuning parameters to remain static, e.g., if you
        want to add a new unit to a decoder but keep the existing parameters for the old units.
        '''
        if adapting_inds is None:  # Stable inds provided
            self.stable_inds = stable_inds
            self.adapting_inds = np.array([x for x in self.feature_inds if x not in self.stable_inds]).astype(int)
        elif stable_inds is None:  # Adapting inds provided
            self.adapting_inds = np.array(adapting_inds).astype(int)
            self.stable_inds = np.array([x for x in self.feature_inds if x not in self.adapting_inds])

        self.adapting_inds_mesh = np.ix_(self.adapting_inds, self.adapting_inds)
        self.stable_inds_mesh = np.ix_(self.stable_inds, self.stable_inds)
        self.stable_inds_independent = stable_inds_independent

    def set_stable_states(self, stable_state_inds, adapting_state_inds=None, stable_state_inds_independent=False):
        '''
        Maybe you want to keep specific states stable (e.g. in iBMI, keep ArmAssist stable but adapt ReHand)
        '''
        if adapting_state_inds is None:
            self.state_adapting_inds = np.array([x for x in self.state_inds if x not in stable_state_inds])
        elif stable_state_inds is None:
            self.state_adapting_inds = np.array(adapting_state_inds).astype(int)
        self.state_adapting_inds_mesh = np.ix_(self.state_adapting_inds, self.state_adapting_inds)
        self.stable_state_inds_independent = stable_state_inds_independent

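def _demo_kfrml_seed_and_rho():
    # Worked example (not part of the original module). The forgetting factor
    # used by KFRML is rho = exp(log(0.5) / (half_life / batch_time)), so after
    # half_life seconds worth of batches the old sufficient statistics carry
    # exactly half their original weight:
    batch_time, half_life = 0.1, 120.   # illustrative values, in seconds
    rho = np.exp(np.log(0.5) / (half_life / batch_time))
    assert np.isclose(rho ** (half_life / batch_time), 0.5)

    # Seeding the statistics from random training data; with include_offset
    # the hidden state picks up an extra constant-1 row.
    X = np.random.randn(7, 200)                 # (n_states, n_samples)
    Y = np.random.poisson(1., size=(20, 200))   # (n_features, n_samples)
    R, S, T, ESS = KFRML.compute_suff_stats(X, Y)
    assert R.shape == (8, 8) and S.shape == (20, 8) and ESS == 200
    return rho
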

class KFRML_IVC(KFRML):
    '''
    RML version where diagonality constraints are imposed on the steady state KF matrices
    '''
    default_gain = None
    def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=None, values=None, **kwargs):
        '''
        See KFRML.calc for input argument documentation
        '''
        new_params = super(KFRML_IVC, self).calc(intended_kin=intended_kin, spike_counts=spike_counts, decoder=decoder, half_life=half_life, values=values, **kwargs)
        C, Q = new_params['filt.C'], new_params['filt.Q']

        D = (C.T * np.linalg.pinv(Q) * C)
        if self.default_gain is None:
            # assume velocity states are last half of states:
            v0 = int(.5*(D.shape[0] - 1))

            # get non-zero indices (e.g. cursor state only uses indices 3 and 5, not 4)
            vix = np.nonzero(np.diag(D[v0:-1, v0:-1]))[0] + v0

            # take mean:
            d = np.mean(np.diag(D)[vix])

            # set diagonal to mean, off-diagonal to zeros:
            D[v0:-1, v0:-1] = np.diag(np.zeros((v0,))+d)

            # Old: cursor only:
            #d = np.mean([D[3,3], D[5,5]])
            #D[3:6, 3:6] = np.diag([d, d, d])
        else:
            # calculate the gain from the riccati equation solution
            A_diag = np.diag(np.asarray(decoder.filt.A[3:6, 3:6]))
            W_diag = np.diag(np.asarray(decoder.filt.W[3:6, 3:6]))
            D_diag = []
            for a, w, n in zip(A_diag, W_diag, [self.default_gain]*3):
                d = self.scalar_riccati_eq_soln(a, w, n)
                D_diag.append(d)

            D[3:6, 3:6] = np.mat(np.diag(D_diag))

        new_params['filt.C_xpose_Q_inv_C'] = D
        new_params['filt.C_xpose_Q_inv'] = C.T * np.linalg.pinv(Q)
        return new_params

    @classmethod
    def scalar_riccati_eq_soln(cls, a, w, n):
        '''
        For the scalar case, determine what you want the prediction covariance of the KF,
        which follows the riccati recursion for constant model parameters,
        based on what gain you want to set for the steady-state KF

        Parameters
        ----------
        a : float
            Diagonal value of the A matrix for the velocity terms
        w : float
            Diagonal value of the W matrix for the velocity terms
        n : float
            Steady-state kalman filter gain for the velocity terms

        Returns
        -------
        float
        '''
        return (1-a*n)/w * (a-n)/n


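# A quick numerical check of scalar_riccati_eq_soln. This is a sketch, under
# the reading that for a scalar model x_t = a*x_{t-1} + N(0, w),
# y_t = c*x_t + N(0, q) with d = c**2/q, the returned d makes the steady-state
# filter satisfy a*(1 - kappa) == n, where kappa is the dimensionless
# steady-state Kalman gain K*c (requires 0 < n < a so that d > 0; values
# below are made up):
import numpy as np

def check_riccati(a, w, n, n_iter=2000):
    d = (1 - a*n)/w * (a - n)/n       # scalar_riccati_eq_soln(a, w, n)
    c, q = 1.0, 1.0/d                 # any (c, q) with c**2/q == d works
    S = w                             # prediction covariance, arbitrary init
    for _ in range(n_iter):
        P = S - S*c*(c*S*c + q)**-1*c*S   # measurement update
        S = a*P*a + w                     # time update
    kappa = c*S*c / (c*S*c + q)           # steady-state K*c
    return a*(1 - kappa)

print(check_riccati(a=0.9, w=1.0, n=0.1))   # ~0.1
print(check_riccati(a=0.8, w=2.0, n=0.5))   # ~0.5
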
class KFRML_baseline(KFRML):
    '''
    RML version where only the baseline firing rates are adapted
    '''
    def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=None, values=None, **kwargs):
        '''
        See KFRML.calc for input argument documentation
        '''
        print("calculating new baseline parameters")
        if half_life is not None:
            rho = np.exp(np.log(0.5)/(half_life/self.batch_time))
        else:
            rho = self.rho

        drives_neurons = decoder.drives_neurons
        mFR_old = decoder.mFR.copy()
        sdFR_old = decoder.sdFR.copy()

        mFR = mFR_old.copy()
        sdFR = sdFR_old.copy()

        mFR[self.adapting_inds] = (1-rho)*np.mean(spike_counts[self.adapting_inds,:].T, axis=0) + rho*mFR_old[self.adapting_inds]
        sdFR[self.adapting_inds] = (1-rho)*np.std(spike_counts[self.adapting_inds,:].T, axis=0) + rho*sdFR_old[self.adapting_inds]

        new_params = {'mFR':mFR, 'sdFR':sdFR}

        return new_params

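# Why rho = exp(log(0.5)/(half_life/batch_time)): each batch multiplies the
# weight on the old estimate by rho, so after half_life seconds (i.e.,
# half_life/batch_time batches) the original estimate's weight has decayed to
# exactly one half. A minimal check, with made-up times:
import numpy as np

batch_time = 0.1                       # seconds per batch
half_life = 120.0                      # seconds
rho = np.exp(np.log(0.5) / (half_life / batch_time))

n_batches = int(half_life / batch_time)
print(rho ** n_batches)                # -> 0.5 (up to float error)
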
###################################
##### Updaters in development #####
###################################
class PPFRML(Updater):
    '''RML method applied to a more generic GLM'''
    update_kwargs = dict()
    def __init__(self, *args, **kwargs):
        super(PPFRML, self).__init__(self.calc, multiproc=False)

    def init(self, decoder):
        self.dt = decoder.filt.dt
        self.C_est = decoder.filt.C
        self.H = decoder.H
        self.M = decoder.M
        self.S = decoder.S

        self.neuron_driving_state_inds = np.nonzero(decoder.drives_neurons)[0]
        self.neuron_driving_states = list(np.take(decoder.states, np.nonzero(decoder.drives_neurons)[0]))
        self.n_states = len(decoder.states)
        self.full_size = len(decoder.states)

    def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=120., **kwargs):
        '''
        Time-iterative RLS update
        '''
        if (intended_kin is None) or (spike_counts is None) or (decoder is None):
            raise ValueError("must specify intended_kin, spike_counts and decoder objects for the updater to work!")

        batch_size = 1.
        batch_time = batch_size * decoder.binlen
        rho = np.exp(np.log(0.5)/(half_life/batch_time))

        n_cells = self.C_est.shape[0]
        n_obs = intended_kin.shape[1]
        intended_kin = np.mat(intended_kin)
        spike_counts[spike_counts > 1] = 1  # binarize counts for the point-process model
        for k in range(n_cells):
            for m in range(n_obs):
                H = np.mat(self.H[k])
                S = np.mat(self.S[k].reshape(-1,1))
                M = np.mat(self.M[k].reshape(-1,1))
                c = self.C_est[k, self.neuron_driving_state_inds].T

                c_new = c - H.I * (S - M)
                c = rho*c + (1-rho)*c_new
                self.C_est[k, self.neuron_driving_state_inds] = c.T

                x_m = intended_kin[self.neuron_driving_state_inds, m]
                mu_m = np.exp(c.T * x_m)[0,0]
                y_m = spike_counts[k, m]

                self.H[k] = rho*H + (1-rho)*(-mu_m * x_m * x_m.T)
                self.M[k] = np.array(rho*M + (1-rho)*(mu_m * x_m)).ravel()
                self.S[k] = np.array(rho*S + (1-rho)*(y_m*x_m)).ravel()

        return {'filt.C': self.C_est}

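# The inner loop above keeps exponentially-forgotten sufficient statistics per
# cell: H (a Hessian-like term), M (the model's expected spike mass) and
# S (the observed spike mass), and nudges the cell's row of C toward the
# Newton-style step -H^{-1}(S - M). A minimal numpy sketch of one such update
# for one cell with a 2-dimensional driving state (all numbers made up; plain
# arrays stand in for np.mat):
import numpy as np

rho = 0.98                              # forgetting factor
c = np.array([0.1, -0.2])               # current row of C for this cell
H = -np.eye(2) * 0.5                    # running Hessian-like statistic
M = np.zeros(2)                         # running model term
S = np.zeros(2)                         # running observed term

x = np.array([0.3, 0.7])                # intended kinematics this time step
y = 1.0                                 # binarized spike observation
mu = np.exp(c @ x)                      # GLM rate under the current weights

c_new = c - np.linalg.solve(H, S - M)   # Newton-style step
c = rho*c + (1 - rho)*c_new             # smooth toward it

H = rho*H + (1 - rho)*(-mu * np.outer(x, x))
M = rho*M + (1 - rho)*(mu * x)
S = rho*S + (1 - rho)*(y * x)
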
###############################
##### Deprecated updaters #####
###############################
class KFSmoothbatch(Updater):
    '''
    Deprecation warning: this update method has not been used for quite a long time. See KFRML for an enhanced but similar method.

    Calculate KF parameter updates using the SmoothBatch method. See [Orsborn et al, 2012] for mathematical details.
    '''
    update_kwargs = dict(steady_state=True)
    def __init__(self, batch_time, half_life):
        '''
        Constructor for KFSmoothbatch

        Parameters
        ----------
        batch_time : float
            Time over which to collect sample data
        half_life : float
            Time over which parameters are half-overwritten

        Returns
        -------
        KFSmoothbatch instance
        '''
        super(KFSmoothbatch, self).__init__(self.calc, multiproc=False)
        self.half_life = half_life
        self.batch_time = batch_time
        self.rho = np.exp(np.log(0.5) / (self.half_life/batch_time))

    def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=None, **kwargs):
        """
        SmoothBatch calculations

        Run least-squares on (intended_kinematics, spike_counts) to
        determine the C_hat and Q_hat of the new batch, then combine with
        the old parameters using the step size rho.
        """
        print("calculating new SB parameters")
        C_old = decoder.kf.C
        Q_old = decoder.kf.Q
        drives_neurons = decoder.drives_neurons
        mFR_old = decoder.mFR
        sdFR_old = decoder.sdFR

        C_hat, Q_hat = kfdecoder.KalmanFilter.MLE_obs_model(
            intended_kin, spike_counts, include_offset=False, drives_obs=drives_neurons)

        if half_life is not None:
            rho = np.exp(np.log(0.5)/(half_life/self.batch_time))
        else:
            rho = self.rho

        C = (1-rho)*C_hat + rho*C_old
        Q = (1-rho)*Q_hat + rho*Q_old

        mFR = (1-rho)*np.mean(spike_counts.T, axis=0) + rho*mFR_old
        sdFR = (1-rho)*np.std(spike_counts.T, axis=0) + rho*sdFR_old

        D = C.T * np.linalg.pinv(Q) * C
        new_params = {'kf.C':C, 'kf.Q':Q,
                      'kf.C_xpose_Q_inv_C':D, 'kf.C_xpose_Q_inv':C.T * np.linalg.pinv(Q),
                      'mFR':mFR, 'sdFR':sdFR, 'rho':rho }
        return new_params

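# SmoothBatch in one line: each batch, the parameters move toward the new
# batch's ML estimate by a fraction (1 - rho), so old data is exponentially
# forgotten with the configured half-life. A toy sketch with made-up 2x2
# matrices standing in for C:
import numpy as np

rho = 0.9
C_old = np.array([[1.0, 0.0], [0.0, 1.0]])   # current parameters
C_hat = np.array([[2.0, 1.0], [1.0, 2.0]])   # ML estimate from the new batch

C = (1 - rho)*C_hat + rho*C_old              # blended update
print(C)   # 10% of the way from C_old to C_hat
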
class KFOrthogonalPlantSmoothbatch(KFSmoothbatch):
    '''This module is deprecated. See KFRML_IVC'''
    def __init__(self, *args, **kwargs):
        self.default_gain = kwargs.pop('default_gain', None)
        super(KFOrthogonalPlantSmoothbatch, self).__init__(*args, **kwargs)

    def calc(self, *args, **kwargs):
        new_params = super(KFOrthogonalPlantSmoothbatch, self).calc(*args, **kwargs)
        C, Q = new_params['kf.C'], new_params['kf.Q']

        D = (C.T * np.linalg.pinv(Q) * C)
        if self.default_gain is None:
            d = np.mean([D[3,3], D[5,5]])
            D[3:6, 3:6] = np.diag([d, d, d])
        else:
            # calculate the gain from the Riccati equation solution
            # (this branch assumes a 'decoder' keyword argument and the
            # scalar Riccati helper defined on KFRML_IVC)
            decoder = kwargs['decoder']
            A_diag = np.diag(np.asarray(decoder.filt.A[3:6, 3:6]))
            W_diag = np.diag(np.asarray(decoder.filt.W[3:6, 3:6]))
            D_diag = []
            for a, w, n in zip(A_diag, W_diag, self.default_gain):
                d = KFRML_IVC.scalar_riccati_eq_soln(a, w, n)
                D_diag.append(d)

            D[3:6, 3:6] = np.mat(np.diag(D_diag))

        new_params['kf.C_xpose_Q_inv_C'] = D
        new_params['kf.C_xpose_Q_inv'] = C.T * np.linalg.pinv(Q)
        return new_params

class PPFSmoothbatch(Updater):
    '''
    Deprecated: as of 2015-Sept-19, this updater had never been used in an experiment.
    '''
    def __init__(self, batch_time, half_life):
        super(PPFSmoothbatch, self).__init__(self.calc, multiproc=True)
        self.half_life = half_life
        self.batch_time = batch_time
        self.rho = np.exp(np.log(0.5) / (self.half_life/batch_time))

    def calc(self, intended_kin=None, spike_counts=None, decoder=None, half_life=None, **kwargs):
        """
        SmoothBatch calculations

        Run least-squares on (intended_kinematics, spike_counts) to
        determine the C_hat of the new batch, then combine with the
        old parameters using the step size rho.
        """
        if half_life is not None:
            rho = np.exp(np.log(0.5)/(half_life/self.batch_time))
        else:
            rho = self.rho

        C_old = decoder.filt.C
        drives_neurons = decoder.drives_neurons
        states = decoder.states
        decoding_states = np.take(states, np.nonzero(drives_neurons)).ravel().tolist()  # e.g., ['hand_vx', 'hand_vz', 'offset']

        C_hat, pvalues = ppfdecoder.PointProcessFilter.MLE_obs_model(
            intended_kin, spike_counts, include_offset=False, drives_obs=drives_neurons)
        C_hat = train.inflate(C_hat, decoding_states, states, axis=1)
        pvalues = train.inflate(pvalues, decoding_states, states, axis=1)
        # columns filled in by 'inflate' carry p-value 0; mark them as
        # non-significant (inf) so they are never updated (the offset column is excluded)
        pvalues[:, :-1][pvalues[:, :-1] == 0] = np.inf

        # only blend in coefficients whose new estimates are statistically significant
        mesh = np.nonzero(pvalues < 0.1)
        C = np.array(C_old.copy())
        C[mesh] = (1-rho)*C_hat[mesh] + rho*np.array(C_old)[mesh]
        C = np.mat(C)

        new_params = {'filt.C':C}
        return new_params

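# The significance-masked update above in miniature: only entries whose new
# estimate passes the p-value test get blended; everything else keeps its old
# value. Made-up numbers:
import numpy as np

rho = 0.8
C_old = np.zeros((2, 3))
C_hat = np.ones((2, 3))
pvalues = np.array([[0.01, 0.5, 0.02],
                    [0.9,  0.03, 0.7]])

mesh = np.nonzero(pvalues < 0.1)
C = C_old.copy()
C[mesh] = (1 - rho)*C_hat[mesh] + rho*C_old[mesh]
print(C)   # 0.2 where p < 0.1, 0 elsewhere
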
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_extractor_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_extractor_py.html
new file mode 100644
index 00000000..3649cc1e
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_extractor_py.html
@@ -0,0 +1,1424 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\extractor.py: 13%

'''
Classes for extracting "decodable features" from various types of neural signal sources.
Examples include spike rate estimation, LFP power, and EMG amplitude.
'''
import numpy as np
import time
from scipy.signal import butter, lfilter
import math
import os
import subprocess  # used below to call Blackrock's n2h5 conversion utility
import nitime.algorithms as tsa

class FeatureExtractor(object):
    '''
    Parent of all feature extractors, used only for interfacing/type-checking.
    Feature extractors are objects that get the data they need (e.g., spike timestamps, LFP voltages, etc.)
    from the neural data source object and extract features from it.
    '''
    @classmethod
    def extract_from_file(cls, *args, **kwargs):
        raise NotImplementedError


class DummyExtractor(FeatureExtractor):
    '''
    An extractor which does nothing. Used for tasks which are only pretending to be BMI tasks, e.g., visual feedback tasks.
    '''
    feature_type = 'obs'
    feature_dtype = [('obs', 'f8', (1,))]

    def __call__(self, *args, **kwargs):
        return dict(obs=np.array([[np.nan]]))

class BinnedSpikeCountsExtractor(FeatureExtractor):
    '''
    Extracts spike counts from spike timestamps separated into rectangular windows.
    This extractor is (currently) the main type of feature extractor in intracortical BMIs.
    '''
    feature_type = 'spike_counts'

    def __init__(self, source, n_subbins=1, units=[]):
        '''
        Constructor for BinnedSpikeCountsExtractor

        Parameters
        ----------
        source: DataSource instance
            Source must implement a '.get()' function which returns the appropriate data
            (appropriateness will change depending on the source)
        n_subbins: int, optional, default=1
            Number of bins into which to divide the observed spike counts
        units: np.ndarray of shape (N, 2), optional, default=[]
            Units which need spike binning. Each row of the array corresponds to (channel, unit). By default no units will be binned.

        Returns
        -------
        BinnedSpikeCountsExtractor instance
        '''
        self.feature_dtype = [('spike_counts', 'u4', (len(units), n_subbins)), ('bin_edges', 'f8', 2)]

        self.source = source
        self.n_subbins = n_subbins
        self.units = units

        extractor_kwargs = dict()
        extractor_kwargs['n_subbins'] = self.n_subbins
        extractor_kwargs['units'] = self.units
        self.extractor_kwargs = extractor_kwargs
        self.last_get_spike_counts_time = 0

    def set_n_subbins(self, n_subbins):
        '''
        Alter the # of subbins without changing the extractor kwargs of a decoder

        Parameters
        ----------
        n_subbins : int
            Number of bins into which to divide the observed spike counts

        Returns
        -------
        None
        '''
        self.n_subbins = n_subbins
        self.extractor_kwargs['n_subbins'] = n_subbins
        self.feature_dtype = [('spike_counts', 'u4', (len(self.units), n_subbins)), ('bin_edges', 'f8', 2)]

    def get_spike_ts(self, *args, **kwargs):
        '''
        Get the spike timestamps from the neural data source. This function has no type checking,
        i.e., it is assumed that the Extractor object was created with the proper source

        Parameters
        ----------
        None are needed (args and kwargs are ignored)

        Returns
        -------
        numpy record array
            Spike timestamps in the format of riglib.plexon.Spikes.dtype
        '''
        return self.source.get()

    def get_bin_edges(self, ts):
        '''
        Determine the first and last spike timestamps to allow HDF files
        created by the BMI to be semi-synchronized with the neural data file

        Parameters
        ----------
        ts : numpy record array
            Must have field 'ts' of spike timestamps in seconds

        Returns
        -------
        np.ndarray of shape (2,)
            The smallest and largest timestamps corresponding to the current feature;
            useful for rough synchronization of the BMI event loop with the neural recording system.
        '''
        if len(ts) == 0:
            bin_edges = np.array([np.nan, np.nan])
        else:
            min_ind = np.argmin(ts['ts'])
            max_ind = np.argmax(ts['ts'])
            bin_edges = np.array([ts[min_ind]['ts'], ts[max_ind]['ts']])
        return bin_edges

    @classmethod
    def bin_spikes(cls, ts, units, max_units_per_channel=13):
        '''
        Count up the number of BMI spikes in a list of spike timestamps.

        Parameters
        ----------
        ts : numpy record array
            Must have field 'ts' of spike timestamps in seconds
        units : np.ndarray of shape (N, 2)
            Each row corresponds to the channel index (typically the electrode number) and
            the unit index (an index to differentiate the possibly many units on the same electrode). These are
            the units used in the BMI.
        max_units_per_channel : int, optional, default=13
            This int is used to map from a (channel, unit) index to a single 'unit_ind'
            for faster binning of spike timestamps. Just set it to a large number.

        Returns
        -------
        np.ndarray of shape (N, 1)
            Column vector of counts of spike events for each of the N units.
        '''
        unit_inds = units[:,0]*max_units_per_channel + units[:,1]
        edges = np.sort(np.hstack([unit_inds - 0.5, unit_inds + 0.5]))

        spiking_unit_inds = ts['chan']*max_units_per_channel + ts['unit']
        counts, _ = np.histogram(spiking_unit_inds, edges)
        return counts[::2]

    def __call__(self, start_time, *args, **kwargs):
        '''
        Main function to retrieve new spike data and bin the counts

        Parameters
        ----------
        start_time : float
            Absolute time from the task event loop. This is used only to subdivide
            the spike timestamps into multiple bins, if desired (if the 'n_subbins' attribute is > 1)
        *args, **kwargs : optional positional/keyword arguments
            These are passed to the source, or ignored (not needed for this extractor).

        Returns
        -------
        dict
            Extracted features to be saved in the task.
        '''
        ts = self.get_spike_ts(*args, **kwargs)
        if len(ts) == 0:
            counts = np.zeros([len(self.units), self.n_subbins])
        elif self.n_subbins > 1:
            subbin_edges = np.linspace(self.last_get_spike_counts_time, start_time, self.n_subbins+1)

            # Decrease the first subbin index to include any spikes that were
            # delayed in getting to the task layer due to threading issues.
            # An acceptable delay is 1 sec or less. Realistically, most delays should be
            # on the millisecond order.
            subbin_edges[0] -= 1
            subbin_inds = np.digitize(ts['arrival_ts'], subbin_edges)
            counts = np.vstack([self.bin_spikes(ts[subbin_inds == k], self.units) for k in range(1, self.n_subbins+1)]).T
        else:
            counts = self.bin_spikes(ts, self.units).reshape(-1, 1)

        counts = np.array(counts, dtype=np.uint32)
        bin_edges = self.get_bin_edges(ts)
        self.last_get_spike_counts_time = start_time

        return dict(spike_counts=counts, bin_edges=bin_edges)

    @classmethod
    def extract_from_file(cls, files, neurows, binlen, units, extractor_kwargs, strobe_rate=60.0):
        '''
        Compute binned spike count features

        Parameters
        ----------
        files : dict
            Data files used to train the decoder. Should contain exactly one type of neural data file (e.g., Plexon, Blackrock, TDT)
        neurows: np.ndarray of shape (T,)
            Timestamps in the plexon time reference corresponding to bin boundaries
        binlen: float
            Length of time over which to sum spikes from the specified cells
        units: np.ndarray of shape (N, 2)
            List of units that the decoder will be trained on. The first column specifies the electrode number and the second specifies the unit on the electrode
        extractor_kwargs: dict
            Any additional parameters to be passed to the feature extractor. This function is agnostic to the actual extractor utilized
        strobe_rate: float, optional, default=60.0
            The rate at which the task sends the sync pulse to the plx file

        Returns
        -------
        spike_counts : np.ndarray of shape (N, T)
            Spike counts binned over the length of the datafile.
        units : np.ndarray of shape (N, 2)
            Each row corresponds to the channel index (typically the electrode number) and
            the unit index (an index to differentiate the possibly many units on the same electrode). These are
            the units used in the BMI.
        extractor_kwargs : dict
            Parameters used to instantiate the feature extractor, to be stored
            along with the trained decoder so that the exact same feature extractor can be re-created at runtime.
        '''
        if 'plexon' in files:
            from plexon import plexfile
            plx = plexfile.openFile(str(files['plexon']))
            # interpolate between the rows to 180 Hz
            if binlen < 1./strobe_rate:
                interp_rows = []
                neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows])
                for r1, r2 in zip(neurows[:-1], neurows[1:]):
                    interp_rows += list(np.linspace(r1, r2, 4)[1:])
                interp_rows = np.array(interp_rows)
            else:
                step = int(binlen/(1./strobe_rate))  # downsample kinematic data according to the decoder bin length (assumes non-overlapping bins)
                interp_rows = neurows[::step]
                print(('step: ', step))
            from plexon import psth
            spike_bin_fn = psth.SpikeBin(units, binlen)
            spike_counts = np.array(list(plx.spikes.bin(interp_rows, spike_bin_fn)))

            # discard units that never fired at all
            discard_zero_units = extractor_kwargs.pop('discard_zero_units', True)
            if discard_zero_units:
                unit_inds, = np.nonzero(np.sum(spike_counts, axis=0))
                units = units[unit_inds,:]
                spike_counts = spike_counts[:, unit_inds]
            extractor_kwargs['units'] = units

            return spike_counts, units, extractor_kwargs

        elif 'blackrock' in files:
            import h5py
            nev_fname = [name for name in files['blackrock'] if '.nev' in name][0]  # only one of them
            nev_hdf_fname = [name for name in files['blackrock'] if '.nev' in name and name[-4:]=='.hdf']
            nsx_fnames = [name for name in files['blackrock'] if '.ns' in name]
            # interpolate between the rows to 180 Hz
            if binlen < 1./strobe_rate:
                interp_rows = []
                neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows])
                for r1, r2 in zip(neurows[:-1], neurows[1:]):
                    interp_rows += list(np.linspace(r1, r2, 4)[1:])
                interp_rows = np.array(interp_rows)
            else:
                step = int(binlen/(1./strobe_rate))  # downsample kinematic data according to the decoder bin length (assumes non-overlapping bins)
                interp_rows = neurows[::step]

            if len(nev_hdf_fname) == 0:
                nev_hdf_fname = nev_fname + '.hdf'

                if not os.path.isfile(nev_hdf_fname):
                    # convert the .nev file to an hdf file using Blackrock's n2h5 utility
                    subprocess.call(['n2h5', nev_fname, nev_hdf_fname])
            else:
                nev_hdf_fname = nev_hdf_fname[0]

            try:
                nev_hdf = h5py.File(nev_hdf_fname, 'r')
                open_method = 1
            except:
                import tables
                nev_hdf = tables.openFile(nev_hdf_fname)
                open_method = 2

            n_bins = len(interp_rows)
            n_units = units.shape[0]
            spike_counts = np.zeros((n_bins, n_units))

            for i in range(n_units):
                chan = units[i, 0]

                # 1-based numbering (comes from web interface)
                unit = units[i, 1]

                chan_str = str(chan).zfill(5)
                path = 'channel/channel%s/spike_set' % chan_str

                if open_method == 1:
                    ts = nev_hdf.get(path).value['TimeStamp']
                    # the units corresponding to each timestamp in ts
                    # 0-based numbering (comes from .nev file), so add 1
                    units_ts = nev_hdf.get(path).value['Unit']

                elif open_method == 2:
                    try:
                        grp = nev_hdf.getNode('/'+path)
                        ts = grp[:]['TimeStamp']
                        units_ts = grp[:]['Unit']
                    except:
                        print(('no spikes recorded on channel: ', chan_str, ': adding zeros'))
                        ts = []
                        units_ts = []

                # get the ts for this unit, in units of secs
                fs = 30000.
                ts = [t/fs for idx, (t, u_t) in enumerate(zip(ts, units_ts)) if u_t == unit]

                # insert the value interp_rows[0]-step at the beginning of the interp_rows array
                interp_rows_ = np.insert(interp_rows, 0, interp_rows[0]-step)

                # use ts to fill in the spike_counts that correspond to unit i
                spike_counts[:, i] = np.histogram(ts, interp_rows_)[0]

            # discard units that never fired at all
            if 'keep_zero_units' in extractor_kwargs:
                print('keeping zero firing units')
            else:
                unit_inds, = np.nonzero(np.sum(spike_counts, axis=0))
                units = units[unit_inds,:]
                spike_counts = spike_counts[:, unit_inds]

            extractor_kwargs['units'] = units

            return spike_counts, units, extractor_kwargs

        elif 'tdt' in files:
            raise NotImplementedError

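# How bin_spikes counts in one histogram pass: each (channel, unit) pair is
# hashed to a scalar unit_ind, edges are placed at unit_ind +/- 0.5, and every
# other histogram bin (the +/-0.5 windows) holds that unit's count. A sketch
# with a made-up record array mimicking the 'chan'/'unit' fields:
import numpy as np

units = np.array([[1, 1], [1, 2], [3, 1]])        # (channel, unit) pairs in the BMI
spikes = np.array([(1, 1), (1, 1), (3, 1), (2, 4)],
                  dtype=[('chan', 'i4'), ('unit', 'i4')])

max_u = 13
unit_inds = units[:, 0]*max_u + units[:, 1]
edges = np.sort(np.hstack([unit_inds - 0.5, unit_inds + 0.5]))
spiking = spikes['chan']*max_u + spikes['unit']
counts = np.histogram(spiking, edges)[0][::2]
print(counts)   # -> [2 0 1]; the (2, 4) spike is not a BMI unit, so it falls in a gap bin
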
# bands should be a list of tuples representing ranges,
# e.g., bands = [(0, 10), (10, 20), (130, 140)] for 0-10, 10-20, and 130-140 Hz
start = 0
end = 150
step = 10
default_bands = []
for freq in range(start, end, step):
    default_bands.append((freq, freq+step))

class LFPMTMPowerExtractor(object):
    '''
    Computes log power of the LFP in different frequency bands (for each
    channel) in the frequency domain using the multi-taper method.
    '''

    feature_type = 'lfp_power'

    def __init__(self, source, channels=[], bands=default_bands, win_len=0.2, NW=3, fs=1000, **kwargs):
        '''
        Constructor for LFPMTMPowerExtractor, which extracts LFP power using the multi-taper method

        Parameters
        ----------
        source : riglib.source.Source object
            Object which yields new data when its 'get' method is called
        channels : list
            LFP electrode indices to use for feature extraction
        bands : list of tuples
            Each tuple defines a frequency band of interest as (start frequency, end frequency)

        Returns
        -------
        LFPMTMPowerExtractor instance
        '''
        self.source = source
        self.channels = channels
        self.bands = bands
        self.win_len = win_len
        self.NW = NW
        if source is not None:
            self.fs = source.source.update_freq
        else:
            self.fs = fs

        extractor_kwargs = dict()
        extractor_kwargs['channels'] = self.channels
        extractor_kwargs['bands'] = self.bands
        extractor_kwargs['win_len'] = self.win_len
        extractor_kwargs['NW'] = self.NW
        extractor_kwargs['fs'] = self.fs

        extractor_kwargs['no_log'] = 'no_log' in kwargs and kwargs['no_log'] == True    # skip the log calculation
        extractor_kwargs['no_mean'] = 'no_mean' in kwargs and kwargs['no_mean'] == True  # skip averaging over bands
        self.extractor_kwargs = extractor_kwargs

        self.n_pts = int(self.win_len * self.fs)
        self.nfft = 2**int(np.ceil(np.log2(self.n_pts)))  # nextpow2(self.n_pts)
        fft_freqs = np.arange(0., self.fs, float(self.fs)/self.nfft)[:self.nfft//2 + 1]
        self.fft_inds = dict()
        for band_idx, band in enumerate(bands):
            self.fft_inds[band_idx] = [freq_idx for freq_idx, freq in enumerate(fft_freqs) if band[0] <= freq < band[1]]

        extractor_kwargs['fft_inds'] = self.fft_inds
        extractor_kwargs['fft_freqs'] = fft_freqs

        self.epsilon = 1e-9

        if extractor_kwargs['no_mean']:  # used in the LFP 1D control task
            self.feature_dtype = ('lfp_power', 'f8', (len(channels)*len(fft_freqs), 1))
        else:
            self.feature_dtype = ('lfp_power', 'f8', (len(channels)*len(bands), 1))

    def get_cont_samples(self, *args, **kwargs):
        '''
        Retrieves the last n_pts samples for each LFP channel from the neural data 'source'

        Parameters
        ----------
        *args, **kwargs : optional arguments
            Ignored for this extractor (not necessary)

        Returns
        -------
        np.ndarray of shape (n_channels, n_pts)
        '''
        return self.source.get(self.n_pts, self.channels)

    def extract_features(self, cont_samples):
        '''
        Extract spectral features from a block of time series samples

        Parameters
        ----------
        cont_samples : np.ndarray of shape (n_channels, n_samples)
            Raw voltage time series (one per channel) from which to extract spectral features

        Returns
        -------
        lfp_power : np.ndarray of shape (n_channels * n_features, 1)
            Multi-band power estimates for each channel, for each band specified when the feature extractor was instantiated.
        '''
        psd_est = tsa.multi_taper_psd(cont_samples, Fs=self.fs, NW=self.NW, jackknife=False, low_bias=True, NFFT=self.nfft)[1]

        if ('no_mean' in self.extractor_kwargs) and (self.extractor_kwargs['no_mean'] is True):
            return psd_est.reshape(psd_est.shape[0]*psd_est.shape[1], 1)

        else:
            # compute the average power of each band of interest
            n_chan = len(self.channels)
            lfp_power = np.zeros((n_chan * len(self.bands), 1))
            for idx, band in enumerate(self.bands):
                if self.extractor_kwargs['no_log']:
                    lfp_power[idx*n_chan : (idx+1)*n_chan, 0] = np.mean(psd_est[:, self.fft_inds[idx]], axis=1)
                else:
                    lfp_power[idx*n_chan : (idx+1)*n_chan, 0] = np.mean(np.log10(psd_est[:, self.fft_inds[idx]] + self.epsilon), axis=1)

            return lfp_power

    def __call__(self, start_time, *args, **kwargs):
        '''
        Parameters
        ----------
        start_time : float
            Absolute time from the task event loop. This is unused by LFP extractors in their current implementation
            and only passed in to ensure that function signatures are the same across extractors.
        *args, **kwargs : optional positional/keyword arguments
            These are passed to the source, or ignored (not needed for this extractor).

        Returns
        -------
        dict
            Extracted features to be saved in the task.
        '''
        cont_samples = self.get_cont_samples(*args, **kwargs)  # dims of channels x time
        lfp_power = self.extract_features(cont_samples)

        return dict(lfp_power=lfp_power)

    @classmethod
    def extract_from_file(cls, files, neurows, binlen, units, extractor_kwargs, strobe_rate=60.0):
        '''
        Compute binned LFP power features

        Parameters
        ----------
        files : dict
            Data files used to train the decoder. Should contain exactly one type of neural data file (e.g., Plexon, Blackrock, TDT)
        neurows: np.ndarray of shape (T,)
            Timestamps in the plexon time reference corresponding to bin boundaries
        binlen: float
            Length of time over which to sum spikes from the specified cells
        units: np.ndarray of shape (N, 2)
            List of units that the decoder will be trained on. The first column specifies the electrode number and the second specifies the unit on the electrode
        extractor_kwargs: dict
            Any additional parameters to be passed to the feature extractor. This function is agnostic to the actual extractor utilized
        strobe_rate: float, optional, default=60.0
            The rate at which the task sends the sync pulse to the plx file

        Returns
        -------
        lfp_power : np.ndarray of shape (T, n_channels * n_bands)
            LFP power features binned over the length of the datafile.
        units :
            Not used by this type of extractor, just passed back from the input argument to make the outputs consistent with spike count extractors
        extractor_kwargs : dict
            Parameters used to instantiate the feature extractor, to be stored
            along with the trained decoder so that the exact same feature extractor can be re-created at runtime.
        '''
        if 'plexon' in files:
            from plexon import plexfile
            plx = plexfile.openFile(str(files['plexon']))

            # interpolate between the rows to 180 Hz
            if binlen < 1./strobe_rate:
                interp_rows = []
                neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows])
                for r1, r2 in zip(neurows[:-1], neurows[1:]):
                    interp_rows += list(np.linspace(r1, r2, 4)[1:])
                interp_rows = np.array(interp_rows)
            else:
                step = int(binlen/(1./strobe_rate))  # downsample kinematic data according to the decoder bin length (assumes non-overlapping bins)
                interp_rows = neurows[::step]

            # create extractor object
            f_extractor = LFPMTMPowerExtractor(None, **extractor_kwargs)
            extractor_kwargs = f_extractor.extractor_kwargs

            win_len = f_extractor.win_len
            bands = f_extractor.bands
            channels = f_extractor.channels
            fs = f_extractor.fs
            print(('bands:', bands))

            n_itrs = len(interp_rows)
            n_chan = len(channels)
            lfp_power = np.zeros((n_itrs, n_chan * len(bands)))

            lfp = plx.lfp[:].data[:, channels-1]
            n_pts = int(win_len * fs)
            for i, t in enumerate(interp_rows):
                try:
                    sample_num = int(t * fs)
                    cont_samples = lfp[sample_num-n_pts:sample_num, :]
                    lfp_power[i, :] = f_extractor.extract_features(cont_samples.T).T
                except:
                    print("Error with LFP decoder training")
                    print((i, t))

            # TODO -- discard any channel(s) for which the log power in any frequency
            # band was ever equal to -inf (i.e., power was equal to 0),
            # or perhaps just add a small epsilon inside the log to avoid this;
            # then, remember to do this: extractor_kwargs['channels'] = channels
            # and reset the units variable

            return lfp_power, units, extractor_kwargs

        elif 'blackrock' in files:
            raise NotImplementedError

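# The band-averaging step above, in miniature: map FFT bin frequencies to the
# requested bands once, then average (log) power over each band's bins. A
# numpy-only sketch (a plain periodogram rather than a multi-taper estimate)
# with a made-up 80 Hz test tone; the (70, 90) band should dominate:
import numpy as np

fs, win_len = 1000.0, 0.2
n_pts = int(win_len * fs)
nfft = 2**int(np.ceil(np.log2(n_pts)))
t = np.arange(n_pts) / fs
x = np.sin(2*np.pi*80*t) + 0.1*np.random.randn(n_pts)

freqs = np.fft.rfftfreq(nfft, d=1.0/fs)
psd = np.abs(np.fft.rfft(x, nfft))**2 / (fs * n_pts)

bands = [(0, 10), (70, 90), (130, 140)]
band_power = [np.mean(np.log10(psd[(freqs >= lo) & (freqs < hi)] + 1e-9))
              for (lo, hi) in bands]
print(band_power)
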
#########################################################
##### Reconstruction extractors, used in test cases #####
#########################################################
class ReplaySpikeCountsExtractor(BinnedSpikeCountsExtractor):
    '''
    A "feature extractor" that replays spike counts from an HDF file
    '''
    feature_type = 'spike_counts'
    def __init__(self, hdf_table, source='spike_counts', cycle_rate=60.0, units=[]):
        '''
        Parameters
        ----------
        hdf_table : HDF table
            Data table to replay. Usually the 'task' table.
        source : string, optional, default='spike_counts'
            Column of the HDF table to replay
        cycle_rate : float, optional, default=60.0
            Rate at which the task FSM "cycles", i.e., the rate at which the task will ask for new observations
        units : iterable, optional, default=[]
            Names (channel, unit) of the units. If none are specified, some fake names are created

        Returns
        -------
        ReplaySpikeCountsExtractor instance
        '''
        self.idx = 0
        self.hdf_table = hdf_table
        self.source = source
        self.units = units
        self.n_subbins = hdf_table[0][source].shape[1]
        self.last_get_spike_counts_time = 0
        self.cycle_rate = cycle_rate

        n_units = hdf_table[0]['spike_counts'].shape[0]
        self.feature_dtype = [('spike_counts', 'u4', (n_units, self.n_subbins)),
                              ('bin_edges', 'f8', 2)]

    def get_spike_ts(self):
        '''
        Make up fake timestamps to go with the spike counts extracted from the HDF file
        '''
        from . import sim_neurons

        # Get counts from the HDF file
        counts = self.hdf_table[self.idx][self.source]
        n_subbins = counts.shape[1]

        # Convert counts to timestamps between (self.idx*1./cycle_rate, (self.idx+1)*1./cycle_rate)
        # NOTE: this code is mostly copied from riglib.bmi.sim_neurons.CLDASimPointProcessEnsemble
        ts_data = []
        cycle_rate = self.cycle_rate
        for k in range(n_subbins):
            fake_time = (self.idx - 1) * 1./cycle_rate + (k + 0.5)*1./cycle_rate*1./n_subbins
            nonzero_units, = np.nonzero(counts[:,k])
            for unit_ind in nonzero_units:
                n_spikes = counts[unit_ind, k]
                for m in range(n_spikes):
                    ts = (fake_time, self.units[unit_ind, 0], self.units[unit_ind, 1], fake_time)
                    ts_data.append(ts)

        ts_dtype_new = sim_neurons.ts_dtype_new
        return np.array(ts_data, dtype=ts_dtype_new)

    def get_bin_edges(self, ts):
        '''
        Get the first and last timestamp of spikes in the current "bin" as saved in the HDF file
        '''
        return self.hdf_table[self.idx]['bin_edges']

    def __call__(self, *args, **kwargs):
        '''
        See BinnedSpikeCountsExtractor.__call__ for documentation
        '''
        output = super(ReplaySpikeCountsExtractor, self).__call__(*args, **kwargs)
        if not np.array_equal(output['spike_counts'], self.hdf_table[self.idx][self.source]):
            print(("spike binning error: ", self.idx))
        self.idx += 1
        return output

class ReplayLFPPowerExtractor(BinnedSpikeCountsExtractor):
    '''
    A "feature extractor" that replays LFP power estimates from an HDF file
    '''
    feature_type = 'lfp_power'
    def __init__(self, hdf_table, source='lfp_power'):
        '''
        Constructor for ReplayLFPPowerExtractor

        Parameters
        ----------
        hdf_table : HDF table
            Data table to replay. Usually the 'task' table.
        source : string, optional, default='lfp_power'
            Column of the HDF table to replay

        Returns
        -------
        ReplayLFPPowerExtractor instance
        '''
        self.idx = 0
        self.hdf_table = hdf_table
        self.source = source
        self.n_subbins = hdf_table[0][source].shape[1]
        self.last_get_spike_counts_time = 0

        n_units = hdf_table[0][source].shape[0]
        self.feature_dtype = [('lfp_power', 'f8', (n_units, self.n_subbins)), ]

    def __call__(self, *args, **kwargs):
        '''
        See BinnedSpikeCountsExtractor.__call__ for documentation
        '''
        output = self.hdf_table[self.idx][self.source]
        self.idx += 1
        return dict(lfp_power=output)

#################################
##### Simulation extractors #####
#################################
class SimBinnedSpikeCountsExtractor(BinnedSpikeCountsExtractor):
    '''
    Spike count features are generated by a population of synthetic neurons
    '''
    feature_type = 'spike_counts'

    def __init__(self, input_device, encoder, n_subbins, units, task=None):
        '''
        Constructor for SimBinnedSpikeCountsExtractor

        Parameters
        ----------
        input_device: object with a "calc_next_state" method
            Generates the "intended" next state, e.g., by a feedback control policy
        encoder: callable with 1 argument
            Maps the "control" input into the spike timestamps of a set of neurons
        n_subbins: int
            Number of subbins to divide the spike counts into, e.g., 3 are necessary for the PPF
        units: np.ndarray of shape (N, 2)
            Each row of the array corresponds to (channel, unit)

        Returns
        -------
        SimBinnedSpikeCountsExtractor instance
        '''
        self.input_device = input_device
        self.encoder = encoder
        self.n_subbins = n_subbins
        self.units = units
        self.last_get_spike_counts_time = 0
        self.feature_dtype = [('spike_counts', 'f8', (len(units), n_subbins)), ('bin_edges', 'f8', 2),
                              ('ctrl_input', 'f8', self.encoder.C.shape[1])]
        self.task = task
        self.sim_ctrl = np.zeros((self.encoder.C.shape[1]))

    def get_spike_ts(self):
        '''
        See BinnedSpikeCountsExtractor.get_spike_ts for docs
        '''
        current_state = self.task.get_current_state()
        target_state = self.task.get_target_BMI_state()
        ctrl = self.input_device.calc_next_state(current_state, target_state)
        self.sim_ctrl = ctrl
        ts_data = self.encoder(ctrl)
        return ts_data

class SimDirectObsExtractor(SimBinnedSpikeCountsExtractor):
    '''
    This extractor just passes back the observation vector generated by the encoder
    '''
    def __call__(self, start_time, *args, **kwargs):
        y_t = self.get_spike_ts(*args, **kwargs)
        return dict(spike_counts=np.reshape(y_t, (len(self.units), self.n_subbins)))

#############################################
##### Feature extractors in development #####
#############################################
class LFPButterBPFPowerExtractor(object):
    '''
    Computes log power of the LFP in different frequency bands (for each
    channel) in the time domain using Butterworth band-pass filters.
    '''

    feature_type = 'lfp_power'

    def __init__(self, source, channels=[], bands=default_bands, win_len=0.2, filt_order=5, fs=1000):
        # log band power is a real-valued quantity, so store it as a float
        # (a 'u4' dtype would truncate and wrap negative log values)
        self.feature_dtype = ('lfp_power', 'f8', (len(channels)*len(bands), 1))

        self.source = source
        self.channels = channels
        self.bands = bands
        self.win_len = win_len  # secs
        self.filt_order = filt_order
        if source is not None:
            self.fs = source.source.update_freq
        else:
            self.fs = fs

        extractor_kwargs = dict()
        extractor_kwargs['channels'] = self.channels
        extractor_kwargs['bands'] = self.bands
        extractor_kwargs['win_len'] = self.win_len
        extractor_kwargs['filt_order'] = self.filt_order
        extractor_kwargs['fs'] = self.fs
        self.extractor_kwargs = extractor_kwargs

        self.n_pts = int(self.win_len * self.fs)
        self.filt_coeffs = dict()
        for band in bands:
            nyq = 0.5 * self.fs
            low = band[0] / nyq
            high = band[1] / nyq
            self.filt_coeffs[band] = butter(self.filt_order, [low, high], btype='band')  # returns (b, a)

        self.epsilon = 1e-9

        self.last_get_lfp_power_time = 0  # TODO -- is this variable necessary for LFP?

    def get_cont_samples(self, *args, **kwargs):
        return self.source.get(self.n_pts, self.channels)

    def extract_features(self, cont_samples):
        n_chan = len(self.channels)

        lfp_power = np.zeros((n_chan * len(self.bands), 1))
        for i, band in enumerate(self.bands):
            b, a = self.filt_coeffs[band]
            y = lfilter(b, a, cont_samples)
            lfp_power[i*n_chan:(i+1)*n_chan] = np.log((1. / self.n_pts) * np.sum(y**2, axis=1) + self.epsilon).reshape(-1, 1)

        return lfp_power

    def __call__(self, start_time, *args, **kwargs):
        cont_samples = self.get_cont_samples(*args, **kwargs)  # dims of channels x time
        lfp_power = self.extract_features(cont_samples)

        self.last_get_lfp_power_time = start_time

        return dict(lfp_power=lfp_power)

    @classmethod
    def extract_from_file(cls, files, neurows, binlen, units, extractor_kwargs, strobe_rate=60.0):
        '''Compute LFP power features from a Blackrock data file.'''

        nsx_fnames = [name for name in files['blackrock'] if '.ns' in name]

        # interpolate between the rows to 180 Hz
        if binlen < 1./strobe_rate:
            interp_rows = []
            neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows])
            for r1, r2 in zip(neurows[:-1], neurows[1:]):
                interp_rows += list(np.linspace(r1, r2, 4)[1:])
            interp_rows = np.array(interp_rows)
        else:
            step = int(binlen/(1./strobe_rate))  # downsample kinematic data according to the decoder bin length (assumes non-overlapping bins)
            interp_rows = neurows[::step]

        # TODO -- for now, use the .ns3 or .ns2 file (2 kS/s)
        nsx_fname = None
        for fname in nsx_fnames:
            if '.ns3' in fname:
                nsx_fname = fname
                fs_ = 2000
        if nsx_fname is None:
            for fname in nsx_fnames:
                if '.ns2' in fname:
                    nsx_fname = fname
                    fs_ = 1000

        if nsx_fname is None:
            raise Exception('Need an nsx file --> .ns2 or .ns3 is acceptable. Higher nsx files yield memory errors')
        extractor_kwargs['fs'] = fs_

        # the default order of 5 seems to cause problems when fs > 1000
        extractor_kwargs['filt_order'] = 3

        if nsx_fname[-4:] == '.hdf':
            nsx_hdf_fname = nsx_fname
        else:
            nsx_hdf_fname = nsx_fname + '.hdf'
            if not os.path.isfile(nsx_hdf_fname):
                # convert the .nsx file to an hdf file using Blackrock's n2h5 utility
                from db.tracker import models
                models.parse_blackrock_file(None, [nsx_fname], )

        import h5py
        nsx_hdf = h5py.File(nsx_hdf_fname, 'r')

        # create extractor object
        f_extractor = LFPButterBPFPowerExtractor(None, **extractor_kwargs)
        extractor_kwargs = f_extractor.extractor_kwargs

        win_len = f_extractor.win_len
        bands = f_extractor.bands
        channels = f_extractor.channels
        fs = f_extractor.fs

        n_itrs = len(interp_rows)
        n_chan = len(channels)
        lfp_power = np.zeros((n_itrs, n_chan * len(bands)))
        n_pts = int(win_len * fs)

        # for i, t in enumerate(interp_rows):
        #     sample_num = int(t * fs)
        #     cont_samples = np.zeros((n_chan, n_pts))
        #     for j, chan in enumerate(channels):
        #         chan_str = str(chan).zfill(5)
        #         path = 'channel/channel%s/continuous_set' % chan_str
        #         cont_samples[j, :] = nsx_hdf.get(path).value[sample_num-n_pts:sample_num]
        #     lfp_power[i, :] = f_extractor.extract_features(cont_samples).T

        print(('*' * 40))
        print('WARNING: replacing LFP values from .ns3 file with random values!!')
        print(('*' * 40))

        lfp_power = abs(np.random.randn(n_itrs, n_chan * len(bands)))

        # TODO -- discard any channel(s) for which the log power in any frequency
        # band was ever equal to -inf (i.e., power was equal to 0),
        # or perhaps just add a small epsilon inside the log to avoid this;
        # then, remember to do this: extractor_kwargs['channels'] = channels
        # and reset the units variable

        return lfp_power, units, extractor_kwargs

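# The time-domain band power recipe above, end to end: band-pass filter, then
# take the log of the mean squared output. A sketch with a made-up 25 Hz tone
# (bands chosen to avoid a 0 Hz edge, which butter rejects); the (20, 30) band
# should report the largest log power:
import numpy as np
from scipy.signal import butter, lfilter

fs, win_len = 1000.0, 0.2
n_pts = int(win_len * fs)
t = np.arange(n_pts) / fs
x = np.sin(2*np.pi*25*t) + 0.05*np.random.randn(n_pts)

for band in [(5, 15), (20, 30), (40, 50)]:
    nyq = 0.5 * fs
    b, a = butter(3, [band[0]/nyq, band[1]/nyq], btype='band')
    y = lfilter(b, a, x)
    print(band, np.log(np.mean(y**2) + 1e-9))
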
class AIMTMPowerExtractor(LFPMTMPowerExtractor):
    '''Multitaper extractor for Plexon analog input channels'''

    feature_type = 'ai_power'

    def __init__(self, source, channels=[], bands=default_bands, win_len=0.2, NW=3, fs=1000, **kwargs):
        self.source = source
        self.channels = channels
        self.bands = bands
        self.win_len = win_len
        self.NW = NW
        if source is not None:
            self.fs = source.source.update_freq
        else:
            self.fs = fs

        extractor_kwargs = dict()
        extractor_kwargs['channels'] = self.channels
        extractor_kwargs['bands'] = self.bands
        extractor_kwargs['win_len'] = self.win_len
        extractor_kwargs['NW'] = self.NW
        extractor_kwargs['fs'] = self.fs

        extractor_kwargs['no_log'] = 'no_log' in kwargs and kwargs['no_log'] == True    # skip the log calculation
        extractor_kwargs['no_mean'] = 'no_mean' in kwargs and kwargs['no_mean'] == True  # skip averaging over bands
        self.extractor_kwargs = extractor_kwargs

        self.n_pts = int(self.win_len * self.fs)
        self.nfft = 2**int(np.ceil(np.log2(self.n_pts)))  # nextpow2(self.n_pts)
        fft_freqs = np.arange(0., self.fs, float(self.fs)/self.nfft)[:self.nfft//2 + 1]
        self.fft_inds = dict()
        for band_idx, band in enumerate(bands):
            self.fft_inds[band_idx] = [freq_idx for freq_idx, freq in enumerate(fft_freqs) if band[0] <= freq < band[1]]

        extractor_kwargs['fft_inds'] = self.fft_inds
        extractor_kwargs['fft_freqs'] = fft_freqs

        self.epsilon = 1e-9

        if extractor_kwargs['no_mean']:  # used in the LFP 1D control task
            self.feature_dtype = ('ai_power', 'f8', (len(channels)*len(fft_freqs), 1))
        else:
            self.feature_dtype = ('ai_power', 'f8', (len(channels)*len(bands), 1))

    def __call__(self, start_time, *args, **kwargs):
        cont_samples = self.get_cont_samples(*args, **kwargs)  # dims of channels x time
        lfp_power = self.extract_features(cont_samples)

        return dict(ai_power=lfp_power)

    @classmethod
    def extract_from_file(cls, files, neurows, binlen, units, extractor_kwargs, strobe_rate=60.0):
        '''
        Compute binned analog-input power features

        Parameters
        ----------
        files : dict
            Data files used to train the decoder. Should contain exactly one type of neural data file (e.g., Plexon, Blackrock, TDT)
        neurows: np.ndarray of shape (T,)
            Timestamps in the plexon time reference corresponding to bin boundaries
        binlen: float
            Length of time over which to sum spikes from the specified cells
        units: np.ndarray of shape (N, 2)
            List of units that the decoder will be trained on. The first column specifies the electrode number and the second specifies the unit on the electrode
        extractor_kwargs: dict
            Any additional parameters to be passed to the feature extractor. This function is agnostic to the actual extractor utilized
        strobe_rate: float, optional, default=60.0
            The rate at which the task sends the sync pulse to the plx file

        Returns
        -------
        lfp_power : np.ndarray of shape (T, n_channels * n_bands)
            Analog-input power features binned over the length of the datafile.
        units : np.ndarray of shape (N, 2)
            Passed back unchanged from the input arguments
        extractor_kwargs : dict
            Parameters used to instantiate the feature extractor
        '''
        if 'plexon' in files:
            from plexon import plexfile
            plx = plexfile.openFile(str(files['plexon']))

            # interpolate between the rows to 180 Hz
            if binlen < 1./strobe_rate:
                interp_rows = []
                neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows])
                for r1, r2 in zip(neurows[:-1], neurows[1:]):
                    interp_rows += list(np.linspace(r1, r2, 4)[1:])
                interp_rows = np.array(interp_rows)
            else:
                step = int(binlen/(1./strobe_rate))  # downsample kinematic data according to the decoder bin length (assumes non-overlapping bins)
                interp_rows = neurows[::step]

            # create extractor object
            f_extractor = AIMTMPowerExtractor(None, **extractor_kwargs)
            extractor_kwargs = f_extractor.extractor_kwargs

            win_len = f_extractor.win_len
            bands = f_extractor.bands
            channels = f_extractor.channels
            fs = f_extractor.fs
            print(('bands:', bands))

            n_itrs = len(interp_rows)
            n_chan = len(channels)
            lfp_power = np.zeros((n_itrs, n_chan * len(bands)))

            lfp = plx.lfp[:].data[:, channels-1]
            n_pts = int(win_len * fs)
            for i, t in enumerate(interp_rows):
                try:
                    sample_num = int(t * fs)
                    cont_samples = lfp[sample_num-n_pts:sample_num, :]
                    lfp_power[i, :] = f_extractor.extract_features(cont_samples.T).T
                except:
                    print("Error with LFP decoder training")
                    print((i, t))

            # TODO -- discard any channel(s) for which the log power in any frequency
            # band was ever equal to -inf (i.e., power was equal to 0),
            # or perhaps just add a small epsilon inside the log to avoid this;
            # then, remember to do this: extractor_kwargs['channels'] = channels
            # and reset the units variable

            return lfp_power, units, extractor_kwargs

        elif 'blackrock' in files:
            raise NotImplementedError

class AIAmplitudeExtractor(object):
    '''
    Computes the analog input channel amplitude. Out of date...
    '''

    feature_type = 'ai_amplitude'

    def __init__(self, source, channels=[], win_len=0.1, fs=1000):
        self.feature_dtype = ('emg_amplitude', 'u4', (len(channels), 1))

        self.source = source
        self.channels = channels
        self.win_len = win_len
        if source is not None:
            self.fs = source.source.update_freq
        else:
            self.fs = fs

        extractor_kwargs = dict()
        extractor_kwargs['channels'] = self.channels
        extractor_kwargs['fs'] = self.fs
        extractor_kwargs['win_len'] = self.win_len
        self.extractor_kwargs = extractor_kwargs

        self.n_pts = int(self.win_len * self.fs)

    def get_cont_samples(self, *args, **kwargs):
        return self.source.get(self.n_pts, self.channels)

    def extract_features(self, cont_samples):
        emg_amplitude = np.mean(cont_samples, axis=1)
        emg_amplitude = emg_amplitude[:, None]
        return emg_amplitude

    def __call__(self, start_time, *args, **kwargs):
        cont_samples = self.get_cont_samples(*args, **kwargs)  # dims of channels x time
        emg = self.extract_features(cont_samples)
        return emg, None

1083class WaveformClusterCountExtractor(FeatureExtractor): 

+

1084 feature_type = 'cluster_counts' 

+

1085 def __init__(self, source, gmm_model_params, n_subbins=1, units=[]): 

+

1086 self.feature_dtype = [('cluster_counts', 'f8', (len(units), n_subbins)), ('bin_edges', 'f8', 2)] 

+

1087 

+

1088 self.source = source 

+

1089 self.gmm_model_params = gmm_model_params 

+

1090 self.n_subbins = n_subbins 

+

1091 self.units = units 

+

1092 self.n_units = len(units) 

+

1093 

+

1094 extractor_kwargs = dict() 

+

1095 extractor_kwargs['n_subbins'] = self.n_subbins 

+

1096 extractor_kwargs['units'] = self.units 

+

1097 extractor_kwargs['gmm_model_params'] = gmm_model_params 

+

1098 self.extractor_kwargs = extractor_kwargs 

+

1099 

+

1100 self.last_get_spike_counts_time = 0 

+

1101 

+

1102 def get_spike_data(self): 

+

1103 ''' 

+

1104 Get the spike timestamps from the neural data source. This function has no type checking,  

+

1105 i.e., it is assumed that the Extractor object was created with the proper source 

+

1106 ''' 

+

1107 return self.source.get() 

+

1108 

+

1109 def get_bin_edges(self, ts): 

+

1110 ''' 

+

1111 Determine the first and last spike timestamps to allow HDF files  

+

1112 created by the BMI to be semi-synchronized with the neural data file 

+

1113 ''' 

+

1114 if len(ts) == 0: 

+

1115 bin_edges = np.array([np.nan, np.nan]) 

+

1116 else: 

+

1117 min_ind = np.argmin(ts['ts']) 

+

1118 max_ind = np.argmax(ts['ts']) 

+

1119 bin_edges = np.array([ts[min_ind]['ts'], ts[max_ind]['ts']]) 

+

1120 

+

1121 def __call__(self, start_time, *args, **kwargs): 

+

1122 

+

1123 spike_data = self.get_spike_data() 

+

1124 if len(spike_data) == 0: 

+

1125 counts = np.zeros([len(self.units), self.n_subbins]) 

+

1126 elif self.n_subbins > 1: 

+

1127 subbin_edges = np.linspace(self.last_get_spike_counts_time, start_time, self.n_subbins+1) 

+

1128 

+

1129 # Decrease the first subbin index to include any spikes that were 

+

1130 # delayed in getting to the task layer due to threading issues 

+

1131 # An acceptable delay is 1 sec or less. Realistically, most delays should be 

+

1132 # on the millisecond order 

+

1133 # subbin_edges[0] -= 1 

+

1134 # subbin_inds = np.digitize(spike_data['arrival_ts'], subbin_edges) 

+

1135 # counts = np.vstack([bin_spikes(ts[subbin_inds == k], self.units) for k in range(1, self.n_subbins+1)]).T 

+

1136 raise NotImplementedError 

+

1137 else: 

+

1138 # TODO pull the waveforms 

+

1139 waveforms = [] 

+

1140 

+

1141 # TODO determine p(class) for each waveform against the model params 

+

1142 counts = np.zeros(self.n_units) 

+

1143 wf_class_probs = [] 

+

1144 for wf in waveforms: 

+

1145 raise NotImplementedError 

+

1146 

+

1147 # counts = bin_spikes(ts, self.units).reshape(-1, 1) 

+

1148 

+

1149 counts = np.array(counts, dtype=np.uint32) 

+

1150 bin_edges = self.get_bin_edges(ts) 

+

1151 self.last_get_spike_counts_time = start_time 

+

1152 

+

1153 return dict(spike_counts=counts, bin_edges=bin_edges) 

+

1154 

+

1155 @classmethod 

+

1156 def extract_from_file(cls, files, neurows, binlen, units, extractor_kwargs, strobe_rate=60.0): 

+

1157 from sklearn.mixture import GMM 

+

1158 if 'plexon' in files: 

+

1159 from plexon import plexfile 

+

1160 plx = plexfile.openFile(str(files['plexon'])) 

+

1161 

+

1162 channels = units[:,0] 

+

1163 channels = np.unique(channels) 

+

1164 np.sort(channels) 

+

1165 

+

1166 spike_chans = plx.spikes[:].data['chan'] 

+

1167 spike_times = plx.spikes[:].data['ts'] 

+

1168 waveforms = plx.spikes[:].waveforms 

+

1169 

+

1170 # construct the feature matrix (n_timepoints, n_units) 

+

1171 # interpolate between the rows to 180 Hz 

+

1172 if binlen < 1./strobe_rate: 

+

1173 interp_rows = [] 

+

1174 neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows]) 

+

1175 for r1, r2 in zip(neurows[:-1], neurows[1:]): 

+

1176 interp_rows += list(np.linspace(r1, r2, 4)[1:]) 

+

1177 interp_rows = np.array(interp_rows) 

+

1178 else: 

+

1179 step = int(binlen/(1./strobe_rate)) # Downsample kinematic data according to decoder bin length (assumes non-overlapping bins) 

+

1180 interp_rows = neurows[::step] 

+

1181 

+

1182 # digitize the spike timestamps into interp_rows 

+

1183 spike_bin_ind = np.digitize(spike_times, interp_rows) 

+

1184 spike_counts = np.zeros(len(interp_rows), n_units) 

+

1185 

+

1186 for ch in channels: 

+

1187 ch_waveforms = waveforms[spike_chans == ch] 

+

1188 

+

1189 # cluster the waveforms using a GMM 

+

1190 # TODO pick the number of components in an unsupervised way! 

+

1191 n_components = len(np.nonzero(units[:,0] == ch)[0]) 

+

1192 gmm = GMM(n_components=n_components) 

+

1193 gmm.fit(ch_waveforms) 

+

1194 

+

1195 # store the cluster probabilities back in the same order that the waveforms were extracted 

+

1196 wf_probs = gmm.predict_proba(ch_waveforms) 

+

1197 

+

1198 ch_spike_bin_inds = spike_bin_ind[spike_chans == ch] 

+

1199 ch_inds, = np.nonzero(units[:,0] == ch) 

+

1200 

+

1201 # TODO don't assume the units are sorted! 

+

1202 for bin_ind, wf_prob in zip(ch_spike_bin_inds, wf_probs): 

+

1203 spike_counts[bin_ind, ch_inds] += wf_prob 

+

1204 

+

1205 

+

1206 # discard units that never fired at all 

+

1207 unit_inds, = np.nonzero(np.sum(spike_counts, axis=0)) 

+

1208 units = units[unit_inds,:] 

+

1209 spike_counts = spike_counts[:, unit_inds] 

+

1210 extractor_kwargs['units'] = units 

+

1211 

+

1212 return spike_counts, units, extractor_kwargs 

+

1213 else: 

+

1214 raise NotImplementedError('Not implemented for blackrock/TDT data yet!') 

+

1215 

+

1216 

+

1217 

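extract_from_file (above) soft-assigns each spike waveform to clusters with a per-channel GMM and accumulates the class probabilities as fractional counts. sklearn's old GMM class has since been renamed GaussianMixture; a minimal sketch of the soft-count idea under the modern name, on synthetic waveforms:

    import numpy as np
    from sklearn.mixture import GaussianMixture   # modern replacement for the old GMM class

    # synthetic "waveforms": 200 snippets of 32 samples drawn from 2 clusters
    rng = np.random.default_rng(0)
    waveforms = np.vstack([rng.normal(0, 1, (100, 32)), rng.normal(3, 1, (100, 32))])

    gmm = GaussianMixture(n_components=2).fit(waveforms)
    wf_probs = gmm.predict_proba(waveforms)       # (200, 2) soft cluster memberships

    # accumulating soft counts into a single bin, as the per-spike loop above does
    counts = wf_probs.sum(axis=0)
    print(counts)                                 # roughly [100, 100]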
def get_butter_bpf_lfp_power(plx, neurows, binlen, units, extractor_kwargs, strobe_rate=60.0):
    '''
    Compute lfp power features -- corresponds to LFPButterBPFPowerExtractor.
    '''
    # interpolate between the rows to 180 Hz
    if binlen < 1./strobe_rate:
        interp_rows = []
        neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows])
        for r1, r2 in zip(neurows[:-1], neurows[1:]):
            interp_rows += list(np.linspace(r1, r2, 4)[1:])
        interp_rows = np.array(interp_rows)
    else:
        step = int(binlen/(1./strobe_rate))  # Downsample kinematic data according to decoder bin length (assumes non-overlapping bins)
        interp_rows = neurows[::step]

    # create extractor object
    f_extractor = extractor.LFPButterBPFPowerExtractor(None, **extractor_kwargs)
    extractor_kwargs = f_extractor.extractor_kwargs

    win_len = f_extractor.win_len
    bands = f_extractor.bands
    channels = f_extractor.channels
    fs = f_extractor.fs

    n_itrs = len(interp_rows)
    n_chan = len(channels)
    lfp_power = np.zeros((n_itrs, n_chan * len(bands)))
    # for i, t in enumerate(interp_rows):
    #     cont_samples = plx.lfp[t-win_len:t].data[:, channels-1]
    #     lfp_power[i, :] = f_extractor.extract_features(cont_samples.T).T
    lfp = plx.lfp[:].data[:, channels-1]
    n_pts = int(win_len * fs)
    for i, t in enumerate(interp_rows):
        sample_num = int(t * fs)
        cont_samples = lfp[sample_num-n_pts:sample_num, :]
        lfp_power[i, :] = f_extractor.extract_features(cont_samples.T).T

    # TODO -- discard any channel(s) for which the log power in any frequency
    # bands was ever equal to -inf (i.e., power was equal to 0)
    # or, perhaps just add a small epsilon inside the log to avoid this
    # then, remember to do this: extractor_kwargs['channels'] = channels
    # and reset the units variable

    return lfp_power, units, extractor_kwargs

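The extraction loop above converts each interpolated row timestamp into a sample index and takes the win_len seconds of LFP ending at that index. The same slicing in isolation, on fake data (array sizes are illustrative):

    import numpy as np

    fs, win_len = 1000.0, 0.2
    n_pts = int(win_len * fs)
    lfp = np.random.randn(60 * int(fs), 2)        # 60 s of fake 2-channel LFP
    interp_rows = np.arange(1.0, 59.0, 0.1)       # row timestamps, in seconds

    windows = np.empty((len(interp_rows), n_pts, lfp.shape[1]))
    for i, t in enumerate(interp_rows):
        sample_num = int(t * fs)
        windows[i] = lfp[sample_num - n_pts:sample_num, :]   # window *ending* at t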
def get_mtm_lfp_power(plx, neurows, binlen, units, extractor_kwargs, strobe_rate=60.0):
    '''
    Compute lfp power features -- corresponds to LFPMTMPowerExtractor.
    '''
    # interpolate between the rows to 180 Hz
    if binlen < 1./strobe_rate:
        interp_rows = []
        neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows])
        for r1, r2 in zip(neurows[:-1], neurows[1:]):
            interp_rows += list(np.linspace(r1, r2, 4)[1:])
        interp_rows = np.array(interp_rows)
    else:
        step = int(binlen/(1./strobe_rate))  # Downsample kinematic data according to decoder bin length (assumes non-overlapping bins)
        interp_rows = neurows[::step]

    # create extractor object
    f_extractor = extractor.LFPMTMPowerExtractor(None, **extractor_kwargs)
    extractor_kwargs = f_extractor.extractor_kwargs

    win_len = f_extractor.win_len
    bands = f_extractor.bands
    channels = f_extractor.channels
    fs = f_extractor.fs
    print(('bands:', bands))

    n_itrs = len(interp_rows)
    n_chan = len(channels)
    lfp_power = np.zeros((n_itrs, n_chan * len(bands)))

    # for i, t in enumerate(interp_rows):
    #     cont_samples = plx.lfp[t-win_len:t].data[:, channels-1]
    #     lfp_power[i, :] = f_extractor.extract_features(cont_samples.T).T
    lfp = plx.lfp[:].data[:, channels-1]
    n_pts = int(win_len * fs)
    for i, t in enumerate(interp_rows):
        sample_num = int(t * fs)
        cont_samples = lfp[sample_num-n_pts:sample_num, :]
        lfp_power[i, :] = f_extractor.extract_features(cont_samples.T).T

    # TODO -- discard any channel(s) for which the log power in any frequency
    # bands was ever equal to -inf (i.e., power was equal to 0)
    # or, perhaps just add a small epsilon inside the log to avoid this
    # then, remember to do this: extractor_kwargs['channels'] = channels
    # and reset the units variable

    return lfp_power, units, extractor_kwargs

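The interp_rows block above appears verbatim in get_butter_bpf_lfp_power, get_mtm_lfp_power, get_emg_amplitude and extract_from_file, and could be factored into a single helper. A sketch of such a helper (the name _get_interp_rows is mine, not in the codebase):

    import numpy as np

    def _get_interp_rows(neurows, binlen, strobe_rate=60.0):
        """Upsample the row timestamps 3x when the bin is shorter than one
        strobe period; otherwise downsample by the integer number of strobe
        periods per bin. Mirrors the inline block used by the extractors above."""
        if binlen < 1. / strobe_rate:
            interp_rows = []
            neurows = np.hstack([neurows[0] - 1. / strobe_rate, neurows])
            for r1, r2 in zip(neurows[:-1], neurows[1:]):
                interp_rows += list(np.linspace(r1, r2, 4)[1:])
            return np.array(interp_rows)
        step = int(binlen / (1. / strobe_rate))
        return neurows[::step]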
def get_emg_amplitude(plx, neurows, binlen, units, extractor_kwargs, strobe_rate=60.0):
    '''
    Compute EMG features.
    '''
    # interpolate between the rows to 180 Hz
    if binlen < 1./strobe_rate:
        interp_rows = []
        neurows = np.hstack([neurows[0] - 1./strobe_rate, neurows])
        for r1, r2 in zip(neurows[:-1], neurows[1:]):
            interp_rows += list(np.linspace(r1, r2, 4)[1:])
        interp_rows = np.array(interp_rows)
    else:
        step = int(binlen/(1./strobe_rate))  # Downsample kinematic data according to decoder bin length (assumes non-overlapping bins)
        interp_rows = neurows[::step]

    # create extractor object
    f_extractor = extractor.EMGAmplitudeExtractor(None, **extractor_kwargs)
    extractor_kwargs = f_extractor.extractor_kwargs

    win_len = f_extractor.win_len
    channels = f_extractor.channels
    fs = f_extractor.fs

    n_itrs = len(interp_rows)
    n_chan = len(channels)
    emg = np.zeros((n_itrs, n_chan))

    # for i, t in enumerate(interp_rows):
    #     cont_samples = plx.lfp[t-win_len:t].data[:, channels-1]
    #     lfp_power[i, :] = f_extractor.extract_features(cont_samples.T).T
    emgraw = plx.analog[:].data[:, channels-1]
    n_pts = int(win_len * fs)
    for i, t in enumerate(interp_rows):
        sample_num = int(t * fs)
        cont_samples = emgraw[sample_num-n_pts:sample_num, :]
        emg[i, :] = f_extractor.extract_features(cont_samples.T).T

    return emg, units, extractor_kwargs

diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_feedback_controllers_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_feedback_controllers_py.html
new file mode 100644
index 00000000..f379b120
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_feedback_controllers_py.html
@@ -0,0 +1,462 @@

Coverage for C:\GitHub\brain-python-interface\riglib\bmi\feedback_controllers.py: 27%

'''
Feedback controllers. In the BMI context, these have three potential
applications: (1) assistive/shared control between neural and pure machine sources,
(2) estimating the "intended" BMI movements of the subject (see clda.py), and
(3) driving populations of simulated neurons with artificially constructed tuning
functions.
'''
import numpy as np


class FeedbackController(object):
    '''
    Abstract class for feedback controllers, only used for type-checking & interface standardization
    '''
    def __init__(self, *args, **kwargs):
        pass

    def calc_next_state(self, current_state, target_state, mode=None):
        raise NotImplementedError

    def __call__(self, current_state, target_state, mode=None):
        raise NotImplementedError

    def get(self, current_state, target_state, mode=None):
        raise NotImplementedError

class LinearFeedbackController(FeedbackController):
    '''
    Generic linear state-feedback controller. Can be time-varying in general and need not be related to a specific cost function
    '''
    def __init__(self, A, B, F):
        '''
        Constructor for LinearFeedbackController

        FC for a linear system
            x_{t+1} = Ax_t + Bu_t
        where the control input u_t is calculated using linear feedback
            u_t = F(x^* - x_t)

        Parameters
        ----------
        A : np.mat
            State transition model for the system
        B : np.mat
            Input matrix for the system
        F : np.mat
            static feedback gain matrix

        Returns
        -------
        LinearFeedbackController instance
        '''
        self.A = A
        self.B = B
        self.F = F

    def calc_next_state(self, current_state, target_state, mode=None):
        '''
        Returns x_{t+1} = Ax_t + BF(x* - x_t)

        Parameters
        ----------
        current_state : np.matrix
            Current state of the system, x_t
        target_state : np.matrix
            State that you're trying to steer the system toward, x*
        mode : object, default=None
            Select the operational mode of the feedback controller. Ignored in this class

        Returns
        -------
        np.matrix
        '''
        # explicitly cast current_state and target_state to column vectors
        current_state = np.mat(current_state).reshape(-1, 1)
        target_state = np.mat(target_state).reshape(-1, 1)
        ns = self.A * current_state + self.B * self.F * (target_state - current_state)

        return ns

    def __call__(self, current_state, target_state, mode=None):
        '''
        Calculate the control input u_t = BF(x* - x_t)

        Parameters
        ----------
        (see self.calc_next_state for input argument descriptions)

        Returns
        -------
        np.mat of shape (N, 1)
            B*u where u_t = F(x^* - x_t)
        '''
        current_state = np.mat(current_state).reshape(-1, 1)
        target_state = np.mat(target_state).reshape(-1, 1)
        Bu = self.B * self.F * (target_state - current_state)
        return Bu

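A quick usage sketch for LinearFeedbackController on a discretized 1-D point mass with an offset state; the gains are hand-picked for illustration, and the class is assumed importable from this module:

    import numpy as np

    dt = 0.1
    A = np.mat([[1, dt, 0], [0, 1, 0], [0, 0, 1]])   # state = [pos, vel, offset]
    B = np.mat([[0], [dt], [0]])
    F = np.mat([[5.0, 2.0, 0.0]])                    # hand-tuned feedback gains

    lfc = LinearFeedbackController(A, B, F)
    x = np.mat([[0.], [0.], [1.]])
    target = np.mat([[1.], [0.], [1.]])
    for _ in range(50):
        x = lfc.calc_next_state(x, target)
    print(float(x[0, 0]))                            # approaches the target position 1.0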
class LQRController(LinearFeedbackController):
    '''Linear feedback controller where control gains are set by optimizing a quadratic cost function'''
    def __init__(self, A, B, Q, R, **kwargs):
        '''
        Constructor for LQRController

        The system should evolve as
        $$x_{t+1} = Ax_t + Bu_t + w_t; w_t ~ N(0, W)$$

        with infinite horizon cost
        $$\sum_{t=0}^{+\infty} (x_t - x_{target})^T Q (x_t - x_{target}) + u_t^T R u_t$$

        Parameters
        ----------
        A: np.ndarray of shape (n_states, n_states)
            Model of the state transition matrix of the system to be controlled.
        B: np.ndarray of shape (n_states, n_controls)
            Control input matrix of the system to be controlled.
        Q: np.ndarray of shape (n_states, n_states)
            Quadratic cost on state
        R: np.ndarray of shape (n_controls, n_controls)
            Quadratic cost on control inputs

        Returns
        -------
        LQRController instance
        '''
        self.A = np.mat(A)
        self.B = np.mat(B)
        self.Q = np.mat(Q)
        self.R = np.mat(R)
        F = self.dlqr(self.A, self.B, self.Q, self.R, **kwargs)  # bug fix: use the np.mat-cast matrices so dlqr's '*' is matrix multiplication
        super(LQRController, self).__init__(A, B, F, **kwargs)

    @staticmethod
    def dlqr(A, B, Q, R, Q_f=None, T=np.inf, max_iter=1000, eps=1e-10, dtype=np.mat):
        '''
        Find the solution to the discrete-time LQR problem

        The system should evolve as
        $$x_{t+1} = Ax_t + Bu_t + w_t; w_t ~ N(0, W)$$

        with cost function
        $$\sum_{t=0}^{T} (x_t - x_{target})^T Q (x_t - x_{target}) + u_t^T R u_t$$

        The cost function can be either finite or infinite horizon, where finite horizon is assumed if
        a final cost Q_f is specified

        Parameters
        ----------
        A: np.ndarray of shape (n_states, n_states)
            Model of the state transition matrix of the system to be controlled.
        B: np.ndarray of shape (n_states, n_controls)
            Control input matrix of the system to be controlled.
        Q: np.ndarray of shape (n_states, n_states)
            Quadratic cost on state
        R: np.ndarray of shape (n_controls, n_controls)
            Quadratic cost on control inputs
        Q_f: np.ndarray of shape (n_states, n_states), optional, default=None
            Final quadratic cost on state at the end of the horizon. Only applies to finite-horizon variants
        T: int, optional, default=np.inf
            Control horizon duration. Infinite by default. Must be less than infinity (and Q_f must be specified)
            to get the finite horizon feedback controllers
        eps: float, optional, default=1e-10
            Threshold of change in feedback matrices to define when the Riccati recursion has converged
        dtype: callable, optional, default=np.mat
            Callable function to reformat the feedback matrices

        Returns
        -------
        K: list or matrix
            Returns a sequence of feedback gains if finite horizon or a single controller if infinite horizon.
        '''
        if Q_f is None:  # bug fix: 'Q_f == None' is ambiguous for array arguments
            Q_f = Q

        if T < np.inf:  # Finite horizon
            K = [None]*T
            P = Q_f
            for t in range(0, T-1)[::-1]:
                K[t] = (R + B.T*P*B).I * B.T*P*A
                P = Q + A.T*P*A - A.T*P*B*K[t]
            return dtype(K)
        else:  # Infinite horizon
            P = Q_f
            K = np.inf
            for t in range(max_iter):
                K_old = K
                K = (R + B.T*P*B).I * B.T*P*A
                P = Q + A.T*P*A - A.T*P*B*K
                if np.linalg.norm(K - K_old) < eps:
                    break
            return dtype(K)

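A usage sketch for the infinite-horizon branch of dlqr on a double integrator (all numerical values are illustrative):

    import numpy as np

    dt = 0.1
    A = np.mat([[1, dt], [0, 1]])
    B = np.mat([[0], [dt]])
    Q = np.mat(np.diag([1.0, 0.1]))    # penalize position error most
    R = np.mat([[0.01]])               # cheap control

    F = LQRController.dlqr(A, B, Q, R)
    print(F)                           # effectively a PD gain on [position, velocity]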
class MultiModalLFC(LinearFeedbackController):
    '''
    A linear feedback controller with different feedback gains in different "modes"
    '''
    def __init__(self, A=None, B=None, F_dict=dict()):
        '''
        Constructor for MultiModalLFC

        Parameters
        ----------
        B : np.mat
            Input matrix for the system
        F_dict : dict
            keys should be control 'modes', values should be np.mat static feedback gain matrices

        Returns
        -------
        MultiModalLFC instance
        '''
        self.A = A
        self.B = B
        self.F_dict = F_dict
        self.F = None

    def calc_next_state(self, current_state, target_state, mode=None):
        '''
        See LinearFeedbackController.calc_next_state for docs
        '''
        self.F = self.F_dict[mode]
        return super(MultiModalLFC, self).calc_next_state(current_state, target_state, mode=mode)  # bug fix: the parent's result was computed but never returned

    def __call__(self, current_state, target_state, mode=None):
        '''
        See LinearFeedbackController.__call__ for docs
        '''
        self.F = self.F_dict[mode]
        return super(MultiModalLFC, self).__call__(current_state, target_state, mode=mode)  # bug fix: same missing return as above

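A usage sketch for MultiModalLFC: the mode key passed to each call selects which gain matrix is applied (all matrices are illustrative):

    import numpy as np

    dt = 0.1
    A = np.mat([[1, dt, 0], [0, 1, 0], [0, 0, 1]])
    B = np.mat([[0], [dt], [0]])
    F_slow = np.mat([[1.0, 0.5, 0.0]])
    F_fast = np.mat([[8.0, 3.0, 0.0]])

    ctrl = MultiModalLFC(A=A, B=B, F_dict={'approach': F_fast, 'hold': F_slow})
    x = np.mat([[0.], [0.], [1.]])
    target = np.mat([[1.], [0.], [1.]])
    Bu = ctrl(x, target, mode='approach')   # this call uses the 'approach' gains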
class PIDController(FeedbackController):
    '''
    Linear feedback controller where gains are set directly instead of through a cost function
    '''
    def __init__(self, K_prop, K_deriv, K_int, state_order):
        '''
        Constructor for PIDController

        Parameters
        ----------
        K_prop : float
            Gain on proportional error
        K_deriv : float
            Gain on derivative error
        K_int : float
            Gain on integrated error
        state_order : np.ndarray of shape (N, 1)
            Specify whether each element of the state vector is a proportional, derivative, or integral state

        Returns
        -------
        PIDController instance
        '''
        self.K_prop = K_prop
        self.K_deriv = K_deriv
        self.K_int = K_int

        self.int_terms, = np.nonzero(state_order == -1)
        self.prop_terms, = np.nonzero(state_order == 0)
        self.deriv_terms, = np.nonzero(state_order == 1)

    def __call__(self, current_state, target_state):
        '''
        Determine the PID controller output to be added onto the current state

        Parameters
        ----------
        current_state : np.matrix
            Current state of the system, x_t
        target_state : np.matrix
            State that you're trying to steer the system toward, x*

        Returns
        -------
        cmd : np.ndarray of shape (K, 1)
            K is the number of states in the proportional term.
        '''
        state_diff = target_state - current_state
        cmd = 0
        if len(self.prop_terms) > 0:
            cmd += self.K_prop * state_diff[self.prop_terms]

        if len(self.deriv_terms) > 0:
            cmd += self.K_deriv * state_diff[self.deriv_terms]

        if len(self.int_terms) > 0:
            cmd += self.K_int * state_diff[self.int_terms]

        return cmd

    def calc_next_state(self, current_state, target_state, **kwargs):
        '''
        see self.__call__
        '''
        return self.__call__(current_state, target_state)

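A usage sketch for PIDController on a [position, velocity] state, where state_order marks position as the proportional term and velocity as the derivative term (no integral state here):

    import numpy as np

    state_order = np.array([0, 1])       # 0 = proportional state, 1 = derivative state
    pid = PIDController(K_prop=1.0, K_deriv=0.2, K_int=0.0, state_order=state_order)

    current = np.array([[0.0], [0.0]])
    target = np.array([[1.0], [0.0]])
    print(pid(current, target))          # [[1.0]]: 1.0*(pos error) + 0.2*(vel error)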
######################################################################
##### Deprecated task/plant-specific controllers for simulations #####
######################################################################
class CenterOutCursorGoal(object):
    '''
    Cursor controller which moves the cursor toward the target at a constant speed
    '''
    def __init__(self, angular_noise_var=0, gain=0.15):
        '''
        Constructor for CenterOutCursorGoal

        Parameters
        ----------
        angular_noise_var: float, optional, default=0
            Angular noise is added onto the control direction as a clipped Gaussian distribution with this variance
        gain: float, optional, default=0.15
            Speed at which to move the cursor, in m/s

        Returns
        -------
        CenterOutCursorGoal instance
        '''
        self.angular_noise_var = angular_noise_var
        self.gain = gain

    def get(self, cur_target, cur_pos, keys_pressed=None):
        # Make sure y-dimension is 0
        assert cur_pos[1] == 0
        assert cur_target[1] == 0

        dir_to_targ = cur_target - cur_pos

        if self.angular_noise_var > 0:
            angular_noise_rad = np.random.normal(0, self.angular_noise_var)
            while abs(angular_noise_rad) > np.pi:
                angular_noise_rad = np.random.normal(0, self.angular_noise_var)
        else:
            angular_noise_rad = 0
        angle = np.arctan2(dir_to_targ[2], dir_to_targ[0])
        sum_angle = angle + angular_noise_rad
        return self.gain*np.array([np.cos(sum_angle), np.sin(sum_angle)])


class CenterOutCursorGoalJointSpace2D(CenterOutCursorGoal):
    '''2-link arm controller which moves the endpoint toward a target position at a constant speed'''
    def __init__(self, link_lengths, shoulder_anchor, *args, **kwargs):
        '''
        Constructor for CenterOutCursorGoalJointSpace2D

        Parameters
        ----------
        link_lengths:
        shoulder_anchor:
        args, kwargs: positional and keyword arguments for parent constructor (CenterOutCursorGoal)

        Returns
        -------
        '''
        self.link_lengths = link_lengths
        self.shoulder_anchor = shoulder_anchor
        super(CenterOutCursorGoalJointSpace2D, self).__init__(*args, **kwargs)

    def get(self, cur_target, cur_pos, keys_pressed=None):
        '''
        cur_target and cur_pos should be specified in workspace coordinates
        '''
        vx, vz = super(CenterOutCursorGoalJointSpace2D, self).get(cur_target, cur_pos, keys_pressed)
        vy = 0

        px, py, pz = cur_pos

        pos = np.array([px, py, pz]) - self.shoulder_anchor
        vel = np.array([vx, vy, vz])

        # Convert to joint velocities
        from riglib.stereo_opengl import ik
        joint_pos, joint_vel = ik.inv_kin_2D(pos, self.link_lengths[0], self.link_lengths[1], vel)
        return joint_vel[0]['sh_vabd'], joint_vel[0]['el_vflex']


class PosFeedbackController(FeedbackController):
    '''
    Dumb controller that just spits back the target
    '''
    def __init__(self, *args, **kwargs):
        pass

    def calc_next_state(self, current_state, target_state, mode=None):
        return target_state

    def __call__(self, current_state, target_state, mode=None):
        return target_state

    def get(self, current_state, target_state, mode=None):
        return target_state

diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_goal_calculators_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_goal_calculators_py.html
new file mode 100644
index 00000000..f23fb4b9
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_goal_calculators_py.html
@@ -0,0 +1,422 @@

Coverage for C:\GitHub\brain-python-interface\riglib\bmi\goal_calculators.py: 18%

#!/usr/bin/python
'''
Classes to determine the "goal" during a BMI task. Knowledge of the
task goal is required for many versions of assistive control (see assist.py) and
many versions of CLDA (see clda.py).
'''
import numpy as np
from . import train
from riglib import mp_calc
# from riglib.stereo_opengl import ik  # appears to be unused
import re
import pickle
from riglib.bmi import state_space_models

# TODO -- does ssm really need to be passed as an argument into __init__?
# maybe just make it an optional kwarg for the classes that really need it

class GoalCalculator(object):
    def reset(self):
        pass
class Obs_Goal_Calc(GoalCalculator):
    def __init__(self, ssm=None, **kwargs):
        self.ssm = ssm
        import os
        self.pre_obs = True
        self.mid_speed = kwargs.pop('mid_targ_speed', 10)
        self.mid_targ_rad = kwargs.pop('mid_targ_rad', 6)
        self.targ_cnt = 0
        self.pre_obs_targ_state = None
        self.post_obs_targ_state = None

    def clear(self):
        self.pre_obs_targ_state = None
        self.post_obs_targ_state = None
        print('CLEAR')

    def __call__(self, target_pos, **kwargs):
        #Use q_start th
        pos = kwargs.pop('q_start')

        #if past obstacle midline:
        if 'center_pos' in kwargs:
            obstacle_center = kwargs['center_pos'] + (target_pos - kwargs['center_pos'])*.5
            center = kwargs['center_pos']
        else:
            obstacle_center = target_pos/2.
            center = np.zeros((3, ))

        target_pos = target_pos.round(1)
        try:
            slope = -1*1./((target_pos[2] - center[2])/(target_pos[0]-center[0]))
        except:
            slope = np.inf

        #if ((np.abs(slope) != np.inf) and (np.abs(slope) != np.nan) and np.abs(slope)!=0):
        pre_obs = self.fcn_det(slope, obstacle_center, pos, center, target_pos)
        #print 'pre_obs: ', pre_obs, slope
        # else:
        #     #print 'division by zero!'
        #     if target_pos[0] == 0:
        #         #Division by zero
        #         pre_obs = False
        #         if np.abs(pos[2]) < (np.abs(obstacle_center[2])-.2): pre_obs = True
        #     elif target_pos[2] == 0:
        #         pre_obs = False
        #         if np.abs(pos[0]) < (np.abs(obstacle_center[0]) -.2): pre_obs = True
        #     else:
        #         Exception('Not vertical or horiz. line causing divide by zero --> error')

        if pre_obs:
            if 1:
                #if self.pre_obs_targ_state is None:
                obs_ang = np.angle(obstacle_center[0]-center[0] + 1j*(obstacle_center[2]-center[2]))
                obs_r = np.abs((obstacle_center[0]-center[0]) + 1j*(obstacle_center[2]-center[2]))

                if self.ccw_fcn(pos, obstacle_center):
                    targ_vect_ang = np.pi/2
                else:
                    targ_vect_ang = -1*np.pi/2

                target_state_pos = obstacle_center + self.mid_targ_rad*(np.array([np.cos(targ_vect_ang+obs_ang), 0, np.sin(targ_vect_ang+obs_ang)]))
                target_vel = self.mid_speed*np.array([np.cos(obs_ang), 0, np.sin(obs_ang)])
                target_state = np.hstack((target_state_pos, target_vel, 1)).reshape(-1, 1)
                self.pre_obs_targ_state = target_state
            else:
                target_state = self.pre_obs_targ_state

        else:
            if 1:
                #if self.post_obs_targ_state is None:
                target_vel = np.zeros_like(target_pos)
                offset_val = 1
                target_state = np.hstack([target_pos, target_vel, 1]).reshape(-1, 1)

                if self.pre_obs_targ_state is not None:
                    self.post_obs_targ_state = target_state
            else:
                target_state = self.post_obs_targ_state

        error = 0

        # if self.pre_obs != pre_obs:
        #     self.pre_obs = pre_obs
        #     print self.pre_obs, target_state

        return (target_state, error), True

    def fcn_det(self, slope, pt_on_line, test_pt, center, target):
        d_center = np.sqrt(np.sum((test_pt - center)**2))
        d_target = np.sqrt(np.sum((test_pt - target)**2))
        if d_center < d_target:
            return True
        else:
            return False

        # abs_pt = np.abs(pt_on_line)
        # abs_test = np.abs(test_pt)
        # slope = -1*np.abs(slope)

        # b = abs_pt[2] - slope*abs_pt[0]

        # if abs_test[2] + 0.3 < (b + slope*abs_test[0]):
        #     return True
        # else:
        #     return False

        # zz = False
        # if 0 < b: zz = True

        # if test_pt[2] < 0:
        #     eps = -.2
        # else:
        #     eps = .2

        # test = False
        # if (test_pt[2]+eps) < ((slope*test_pt[0]) + b): test = True

        # if zz != test:
        #     return True
        # else:
        #     return False

    def ccw_fcn(self, pos_test, pos_ref):
        theta1 = np.angle(pos_test[0] + 1j*pos_test[2])
        theta2 = np.angle(pos_ref[0] + 1j*pos_ref[2])

        if pos_ref[0] < 0 and pos_test[0] < 0:
            if pos_ref[2] < 0 and pos_test[2] > 0:
                theta2 += 2*np.pi
            elif pos_ref[2] > 0 and pos_test[2] < 0:
                theta1 += 2*np.pi
        return theta1 > theta2
class ZeroVelocityGoal(GoalCalculator):
    '''
    Assumes that the target state of the BMI is to move to the task-specified position with zero velocity
    '''
    def __init__(self, ssm=None):
        '''
        Constructor for ZeroVelocityGoal

        Parameters
        ----------
        ssm : state_space_models.StateSpace instance
            The state-space model of the Decoder that is being assisted/adapted. Not needed for this particular method

        Returns
        -------
        ZeroVelocityGoal instance
        '''
        try:
            self.ssm = ssm()
        except:
            self.ssm = ssm

    def __call__(self, target_pos, **kwargs):
        '''
        Calculate the goal state [p, 0, 1] where p is the n-dim position and 0 is the n-dim velocity

        Parameters
        ----------
        target_pos : np.ndarray
            Optimal position, in generalized coordinates (i.e., need not be cartesian coordinates)
        kwargs : optional kwargs
            These are ignored, just present for function call compatibility

        Returns
        -------
        np.ndarray
            (N, 1) indicating the target state
        '''
        # Add zero velocity if needed:
        n_pos_vel_states = int(self.ssm.n_states) - 1
        if len(target_pos) < n_pos_vel_states:
            target_vel = np.zeros_like(target_pos)
            offset_val = 1
            target_state = np.hstack([target_pos, target_vel, 1]).reshape(-1, 1)
        elif len(target_pos) == n_pos_vel_states:
            target_state = np.hstack([target_pos, 1]).reshape(-1, 1)
        else:
            target_state = np.hstack(target_pos).reshape(-1, 1)  # don't add offset
        error = 0
        return (target_state, error), True

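A usage sketch for ZeroVelocityGoal; the stand-in state space below provides only the n_states attribute the goal calculator actually reads (7 states: 3 position, 3 velocity, 1 offset):

    import numpy as np

    class _FakeSSM(object):      # stand-in, not a real state_space_models class
        n_states = 7

    goal_calc = ZeroVelocityGoal(ssm=_FakeSSM)
    (target_state, error), done = goal_calc(np.array([5., 0., 5.]))
    print(target_state.ravel())  # [5. 0. 5. 0. 0. 0. 1.]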
class ZeroVelocityGoal_ismore(GoalCalculator):
    def __init__(self, ssm=None, pause_states=[]):
        try:
            self.ssm = ssm()
        except:
            self.ssm = ssm

        self.pause_states = pause_states

    def __call__(self, target_pos, state, **kwargs):
        if state in self.pause_states:
            target_state = kwargs['current_state']
            error = 0
        else:
            n_pos_vel_states = int(self.ssm.n_states) - 1
            if len(target_pos) < n_pos_vel_states:
                target_vel = np.zeros_like(target_pos)
                offset_val = 1
                target_state = np.hstack([target_pos, target_vel, 1]).reshape(-1, 1)
            else:
                target_state = np.hstack([target_pos, 1]).reshape(-1, 1)
            error = 0
        return (target_state, error), True

class ZeroVelocityAccelGoal(ZeroVelocityGoal):
    '''
    Similar to ZeroVelocityGoal, but used for a second order system where you also want the goal acceleration to be zero.
    '''
    def __call__(self, target_pos, **kwargs):
        '''
        See ZeroVelocityGoal.__call__ for argument documentation
        '''
        target_vel = np.zeros_like(target_pos)
        target_acc = np.zeros_like(target_pos)
        offset_val = 1
        error = 0
        target_state = np.hstack([target_pos, target_vel, target_acc, 1])
        return (target_state, error), True
class PlanarMultiLinkJointGoal(GoalCalculator, mp_calc.FuncProxy):
    '''
    Looks up goal configuration for a redundant system based on the endpoint goal and tries to find the closest solution.

    DEPRECATED: The method implemented has not been used for a long time, and is not the best method for finding the "closest" configuration-space solution as desired.
    '''
    def __init__(self, ssm, shoulder_anchor, kin_chain, multiproc=False, init_resp=None):
        def fn(target_pos, **kwargs):
            joint_pos = kin_chain.inverse_kinematics(target_pos, **kwargs)
            endpt_error = np.linalg.norm(kin_chain.endpoint_pos(joint_pos) - target_pos)

            target_state = np.hstack([joint_pos, np.zeros_like(joint_pos), 1])

            return target_state, endpt_error
        super(PlanarMultiLinkJointGoal, self).__init__(fn, multiproc=multiproc, waiting_resp='prev', init_resp=init_resp)

class PlanarMultiLinkJointGoalCached(GoalCalculator, mp_calc.FuncProxy):
    '''
    Determine the goal state of a redundant system by look-up table, i.e. redundancy is collapsed
    by an arbitrary mapping between the redundant target space and configuration space

    TODO: since multiprocessing is not required for this class, it needs to do a better job of hiding the multiprocessing.
    '''
    def __init__(self, ssm, shoulder_anchor, kin_chain, multiproc=False, init_resp=None, **kwargs):
        '''
        Constructor for PlanarMultiLinkJointGoalCached

        Parameters
        ----------
        ssm : state_space_models.StateSpace instance
        shoulder_anchor : np.ndarray of shape (3,)
            Position of the manipulator anchor
        kin_chain : robot_arms.KinematicChain instance
            Object representing the kinematic chain linkages (D-H parameters)
        multiproc : bool, optional, default=False
            Should leave this false for this 'cached' method
        init_resp : None
            Ignore this if multiproc=False, as directed above.
        kwargs : optional keyword arguments
            Can pass in 'goal_cache_block' to specify from which task entry to grab the cache file. WARNING: this is currently commented out

        Returns
        -------
        PlanarMultiLinkJointGoalCached instance
        '''
        self.ssm = ssm
        self.shoulder_anchor = shoulder_anchor
        self.kin_chain = kin_chain
        if 0:  # 'goal_cache_block' in kwargs:
            goal_cache_block = kwargs.pop('goal_cache_block')
            self.cached_data = pickle.load(open('/storage/assist_params/tentacle_cache_%d.pkl' % goal_cache_block, 'rb'))  # 'rb': pickles must be opened in binary mode under Python 3
        else:
            self.cached_data = pickle.load(open('/storage/assist_params/tentacle_cache3.pkl', 'rb'))

        def fn(target_pos, **kwargs):
            ''' Docstring '''
            joint_pos = None
            for pos in self.cached_data:
                if np.linalg.norm(target_pos - np.array(pos)) < 0.001:
                    possible_joint_pos = self.cached_data[pos]
                    ind = np.random.randint(0, len(possible_joint_pos))
                    joint_pos = possible_joint_pos[ind]
                    break

            if joint_pos is None:
                raise ValueError("Unknown target position!: %s" % str(target_pos))

            target_state = np.hstack([joint_pos, np.zeros_like(joint_pos), 1])

            int_endpt_location = target_pos - shoulder_anchor
            endpt_error = np.linalg.norm(kin_chain.endpoint_pos(-joint_pos) - int_endpt_location)
            print(endpt_error)

            return (target_state, endpt_error)

        super(PlanarMultiLinkJointGoalCached, self).__init__(fn, multiproc=multiproc, waiting_resp='prev', init_resp=init_resp)

    def __call__(self, target_pos, **kwargs):
        '''
        Calculate the goal state [p, 0, 1] where p is the n-dim position and 0 is the n-dim velocity.
        p is the configuration space position, which must be looked up based on the target_pos
        (not a one-to-one mapping in general)

        Parameters
        ----------
        target_pos : np.ndarray
            Optimal position, in generalized coordinates (i.e., need not be cartesian coordinates)
        kwargs : optional kwargs
            These are ignored, just present for function call compatibility

        Returns
        -------
        np.ndarray
            (N, 1) indicating the target state
        '''
        joint_pos = None
        for pos in self.cached_data:
            if np.linalg.norm(target_pos - np.array(pos)) < 0.001:
                joint_pos = self.cached_data[pos]

        if joint_pos is None:  # bug fix: 'joint_pos == None' is ambiguous once joint_pos is an array
            raise ValueError("Unknown target position!: %s" % str(target_pos))

        target_state = np.hstack([joint_pos, np.zeros_like(joint_pos), 1])

        endpt_error = 0

        return (target_state, endpt_error), True

diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_kfdecoder_fcns_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_kfdecoder_fcns_py.html
new file mode 100644
index 00000000..fd22e66e
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_kfdecoder_fcns_py.html
@@ -0,0 +1,476 @@

Coverage for C:\GitHub\brain-python-interface\riglib\bmi\kfdecoder_fcns.py: 0%

from db.tracker import models
from db import dbfunctions as dbfn
from db.tracker.models import Decoder
from db import trainbmi
import numpy as np
import scipy
from .kfdecoder import KalmanFilter, KFDecoder
from . import train
import pickle
import re
import tables

########## MAIN DECODER MANIPULATION METHODS #################
def add_rm_units(task_entry_id, units, add_or_rm, flag_added_for_adaptation, name_suffix='', decoder_entry_id=None, **kwargs):
    '''
    Summary: Method to add or remove units from a KF decoder.
        Takes in task_entry_id number or decoder_entry_id to get the decoder,
        then removes units from or adds units to it.
        If adding, sets the new decoder weights to small random entries.

    Input param: task_entry_id: Decoder = dbfn.TaskEntry(task_entry_id).get_decoders_trained_in_block()
    Input param: units: list of units to add or remove
    Input param: add_or_rm: 'add' or 'rm'
    Input param: flag_added_for_adaptation: whether or not to flag newly added units as adapting inds
    Input param: name_suffix: new decoder suffix. If empty, will append 'add_or_rm_units_len(units)'
    Input param: decoder_entry_id: used if more than 1 decoder was trained on the block
    '''
    if 'decoder_path' in kwargs:
        kfdec = pickle.load(open(kwargs['decoder_path'], 'rb'))  # 'rb': pickles must be opened in binary mode under Python 3
    else:
        kfdec = get_decoder_corr(task_entry_id, decoder_entry_id)

    if add_or_rm == 'add':  # bug fix: comparing strings with 'is' relies on interning; use '=='
        kfdec_new, n_new_units = add_units(kfdec, units)

        # Only adapt new units:
        if flag_added_for_adaptation:
            kfdec_new.adapting_neural_inds = np.arange(len(kfdec_new.units)-len(units), len(kfdec_new.units))

        save_new_dec(task_entry_id, kfdec_new, name_suffix+'_add_'+str(n_new_units)+'_units')

    elif add_or_rm == 'rm':
        orig_units = kfdec.units
        inds_to_keep = proc_units(kfdec, units, 'remove')
        if len(orig_units) == len(inds_to_keep):
            print(' units cannot be removed since they are not in the original decoder', orig_units)
        else:
            dec_new = return_proc_units_decoder(kfdec, inds_to_keep)
            save_new_dec(task_entry_id, dec_new, name_suffix+'_rm_'+str(len(units))+'_units')

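Hypothetical usage of add_rm_units (the task entry ID and unit names below are made up for illustration):

    # add two units (channel 33 unit a, channel 41 unit b) and flag them as
    # the only adapting units for subsequent CLDA
    add_rm_units(1234, ['33a', '41b'], 'add', flag_added_for_adaptation=True)

    # remove one unit from the decoder attached to the same task entry
    add_rm_units(1234, ['17a'], 'rm', flag_added_for_adaptation=False)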
def flag_adapting_inds_for_CLDA(task_entry_id, state_names_to_adapt=None, units_to_adapt=None, decoder_entry_id=None):
    decoder = get_decoder_corr(task_entry_id, decoder_entry_id)
    state_adapting_inds = []
    for s in state_names_to_adapt:  # bug fix: was 'state_names', which is undefined
        ix = np.nonzero(decoder.states == s)[0]
        assert len(ix) == 1  # bug fix: was '== 0', which fails exactly when the state is found
        state_adapting_inds.append(int(ix))
    decoder.adapting_state_inds = np.array(state_adapting_inds)  # bug fix: was 'adapting_inds', which is undefined

    neural_adapting_inds = []
    for u in units_to_adapt:
        uix = np.nonzero(np.logical_and(decoder.units[:, 0] == u[0], decoder.units[:, 1] == u[1]))[0]
        neural_adapting_inds.append(int(uix))
    decoder.adapting_neural_inds = np.array(neural_adapting_inds)
    save_new_dec(task_entry_id, decoder, '_adapt_only_'+str(len(units_to_adapt))+'_units_'+str(len(state_names_to_adapt))+'_states')
def zscore_units(task_entry_id, calc_zscore_from_te, pos_key='cursor', decoder_entry_id=None,
                 training_method=train.train_KFDecoder, retrain_flag=False, **kwargs):
    '''
    Summary: Method to 'convert' a trained decoder that uses z-scoring to one whose z-score parameters
        come from another session (e.g. you train a decoder from VFB, but you want to z-score units
        according to an earlier passive/still session). Use the task_entry_id that was used to train
        the decoder OR an entry that used the decoder. 'calc_zscore_from_te' is the task entry ID used
        to compute the z-scored units. You can either retrain the decoder with the new z-scored units or not.

    Input param: task_entry_id:
    Input param: decoder_entry_id:
    Input param: calc_zscore_from_te:
    Output param:
    '''
    if 'decoder_path' in kwargs:
        decoder = pickle.load(open(kwargs['decoder_path'], 'rb'))
    else:
        decoder = get_decoder_corr(task_entry_id, decoder_entry_id)

    assert (hasattr(decoder, 'zscore') and decoder.zscore is True), "Cannot update mFR/sdFR of a decoder that was not trained as a z-scored decoder. Retrain!"

    # Init mFR / sdFR
    if 'hdf_path' in kwargs:
        hdf = tables.openFile(kwargs['hdf_path'])
    else:
        hdf = dbfn.TaskEntry(calc_zscore_from_te).hdf

    # Get HDF update rate from hdf file.
    try:
        hdf_update_rate = np.round(np.mean(hdf.root.task[:]['loop_time'])*1000.)/1000.
    except:
        from config import config
        if config.recording_system == 'blackrock':
            hdf_update_rate = .05
        elif config.recording_system == 'plexon':
            hdf_update_rate = 1/60.

    spk_counts = hdf.root.task[:]['spike_counts'][:, :, 0]

    # Make sure there are no repeated entries:
    sum_spk_counts = np.sum(spk_counts, axis=1)
    ix = np.nonzero(sum_spk_counts)[0][0]
    sample = 1 + sum_spk_counts[ix:ix+6] - sum_spk_counts[ix]
    assert np.sum(sample) != 6

    decoder_update_rate = decoder.binlen
    bin_spks, _ = bin_(None, spk_counts.T, hdf_update_rate, decoder_update_rate, only_neural=True)
    mFR = np.squeeze(np.mean(bin_spks, axis=1))
    sdFR = np.std(bin_spks, axis=1)
    kwargs2 = dict(mFR=mFR, sdFR=sdFR)

    # Retrain decoder w/ new zscoring:
    if 'te_id' in kwargs:
        training_id = kwargs['te_id']
    else:
        training_id = decoder.te_id

    if retrain_flag:
        raise NotImplementedError("Need to test retraining with real data")
        saved_files = models.DataFile.objects.filter(entry_id=training_id)
        files = {}
        for fl in saved_files:
            files[fl.system.name] = fl.get_path()
        import bmilist
        decoder = training_method(files, decoder.extractor_cls, decoder.extractor_kwargs, bmilist.kin_extractors[''], decoder.ssm,
                                  decoder.units, update_rate=decoder.binlen, tslice=decoder.tslice, pos_key=pos_key, zscore=True, **kwargs2)
        suffx = '_zscore_set_from_'+str(calc_zscore_from_te)+'_retrained'
    else:
        decoder.mFR = 0.
        decoder.sdFR = 1.
        decoder.init_zscore(mFR, sdFR)
        decoder.mFR = mFR
        decoder.sdFR = sdFR

        suffx = '_zscore_set_from_'+str(calc_zscore_from_te)

    if task_entry_id is not None:
        save_new_dec(task_entry_id, decoder, suffx)
    else:
        return decoder, suffx

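The mFR/sdFR computed above are the per-unit mean and standard deviation of the re-binned counts; a decoder trained with zscore=True then normalizes incoming features with them. A small sketch of that normalization on fake counts:

    import numpy as np

    bin_spks = np.random.poisson(3.0, size=(20, 500))   # units x bins, fake counts
    mFR = np.mean(bin_spks, axis=1)
    sdFR = np.std(bin_spks, axis=1)

    y = bin_spks[:, 0]           # one incoming observation vector
    y_z = (y - mFR) / sdFR       # what a z-score-trained decoder consumes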
def adj_state_noise(task_entry_id, decoder_entry_id, new_w):
    decoder = get_decoder_corr(task_entry_id, decoder_entry_id, get_dec_used=True)  # bug fix: keyword was 'return_used_te', which get_decoder_corr does not accept
    W = np.diag(decoder.filt.W)
    wix = np.nonzero(W)[0]
    W[wix] = new_w
    W_new = np.diag(W)
    decoder.filt.W = W_new
    save_new_dec(task_entry_id, decoder, '_W_new_'+str(new_w))

########################## HELPER DECODER MANIPULATION METHODS #################################
def get_decoder_corr(task_entry_id, decoder_entry_id, get_dec_used=True):
    '''
    Summary: get KF decoder either from the entry that trained the decoder (in which case decoder_entry_id is needed if > 1 decoder),
        or the decoder that was used during task_entry_id
    Input param: task_entry_id: dbname task entry ID
    Input param: decoder_entry_id: decoder entry id: (models.Decoder.objects.get(entry=entry))
    Output param: KF Decoder
    '''
    ld = True
    if get_dec_used is False:
        decoder_entries = dbfn.TaskEntry(task_entry_id).get_decoders_trained_in_block()
        if len(decoder_entries) > 0:
            print('Loading decoder TRAINED from task %d' % task_entry_id)
            if type(decoder_entries) is models.Decoder:
                decoder = decoder_entries
                ld = False
            else:  # list of decoders. Search for the right one.
                try:
                    dec_ids = [de.pk for de in decoder_entries]
                    _ix = np.nonzero(np.array(dec_ids) == decoder_entry_id)[0][0]  # bug fix: comparing a plain list to an int is always False; cast to array and take the first match
                    decoder = decoder_entries[_ix]
                    ld = False
                except:
                    if decoder_entry_id is None:
                        print('Too many decoder entries trained from this TE, specify decoder_entry_id')
                    else:
                        print('Too many decoder entries trained from this TE, no match to decoder_entry_id %d' % decoder_entry_id)
    if ld is False:
        kfdec = decoder.load()
    else:
        try:
            kfdec = dbfn.TaskEntry(task_entry_id).decoder
            print('Loading decoder USED in task %s' % dbfn.TaskEntry(task_entry_id).task)
        except:
            raise Exception('Cannot load decoder from TE%d' % task_entry_id)
    return kfdec
def add_units(kfdec, units):
    '''
    Add units to KFDecoder, e.g. to account for the appearance of new cells
    on a particular day; CLDA will then be needed to fit the new decoder weights

    Parameters:
        units: string or np.ndarray of shape (N, 2) of units to ADD to current decoder
    '''
    units_curr = kfdec.units
    new_units = proc_units(kfdec, units, 'to_int')

    keep_ix = []
    for r, r_un in enumerate(new_units):
        if len(np.nonzero(np.all(r_un == units_curr, axis=1))[0]) > 0:
            print('not adding unit ', r_un, ' -- already in decoder')
        else:
            keep_ix.append(r)

    new_units = np.array(new_units)[keep_ix, :]
    units = np.vstack((units_curr, new_units))
    n_states = kfdec.filt.C.shape[1]
    n_features = len(units)

    C = np.vstack((kfdec.filt.C, 1e-3*np.random.randn(len(new_units), kfdec.ssm.n_states)))
    Q = np.eye(len(units), len(units))
    Q[np.ix_(np.arange(len(units_curr)), np.arange(len(units_curr)))] = kfdec.filt.Q
    Q_inv = np.linalg.inv(Q)

    if isinstance(kfdec.mFR, np.ndarray):
        mFR = np.hstack((kfdec.mFR, np.zeros((len(new_units)))))
        sdFR = np.hstack((kfdec.sdFR, np.ones((len(new_units)))))
    else:
        mFR = kfdec.mFR
        sdFR = kfdec.sdFR

    filt = KalmanFilter(A=kfdec.filt.A, W=kfdec.filt.W, C=C, Q=Q, is_stochastic=kfdec.filt.is_stochastic)
    C_xpose_Q_inv = C.T * Q_inv
    C_xpose_Q_inv_C = C.T * Q_inv * C
    filt.C_xpose_Q_inv = C_xpose_Q_inv
    filt.C_xpose_Q_inv_C = C_xpose_Q_inv_C

    filt.R = kfdec.filt.R
    ix = np.random.permutation(n_features)[:len(new_units)]
    filt.S = np.vstack((kfdec.filt.S, kfdec.filt.S[ix, :]))
    filt.T = Q + filt.S * filt.S.T
    filt.ESS = kfdec.filt.ESS

    decoder = KFDecoder(filt, units, kfdec.ssm, mFR=mFR, sdFR=sdFR, binlen=kfdec.binlen, tslice=kfdec.tslice)
    decoder.n_features = units.shape[0]
    decoder.units = units
    decoder.extractor_cls = kfdec.extractor_cls
    decoder.extractor_kwargs = kfdec.extractor_kwargs
    try:
        CE = kfdec.corresp_encoder
        CE.C = np.vstack((CE.C, 3*np.random.randn(len(new_units), CE.C.shape[1])))
        Q = .1*np.eye(len(units))
        Q[:len(units_curr), :len(units_curr)] = CE.Q
        CE.Q = Q
        CE.n_features = len(units)
        decoder.corresp_encoder = CE
        print('adjusted corresp_encoder too!')
    except:
        pass
    decoder.extractor_kwargs['units'] = units
    return decoder, len(keep_ix)
def proc_units(kfdec, units, mode):
    '''
    Parse list of unit indices to keep from string or np.ndarray of shape (N, 2)
    Inputs:
        units -- units as strings (e.g. '12a') or as (channel, unit) pairs
        mode -- can be 'keep' or 'remove' or 'to_int'. Tells function what to do with the units
    '''
    if isinstance(units[0], str):
        # convert to array
        if isinstance(units, str):
            units = units.split(', ')

        units_lut = dict(a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8, i=9, j=10, k=11)
        units_int = []
        for u in units:
            ch = int(re.match(r'(\d+)([a-k])', u).group(1))  # raw string avoids the invalid-escape warning
            unit_ind = re.match(r'(\d+)([a-k])', u).group(2)
            # import pdb; pdb.set_trace()
            units_int.append((ch, units_lut[unit_ind]))

        units = units_int

    if mode == 'to_int':
        return units

    inds_to_keep = []
    new_units = list(map(tuple, units))
    for k, old_unit in enumerate(kfdec.units):
        if mode == 'keep':
            if tuple(old_unit) in new_units:
                inds_to_keep.append(k)
        elif mode == 'remove':
            if tuple(old_unit) not in new_units:
                inds_to_keep.append(k)
    return inds_to_keep

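proc_units parses unit strings such as '12a' into (channel, unit) pairs via the regex and the letter lookup table; the same parsing in isolation:

    import re

    units_lut = dict(a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8, i=9, j=10, k=11)
    m = re.match(r'(\d+)([a-k])', '12a')
    print(int(m.group(1)), units_lut[m.group(2)])   # -> 12 1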
def return_proc_units_decoder(kfdec, inds_to_keep):
    A = kfdec.filt.A
    W = kfdec.filt.W
    C = kfdec.filt.C
    Q = kfdec.filt.Q
    print('Indices to keep: ', inds_to_keep)
    C = C[inds_to_keep, :]
    Q = Q[np.ix_(inds_to_keep, inds_to_keep)]
    Q_inv = np.linalg.inv(Q)

    if isinstance(kfdec.mFR, np.matrix):
        mFR = np.squeeze(np.array(kfdec.mFR))[inds_to_keep]
        sdFR = np.squeeze(np.array(kfdec.sdFR))[inds_to_keep]
    elif isinstance(kfdec.mFR, np.ndarray):
        mFR = kfdec.mFR[inds_to_keep]
        sdFR = kfdec.sdFR[inds_to_keep]  # bug fix: was kfdec.mFR[inds_to_keep]
    else:
        mFR = kfdec.mFR
        sdFR = kfdec.sdFR

    filt = KalmanFilter(A=A, W=W, C=C, Q=Q, is_stochastic=kfdec.filt.is_stochastic)
    C_xpose_Q_inv = C.T * Q_inv
    C_xpose_Q_inv_C = C.T * Q_inv * C
    filt.C_xpose_Q_inv = C_xpose_Q_inv
    filt.C_xpose_Q_inv_C = C_xpose_Q_inv_C

    units = kfdec.units[inds_to_keep]

    filt.R = kfdec.filt.R
    filt.S = kfdec.filt.S[inds_to_keep, :]
    filt.T = kfdec.filt.T[np.ix_(inds_to_keep, inds_to_keep)]
    filt.ESS = kfdec.filt.ESS

    decoder = KFDecoder(filt, units, kfdec.ssm, mFR=mFR, sdFR=sdFR, binlen=kfdec.binlen, tslice=kfdec.tslice)

    decoder.n_features = units.shape[0]
    decoder.units = units
    decoder.extractor_cls = kfdec.extractor_cls
    decoder.extractor_kwargs = kfdec.extractor_kwargs
    decoder.extractor_kwargs['units'] = units
    try:
        CE = kfdec.corresp_encoder
        CE.C = CE.C[inds_to_keep, :]
        CE.Q = CE.Q[np.ix_(inds_to_keep, inds_to_keep)]
        CE.n_features = len(units)
        decoder.corresp_encoder = CE
        print('adjusted corresp_encoder too!')
    except:
        pass

    return decoder
def save_new_dec(task_entry_id, dec_obj, suffix):
    '''
    Summary: Method to save a decoder to the DB -- saves to the TE that the original decoder came from
    Input param: task_entry_id: original task to save decoder to
    Input param: dec_obj: new KF decoder
    Input param: suffix:
    Output param:
    '''
    te = dbfn.TaskEntry(task_entry_id)
    try:
        te_id = te.te_id
    except:
        dec_nm = te.name
        te_ix = re.search('te[0-9]', dec_nm)
        ix = te_ix.start() + 2
        sub_dec_nm = dec_nm[ix:]

        te_ix_end = sub_dec_nm.find('_')
        if te_ix_end == -1:
            te_ix_end = len(sub_dec_nm)
        te_id = int(sub_dec_nm[:te_ix_end])

    old_dec_obj = te.decoder_record
    if old_dec_obj is None:
        old_dec_obj = faux_decoder_obj(task_entry_id)
    trainbmi.save_new_decoder_from_existing(dec_obj, old_dec_obj, suffix=suffix)
def bin_(kin, neural_features, update_rate, desired_update_rate, only_neural=False):
    n = desired_update_rate/float(update_rate)
    if not only_neural:
        assert kin.shape[1] == neural_features.shape[1]

    ix_end = int(np.floor(neural_features.shape[1] / n)*n)

    if abs(n - round(n)) < 1e-5:  # bug fix: without abs(), any n just below its nearest integer passed the check
        n = int(n)
        if not only_neural:
            kin_ = kin[:, :ix_end].reshape(kin[:, :ix_end].shape[0], kin[:, :ix_end].shape[1]//n, n)  # bug fix: integer division for reshape dims under Python 3
            bin_kf = np.mean(kin_, axis=2)

        nf_ = neural_features[:, :ix_end].reshape(neural_features[:, :ix_end].shape[0], neural_features[:, :ix_end].shape[1]//n, n)
        bin_nf = np.sum(nf_, axis=2)

        if only_neural:
            return bin_nf, desired_update_rate
        else:
            return bin_nf, bin_kf, desired_update_rate
    else:
        raise Exception('Desired rate %f not a multiple of original rate %f' % (desired_update_rate, update_rate))  # bug fix: the message was never actually %-formatted

class faux_decoder_obj(object):
    def __init__(self, task_entry_id, *args, **kwargs):
        self.name = ''
        self.entry_id = task_entry_id

+
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_kfdecoder_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_kfdecoder_py.html
new file mode 100644
index 00000000..71931d4b
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_kfdecoder_py.html
@@ -0,0 +1,979 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\kfdecoder.py: 27%
'''
Classes for BMI decoding using the Kalman filter.
'''

import numpy as np
from scipy.io import loadmat

from . import bmi
import pickle
import re

class KalmanFilter(bmi.GaussianStateHMM):
    """
    Low-level KF, agnostic to application

    Model:
       x_{t+1} = Ax_t + w_t; w_t ~ N(0, W)
       y_t = Cx_t + q_t; q_t ~ N(0, Q)
    """
    model_attrs = ['A', 'W', 'C', 'Q', 'C_xpose_Q_inv', 'C_xpose_Q_inv_C']
    attrs_to_pickle = ['A', 'W', 'C', 'Q', 'C_xpose_Q_inv', 'C_xpose_Q_inv_C', 'R', 'S', 'T', 'ESS']

    def __init__(self, A=None, W=None, C=None, Q=None, is_stochastic=None):
        '''
        Constructor for KalmanFilter

        Parameters
        ----------
        A : np.mat, optional
            Model of state transition matrix
        W : np.mat, optional
            Model of process noise covariance
        C : np.mat, optional
            Model of conditional distribution between observations and hidden state
        Q : np.mat, optional
            Model of observation noise covariance
        is_stochastic : np.array, optional
            Array of booleans specifying for each state whether it is stochastic.
            If 'None' specified, all states are assumed to be stochastic

        Returns
        -------
        KalmanFilter instance
        '''
        if A is None and W is None and C is None and Q is None:
            ## This condition should only be true in the unpickling phase
            pass
        else:
            self.A = np.mat(A)
            self.W = np.mat(W)
            self.C = np.mat(C)
            self.Q = np.mat(Q)

            if is_stochastic is None:
                n_states = self.A.shape[0]
                self.is_stochastic = np.ones(n_states, dtype=bool)
            else:
                self.is_stochastic = is_stochastic

            self.state_noise = bmi.GaussianState(0.0, self.W)
            self.obs_noise = bmi.GaussianState(0.0, self.Q)
            self._pickle_init()

    def _pickle_init(self):
        """Code common to unpickling and initialization
        """
        nS = self.A.shape[0]
        offset_row = np.zeros(nS)
        offset_row[-1] = 1
        self.include_offset = np.array_equal(np.array(self.A)[-1, :], offset_row)

        self.alt = nS < self.C.shape[0]  # No. of states less than no. of observations
        attrs = list(self.__dict__.keys())
        if 'C_xpose_Q_inv_C' not in attrs:
            C, Q = self.C, self.Q
            self.C_xpose_Q_inv = C.T * np.linalg.pinv(Q)
            self.C_xpose_Q_inv_C = C.T * np.linalg.pinv(Q) * C

        try:
            self.is_stochastic
        except AttributeError:
            n_states = self.A.shape[0]
            self.is_stochastic = np.ones(n_states, dtype=bool)

    def _obs_prob(self, state):
        '''
        Predict the observations based on the model parameters:
            y_est = C*x_t + q_t, q_t ~ N(0, Q)

        Parameters
        ----------
        state : bmi.GaussianState instance
            The model-predicted state

        Returns
        -------
        bmi.GaussianState instance
            the model-predicted observations
        '''
        return self.C * state + self.obs_noise

    def _forward_infer(self, st, obs_t, Bu=None, u=None, x_target=None, F=None, obs_is_control_independent=True, **kwargs):
        '''
        Estimate p(x_t | ..., y_{t-1}, y_t)

        Parameters
        ----------
        st : GaussianState
            Current estimate (mean and cov) of hidden state
        obs_t : np.mat of shape (N, 1)
            Vector of observations for the current bin (e.g., binned spike counts)
        Bu : np.mat, optional, default=None
            Pre-multiplied control input B*u_t to apply during the prediction step
        u : np.mat, optional, default=None
            Control input; multiplied by the filter's B matrix if supplied
        x_target : np.mat, optional, default=None
            Target state, used with the feedback gain to generate a control input
        obs_is_control_independent : bool, optional, default=True
            If True, treat the observations as independent of the control input
        kwargs : optional kwargs
            Ignored

        Returns
        -------
        GaussianState
            New state estimate incorporating the most recent observation

        '''
        using_control_input = (Bu is not None) or (u is not None) or (x_target is not None)
        pred_state = self._ssm_pred(st, target_state=x_target, Bu=Bu, u=u, F=F)

        C, Q = self.C, self.Q
        P = pred_state.cov

        K = self._calc_kalman_gain(P)
        I = np.mat(np.eye(self.C.shape[1]))
        D = self.C_xpose_Q_inv_C
        KC = P*(I - D*P*(I + D*P).I)*D
        F = (I - KC)*self.A

        post_state = pred_state

        if obs_is_control_independent and using_control_input:
            post_state.mean += -KC*self.A*st.mean + K*obs_t
        else:
            post_state.mean += -KC*pred_state.mean + K*obs_t

        post_state.cov = (I - KC) * P

        return post_state

    def set_state_cov(self, n_steps):
        C, Q = self.C, self.Q
        A, W = self.A, self.W
        P = self.state.cov
        for k in range(n_steps):

            P = A*P*A.T + W

            K = self._calc_kalman_gain(P)
            I = np.mat(np.eye(self.C.shape[1]))
            D = self.C_xpose_Q_inv_C
            KC = P*(I - D*P*(I + D*P).I)*D
            P = (I - KC) * P

        return P

    def _calc_kalman_gain(self, P):
        '''
        Calculate Kalman gain using the 'alternate' definition

        Parameters
        ----------
        P : np.matrix
            Prediction covariance matrix, i.e., cov(x_{t+1} | y_1, ..., y_t)

        Returns
        -------
        K : np.matrix
            Kalman gain matrix for the input next-state prediction covariance.
        '''
        nX = P.shape[0]
        I = np.mat(np.eye(nX))
        D = self.C_xpose_Q_inv_C
        L = self.C_xpose_Q_inv
        K = P * (I - D*P*(I + D*P).I) * L
        return K
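As a sanity check, the 'alternate' gain above is algebraically identical to the textbook gain K = P C^T (C P C^T + Q)^{-1}; it is a matrix-inversion-lemma rewrite that reuses the precomputed C^T Q^{-1} and C^T Q^{-1} C products. A standalone numerical check, not part of riglib:

import numpy as np

rng = np.random.default_rng(0)
nX, nY = 4, 6
P = np.cov(rng.standard_normal((nX, 50)))   # any SPD prediction covariance
C = rng.standard_normal((nY, nX))
Q = np.cov(rng.standard_normal((nY, 100)))  # SPD observation noise

K_textbook = P @ C.T @ np.linalg.inv(C @ P @ C.T + Q)

Qinv = np.linalg.inv(Q)
L = C.T @ Qinv                               # C_xpose_Q_inv
D = C.T @ Qinv @ C                           # C_xpose_Q_inv_C
I = np.eye(nX)
K_alt = P @ (I - D @ P @ np.linalg.inv(I + D @ P)) @ L

assert np.allclose(K_textbook, K_alt)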

    def get_sskf(self, tol=1e-15, return_P=False, dtype=np.array, max_iter=4000,
            verbose=False, return_Khist=False, alt=True):
        """Calculate the steady-state KF matrices

        value of P returned is the posterior error cov, i.e. P_{t|t}

        Parameters
        ----------
        tol : float, optional
            Convergence tolerance on the change in K between iterations
        max_iter : int, optional
            Maximum number of Riccati iterations to run

        Returns
        -------
        F, K
            Steady-state system matrices such that x_{t+1} = F*x_t + K*y_t
        """
        A, W, C, Q = np.mat(self.A), np.mat(self.W), np.mat(self.C), np.mat(self.Q)

        nS = A.shape[0]
        P = np.mat(np.zeros([nS, nS]))
        I = np.mat(np.eye(nS))

        D = self.C_xpose_Q_inv_C

        last_K = np.mat(np.ones(C.T.shape))*np.inf
        K = np.mat(np.ones(C.T.shape))*0

        K_hist = []

        iter_idx = 0
        last_P = None
        while np.linalg.norm(K-last_K) > tol and iter_idx < max_iter:
            P = A*P*A.T + W
            last_K = K
            K = self._calc_kalman_gain(P)
            K_hist.append(K)
            KC = P*(I - D*P*(I + D*P).I)*D
            last_P = P
            P -= KC*P
            iter_idx += 1
        if verbose:
            print("Converged in %d iterations--error: %g" % (iter_idx, np.linalg.norm(K-last_K)))

        n_state_vars, n_state_vars = A.shape
        F = (np.mat(np.eye(n_state_vars, n_state_vars)) - KC) * A

        if return_P and return_Khist:
            return dtype(F), dtype(K), dtype(last_P), K_hist
        elif return_P:
            return dtype(F), dtype(K), dtype(last_P)
        elif return_Khist:
            return dtype(F), dtype(K), K_hist
        else:
            return dtype(F), dtype(K)
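The loop above is a discrete Riccati iteration: propagate P through the dynamics, compute the gain, shrink P by the measurement update, and repeat until K stops moving. The same fixed point can be reached with plain numpy, e.g. for a random stable system (standalone sketch, toy matrices):

import numpy as np

rng = np.random.default_rng(1)
A = 0.9 * np.eye(3)                # stable dynamics
W = 0.1 * np.eye(3)                # process noise
C = rng.standard_normal((5, 3))    # observation model
Q = np.eye(5)                      # observation noise

P = np.zeros((3, 3))
K_last = np.full((3, 5), np.inf)
for _ in range(4000):
    P = A @ P @ A.T + W                              # predict
    K = P @ C.T @ np.linalg.inv(C @ P @ C.T + Q)     # gain
    P = (np.eye(3) - K @ C) @ P                      # update
    if np.linalg.norm(K - K_last) < 1e-15:
        break
    K_last = K

F = (np.eye(3) - K @ C) @ A        # steady-state closed-loop matrix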

    def get_kalman_gain_seq(self, N=1000, tol=1e-10, verbose=False):
        '''
        Calculate K_t for times {0, 1, ..., N}

        Parameters
        ----------
        N : int, optional
            Number of steps to calculate Kalman gain for, default = 1000
        tol : float, optional
            Tolerance on K matrix convergence, default = 1e-10
        verbose : bool, optional
            Print intermediate/debugging information if true, default=False

        Returns
        -------
        list
            [K_0, K_1, ..., K_{N-1}]
        '''
        A, W, H, Q = np.mat(self.kf.A), np.mat(self.kf.W), np.mat(self.kf.H), np.mat(self.kf.Q)
        P = np.mat(np.zeros(A.shape))
        K = [None]*N

        ss_idx = None  # index at which K is steady-state (within tol)
        for n in range(N):
            if ss_idx is not None and n > ss_idx:
                K[n] = K[ss_idx]
            else:
                P = A*P*A.T + W
                K[n] = (P*H.T)*np.linalg.pinv(H*P*H.T + Q)  # was bare 'linalg', which is never imported
                P -= K[n]*H*P
                if n > 0 and np.linalg.norm(K[n] - K[n-1]) < tol:
                    ss_idx = n
                    if verbose:
                        print("breaking after %d iterations" % n)

        return K, ss_idx

    def get_kf_system_mats(self, T):
        """
        KF system matrices

        x_{t+1} = F_t*x_t + K_t*y_t

        Parameters
        ----------
        T : int
            Number of system iterations to calculate (F_t, K_t)

        Returns
        -------
        tuple of lists
            Each element of the tuple is (F_t, K_t) for a given 't'

        """
        F = [None]*T
        K, ss_idx = self.get_kalman_gain_seq(N=T, verbose=False)
        nX = self.kf.A.shape[0]
        I = np.mat(np.eye(nX))

        for t in range(T):
            if t > ss_idx: F[t] = F[ss_idx]
            else: F[t] = (I - K[t]*self.kf.H)*self.kf.A

        return F, K

    @classmethod
    def MLE_obs_model(self, hidden_state, obs, include_offset=True, drives_obs=None,
            regularizer=None):
        """
        Unconstrained ML estimator of {C, Q} given observations and
        the corresponding hidden states

        Parameters
        ----------
        include_offset : bool, optional, default=True
            A row of all 1's is added as the last row of hidden_state if one is not already present

        Returns
        -------
        (C, Q) : tuple
            ML estimates of the observation matrix and the observation noise covariance
        """
        assert hidden_state.shape[1] == obs.shape[1], "different numbers of time samples: %s vs %s" % (str(hidden_state.shape), str(obs.shape))

        if isinstance(hidden_state, np.ma.core.MaskedArray):
            mask = ~hidden_state.mask[0,:]  # NOTE THE INVERTER
            inds = np.nonzero([ mask[k]*mask[k+1] for k in range(len(mask)-1)])[0]

            X = np.mat(hidden_state[:,mask])
            T = len(np.nonzero(mask)[0])

            Y = np.mat(obs[:,mask])
            if include_offset:
                if not np.all(X[-1,:] == 1):
                    X = np.vstack([ X, np.ones([1,T]) ])
        else:
            num_hidden_state, T = hidden_state.shape
            X = np.mat(hidden_state)
            if include_offset:
                if not np.all(X[-1,:] == 1):
                    X = np.vstack([ X, np.ones([1,T]) ])
            Y = np.mat(obs)

        n_states = X.shape[0]
        if drives_obs is not None:
            X = X[drives_obs, :]

        # ML estimate of C and Q
        if regularizer is None:
            C = np.mat(np.linalg.lstsq(X.T, Y.T)[0].T)
        else:
            x = X.T
            y = Y.T
            XtX_lamb = x.T.dot(x) + regularizer * np.eye(x.shape[1])
            XtY = x.T.dot(y)
            C = np.linalg.solve(XtX_lamb, XtY).T
        Q = np.cov(Y - C*X, bias=1)

        if np.ndim(Q) == 0:
            # if "obs" only has 1 feature, Q might get collapsed to a scalar
            Q = np.mat(Q.reshape(1,1))

        if drives_obs is not None:
            n_obs = C.shape[0]
            C_tmp = np.zeros([n_obs, n_states])
            C_tmp[:,drives_obs] = C
            C = C_tmp
        return (C, Q)
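The regularizer branch above is ordinary ridge regression solved via the normal equations; with a vanishing penalty it reduces to the least-squares branch. A standalone check on toy data (numpy only):

import numpy as np

rng = np.random.default_rng(2)
T, n_states, n_units = 200, 4, 10
X = rng.standard_normal((n_states, T))       # hidden state, states x time
C_true = rng.standard_normal((n_units, n_states))
Y = C_true @ X + 0.1 * rng.standard_normal((n_units, T))

lam = 1e-8                                   # tiny ridge penalty
x, y = X.T, Y.T
C_ridge = np.linalg.solve(x.T @ x + lam * np.eye(n_states), x.T @ y).T
C_ls = np.linalg.lstsq(x, y, rcond=None)[0].T

assert np.allclose(C_ridge, C_ls, atol=1e-6)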

    @classmethod
    def MLE_state_space_model(self, hidden_state, include_offset=True):
        '''
        Train state space model for KF from fully observed hidden state

        Parameters
        ----------
        hidden_state : np.ndarray of shape (N, T)
            N = dimensionality of state vector, T = number of observations
        include_offset : boolean, optional, default=True
            if True, append a "1" to each state vector to add an offset term into the
            regression

        Returns
        -------
        A : np.ndarray of shape (N, N)
        W : np.ndarray of shape (N, N)
        '''
        X = hidden_state
        T = hidden_state.shape[1]
        if include_offset:
            X = np.vstack([ X, np.ones([1,T]) ])
        X1 = X[:,:-1]
        X2 = X[:,1:]
        A = np.linalg.lstsq(X1.T, X2.T)[0].T
        W = np.cov(X2 - np.dot(A, X1), bias=1)
        return A, W
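This is just a one-step-ahead linear regression: stack x_t as predictors and x_{t+1} as targets, and the residual covariance gives W. A quick standalone check on a simulated first-order system:

import numpy as np

rng = np.random.default_rng(3)
A_true = np.array([[0.95, 0.1], [0.0, 0.8]])
T = 5000
X = np.zeros((2, T))
for t in range(T - 1):
    X[:, t + 1] = A_true @ X[:, t] + 0.1 * rng.standard_normal(2)

X1, X2 = X[:, :-1], X[:, 1:]
A_hat = np.linalg.lstsq(X1.T, X2.T, rcond=None)[0].T
W_hat = np.cov(X2 - A_hat @ X1, bias=True)

assert np.allclose(A_hat, A_true, atol=0.05)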

    def set_steady_state_pred_cov(self):
        '''
        Calculate the steady-state prediction covariance and set the current state prediction covariance to the steady-state value
        '''

        A, W, C, Q = np.mat(self.A), np.mat(self.W), np.mat(self.C), np.mat(self.Q)
        D = self.C_xpose_Q_inv_C
        nS = A.shape[0]
        P = np.mat(np.zeros([nS, nS]))
        I = np.mat(np.eye(nS))

        last_K = np.mat(np.ones(C.T.shape))*np.inf
        K = np.mat(np.ones(C.T.shape))*0

        iter_idx = 0
        for iter_idx in range(40):
            P = A*P*A.T + W
            last_K = K
            KC = P*(I - D*P*(I + D*P).I)*D
            P -= KC*P

        # TODO fix
        P[0:3, 0:3] = 0
        F, K = self.get_sskf()
        F = (I - KC)*A
        self._init_state(init_state=self.state.mean, init_cov=P)

    def get_K_null(self):
        '''
        $$y_{null} = K_{null} * y_t$$ gives the "null" component of the spike inputs, i.e. $$K_t*y_{null} = 0_{N\times 1}$$

        Returns
        -------
        K_null : np.ndarray
            Projection matrix onto the null space of the steady-state Kalman gain
        '''
        F, K = self.get_sskf()
        K = np.mat(K)
        n_neurons = K.shape[1]
        K_null = np.eye(n_neurons) - np.linalg.pinv(K) * K
        return K_null
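K_null is the orthogonal projector onto the null space of K (identity minus the projector pinv(K)K onto its row space), so any spike pattern it produces is invisible to the decoder. Standalone check:

import numpy as np

rng = np.random.default_rng(4)
K = rng.standard_normal((7, 20))          # states x neurons; wide, so the null space is nontrivial
K_null = np.eye(20) - np.linalg.pinv(K) @ K

y = rng.standard_normal((20, 1))
assert np.allclose(K @ (K_null @ y), 0)   # the "null" spike pattern decodes to nothing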

class KalmanFilterDriftCorrection(KalmanFilter):
    attrs_to_pickle = ['A', 'W', 'C', 'Q', 'C_xpose_Q_inv',
        'C_xpose_Q_inv_C', 'R', 'S', 'T', 'ESS', 'drift_corr', 'prev_drift_corr']
    noise_threshold = 96.*3.5

    def _init_state(self):
        if hasattr(self, 'prev_drift_corr'):
            self.drift_corr = self.prev_drift_corr.copy()
            print(('prev drift corr', np.mean(self.prev_drift_corr)))
        else:
            self.drift_corr = np.mat(np.zeros((self.A.shape[0], 1)))
            self.prev_drift_corr = np.mat(np.zeros((self.A.shape[0], 1)))

        if hasattr(self, 'noise_rej'):
            if self.noise_rej:
                print(('noise rej thresh: ', self.noise_rej_cutoff))
        else:
            self.noise_rej = False
        self.noise_cnt = 0

        super(KalmanFilterDriftCorrection, self)._init_state()

    def _forward_infer(self, st, obs_t, Bu=None, u=None, x_target=None, F=None, obs_is_control_independent=True, **kwargs):

        if self.noise_rej:
            if np.sum(obs_t) > self.noise_rej_cutoff:
                # observation vector looks like broadband noise; substitute the mean firing rates
                self.noise_cnt += 1
                obs_t = np.mat(self.noise_rej_mFR).T

        state = super(KalmanFilterDriftCorrection, self)._forward_infer(st, obs_t, Bu=None, u=None, x_target=None, F=None,
            obs_is_control_independent=True, **kwargs)

        ### Apply Drift Correction ###
        decoded_vel = state.mean.copy()
        state.mean[self.vel_ix] = decoded_vel[self.vel_ix] - self.drift_corr[self.vel_ix]

        ### Update Drift Correction ###
        self.drift_corr[self.vel_ix] = self.drift_corr[self.vel_ix]*self.drift_rho + decoded_vel[self.vel_ix]*float(1. - self.drift_rho)
        self.prev_drift_corr = self.drift_corr.copy()

        return state
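The drift estimate is an exponentially weighted moving average of the decoded velocity, d_t = rho*d_{t-1} + (1-rho)*v_t, subtracted before output so a constant bias decays away. A scalar sketch of the same update:

import numpy as np

rho = 0.98                  # forgetting factor; closer to 1 = slower adaptation
bias = 0.5                  # constant velocity bias injected into the decode
d = 0.0
out = []
for t in range(500):
    v = bias + 0.05 * np.random.randn()   # raw decoded velocity
    out.append(v - d)                     # drift-corrected output
    d = rho * d + (1 - rho) * v           # update the drift estimate

assert abs(np.mean(out[-100:])) < 0.1     # the bias is mostly removed by the end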

class PCAKalmanFilter(KalmanFilter):
    '''
    A modified KalmanFilter where the Kalman gain is confined to produce outputs in a lower-dimensional linear subspace, i.e. some principal component space
    '''
    def _forward_infer(self, st, obs_t, Bu=None, u=None, target_state=None, obs_is_control_independent=True, **kwargs):
        '''
        See KalmanFilter._forward_infer for docs
        '''
        using_control_input = (Bu is not None) or (u is not None) or (target_state is not None)
        pred_state = self._ssm_pred(st, target_state=target_state, Bu=Bu, u=u)

        C, Q = self.C, self.Q
        P = pred_state.cov

        try:
            M = self.M
            pca_offset = self.pca_offset
        except AttributeError:
            print("couldn't extract PCA parameters!")
            M = 1
            pca_offset = 0

        K = self._calc_kalman_gain(P)
        I = np.mat(np.eye(self.C.shape[1]))
        D = self.C_xpose_Q_inv_C

        KC = K*C
        F = (I - KC)*self.A

        post_state = pred_state
        if obs_is_control_independent and using_control_input:
            post_state.mean += -KC*self.A*st.mean + M*K*obs_t + pca_offset
        else:
            post_state.mean += -KC*pred_state.mean + M*K*obs_t + pca_offset

        post_state.cov = (I - KC) * P

        return post_state

    def __getstate__(self):
        '''
        See KalmanFilter.__getstate__ for docs
        '''
        data = super(PCAKalmanFilter, self).__getstate__()
        data['M'] = self.M
        data['pca_offset'] = self.pca_offset
        return data

    def __setstate__(self, state):
        '''
        See KalmanFilter.__setstate__ for docs
        '''
        super(PCAKalmanFilter, self).__setstate__(state)
        self.M = state['M']
        self.pca_offset = state['pca_offset']

class FAKalmanFilter(KalmanFilter):

    def _forward_infer(self, st, obs_t, Bu=None, u=None, target_state=None, obs_is_control_independent=True, **kwargs):
        input_dict = {}
        if hasattr(self, 'FA_kwargs'):

            input_type = self.FA_input + '_input'

            input_dict['all_input'] = obs_t.copy()

            dmn = obs_t - self.FA_kwargs['fa_mu']
            shar = (self.FA_kwargs['fa_sharL'] * dmn)
            priv = (dmn - shar)
            main_shar = (self.FA_kwargs['fa_main_shared'] * dmn)
            main_priv = (dmn - main_shar)

            FA = self.FA_kwargs['FA_model']

            inp = obs_t.copy()
            if inp.shape[1] == 1:
                inp = inp.T  # want 1 x neurons
            z = FA.transform(dmn.T)
            z = z.T  # transform to factors x 1
            z = z[:self.FA_kwargs['fa_main_shar_n_dim'], :]  # only use the number of factors in the main shared space

            input_dict['private_input'] = priv + self.FA_kwargs['fa_mu']
            input_dict['shared_input'] = shar + self.FA_kwargs['fa_mu']

            input_dict['private_scaled_input'] = np.multiply(priv, self.FA_kwargs['fa_priv_var_sc']) + self.FA_kwargs['fa_mu']
            input_dict['shared_scaled_input'] = np.multiply(shar, self.FA_kwargs['fa_shar_var_sc']) + self.FA_kwargs['fa_mu']

            input_dict['all_scaled_by_shar_input'] = np.multiply(dmn, self.FA_kwargs['fa_shar_var_sc']) + self.FA_kwargs['fa_mu']

            input_dict['sc_shared+unsc_priv_input'] = input_dict['shared_scaled_input'] + input_dict['private_input'] - self.FA_kwargs['fa_mu']
            input_dict['sc_shared+sc_priv_input'] = input_dict['shared_scaled_input'] + input_dict['private_scaled_input'] - self.FA_kwargs['fa_mu']

            input_dict['main_shared_input'] = main_shar + self.FA_kwargs['fa_mu']
            input_dict['main_sc_shared_input'] = np.multiply(main_shar, self.FA_kwargs['fa_main_shared_sc']) + self.FA_kwargs['fa_mu']

            input_dict['main_sc_shar+unsc_priv_input'] = input_dict['main_sc_shared_input'] + input_dict['private_input'] - self.FA_kwargs['fa_mu']
            input_dict['main_sc_shar+sc_priv_input'] = input_dict['main_sc_shared_input'] + input_dict['private_scaled_input'] - self.FA_kwargs['fa_mu']
            input_dict['main_sc_private_input'] = np.multiply(main_priv, self.FA_kwargs['fa_main_private_sc']) + self.FA_kwargs['fa_mu']

            #z = self.FA_kwargs['u_svd'].T*self.FA_kwargs['uut_psi_inv']*dmn
            input_dict['split_input'] = np.vstack((z, main_priv))

            own_pc_trans = np.mat(self.FA_kwargs['own_pc_trans'])*np.mat(dmn)
            input_dict['pca_input'] = own_pc_trans + self.FA_kwargs['fa_mu']

            if input_type in list(input_dict.keys()):
                obs_t_mod = input_dict[input_type]
            else:
                print(input_type)
                raise Exception("Error in FA_KF input_type, none of the expected inputs")
        else:
            obs_t_mod = obs_t.copy()

        input_dict['task_input'] = obs_t_mod.copy()

        post_state = super(FAKalmanFilter, self)._forward_infer(st, obs_t_mod, Bu=Bu, u=u, target_state=target_state,
            obs_is_control_independent=obs_is_control_independent, **kwargs)

        self.FA_input_dict = input_dict

        return post_state

class KFDecoder(bmi.BMI, bmi.Decoder):
    '''
    Wrapper for KalmanFilter specifically for the application of BMI decoding.
    '''
    def __init__(self, *args, **kwargs):
        '''
        Constructor for KFDecoder

        Parameters
        ----------
        *args, **kwargs : see riglib.bmi.bmi.Decoder for arguments

        Returns
        -------
        KFDecoder instance
        '''

        super(KFDecoder, self).__init__(*args, **kwargs)
        mFR = kwargs.pop('mFR', 0.)
        sdFR = kwargs.pop('sdFR', 1.)
        self.mFR = mFR
        self.sdFR = sdFR
        self.zeromeanunits = None
        self.zscore = False
        self.kf = self.filt

    def _pickle_init(self):
        super(KFDecoder, self)._pickle_init()
        if not hasattr(self.filt, 'B'):
            self.filt.B = np.mat(np.vstack([np.zeros([3,3]), np.eye(3)*1000*self.binlen, np.zeros(3)]))

        if not hasattr(self.filt, 'F'):
            self.filt.F = np.mat(np.zeros([3,7]))

    def init_zscore(self, mFR_curr, sdFR_curr):
        '''
        Initialize parameters for z-scoring observations, if that feature is enabled in the decoder object

        Parameters
        ----------
        mFR_curr : np.array of shape (N,)
            Current mean estimates (as opposed to potentially old estimates already stored in the decoder)
        sdFR_curr : np.array of shape (N,)
            Current standard deviation estimates (as opposed to potentially old estimates already stored in the decoder)

        Returns
        -------
        None
        '''

        # if interfacing with the Kinarm system, the mean and sd may be of shape (n, 1)
        self.zeromeanunits, = np.nonzero(mFR_curr == 0)  # find any units with a mean FR of zero for this session
        sdFR_curr[self.zeromeanunits] = np.nan  # set mean and SD of quiet units to nan to avoid divide by 0 error
        mFR_curr[self.zeromeanunits] = np.nan
        #self.sdFR_ratio = self.sdFR/sdFR_curr
        #self.mFR_diff = mFR_curr-self.mFR
        #self.mFR_curr = mFR_curr
        self.mFR = mFR_curr
        self.sdFR = sdFR_curr
        self.zscore = True

    def update_params(self, new_params, steady_state=True):
        '''
        Update the decoder parameters if new parameters are available (e.g., by CLDA). See Decoder.update_params
        '''
        super(KFDecoder, self).update_params(new_params)

        # set the KF to the new steady state
        if steady_state:
            self.kf.set_steady_state_pred_cov()

    def __setstate__(self, state):
        """
        Set decoder state after un-pickling. See Decoder.__setstate__, which runs the _pickle_init function at some point during the un-pickling process

        Parameters
        ----------
        state : dict
            Variables to set as attributes of the unpickled object.

        Returns
        -------
        None
        """
        if 'kf' in state and 'filt' not in state:
            state['filt'] = state['kf']

        super(KFDecoder, self).__setstate__(state)

    def plot_K(self, **kwargs):
        '''
        Plot the Kalman gain weights

        Parameters
        ----------
        **kwargs : optional kwargs
            These are passed to the plot function (e.g., which rows to plot)

        Returns
        -------
        None
        '''

        F, K = self.kf.get_sskf()
        self.plot_pds(K.T, **kwargs)

    def shuffle(self, shuffle_baselines=False):
        '''
        Shuffle the neural model

        Parameters
        ----------
        shuffle_baselines : bool, optional, default = False
            If true, shuffle the estimates of the baseline firing rates in addition to the state-dependent neural tuning parameters.

        Returns
        -------
        None (shuffling is done on the current decoder object)

        '''
        # generate random permutation
        import random
        inds = list(range(self.filt.C.shape[0]))
        random.shuffle(inds)

        # shuffle rows of C, and rows+cols of Q
        C_orig = self.filt.C.copy()
        self.filt.C = self.filt.C[inds, :]
        if not shuffle_baselines:
            self.filt.C[:,-1] = C_orig[:,-1]
        self.filt.Q = self.filt.Q[inds, :]
        self.filt.Q = self.filt.Q[:, inds]

        self.filt.C_xpose_Q_inv = self.filt.C.T * np.linalg.pinv(self.filt.Q)  # was pinv(self.filt.Q.I), which inverts Q twice

        # RML sufficient statistics (S and T, but not R and ESS)
        # shuffle rows of S, and rows+cols of T
        try:
            self.filt.S = self.filt.S[inds, :]
            self.filt.T = self.filt.T[inds, :]
            self.filt.T = self.filt.T[:, inds]
        except AttributeError:
            # if this decoder never had the RML sufficient statistics
            # (R, S, T, and ESS) as attributes of self.filt
            pass

    def change_binlen(self, new_binlen, screen_update_rate=60.0):
        '''
        Function to change the binlen of the KFDecoder analytically.

        Parameters
        ----------
        new_binlen : float
            New bin length of the decoder, in seconds
        screen_update_rate: float, optional, default = 60Hz
            Rate at which the __call__ function will be called
        '''
        bin_gain = new_binlen / self.binlen
        self.binlen = new_binlen

        # Alter bminum, bmicount, # of subbins
        screen_update_period = 1./screen_update_rate
        if self.binlen < screen_update_period:
            self.n_subbins = int(screen_update_period / self.binlen)
            self.bmicount = 0
            if hasattr(self, 'bminum'):
                del self.bminum
        else:
            self.n_subbins = 1
            self.bminum = int(self.binlen / screen_update_period)
            self.bmicount = 0

        # change C matrix
        self.filt.C *= bin_gain
        self.filt.Q *= bin_gain**2
        self.filt.C_xpose_Q_inv *= 1./bin_gain

        # change state space model
        # TODO generalize this beyond endpoint
        from . import state_space_models
        A, W = self.ssm.get_ssm_matrices(update_rate=new_binlen)
        self.filt.A = A
        self.filt.W = W

    def conv_to_steady_state(self):
        '''
        Create an SSKFDecoder object based on KalmanFilter parameters in this KFDecoder object
        '''
        from . import sskfdecoder
        self.filt = sskfdecoder.SteadyStateKalmanFilter(A=self.filt.A, W=self.filt.W, C=self.filt.C, Q=self.filt.Q)

    def subselect_units(self, units):
        '''
        Prune units from the KFDecoder, e.g., due to loss of recordings for a particular cell

        Parameters
        ----------
        units : string or np.ndarray of shape (N, 2)
            The units which should be KEPT in the decoder

        Returns
        -------
        KFDecoder
            New KFDecoder object using only a subset of the cells of the original KFDecoder
        '''
        # Parse units into list of indices to keep
        inds_to_keep = self._proc_units(units, 'keep')
        dec_new = self._return_proc_units_decoder(inds_to_keep)
        return dec_new
        #self._save_new_dec(dec_new, '_subset')
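change_binlen treats a bin-length change as a pure gain g = new/old on the observations: y -> g*y implies C -> g*C, Q -> g^2*Q, and therefore C^T Q^{-1} -> (1/g)*C^T Q^{-1}, which is exactly the three rescalings applied above. A standalone consistency check (numpy only, toy matrices):

import numpy as np

rng = np.random.default_rng(5)
C = rng.standard_normal((6, 3))
Q = np.cov(rng.standard_normal((6, 60)))
g = 0.1 / 0.05                      # e.g., going from 50 ms bins to 100 ms bins

C2, Q2 = g * C, g**2 * Q
lhs = C2.T @ np.linalg.pinv(Q2)     # rescaled C_xpose_Q_inv
rhs = (1.0 / g) * (C.T @ np.linalg.pinv(Q))
assert np.allclose(lhs, rhs)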

def project_Q(C_v, Q_hat):
    """
    Deprecated! See clda.KFRML_IVC
    """
    print("projecting!")
    from scipy.optimize import fmin_bfgs, fmin_ncg

    C_v = np.mat(C_v)
    Q_hat = np.mat(Q_hat)
    Q_hat_inv = Q_hat.I

    c_1 = C_v[:,0]
    c_2 = C_v[:,1]
    A_1 = c_1*c_1.T - c_2*c_2.T
    A_2 = c_2*c_1.T
    A_3 = c_1*c_2.T
    A = [A_1, A_2, A_3]
    if 1:
        U = np.hstack([c_1 - c_2, c_2, c_1])
        V = np.vstack([(c_1 + c_2).T, c_1.T, c_2.T])
        C_inv_fn = lambda nu: np.mat(np.diag([1./nu[0], 1./(nu[0] + nu[1]), 1./(nu[2] - nu[0])]))
        C_fn = lambda nu: np.mat(np.diag([nu[0], (nu[0] + nu[1]), (nu[2] - nu[0])]))
        nu_0 = np.zeros(3)
        c_scalars = np.ones(3)
    else:
        u_1, s_1, v_1 = np.linalg.svd(A_1)
        c_scalars = np.hstack([s_1[0:2], 1, 1])
        U = np.hstack([u_1[:,0:2], c_2, c_1])
        V = np.vstack([v_1[0:2, :], c_1.T, c_2.T])
        C_fn = lambda nu: np.mat(np.diag(nu * c_scalars))
        nu_0 = np.zeros(4)

    def cost_fn_gen(nu, return_type='cost'):
        C = C_fn(nu)
        S_star_inv = Q_hat + U*C_fn(nu)*V

        if np.any(np.diag(C) == 0):
            S_star = S_star_inv.I
        else:
            C_inv = C.I
            S_star = Q_hat_inv - Q_hat_inv * U * (C_inv + V*Q_hat_inv*U).I*V * Q_hat_inv

        # log-determinant using LU decomposition, required if Q is large, i.e. lots of simultaneous observations
        cost = -np.log(np.linalg.det(S_star_inv))
        #cost = -np.prod(np.linalg.slogdet(S_star_inv))

        # TODO gradient dimension needs to be the same as nu
        grad = -1e-4*np.array(np.hstack([c_1.T*S_star*c_1 - c_2.T*S_star*c_2, c_1.T*S_star*c_2, c_2.T*S_star*c_1])).ravel()
        S = S_star
        hess = np.mat([[np.trace(S*A_1*S*A_1), np.trace(S*A_2*S*A_1), np.trace(S*A_3*S*A_1)],
                       [np.trace(S*A_1*S*A_2), np.trace(S*A_2*S*A_2), np.trace(S*A_3*S*A_2)],
                       [np.trace(S*A_1*S*A_3), np.trace(S*A_2*S*A_3), np.trace(S*A_3*S*A_3)]])

        #grad = hess*np.mat(grad.reshape(-1,1))

        if return_type == 'cost':
            return cost
        elif return_type == 'grad':
            return grad
        elif return_type == 'hess':
            return hess
        elif return_type == 'opt_val':
            return S_star
        else:
            raise ValueError("Cost function doesn't know how to return this: %s" % return_type)

    cost_fn = lambda nu: cost_fn_gen(nu, return_type='cost')
    grad = lambda nu: cost_fn_gen(nu, return_type='grad')
    hess = lambda nu: cost_fn_gen(nu, return_type='hess')
    arg_opt = lambda nu: cost_fn_gen(nu, return_type='opt_val')

    # Call optimization routine
    #v_star = fmin_ncg(cost_fn, nu_0, fprime=grad, fhess=hess, maxiter=10000)
    #v_star = fmin_bfgs(cost_fn, nu_0, maxiter=10000, gtol=1e-15)
    v_star = fmin_bfgs(cost_fn, nu_0, fprime=grad, maxiter=10000, gtol=1e-15)
    print(v_star)

    Q_inv = arg_opt(v_star)
    Q = Q_inv.I
    Q = Q_hat + U * C_fn(v_star) * V

    # TODO print out (log) a more useful measure of success
    return Q
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_lindecoder_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_lindecoder_py.html
new file mode 100644
index 00000000..2eeca02a
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_lindecoder_py.html
@@ -0,0 +1,129 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\lindecoder.py: 24%
'''
Classes for BMI decoding using linear scaling.
'''
import numpy as np

class State(object):
    '''For compatibility with other BMI decoding implementations, literally just holds the state'''

    def __init__(self, mean, *args, **kwargs):
        self.mean = mean

class LinearScaleFilter(object):

    def __init__(self, n_counts, window, n_states, n_units):
        '''
        Parameters:

        n_counts    How many observations to hold
        window      How many observations to average
        n_states    How many state space variables there are
        n_units     Number of neural units
        '''
        self.state = State(np.zeros([n_states, 1]))
        self.obs = np.zeros((n_counts, n_units))
        self.n_states = n_states
        self.window = window
        self.n_units = n_units
        self.count = 0

    def _init_state(self):
        pass

    def get_mean(self):
        return np.array(self.state.mean).ravel()

    def __call__(self, obs, **kwargs):
        self.state = self._normalize(obs, **kwargs)

    def _normalize(self, obs, **kwargs):
        '''Function to compute normalized scaling of new observations'''

        self.obs[:-1, :] = self.obs[1:, :]
        self.obs[-1, :] = np.squeeze(obs)
        if self.count < len(self.obs):
            self.count += 1

        m_win = np.squeeze(np.mean(self.obs[-self.window:, :], axis=0))
        m = np.median(self.obs[-self.count:, :], axis=0)
        # range = max(1, np.amax(self.obs[-self.count:, :]) - np.amin(self.obs[-self.count:, :]))
        range = np.std(self.obs[-self.count:, :], axis=0)*3
        range[range < 1] = 1
        x = (m_win - m) / range
        x = np.squeeze(np.asarray(x)) * 20  # hack for 14x14 cursor

        # Arrange output
        if self.n_states == self.n_units:
            return State(x)
        elif self.n_states == 3 and self.n_units == 2:
            mean = np.zeros([self.n_states, 1])
            mean[0] = x[0]
            mean[2] = x[1]
            return State(mean)
        else:
            raise NotImplementedError()
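LinearScaleFilter's _normalize is a running robust z-score: the recent-window mean minus the running median, divided by three running standard deviations (floored at 1), then scaled to screen units. A standalone sketch of the same arithmetic on a toy spike-count stream:

import numpy as np

rng = np.random.default_rng(6)
obs = rng.poisson(5, size=(120, 2)).astype(float)   # 120 samples x 2 units

window = 10
m_win = obs[-window:, :].mean(axis=0)    # short-term mean
m = np.median(obs, axis=0)               # long-term median (the baseline)
spread = np.maximum(3 * obs.std(axis=0), 1.0)

x = (m_win - m) / spread * 20            # screen-scaled control signal
assert x.shape == (2,)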
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_onedim_lfp_decoder_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_onedim_lfp_decoder_py.html
new file mode 100644
index 00000000..269c38a3
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_onedim_lfp_decoder_py.html
@@ -0,0 +1,306 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\onedim_lfp_decoder.py: 0%
from riglib import bmi, plexon, source
from riglib.bmi import extractor
import numpy as np
from riglib.bmi import clda
from riglib.bmi import train

kinarm_bands = []
for i in np.arange(0, 100, 10):
    kinarm_bands.extend([[i, i+10]])
kinarm_bands.extend([[25, 40], [40, 55], [65, 90], [2, 100]])

class StateHolder(object):
    def __init__(self, x_mat, A_mat, powercap_flag, zbound, *args, **kwargs):
        if powercap_flag:
            pos_mean = np.sum((np.multiply(x_mat, A_mat)*0)+zbound[0], axis=0)
            self.mean = np.squeeze(np.hstack((np.array([pos_mean]), np.zeros((1,4)))))
        else:
            #self.mean = np.dot(x_array, A_array)
            pos_mean = np.sum(np.multiply(x_mat, A_mat), axis=0)
            self.mean = np.squeeze(np.hstack((np.array([pos_mean]), np.zeros((1,4)))))

class SmoothFilter(object):
    '''Moving average filter used in 1D or 2D LFP control:
        x_{t} = a0*x_{t} + a1*x_{t-1} + a2*x_{t-2} + ...

    Parameters
    ----------
    A: np.array of shape (N, 3)
        Weights for previous states
    X: np.array of previous states (N, 3)
    '''

    def __init__(self, n_steps, **kwargs):
        self.n_steps = n_steps
        A = np.ones((self.n_steps, ))/float(self.n_steps)
        self.A = np.tile(np.array([A]).T, [1, 3])

        self.control_method = 'fraction'
        self.current_lfp_pos = 0
        self.current_powercap_flag = 0

    def get_mean(self):
        return np.array(self.state.mean).ravel()

    def init_from_task(self, **kwargs):
        if 'n_steps' in kwargs:
            self.n_steps = kwargs['n_steps']
            A = np.ones((self.n_steps, ))/float(self.n_steps)
            self.A = np.tile(np.array([A]).T, [1, 3])

        if 'powercap' in kwargs:
            self.powercap = kwargs['powercap']

        if 'zboundaries' in kwargs:
            self.zboundaries = kwargs['zboundaries']

        if 'lfp_frac_lims' in kwargs:
            self.frac_lims = kwargs['lfp_frac_lims']

        if 'xlfp_frac_lims' in kwargs:
            self.xfrac_lims = kwargs['xlfp_frac_lims']

    def _init_state(self, init_state=None, **kwargs):
        if init_state is None:
            self.X = np.zeros((self.n_steps, 3))

        # Implemented later:
        elif init_state == 'average':  # was "is 'average'", an identity check instead of equality
            if self.control_method == 'fraction':
                mn = np.mean(np.array(kwargs['frac_lim']))
            elif self.control_method == 'power':
                mn = np.mean(np.array(kwargs['pwr_mean']))
            self.X = np.zeros((self.n_steps, 3)) + mn

        self.state = StateHolder(self.X, self.A, 0, 0)

        self.cnt = 0

    def __call__(self, obs, **kwargs):
        self.state = self._mov_avg(obs, **kwargs)

    def _mov_avg(self, obs, **kwargs):
        #self.zboundaries = kwargs['zboundaries']
        self.fft_inds = kwargs['fft_inds']
        obs = obs.reshape(len(kwargs['channels']), len(kwargs['fft_freqs']))

        self.current_lfp_pos, self.current_powercap_flag = self.get_lfp_cursor(obs)

        self.X = np.vstack((self.X[1:,:], self.current_lfp_pos))
        return StateHolder(self.X, self.A, self.current_powercap_flag, self.zboundaries)

    def get_lfp_cursor(self, psd_est):
        'Obs: channels x frequencies'
        # Control band:
        c_idx = self.control_band_ind

        # As done in the kinarm script, sum together frequencies within a band, then take the mean across channels
        c_val = np.mean(np.sum(psd_est[:, self.fft_inds[c_idx]], axis=1))

        xc_idx = self.x_control_band_ind
        xc_val = np.mean(np.sum(psd_est[:, self.fft_inds[xc_idx]], axis=1))

        p_idx = self.totalpw_band_ind
        p_val = np.mean(np.sum(psd_est[:, self.fft_inds[p_idx]], axis=1))

        if self.control_method == 'fraction':
            lfp_control = c_val / float(p_val)
            xlfp_control = xc_val / float(p_val)

        elif self.control_method == 'power':
            lfp_control = c_val

        cursor_pos = self.lfp_to_cursor(lfp_control, xlfp_control)

        #if p_val <= self.powercap:
        # write c_val, xc_val, p_val to file:
        if self.cnt < 3000:
            self.files[0].write(str(c_val)+',')
            self.files[1].write(str(p_val)+',')
            self.files[2].write(str(xc_val)+',')
            self.cnt += 1
        elif self.cnt == 3000:
            for f in self.files:
                f.close()

        # Hack: make the x axis control the powercap value:
        if xc_val <= self.powercap:
            powercap_flag = 0
        else:
            powercap_flag = 1

        return cursor_pos, powercap_flag

    def lfp_to_cursor(self, lfppos, xlfp_control):
        if self.control_method == 'fraction':
            dmn = lfppos - np.mean(self.frac_lims)
            cursor_pos = dmn * (self.zboundaries[1]-self.zboundaries[0]) / (self.frac_lims[1] - self.frac_lims[0])

            # X cursor position, keep within (0, -16) boundaries:
            xdmn = xlfp_control - np.mean(self.xfrac_lims)
            xcursor_pos = xdmn * (0 - -16) / (self.xfrac_lims[1] - self.xfrac_lims[0])
            xcursor_pos = xcursor_pos - 8  # nonzero offset
            return np.array([xcursor_pos, 0, cursor_pos])


    def _pickle_init(self):
        pass
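lfp_to_cursor is a straight affine map: center the fractional band power on the midpoint of its operating range, then rescale so frac_lims spans the screen's z boundaries. A standalone numeric sketch with made-up limits:

import numpy as np

frac_lims = (0.2, 0.6)        # hypothetical fractional-power operating range
zbound = (-12.0, 12.0)        # screen extent along z

def frac_to_cursor(frac):
    dmn = frac - np.mean(frac_lims)
    return dmn * (zbound[1] - zbound[0]) / (frac_lims[1] - frac_lims[0])

assert np.isclose(frac_to_cursor(0.4), 0.0)    # midpoint of the range maps to screen center
assert np.isclose(frac_to_cursor(0.6), 12.0)   # top of the range maps to the top boundary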

from .bmi import Decoder
class One_Dim_LFP_Decoder(Decoder):

    def __init__(self, *args, **kwargs):

        # Args: sf, units, ssm, extractor_cls, extractor_kwargs
        super(One_Dim_LFP_Decoder, self).__init__(args[0], args[1], args[2])

        self.extractor_cls = args[3]
        self.extractor_kwargs = args[4]

        # For now, hardcoded:
        bands = kinarm_bands
        control_method = 'fraction'
        no_log = True
        no_mean = True

        self.extractor_kwargs['bands'] = bands
        self.extractor_kwargs['no_log'] = no_log
        self.extractor_kwargs['no_mean'] = no_mean

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def predict(self, neural_obs, **kwargs):
        #kwargs['zboundaries'] = self.filt.zboundaries
        kwargs['fft_inds'] = self.extractor_kwargs['fft_inds']
        kwargs['channels'] = self.extractor_kwargs['channels']
        kwargs['fft_freqs'] = self.extractor_kwargs['fft_freqs']

        self.filt(neural_obs, **kwargs)


    def init_from_task(self, **kwargs):
        if 'lfp_control_band' in kwargs:
            self.filt.control_band_ind, self.extractor_kwargs['bands'], self.extractor_kwargs['fft_inds'] = \
                self._get_band_ind(self.extractor_kwargs['fft_freqs'], kwargs['lfp_control_band'], self.extractor_kwargs['bands'])

        if 'lfp_totalpw_band' in kwargs:
            self.filt.totalpw_band_ind, self.extractor_kwargs['bands'], self.extractor_kwargs['fft_inds'] = \
                self._get_band_ind(self.extractor_kwargs['fft_freqs'], kwargs['lfp_totalpw_band'], self.extractor_kwargs['bands'])

        if 'xlfp_control_band' in kwargs:
            self.filt.x_control_band_ind, self.extractor_kwargs['bands'], self.extractor_kwargs['fft_inds'] = \
                self._get_band_ind(self.extractor_kwargs['fft_freqs'], kwargs['xlfp_control_band'], self.extractor_kwargs['bands'])

        c_txt = open('/home/helene/Downloads/pk/txt_write/control.txt', 'w')
        p_txt = open('/home/helene/Downloads/pk/txt_write/tot.txt', 'w')
        xc_txt = open('/home/helene/Downloads/pk/txt_write/x_cont.txt', 'w')
        self.filt.files = [c_txt, p_txt, xc_txt]

    def _get_band_ind(self, freq_pts, band, band_set):
        band_ind = -1
        for b, bd in enumerate(band_set):
            if (bd[0] == band[0]) and (bd[1] == band[1]):
                band_ind = b
        if band_ind == -1:
            band_ind = b + 1
            band_set.extend([band])

        fft_ind = dict()
        for band_idx, band in enumerate(band_set):
            fft_ind[band_idx] = [freq_idx for freq_idx, freq in enumerate(freq_pts) if band[0] <= freq < band[1]]

        return band_ind, band_set, fft_ind


def _init_decoder_for_sim(n_steps=10):
    kw = dict(control_method='fraction')
    sf = SmoothFilter(n_steps, **kw)
    ssm = train.endpt_2D_state_space
    units = [[23, 1], [24, 1], [25, 1]]
    decoder = One_Dim_LFP_Decoder(sf, units, ssm, binlen=0.1, n_subbins=1)
    learner = clda.DumbLearner()

    return decoder

def create_decoder(units, ssm, extractor_cls, extractor_kwargs, n_steps=10):
    kw = dict(control_method='fraction')
    sf = SmoothFilter(n_steps, **kw)
    decoder = One_Dim_LFP_Decoder(sf, units, ssm, extractor_cls, extractor_kwargs)

    return decoder
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_ppfdecoder_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_ppfdecoder_py.html
new file mode 100644
index 00000000..09b5f744
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_ppfdecoder_py.html
@@ -0,0 +1,574 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\ppfdecoder.py: 15%
'''
Classes for BMI decoding using the point-process filter.
'''

import numpy as np

from . import bmi
from .bmi import GaussianState
from scipy.io import loadmat
import time
import cmath
from . import feedback_controllers
import pickle
from . import train

class PointProcessFilter(bmi.GaussianStateHMM):
    """
    Low-level point-process filter, agnostic to application

    Model:
       x_{t+1} = Ax_t + Bu_t + w_t; w_t ~ N(0, W)
       log(y_t) = Cx_t

    See Shanechi et al., "Feedback-Controlled Parallel Point Process Filter for
    Estimation of Goal-Directed Movements From Neural Signals", IEEE TNSRE, 2013
    for mathematical details.
    """
    model_attrs = ['A', 'W', 'C']

    def __init__(self, A=None, W=None, C=None, dt=None, is_stochastic=None, B=0, F=0):
        '''
        Constructor for PointProcessFilter

        Parameters
        ----------
        A : np.mat
            Model of state transition matrix
        W : np.mat
            Model of process noise covariance
        C : np.mat
            Model of conditional distribution between observations and hidden state
            log(obs) = C * hidden_state
        dt : float
            Discrete-time sampling rate of the filter. Used to map spike counts to spike rates
        B : np.ndarray, optional
            Control input matrix
        F : np.ndarray, optional
            State-space feedback gain matrix to drive the state back to the equilibrium state.
        is_stochastic : np.array, optional
            Array of booleans specifying for each state whether it is stochastic.
            If 'None' specified, all states are assumed to be stochastic

        Returns
        -------
        PointProcessFilter instance
        '''
        if A is None and W is None and C is None and dt is None:
            ## This condition should only be true in the unpickling phase
            pass
        else:
            self.A = np.mat(A)
            self.W = np.mat(W)
            self.C = np.mat(C)
            self.dt = dt
            self.spike_rate_dt = dt

            self.B = B
            self.F = F

            if is_stochastic is None:
                n_states = A.shape[0]
                self.is_stochastic = np.ones(n_states, dtype=bool)
            else:
                self.is_stochastic = np.array(is_stochastic)

            self.state_noise = GaussianState(0.0, W)
            self._pickle_init()

    def _pickle_init(self):
        """
        Code common to unpickling and initialization
        """
        nS = self.A.shape[0]
        offset_row = np.zeros(nS)
        offset_row[-1] = 1
        self.include_offset = np.array_equal(np.array(self.A)[-1, :], offset_row)

        self.spike_rate_dt = self.dt

        if not hasattr(self, 'B'): self.B = 0
        if not hasattr(self, 'F'): self.F = 0

    def init_noise_models(self):
        '''
        see bmi.GaussianStateHMM.init_noise_models for documentation
        '''
        self.state_noise = GaussianState(0.0, self.W)
        self.id = np.zeros([1, self.C.shape[0]])

    def _check_valid(self, lambda_predict):
        '''
        Raise an error if any predicted spike probability lambda*dt exceeds 1

        Parameters
        ----------
        lambda_predict : np.array
            Predicted firing rates, in Hz

        Returns
        -------
        None
        '''
        if np.any((lambda_predict * self.spike_rate_dt) > 1):
            raise ValueError("Cell exploded!")

    def _obs_prob(self, state):
        '''
        Predict the firing rates from the log-linear model, lambda = exp(C*x)/dt,
        clipped so that lambda*dt remains a valid spike probability

        Parameters
        ----------
        state : GaussianState instance
            The model-predicted hidden state

        Returns
        -------
        np.array
            Predicted firing rates
        '''
        Loglambda_predict = self.C * state.mean
        lambda_predict = np.exp(Loglambda_predict)/self.spike_rate_dt

        nan_inds = np.isnan(lambda_predict)
        lambda_predict[nan_inds] = 0

        # check max rate is less than 1 b/c it's a probability
        rate_too_high_inds = ((lambda_predict * self.spike_rate_dt) > 1)
        lambda_predict[rate_too_high_inds] = 1./self.spike_rate_dt

        # check min rate is > 0
        rate_too_low_inds = (lambda_predict < 0)
        lambda_predict[rate_too_low_inds] = 0

        invalid_inds = nan_inds | rate_too_high_inds | rate_too_low_inds
        if np.any(invalid_inds):
            pass
        return lambda_predict
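The log-linear observation model means each unit's expected count in a bin is lambda*dt = exp(c.x): the last entry of x is a 1, so the last column of C is a log baseline and the remaining columns tilt the rate with the state. A standalone sketch of the predict-and-clip step (made-up values):

import numpy as np

dt = 0.005                                       # 5 ms bins
x = np.array([0.3, -0.1, 1.0])                   # state with trailing offset term
C = np.array([[0.8, 0.0, np.log(10 * dt)],       # unit tuned to the first state
              [0.0, -0.5, np.log(20 * dt)]])     # unit tuned to the second state

lam = np.exp(C @ x) / dt                         # predicted rates in Hz
lam = np.minimum(lam, 1.0 / dt)                  # clip so lam*dt stays a valid probability
assert np.all(lam * dt <= 1.0)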

    def _forward_infer(self, st, obs_t, Bu=None, u=None, x_target=None, F=None, obs_is_control_independent=False, **kwargs):
        '''
        Estimate the hidden state from the previous estimate and the new observation

        Parameters
        ----------
        st : GaussianState
            Current estimate (mean and cov) of the hidden state
        obs_t : np.array
            Spike counts for the current bin (at most one spike per unit)

        Returns
        -------
        GaussianState
            New state estimate incorporating the most recent observation
        '''
        if np.any(obs_t > 1):
            raise Exception
        using_control_input = (Bu is not None) or (u is not None) or (x_target is not None)
        if x_target is not None:
            x_target = np.mat(x_target[:,0].reshape(-1,1))
        target_state = x_target

        obs_t = np.mat(obs_t.reshape(-1,1))
        C = self.C
        n_obs, n_states = C.shape

        dt = self.spike_rate_dt
        inds, = np.nonzero(self.is_stochastic)
        mesh = np.ix_(inds, inds)
        A = self.A
        W = self.W
        C = C[:,inds]

        pred_state = self._ssm_pred(st, target_state=x_target, Bu=Bu, u=u, F=F)
        x_pred, P_pred = pred_state.mean, pred_state.cov
        P_pred = P_pred[mesh]

        Loglambda_predict = self.C * x_pred
        exp = np.vectorize(lambda x: np.real(cmath.exp(x)))
        lambda_predict = exp(np.array(Loglambda_predict).ravel())/dt

        Q_inv = np.mat(np.diag(lambda_predict*dt))

        if np.linalg.cond(P_pred) > 1e5:
            P_est = P_pred
        else:
            P_est = (P_pred.I + C.T*np.mat(np.diag(lambda_predict*dt))*C).I

        # inflate P_est
        P_est_full = np.mat(np.zeros([n_states, n_states]))
        P_est_full[mesh] = P_est
        P_est = P_est_full

        unpred_spikes = obs_t - np.mat(lambda_predict*dt).reshape(-1,1)

        x_est = np.mat(np.zeros([n_states, 1]))
        x_est = x_pred + P_est*self.C.T*unpred_spikes
        self.neural_push = P_est*self.C.T*obs_t
        self.P_est = P_est
        post_state = GaussianState(x_est, P_est)
        return post_state

    def __getstate__(self):
        '''
        Return model parameters to be pickled. Overrides the default __getstate__ so that things like the P matrix aren't pickled.

        Parameters
        ----------
        None

        Returns
        -------
        dict
        '''
        return dict(A=self.A, W=self.W, C=self.C, dt=self.dt, B=self.B,
            is_stochastic=self.is_stochastic)

    def tomlab(self, unit_scale=1.):
        '''
        Convert to the MATLAB beta matrix convention from the one used here (different state order, transposed)
        '''
        return np.array(np.hstack([self.C[:,-1], unit_scale*self.C[:,self.is_stochastic]])).T

    @classmethod
    def frommlab(self, beta_mat):
        '''
        Convert from the MATLAB beta matrix convention to the one used here (different state order, transposed)
        '''
        return np.vstack([beta_mat[1:,:], beta_mat[0,:]]).T

    @classmethod
    def MLE_obs_model(cls, hidden_state, obs, include_offset=True, drives_obs=None):
        """
        Unconstrained ML estimator of {C} given observations and
        the corresponding hidden states, fit unit-by-unit with a Poisson GLM

        Parameters
        ----------
        include_offset : bool, optional, default=True
            A row of all 1's is added as the last row of hidden_state

        Returns
        -------
        C, pvalues
            GLM coefficients and their p-values, one row per unit
        """
        assert hidden_state.shape[1] == obs.shape[1]

        if isinstance(hidden_state, np.ma.core.MaskedArray):
            mask = ~hidden_state.mask[0,:]  # NOTE THE INVERTER
            inds = np.nonzero([ mask[k]*mask[k+1] for k in range(len(mask)-1)])[0]

            X = np.mat(hidden_state[:,mask])
            T = len(np.nonzero(mask)[0])

            Y = np.mat(obs[:,mask])
            if include_offset:
                X = np.vstack([ X, np.ones([1,T]) ])
        else:
            num_hidden_state, T = hidden_state.shape
            X = np.mat(hidden_state)
            if include_offset:
                X = np.vstack([ X, np.ones([1,T]) ])
                if drives_obs is not None:
                    drives_obs = np.hstack([drives_obs, True])

            Y = np.mat(obs)

        X = np.array(X)
        if drives_obs is not None:
            X = X[drives_obs, :]
        Y = np.array(Y)

        # ML estimate of C
        n_units = Y.shape[0]
        n_states = X.shape[0]
        C = np.zeros([n_units, n_states])
        pvalues = np.zeros([n_units, n_states])
        import statsmodels.api as sm
        glm_family = sm.families.Poisson()
        for k in range(n_units):
            model = sm.GLM(Y[k,:], X.T, family=glm_family)
            try:
                model_fit = model.fit()
                C[k,:] = model_fit.params
                pvalues[k,:] = model_fit.pvalues
            except Exception:
                pvalues[k,:] = np.nan

        return C, pvalues
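The per-unit fit above is an ordinary Poisson regression; the same call pattern works standalone on simulated counts (requires statsmodels; toy values):

import numpy as np
import statsmodels.api as sm

rng = np.random.default_rng(7)
T = 2000
X = np.vstack([rng.standard_normal((2, T)), np.ones((1, T))])  # 2 states + offset row
c_true = np.array([0.4, -0.3, -2.0])                           # log-linear tuning
y = rng.poisson(np.exp(c_true @ X))                            # one unit's counts

fit = sm.GLM(y, X.T, family=sm.families.Poisson()).fit()
assert fit.params.shape == (3,)   # fit.params approximately recovers c_true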

290class OneStepMPCPointProcessFilter(PointProcessFilter): 

+

291 ''' 

+

292 Use MPC with a horizon of 1 to predict  

+

293 ''' 

+

294 attrs_to_pickle = ['A', 'W', 'C'] 

+

295 def _pickle_init(self): 

+

296 super(OneStepMPCPointProcessFilter, self)._pickle_init() 

+

297 

+

298 self.prev_obs = None 

+

299 if not hasattr(self, 'mpc_cost_step'): 

+

300 mpc_cost_half_life = 1200. 

+

301 batch_time = 0.1 

+

302 self.mpc_cost_step = np.exp(np.log(0.5) / (mpc_cost_half_life/batch_time)) 

+

303 self.ESS = 1000 

+

304 
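A quick illustrative check of the forgetting factor defined above: decaying by mpc_cost_step once per 0.1 s batch halves the weight after mpc_cost_half_life seconds.

import numpy as np
step = np.exp(np.log(0.5) / (1200. / 0.1))   # per-batch decay factor
print(step ** (1200. / 0.1))                 # -> 0.5 after one half-life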

+

305 def _ssm_pred(self, state, u=None, Bu=None, target_state=None, F=None): 

+

306 ''' 

+

307 Run the "predict" step of the Kalman filter/HMM inference algorithm: 

+

308 x_{t+1|t} = N(Ax_{t|t}, AP_{t|t}A.T + W) 

+

309 

+

310 Parameters 

+

311 ---------- 

+

312 state: GaussianState instance 

+

313 State estimate and estimator covariance of current state 

+

314 u: np.mat, optional 

+

315 Control input, applied through the input matrix B 

+

316 

+

317 Returns 

+

318 ------- 

+

319 GaussianState instance 

+

320 Represents the mean and estimator covariance of the new state estimate 

+

321 ''' 

+

322 A = self.A 

+

323 

+

324 dt = self.dt 

+

325 

+

326 Loglambda_predict = self.C * state.mean 

+

327 exp = np.vectorize(lambda x: np.real(cmath.exp(x))) 

+

328 lambda_predict = exp(np.array(Loglambda_predict).ravel())/dt 

+

329 

+

330 Q_inv = np.mat(np.diag(lambda_predict*dt)) 

+

331 

+

332 if self.prev_obs is not None: 

+

333 y_ref = self.prev_obs 

+

334 G = self.C.T * Q_inv 

+

335 D = G * self.C 

+

336 D[:,-1] = 0 

+

337 D[-1,:] = 0 

+

338 

+

339 # Solve for R 

+

340 R = 200*D 

+

341 

+

342 alpha = A*state 

+

343 v = np.linalg.pinv(R + D)*(G*y_ref - D*alpha.mean) 

+

344 else: 

+

345 v = np.zeros_like(state.mean) 

+

346 

+

347 if Bu is not None: 

+

348 return A*state + Bu + self.state_noise + v 

+

349 elif u is not None: 

+

350 Bu = self.B * u 

+

351 return A*state + Bu + self.state_noise + v 

+

352 elif target_state is not None: 

+

353 B = self.B 

+

354 F = self.F 

+

355 return (A - B*F)*state + B*F*target_state + self.state_noise + v 

+

356 else: 

+

357 return A*state + self.state_noise + v 

+

358 

+

359 def _forward_infer(self, st, obs_t, **kwargs): 

+

360 res = super(OneStepMPCPointProcessFilter, self)._forward_infer(st, obs_t, **kwargs) 

+

361 

+

362 # if not (self.prev_obs is None): 

+

363 # # Update the sufficient statistics for the R matrix 

+

364 # l = self.mpc_cost_step 

+

365 # self.E00 = l*self.E00 + (1-l)*(self.prev_obs - self.C*self.A*st.mean)*(self.prev_obs - self.C*self.A*st.mean).T 

+

366 # self.E01 = l*self.E01 + (1-l)*(self.prev_obs - self.C*self.A*st.mean)*(obs_t - self.C*self.A*st.mean).T 

+

367 

+

368 self.prev_obs = obs_t 

+

369 return res 

+

370 

+

371class OneStepMPCPointProcessFilterCovFb(OneStepMPCPointProcessFilter): 

+

372 def _ssm_pred(self, state, **kwargs): 

+

373 ''' 

+

374 Run the "predict" step of the Kalman filter/HMM inference algorithm: 

+

375 x_{t+1|t} = N(Ax_{t|t}, AP_{t|t}A.T + W) 

+

376 

+

377 Parameters 

+

378 ---------- 

+

379 state: GaussianState instance 

+

380 State estimate and estimator covariance of current state 

+

381 u: np.mat, optional 

+

382 Control input, applied through the input matrix B 

+

383 

+

384 Returns 

+

385 ------- 

+

386 GaussianState instance 

+

387 Represents the mean and estimator covariance of the new state estimate 

+

388 ''' 

+

389 A = self.A 

+

390 

+

391 dt = self.dt 

+

392 

+

393 Loglambda_predict = self.C * state.mean 

+

394 exp = np.vectorize(lambda x: np.real(cmath.exp(x))) 

+

395 lambda_predict = exp(np.array(Loglambda_predict).ravel())/dt 

+

396 

+

397 Q_inv = np.mat(np.diag(lambda_predict*dt)) 

+

398 

+

399 from .bmi import GaussianState 

+

400 if (self.prev_obs is not None) and (self.r_scale < np.inf): 

+

401 y_ref = self.prev_obs 

+

402 G = self.C.T * Q_inv 

+

403 D = G * self.C 

+

404 D[:,-1] = 0 

+

405 D[-1,:] = 0 

+

406 

+

407 # Solve for R 

+

408 R = self.r_scale*D 

+

409 

+

410 alpha = A*state 

+

411 v = np.linalg.pinv(R + D)*(G*y_ref - D*alpha.mean) 

+

412 I = np.mat(np.eye(D.shape[0])) 

+

413 C = self.C 

+

414 A = (I - G*C) * self.A 

+

415 mean = A*state.mean + G*y_ref 

+

416 cov = A*state.cov*A.T + self.W 

+

417 

+

418 return GaussianState(mean, cov) 

+

419 else: 

+

420 return A*state + self.state_noise 

+

421 # v = np.zeros_like(state.mean) 

+

422 

+

423 

+

424 

+

425class PPFDecoder(bmi.BMI, bmi.Decoder): 

+

426 def __call__(self, obs_t, **kwargs): 

+

427 ''' 

+

428 see bmi.Decoder.__call__ for docs 

+

429 ''' 

+

430 # The PPF model predicts that at most one spike can be observed in  

+

431 # each bin; if more are observed, squash the counts 

+

432 # (make a copy of the observation matrix prior to squashing) 

+

433 obs_t = obs_t.copy() 

+

434 obs_t[obs_t > 1] = 1 

+

435 return super(PPFDecoder, self).__call__(obs_t, **kwargs) 

+

436 

+

437 def shuffle(self): 

+

438 ''' 

+

439 Shuffle the neural model 

+

440  

+

441 Parameters 

+

442 ---------- 

+

443 None 

+

444  

+

445 Returns 

+

446 ------- 

+

447 None 

+

448 ''' 

+

449 import random 

+

450 inds = list(range(self.filt.C.shape[0])) 

+

451 random.shuffle(inds) 

+

452 

+

453 # shuffle rows of C 

+

454 self.filt.C = self.filt.C[inds, :] 

+

455 

+

456 def compute_suff_stats(self, hidden_state, obs, include_offset=True): 

+

457 ''' 

+

458 Calculate initial estimates of the parameter sufficient statistics used in the RML update rules 

+

459 

+

460 Parameters 

+

461 ---------- 

+

462 hidden_state : np.ndarray of shape (n_states, n_samples) 

+

463 Examples of the hidden state x_t taken from training seed data.  

+

464 obs : np.ndarray of shape (n_features, n_samples) 

+

465 Multiple neural observations paired with each of the hidden state examples 

+

466 include_offset : bool, optional 

+

467 If true, a state of all 1's is added to the hidden_state to represent mean offsets. True by default 

+

468 

+

469 Returns 

+

470 ------- 

+

471 R : np.ndarray of shape (n_states, n_states) 

+

472 Proportional to covariance of the hidden state samples  

+

473 S : np.ndarray of shape (n_features, n_states) 

+

474 Proportional to cross-covariance between the hidden states and the neural observations 

+

475 T : np.ndarray of shape (n_features, n_features) 

+

476 Proportional to covariance of the neural observations 

+

477 ESS : float 

+

478 Effective number of samples. In the initialization, this is just the  

+

479 dimension of the array passed in, but the parameter can become non-integer  

+

480 during the update procedure as old parameters are "forgotten". 

+

481 ''' 

+

482 n_obs = obs.shape[0] 

+

483 nS = hidden_state.shape[0] 

+

484 

+

485 H = np.zeros([n_obs, nS, nS]) 

+

486 M = np.zeros([n_obs, nS]) 

+

487 S = np.zeros([n_obs, nS]) 

+

488 

+

489 C = self.filt.C[:,self.drives_neurons] 

+

490 

+

491 X = np.array(hidden_state) 

+

492 T = X.shape[1] 

+

493 if include_offset: 

+

494 if not np.all(X[-1,:] == 1): 

+

495 X = np.vstack([ X, np.ones([1,T]) ]) 

+

496 

+

497 for k in range(n_obs): 

+

498 Mu = np.exp(np.dot(C[k,:], X)).ravel() 

+

499 Y = obs[k,:] 

+

500 H[k] = np.dot((np.tile(np.array(-Mu), [nS, 1]) * X), X.T) 

+

501 M[k] = np.dot(Mu, X.T) 

+

502 S[k] = np.dot(Y, X.T) 

+

503 

+

504 self.H = H 

+

505 self.S = S 

+

506 self.M = M 

+

507 

+

508 @property 

+

509 def n_features(self): 

+

510 return self.filt.C.shape[0] 

+
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_rat_bmi_decoder_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_rat_bmi_decoder_py.html
new file mode 100644
index 00000000..2f07e1c3
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_rat_bmi_decoder_py.html
@@ -0,0 +1,646 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\rat_bmi_decoder.py: 0%

1 

+

2from riglib import bmi, plexon, source 

+

3from riglib.bmi import extractor 

+

4import numpy as np 

+

5from riglib.bmi import clda 

+

6from riglib.bmi import train 

+

7from riglib.bmi import state_space_models 

+

8import datetime 

+

9import matplotlib.pyplot as plt 

+

10 

+

11class State(object): 

+

12 '''For compatibility with other BMI decoding implementations, literally just holds the state''' 

+

13 

+

14 def __init__(self, mean, *args, **kwargs): 

+

15 self.mean = mean 

+

16 

+

17class RatFilter(object): 

+

18 '''Moving Average Filter used in 1D or 2D LFP control: 

+

19 x_{t} = a0*y_{t} + a1*y_{t-1} + a2*y_{t-2} + ... 

+

20 i.e., a weighted sum of the current and past observations y 

+

21 

+

22 

+

23 Parameters 

+

24 

+

25 ---------- 

+

26 A: np.array of shape (N, ) 

+

27 Weights for previous states 

+

28 X: np.array of previous states (N, ) 

+

29 ''' 

+

30 

+

31 def __init__(self, task_params): 

+

32 self.e1_inds = task_params['e1_inds'] 

+

33 self.e2_inds = task_params['e2_inds'] 

+

34 self.FR_to_freq_fn = task_params['FR_to_freq_fn'] 

+

35 self.t1 = task_params['t1'] 

+

36 self.t2 = task_params['t2'] 

+

37 self.mid = task_params['mid'] 

+

38 self.dec_params = task_params 

+

39 

+

40 #Cursor data (X) 

+

41 self.X = 0. 

+

42 

+

43 #Freq data(F)  

+

44 self.F = 0. 

+

45 self.baseline = False 

+

46 

+

47 def get_mean(self): 

+

48 return np.array(self.state.mean).ravel() 

+

49 

+

50 def init_from_task(self, n_units, **kwargs): 

+

51 #Define n_steps 

+

52 if 'nsteps' in kwargs: 

+

53 self.n_steps = kwargs['nsteps'] 

+

54 self.A = np.ones(( self.n_steps, ))/float(self.n_steps) 

+

55 

+

56 #Neural data (Y) 

+

57 self.Y = np.zeros(( self.n_steps, n_units)) 

+

58 self.n_units = n_units 

+

59 

+

60 else: 

+

61 raise Exception("init_from_task requires an 'nsteps' keyword argument") 

+

62 

+

63 

+

64 def _init_state(self, init_state=None,**kwargs): 

+

65 if init_state is None: 

+

66 init_state = 0 

+

67 

+

68 self.state = State(init_state) 

+

69 

+

70 def __call__(self, obs, **kwargs): 

+

71 self.state = self._mov_avg(obs, **kwargs) 

+

72 

+

73 def _mov_avg(self, obs,**kwargs): 

+

74 ''' Function to compute moving average with old mean and new observation''' 

+

75 

+

76 self.Y[:-1, :] = self.Y[1:, :] 

+

77 self.Y[-1, :] = np.squeeze(obs) 

+

78 

+

79 d_fr = np.sum(self.Y[:, self.e1_inds], axis=1) - np.sum(self.Y[:, self.e2_inds], axis=1) 

+

80 mean_FR = np.dot(d_fr, self.A) 

+

81 self.X = mean_FR 

+

82 self.F = self.FR_to_freq_fn(self.X) 

+

83 return State(self.X) 

+

84 
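A minimal stand-alone sketch (assumed shapes) of the E1-minus-E2 moving average that _mov_avg implements: a uniform length-n_steps window over the per-bin rate difference.

import numpy as np
n_steps, n_units = 5, 4
A = np.ones(n_steps) / n_steps                       # uniform window weights
Y = np.random.poisson(3., size=(n_steps, n_units))   # rolling buffer of counts
e1_inds, e2_inds = [0, 1], [2, 3]
d_fr = np.sum(Y[:, e1_inds], axis=1) - np.sum(Y[:, e2_inds], axis=1)
cursor = np.dot(d_fr, A)                             # scalar cursor value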

+

85 def FR_to_freq(self, mean_FR): 

+

86 return self.FR_to_freq_fn(mean_FR) 

+

87 

+

88 def _pickle_init(self): 

+

89 pass 

+

90 

+

91class IsmoreSleepFilter(RatFilter): 

+

92 def __init__(self, task_params): 

+

93 self.e1_inds = task_params['e1_inds'] 

+

94 self.e2_inds = task_params['e2_inds'] 

+

95 self.FR_to_alpha_fn = task_params['FR_to_alpha_fn'] 

+

96 self.dec_params = task_params 

+

97 self.mid = task_params['mid'] 

+

98 self.e1_max = task_params['e1_perc'] 

+

99 self.e2_max = task_params['e2_perc'] 

+

100 self.freq_lim = task_params['freq_lim'] 

+

101 

+

102 #Cursor data (X) 

+

103 self.FR = 0. 

+

104 

+

105 #Freq data(F)  

+

106 self.alpha = 0. 

+

107 self.model_attrs = [] 

+

108 self.baseline = False 

+

109 

+

110 def init_from_task(self, **kwargs): 

+

111 #Define n_steps 

+

112 self.n_steps = kwargs.pop('nsteps', 1) 

+

113 self.A = np.ones(( self.n_steps, ))/float(self.n_steps) 

+

114 

+

115 #Neural data (Y) 

+

116 self.Y = np.zeros(( self.n_steps, len(self.e1_inds)+len(self.e2_inds))) 

+

117 self.n_units = len(self.e1_inds)+len(self.e2_inds) 

+

118 

+

119 def _mov_avg(self, obs,**kwargs): 

+

120 ''' Function to compute moving average with old mean and new observation''' 

+

121 

+

122 self.Y[:-1, :] = self.Y[1:, :] 

+

123 self.Y[-1, :] = np.squeeze(obs) 

+

124 

+

125 if self.e1_max is not None: 

+

126 e1_tmp = np.min([self.e1_max, np.sum(self.Y[:, self.e1_inds], axis=1)]) 

+

127 else: 

+

128 e1_tmp = np.sum(self.Y[:, self.e1_inds], axis=1) 

+

129 

+

130 if self.e2_max is not None: 

+

131 e2_tmp = np.min([self.e2_max, np.sum(self.Y[:, self.e2_inds], axis=1)]) 

+

132 else: 

+

133 e2_tmp = np.sum(self.Y[:, self.e2_inds], axis=1) 

+

134 

+

135 d_fr = e1_tmp - e2_tmp 

+

136 mean_FR = np.dot(d_fr, self.A) 

+

137 self.FR = mean_FR 

+

138 self.alpha = self.FR_to_alpha_fn(self.FR) 

+

139 

+

140 # Clamp alpha to the configured frequency limits: 

+

141 if self.alpha > self.freq_lim[1]: 

+

142 self.alpha = self.freq_lim[1] 

+

143 elif self.alpha < self.freq_lim[0]: 

+

144 self.alpha = self.freq_lim[0] 

+

145 

+

146 self.baseline = self.FR < self.mid 

+

147 return State(self.alpha) 

+

148 

+

149from riglib.bmi.bmi import Decoder 

+

150class RatDecoder(Decoder): 

+

151 

+

152 def __init__(self, *args, **kwargs): 

+

153 

+

154 #Args: filter, units, ssm, extractor_cls, extractor_kwargs 

+

155 super(RatDecoder, self).__init__(args[0], args[1], args[2]) 

+

156 self.extractor_cls = args[3] 

+

157 self.extractor_kwargs = args[4] 

+

158 self.n_features = len(self.filt.e1_inds) + len(self.filt.e2_inds) 

+

159 

+

160 def __getitem__(self, key): 

+

161 return getattr(self, key) 

+

162 

+

163 def __setitem__(self, key, value): 

+

164 setattr(self,key,value) 

+

165 

+

166 def predict(self, neural_obs, **kwargs): 

+

167 self.filt(neural_obs, **kwargs) 

+

168 

+

169 

+

170 def init_from_task(self,**kwargs): 

+

171 pass 

+

172 

+

173class IsmoreSleepDecoder(RatDecoder): 

+

174 def __init__(self, *args, **kwargs): 

+

175 self.binlen = 0.1 

+

176 super(IsmoreSleepDecoder, self).__init__(*args, **kwargs) 

+

177 

+

178########## Functions to make decoder ########### 

+

179 

+

180def create_decoder(ssm, task_params): 

+

181 filter_ = RatFilter(task_params) 

+

182 decoder = RatDecoder(filter_, task_params['units'], ssm, task_params['extractor_cls'], dict()) 

+

183 return decoder 

+

184 

+

185########### Called from trainbmi.py to make decoder from Baseline ##### 

+

186import re 

+

187cellname = re.compile(r'(\d{1,3})\s*(\w{1})') 

+

188 

+

189def calc_decoder_from_baseline_file(neural_features, neural_features_unbinned, units, nsteps, prob_t1, 

+

190 prob_t2, timeout, timeout_pause, freq_lim, e1_inds, e2_inds, sim_fcn='rat', **kwargs): 

+

191 

+

192 #Enter e1, e2 as string:  

+

193 if np.logical_or(e1_inds is None, e2_inds is None): 

+

194 e1_string = input('Enter e1 cells: ') 

+

195 e2_string = input('Enter e2 cells: ') 

+

196 

+

197 e1 = np.array([ (int(c), ord(u) - 96) for c, u in cellname.findall(e1_string)]) 

+

198 e2 = np.array([ (int(c), ord(u) - 96) for c, u in cellname.findall(e2_string)]) 

+

199 

+

200 e1_inds = np.array([i for i, u in enumerate(units) if np.logical_and(u[0] in e1[:,0], u[1] in e1[:,1])]) 

+

201 e2_inds = np.array([i for i, u in enumerate(units) if np.logical_and(u[0] in e2[:,0], u[1] in e2[:,1])]) 

+

202 

+

203 T = neural_features.shape[0] 

+

204 if 'saturate_perc' in kwargs: 

+

205 baseline_data = np.zeros((T - nsteps, 2)) 

+

206 else: 

+

207 baseline_data = np.zeros((T - nsteps)) 

+

208 for ib in range(T-nsteps): 

+

209 if 'saturate_perc' in kwargs: 

+

210 baseline_data[ib, 0] = np.mean(np.sum(neural_features[ib:ib+nsteps, 

+

211 e1_inds], axis=1)) 

+

212 baseline_data[ib, 1] = np.mean(np.sum(neural_features[ib:ib+nsteps, 

+

213 e2_inds], axis=1)) 

+

214 else: 

+

215 baseline_data[ib] = np.mean(np.sum(neural_features[ib:ib+nsteps, 

+

216 e1_inds], axis=1))-np.mean(np.sum(neural_features[ib:ib+nsteps, 

+

217 e2_inds], axis=1)) 

+

218 

+

219 if 'saturate_perc' in kwargs: 

+

220 

+

221 sat_perc = kwargs.pop('saturate_perc') 

+

222 # ignore the first second of data 

+

223 e1_perc = np.percentile(baseline_data[20:, 0], sat_perc) 

+

224 e2_perc = np.percentile(baseline_data[20:, 1], sat_perc) 

+

225 

+

226 baseline_data[:, 0][baseline_data[:, 0] > e1_perc] = e1_perc 

+

227 baseline_data[:, 1][baseline_data[:, 1] > e2_perc] = e2_perc 

+

228 baseline_data = baseline_data[:, 0] - baseline_data[:, 1] 

+

229 else: 

+

230 e1_perc = None 

+

231 e2_perc = None 

+

232 

+

233 x, pdf, pdf_individual = generate_gmm(baseline_data) 

+

234 

+

235 if sim_fcn == 'rat': 

+

236 t2, mid, t1, num_t1, num_t2, num_miss, FR_to_freq_fn = sim_data(x, pdf, pdf_individual, prob_t1, prob_t2, 

+

237 baseline_data, timeout, timeout_pause, freq_lim, sim_bmi_fcn='rat') 

+

238 

+

239 return e1_inds, e2_inds, FR_to_freq_fn, units, t1, t2, mid 

+

240 

+

241 elif sim_fcn == 'ismore': 

+

242 # Get fcn:  

+

243 t1 = prob_under_pdf(x, pdf, prob_t1) 

+

244 t2 = prob_under_pdf(x, pdf, prob_t2) 

+

245 idx_mid = np.argmax(pdf) 

+

246 mid = x[idx_mid] 

+

247 FR_to_alpha_fn = map_to_freq(t2, mid, t1, freq_lim[0], freq_lim[1]) 

+

248 

+

249 # Plot FR to alpha fcn 

+

250 x_axis = np.linspace(t2, t1, 100) 

+

251 y_axis = [] 

+

252 for xi in x_axis: 

+

253 yi = FR_to_alpha_fn(xi) 

+

254 yi = np.max([freq_lim[0], yi]) 

+

255 yi = np.min([freq_lim[1], yi]) 

+

256 y_axis.append(yi) 

+

257 

+

258 x_axis2 = np.arange(-10, 10) 

+

259 y_axis2 = [] 

+

260 for xi in x_axis2: 

+

261 y_axis2.append(FR_to_alpha_fn(xi)) 

+

262 

+

263 import matplotlib.pyplot as plt 

+

264 f, ax = plt.subplots() 

+

265 ax.plot(x_axis, y_axis) 

+

266 ax.plot(x_axis2, y_axis2, 'r.') 

+

267 ax.plot([-5, 5], [0, 0], 'r--') 

+

268 

+

269 kwargs2 = dict(replay_neural_features = neural_features, e1_inds=e1_inds, 

+

270 e2_inds = e2_inds, FR_to_alpha_fn=FR_to_alpha_fn, mid = mid, e1_perc=e1_perc, 

+

271 e2_perc=e2_perc, freq_lim=freq_lim) 

+

272 

+

273 filt = IsmoreSleepFilter(kwargs2) 

+

274 decoder = IsmoreSleepDecoder(filt, units, state_space_models.StateSpaceEndptPos1D(), 

+

275 extractor.BinnedSpikeCountsExtractor, {}) 

+

276 if kwargs['skip_sim']: 

+

277 nrewards = [] 

+

278 import pdb 

+

279 pdb.set_trace() 

+

280 

+

281 else: 

+

282 if type(kwargs['targets_matrix']) is str: 

+

283 import pickle 

+

284 kwargs['targets_matrix'] = pickle.load(open(kwargs['targets_matrix'])) 

+

285 pname = ismore_sim_bmi(neural_features_unbinned, decoder, targets_matrix=kwargs['targets_matrix'], 

+

286 session_length=kwargs['session_length']) 

+

287 

+

288 # Analyze data:  

+

289 import tables 

+

290 import matplotlib.pyplot as plt 

+

291 hdf = tables.open_file(pname[:-4]+'.hdf') 

+

292 

+

293 # Plot x/y trajectory:  

+

294 f, ax = plt.subplots() 

+

295 ax.plot(hdf.root.task[2:]['plant_pos'][:, 0], hdf.root.task[2:]['plant_pos'][:, 1]) 

+

296 

+

297 ix = np.nonzero(hdf.root.task[:]['target_index'] == 1)[0] 

+

298 ax.plot(hdf.root.task[ix[0]]['target_pos'][0], hdf.root.task[ix[0]]['target_pos'][1], 'r.', 

+

299 markersize=20) 

+

300 

+

301 ix = np.nonzero(hdf.root.task[:]['target_index'] == 0)[0] + 5 

+

302 ax.plot(hdf.root.task[ix[0]]['target_pos'][0], hdf.root.task[ix[0]]['target_pos'][1], 'g.', 

+

303 markersize=20) 

+

304 

+

305 nrewards = np.nonzero(hdf.root.task_msgs[:]['msg']=='reward')[0] 

+

306 

+

307 return decoder, len(nrewards) 

+

308 

+

309 

+

310###### From Rat BMI ####### 

+

311###### From Rat BMI ####### 

+

312###### From Rat BMI ####### 

+

313###### From Rat BMI ####### 

+

314 

+

315from sklearn import metrics 

+

316 from sklearn.mixture import GaussianMixture 

+

317import numpy as np 

+

318import matplotlib.pyplot as plt 

+

319 

+

320def generate_gmm(data, ax=None): 

+

321 ##reshape the data 

+

322 X = data.reshape(data.shape[0], 1) 

+

323 ##fit models with 1-10 components 

+

324 N = np.arange(1,11) 

+

325 models = [None for i in range(len(N))] 

+

326 for i in range(len(N)): 

+

327 models[i] = GaussianMixture(n_components=N[i]).fit(X) 

+

328 ##compute AIC 

+

329 AIC = [m.aic(X) for m in models] 

+

330 ##figure out the best-fit mixture 

+

331 M_best = models[np.argmin(AIC)] 

+

332 x = np.linspace(data.min()-1, data.max()+1, data.size) 

+

333 ##compute the pdf 

+

334 logprob = M_best.score_samples(x.reshape(x.size, 1)); responsibilities = M_best.predict_proba(x.reshape(x.size, 1)) 

+

335 pdf = np.exp(logprob) 

+

336 pdf_individual = responsibilities * pdf[:, np.newaxis] 

+

337 #plot the stuff 

+

338 if ax is None: 

+

339 fig, ax = plt.subplots() 

+

340 ax.hist(X, 50, density = True, histtype = 'stepfilled', alpha = 0.4) 

+

341 ax.plot(x, pdf, '-k') 

+

342 ax.plot(x, pdf_individual, '--k') 

+

343 ax.text(0.04, 0.96, "Best-fit Mixture", 

+

344 ha='left', va='top', transform=ax.transAxes) 

+

345 ax.set_xlabel('$x$') 

+

346 ax.set_ylabel('$p(x)$') 

+

347 return x, pdf, pdf_individual 

+

348 
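An illustrative call of generate_gmm on synthetic bimodal data (values are arbitrary):

import numpy as np
data = np.concatenate([np.random.normal(-2, 1, 500), np.random.normal(3, 1, 500)])
x, pdf, pdf_individual = generate_gmm(data)   # AIC should pick ~2 components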

+

349##this function takes in an array of x-values and an array 

+

350##of y-values that correspond to a probability density function 

+

351##and determines the x-value at which the area under the PDF is approximately 

+

352##equal to some value passed in the arguments. 

+

353def prob_under_pdf(x_pdf, y_pdf, prob): 

+

354 auc = 0 

+

355 i = 2 

+

356 while auc < prob: 

+

357 x_range = x_pdf[0:i] 

+

358 y_range = y_pdf[0:i] 

+

359 auc = metrics.auc(x_range, y_range) 

+

360 i+=1 

+

361 return x_pdf[i] 

+

362 
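A sanity-check sketch for prob_under_pdf: on a standard normal grid, the point with half the area to its left should come out near 0.

import numpy as np
xs = np.linspace(-5, 5, 1000)
ys = np.exp(-xs**2 / 2) / np.sqrt(2 * np.pi)
print(prob_under_pdf(xs, ys, 0.5))   # approximately 0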

+

363##function to map ensemble values to frequency values 

+

364def map_to_freq(t2, mid, t1, min_freq, max_freq): 

+

365 fr_pts = np.array([t2, mid, t1]) 

+

366 freq_pts = np.array([min_freq, (float(max_freq) + float(min_freq))/2., max_freq]) 

+

367 z = np.polyfit(fr_pts, freq_pts, 2) 

+

368 p = np.poly1d(z) 

+

369 return p 

+

370 
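An illustrative use of map_to_freq: the returned quadratic passes exactly through the three anchor points (t2 -> min_freq, mid -> the midpoint frequency, t1 -> max_freq).

p = map_to_freq(-2.0, 0.0, 2.0, 1000., 20000.)
print(p(-2.0), p(0.0), p(2.0))   # ~1000, ~10500, ~20000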

+

371def sim_data(x, pdf, pdf_individual, prob_t1, prob_t2, data, 

+

372 timeout, timeout_pause, freq_lim, sim_bmi_fcn='rat'): 

+

373 t1 = prob_under_pdf(x, pdf, prob_t1) 

+

374 t2 = prob_under_pdf(x, pdf, prob_t2) 

+

375 idx_mid = np.argmax(pdf) 

+

376 mid = x[idx_mid] 

+

377 fig, ax1 = plt.subplots() 

+

378 ax1.hist(data+np.random.normal(0, 0.1*data.std(), data.size), 50, 

+

379 density = True, histtype = 'stepfilled', alpha = 0.4) 

+

380 ax1.plot(x, pdf, '-k') 

+

381 ax1.plot(x, pdf_individual, '--k') 

+

382 ax1.text(0.04, 0.96, "Best-fit Mixture", 

+

383 ha='left', va='top', transform=ax1.transAxes) 

+

384 ax1.set_xlabel('Cursor Value (E1-E2)') 

+

385 ax1.set_ylabel('$p(x)$') 

+

386 ##find the points where t1 and t2 lie on the gaussian 

+

387 idx_t2 = np.where(x>t2)[0][0] 

+

388 x_t2 = t2 

+

389 y_t2 = pdf[idx_t2] 

+

390 idx_t1 = np.where(x>t1)[0][0] 

+

391 x_t1 = t1 

+

392 y_t1 = pdf[idx_t1] 

+

393 y_mid = pdf[idx_mid] 

+

394 ax1.plot(x_t1, y_t1, 'o', color = 'g') 

+

395 ax1.plot(x_t2, y_t2, 'o', color = 'g') 

+

396 ax1.plot(mid, y_mid, 'o', color = 'g') 

+

397 ax1.set_title("Firing rate histogram and gaussian fit") 

+

398 ax1.annotate('T1: ('+str(round(x_t1, 3))+')', xy=(x_t1, y_t1), xytext=(40,20), 

+

399 textcoords='offset points', ha='center', va='bottom', 

+

400 bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3), 

+

401 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.5', 

+

402 color='red')) 

+

403 ax1.annotate('T2: ('+str(round(x_t2, 3))+')', xy=(x_t2, y_t2), xytext=(-40,20), 

+

404 textcoords='offset points', ha='center', va='bottom', 

+

405 bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3), 

+

406 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.5', 

+

407 color='red')) 

+

408 ax1.annotate('Base: ('+str(round(mid, 3))+')', xy=(mid, y_mid), xytext=(-100,-20), 

+

409 textcoords='offset points', ha='center', va='bottom', 

+

410 bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3), 

+

411 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.5', 

+

412 color='red')) 

+

413 ##get the control function 

+

414 p = map_to_freq(t2, mid, t1, freq_lim[0], freq_lim[1]) 

+

415 ##run a simulation 

+

416 if sim_bmi_fcn=='rat': 

+

417 num_t1, num_t2, num_miss = sim_bmi(data, t1, t2, mid, timeout, timeout_pause, p) 

+

418 elif sim_bmi_fcn == 'ismore': 

+

419 num_t1, num_t2, num_miss = ismore_sim_bmi(data, t1, t2, mid, timeout, timeout_pause, p) 

+

420 

+

421 print("Simulation results:\nNumber of T1: " + str(num_t1) + "\nNumber of T2: " + str(num_t2) + "\nNumber of Misses: " + str(num_miss)) 

+

422 print("Calculated T2 value is " + str(round(t2, 5))) 

+

423 print("Calculated mid value is " + str(round(mid, 5))) 

+

424 print("Calculated T1 value is " + str(round(t1, 5))) 

+

425 ##plot the control function 

+

426 plot_cursor_func(t2, mid, t1, freq_lim[0], freq_lim[1]) 

+

427 #plt.show() 

+

428 return t2, mid, t1, num_t1, num_t2, num_miss, p 

+

429 

+

430def sim_bmi(baseline_data, t1, t2, midpoint, timeout, timeout_pause, p): 

+

431 data = baseline_data 

+

432 samp_int = 100. #ms 

+

433 

+

434 ##get the timeout duration in samples 

+

435 timeout_samps = int((timeout*1000.0)/samp_int) 

+

436 timeout_pause_samps = int((timeout_pause*1000.0)/samp_int) 

+

437 ##"global" variables 

+

438 num_t1 = 0 

+

439 num_t2 = 0 

+

440 num_miss = 0 

+

441 back_to_baseline = 1 

+

442 ##run through the data and simulate BMI 

+

443 i = 0 

+

444 clock = 0 

+

445 while i < (data.shape[0]-1): 

+

446 cursor = data[i] 

+

447 ##check for a target hit 

+

448 if cursor >= t1: 

+

449 num_t1+=1 

+

450 i += int(4000/samp_int) 

+

451 back_to_baseline = 0 

+

452 ##wait for a return to baseline 

+

453 while cursor >= midpoint and i < (data.shape[0]-1): 

+

454 #advance the sample 

+

455 i+=1 

+

456 ##get cursor value 

+

457 cursor = data[i] 

+

458 ##reset the clock 

+

459 clock = 0 

+

460 elif cursor <= t2: 

+

461 num_t2+=1 

+

462 i += int(4000/samp_int) 

+

463 back_to_baseline = 0 

+

464 ##wait for a return to baseline 

+

465 while cursor <= midpoint and i < (data.shape[0]-1): 

+

466 #advance the sample 

+

467 i+=1 

+

468 ##get cursor value 

+

469 cursor = data[i] 

+

470 ##reset the clock 

+

471 clock = 0 

+

472 elif clock >= timeout_samps: 

+

473 ##advance the samples for the timeout duration 

+

474 i+= timeout_pause_samps 

+

475 num_miss += 1 

+

476 ##reset the clock 

+

477 clock = 0 

+

478 else: 

+

479 ##if nothing else, advance the clock and the sample 

+

480 i+= 1 

+

481 clock+=1 

+

482 return num_t1, num_t2, num_miss 

+

483 

+

484def ismore_sim_bmi(baseline_data, decoder, targets_matrix=None, session_length=0.): 

+

485 import ismore.invasive.bmi_ismoretasks as bmi_ismoretasks 

+

486 from riglib import experiment 

+

487 from features.hdf_features import SaveHDF 

+

488 from ismore.brainamp_features import SimBrainAmpData 

+

489 import datetime 

+

490 import numpy as np 

+

491 import matplotlib.pyplot as plt 

+

492 import multiprocessing as mp 

+

493 from features.blackrock_features import BlackrockBMI 

+

494 from ismore.exo_3D_visualization import Exo3DVisualizationInvasive 

+

495 

+

496 targets = bmi_ismoretasks.SimBMIControlReplayFile.sleep_gen(length=100) 

+

497 plant_type = 'IsMore' 

+

498 kwargs=dict(session_length=session_length, replay_neural_features=baseline_data, decoder=decoder) 

+

499 

+

500 if targets_matrix is not None: 

+

501 kwargs['targets_matrix']=targets_matrix 

+

502 

+

503 Task = experiment.make(bmi_ismoretasks.SimBMIControlReplayFile, [SaveHDF])#, Exo3DVisualizationInvasive]) 

+

504 task = Task(targets, plant_type=plant_type, **kwargs) 

+

505 task.run_sync() 

+

506 pnm = save_dec_enc(task) 

+

507 return pnm 

+

508 

+

509def plot_cursor_func(t2, mid, t1, min_freq, max_freq): 

+

510 f, ax2 = plt.subplots() 

+

511 x = np.linspace(t2-1, t1+1, 1000) 

+

512 func = map_to_freq(t2, mid, t1, min_freq, max_freq) 

+

513 #fig, ax = plt.subplots() 

+

514 ax2.plot(t2, min_freq, 'o', color = 'r') 

+

515 ax2.plot(mid, np.floor((max_freq-min_freq)/2), 'o', color = 'r') 

+

516 ax2.plot(t1, max_freq, 'o', color = 'r') 

+

517 ax2.plot(x, func(x), '-', color = 'g') 

+

518 ax2.annotate('T1: ('+str(round(t1, 3))+')', xy=(t1, max_freq), xytext=(-20, 20), 

+

519 textcoords='offset points', ha='center', va='bottom', 

+

520 bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3), 

+

521 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.5', 

+

522 color='red')) 

+

523 ax2.annotate('T2: ('+str(round(t2, 3))+')', xy=(t2, min_freq), xytext=(-20,20), 

+

524 textcoords='offset points', ha='center', va='bottom', 

+

525 bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3), 

+

526 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.5', 

+

527 color='red')) 

+

528 ax2.annotate('Base: ('+str(round(mid, 3))+')', xy=(mid, np.floor((max_freq-min_freq)/2)), xytext=(-20,20), 

+

529 textcoords='offset points', ha='center', va='bottom', 

+

530 bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3), 

+

531 arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0.5', 

+

532 color='red')) 

+

533 ax2.set_ylabel("Feedback frequency") 

+

534 ax2.set_xlabel("Cursor value (E1-E2)") 

+

535 ax2.set_title("Cursor-frequency map", fontsize = 18) 

+

536 

+

537def save_dec_enc(task, pref='sleep_sim_'): 

+

538 ''' 

+

539 Summary: method to save encoder / decoder and hdf file information from task in sim_data folder 

+

540 Input param: task: task, output from arm_assist_main, or generally task object 

+

541 Input param: pref: prefix to saved file names (defaults to 'enc' for encoder) 

+

542 Output param: pkl file name used to save encoder/decoder 

+

543 ''' 

+

544 #enc = task.encoder 

+

545 task.decoder.save() 

+

546 #enc.corresp_dec = task.decoder 

+

547 

+

548 #Save task info 

+

549 import pickle 

+

550 ct = datetime.datetime.now() 

+

551 #pnm = '/Users/preeyakhanna/ismore/ismore_tests/sim_data/'+pref + ct.strftime("%Y%m%d_%H_%M_%S") + '.pkl' 

+

552 pnm = '/home/tecnalia/code/ismore/ismore_tests/sim_data/'+pref + ct.strftime("%m%d%y_%H%M") + '.pkl' 

+

553 pnm2 = '/Users/preeyakhanna/code/ismore/ismore_tests/sim_data/'+pref + ct.strftime("%m%d%y_%H%M") + '.pkl' 

+

554 

+

555 try: 

+

556 pickle.dump(dict(), open(pnm,'wb')) 

+

557 except: 

+

558 pickle.dump(dict(), open(pnm2, 'wb')) 

+

559 pnm = pnm2 

+

560 

+

561 #Save HDF file 

+

562 new_hdf = pnm[:-4]+'.hdf' 

+

563 import shutil 

+

564 f = open(task.h5file.name) 

+

565 f.close() 

+

566 

+

567 #Wait  

+

568 import time 

+

569 time.sleep(1.) 

+

570 

+

571 #Wait after HDF cleaned up 

+

572 task.cleanup_hdf() 

+

573 import time 

+

574 time.sleep(1.) 

+

575 

+

576 #Copy temp file to actual desired location 

+

577 shutil.copy(task.h5file.name, new_hdf) 

+

578 f = open(new_hdf) 

+

579 f.close() 

+

580 

+

581 #Return filename 

+

582 return pnm 

+
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_robot_arms_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_robot_arms_py.html
new file mode 100644
index 00000000..519219db
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_robot_arms_py.html
@@ -0,0 +1,785 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\robot_arms.py: 32%

1''' 

+

2Classes implementing various kinematic chains. This module is perhaps mis-located 

+

3as it does not have a direct BMI role but rather contains code which is useful in 

+

4supporting BMI control of kinematic chains. 

+

5 

+

6This code depends on the 'robot' module (https://github.com/sgowda/robotics_toolbox) 

+

7''' 

+

8import numpy as np 

+

9try: 

+

10 import robot 

+

11except ImportError: 

+

12 import warnings 

+

13 warnings.warn("The 'robot' module cannot be found! See https://github.com/sgowda/robotics_toolbox") 

+

14 

+

15import matplotlib.pyplot as plt 

+

16from collections import OrderedDict 

+

17 

+

18 

+

19import time 

+

20 

+

21pi = np.pi 

+

22 

+

23class KinematicChain(object): 

+

24 ''' 

+

25 Arbitrary kinematic chain (i.e. spherical joint at the beginning of  

+

26 each joint) 

+

27 ''' 

+

28 def __init__(self, link_lengths=[10., 10.], name='', base_loc=np.array([0., 0., 0.]), rotation_convention=1): 

+

29 ''' 

+

30 Docstring  

+

31 

+

32 Parameters 

+

33 ---------- 

+

34 link_lengths: iterable 

+

35 Lengths of all the distances between joints 

+

36 base_loc: np.array of shape (3,), default=np.array([0, 0, 0]) 

+

37 Location of the base of the kinematic chain in an "absolute" reference frame 

+

38 ''' 

+

39 self.n_links = len(link_lengths) 

+

40 self.link_lengths = link_lengths 

+

41 self.base_loc = base_loc 

+

42 

+

43 assert rotation_convention in [-1, 1] 

+

44 self.rotation_convention = rotation_convention 

+

45 

+

46 # Create the robot object. Override for child classes with different types of joints 

+

47 self._init_serial_link() 

+

48 self.robot.name = name 

+

49 

+

50 def _init_serial_link(self): 

+

51 links = [] 

+

52 for link_length in self.link_lengths: 

+

53 link1 = robot.Link(alpha=-pi/2) 

+

54 link2 = robot.Link(alpha=pi/2) 

+

55 link3 = robot.Link(d=-link_length) 

+

56 links += [link1, link2, link3] 

+

57 

+

58 # By convention, we start the arm in the XY-plane 

+

59 links[1].offset = -pi/2 

+

60 

+

61 self.robot = robot.SerialLink(links) 

+

62 

+

63 def calc_full_joint_angles(self, joint_angles): 

+

64 ''' 

+

65 Override in child classes to perform static transforms on joint angle inputs. If some  

+

66 joints are always static (e.g., if the chain only operates in a plane) 

+

67 this can help declutter joint angle specifications. 

+

68 ''' 

+

69 return self.rotation_convention * joint_angles 

+

70 

+

71 def full_angles_to_subset(self, joint_angles): 

+

72 ''' 

+

73 Docstring  

+

74  

+

75 Parameters 

+

76 ---------- 

+

77  

+

78 Returns 

+

79 ------- 

+

80 ''' 

+

81 return joint_angles 

+

82 

+

83 def plot(self, joint_angles): 

+

84 ''' 

+

85 Docstring  

+

86  

+

87 Parameters 

+

88 ---------- 

+

89  

+

90 Returns 

+

91 ------- 

+

92 ''' 

+

93 

+

94 joint_angles = self.calc_full_joint_angles(joint_angles) 

+

95 self.robot.plot(joint_angles) 

+

96 

+

97 def forward_kinematics(self, joint_angles, **kwargs): 

+

98 ''' 

+

99 Calculate forward kinematics using D-H parameter convention 

+

100 

+

101 Parameters 

+

102 ---------- 

+

103  

+

104 Returns 

+

105 -------  

+

106 ''' 

+

107 joint_angles = self.calc_full_joint_angles(joint_angles) 

+

108 t, allt = self.robot.fkine(joint_angles, **kwargs) 

+

109 self.joint_angles = joint_angles 

+

110 self.t = t 

+

111 self.allt = allt 

+

112 return t, allt 

+

113 

+

114 def apply_joint_limits(self, joint_angles): 

+

115 ''' 

+

116 Docstring  

+

117  

+

118 Parameters 

+

119 ---------- 

+

120  

+

121 Returns 

+

122 ------- 

+

123 ''' 

+

124 return joint_angles 

+

125 

+

126 def inverse_kinematics(self, target_pos, q_start=None, method='pso', **kwargs): 

+

127 ''' 

+

128 Docstring  

+

129  

+

130 Parameters 

+

131 ---------- 

+

132  

+

133 Returns 

+

134 ------- 

+

135 ''' 

+

136 if q_start is None: 

+

137 q_start = self.random_sample() 

+

138 return self.inverse_kinematics_pso(target_pos, q_start, **kwargs) 

+

139 # ik_method = getattr(self, 'inverse_kinematics_%s' % method) 

+

140 # return ik_method(q_start, target_pos) 

+

141 

+

142 def inverse_kinematics_grad_descent(self, target_pos, starting_config, n_iter=1000, verbose=False, eps=0.01, return_path=False): 

+

143 ''' 

+

144 Default inverse kinematics method is RRT since for redundant  

+

145 kinematic chains, an infinite number of inverse kinematics solutions  

+

146 exist 

+

147 

+

148 Docstring  

+

149  

+

150 Parameters 

+

151 ---------- 

+

152  

+

153 Returns 

+

154 -------  

+

155 ''' 

+

156 

+

157 q = starting_config 

+

158 start_time = time.time() 

+

159 endpoint_traj = np.zeros([n_iter, 3]) 

+

160 

+

161 joint_limited = np.zeros(len(q)) 

+

162 

+

163 for k in range(n_iter): 

+

164 # print k 

+

165 # calc endpoint position of the manipulator 

+

166 endpoint_traj[k] = self.endpoint_pos(q) 

+

167 

+

168 current_cost = np.linalg.norm(endpoint_traj[k] - target_pos, 2) 

+

169 if current_cost < eps: 

+

170 print("Terminating early") 

+

171 break 

+

172 

+

173 # calculate the jacobian 

+

174 J = self.jacobian(q) 

+

175 J_pos = J[0:3,:] 

+

176 

+

177 # for joints that are at their limit, zero out the jacobian? 

+

178 # J_pos[:, np.nonzero(self.calc_full_joint_angles(joint_limited))] = 0 

+

179 

+

180 # take a step from the current position toward the target pos using the inverse Jacobian 

+

181 J_inv = np.linalg.pinv(J_pos) 

+

182 # J_inv = J_pos.T 

+

183 

+

184 xdot = (target_pos - endpoint_traj[k])#/np.linalg.norm(endpoint_traj[k] - target_pos)  

+

185 

+

186 # if current_cost < 3 or k > 10: 

+

187 # stepsize = 0.001 

+

188 # else: 

+

189 # stepsize = 0.01 

+

190 

+

191 

+

192 xdot = (target_pos - endpoint_traj[k])#/np.linalg.norm(endpoint_traj[k] - target_pos) 

+

193 # xdot = (endpoint_traj[k] - target_pos)/np.linalg.norm(endpoint_traj[k] - target_pos) 

+

194 qdot = 0.001*np.dot(J_inv, xdot) 

+

195 qdot = self.full_angles_to_subset(np.array(qdot).ravel()) 

+

196 

+

197 q += qdot 

+

198 

+

199 # apply joint limits 

+

200 q, joint_limited = self.apply_joint_limits(q) 

+

201 

+

202 end_time = time.time() 

+

203 runtime = end_time - start_time 

+

204 if verbose: 

+

205 print("Runtime: %g" % runtime) 

+

206 print("# of iterations: %g" % k) 

+

207 

+

208 if return_path: 

+

209 return q, endpoint_traj[:k] 

+

210 else: 

+

211 return q 

+

212 

+

213 def jacobian(self, joint_angles): 

+

214 ''' 

+

215 Return the full jacobian  

+

216 

+

217 Docstring  

+

218  

+

219 Parameters 

+

220 ---------- 

+

221  

+

222 Returns 

+

223 -------  

+

224 ''' 

+

225 joint_angles = self.calc_full_joint_angles(joint_angles) 

+

226 J = self.robot.jacobn(joint_angles) 

+

227 return J 

+

228 

+

229 def endpoint_pos(self, joint_angles): 

+

230 ''' 

+

231 Compute the endpoint position of the chain, in absolute coordinates (base_loc applied) 

+

232  

+

233 Parameters 

+

234 ---------- 

+

235  

+

236 Returns 

+

237 ------- 

+

238 ''' 

+

239 

+

240 t, allt = self.forward_kinematics(joint_angles) 

+

241 pos_rel_to_base = np.array(t[0:3,-1]).ravel() 

+

242 return pos_rel_to_base + self.base_loc 

+

243 

+

244 def ik_cost(self, q, q_start, target_pos, weight=100): 

+

245 ''' 

+

246 Docstring  

+

247  

+

248 Parameters 

+

249 ---------- 

+

250  

+

251 Returns 

+

252 ------- 

+

253 ''' 

+

254 

+

255 q_diff = q - q_start 

+

256 return np.linalg.norm(q_diff[0:2]) + weight*np.linalg.norm(self.endpoint_pos(q) - target_pos) 

+

257 

+

258 def inverse_kinematics_pso(self, target_pos, q_start, time_limit=np.inf, verbose=False, eps=0.5, n_particles=10, n_iter=10): 

+

259 ''' 

+

260 Docstring  

+

261  

+

262 Parameters 

+

263 ---------- 

+

264  

+

265 Returns 

+

266 ------- 

+

267 ''' 

+

268 

+

269 # Initialize the particles;  

+

270 n_joints = self.n_joints 

+

271 

+

272 particles_q = np.tile(q_start, [n_particles, 1]) 

+

273 

+

274 # if 0: 

+

275 # # initialize the velocities to be biased around the direction the jacobian tells you is correct 

+

276 # current_pos = self.endpoint_pos(q_start) 

+

277 # int_displ = target_pos - current_pos 

+

278 # print int_displ, target_pos 

+

279 # J = self.jacobian(q_start) 

+

280 # endpoint_vel = np.random.randn(n_particles, 3)# + int_displ 

+

281 # particles_v = np.dot(J[0:3,1::3].T, endpoint_vel.T).T 

+

282 # else: 

+

283 # # initialize particle velocities randomly 

+

284 

+

285 

+

286 particles_v = np.random.randn(n_particles, n_joints) #/ np.array([1., 1., 1, 1]) #np.array(self.link_lengths) 

+

287 

+

288 cost_fn = lambda q: self.ik_cost(q, q_start, target_pos) 

+

289 

+

290 gbest = particles_q.copy() 

+

291 gbestcost = np.array(list(map(cost_fn, gbest))) 

+

292 pbest = gbest[np.argmin(gbestcost)] 

+

293 pbestcost = cost_fn(pbest) 

+

294 

+

295 min_limits = np.array([x[0] for x in self.joint_limits]) 

+

296 max_limits = np.array([x[1] for x in self.joint_limits]) 

+

297 min_limits = np.tile(min_limits, [n_particles, 1]) 

+

298 max_limits = np.tile(max_limits, [n_particles, 1]) 

+

299 

+

300 start_time = time.time() 

+

301 for k in range(n_iter): 

+

302 if time.time() - start_time > time_limit: 

+

303 break 

+

304 

+

305 # update positions of particles 

+

306 particles_q += particles_v 

+

307 

+

308 # apply joint limits 

+

309 min_viol = particles_q < min_limits 

+

310 max_viol = particles_q > max_limits 

+

311 particles_q[min_viol] = min_limits[min_viol] 

+

312 particles_q[max_viol] = max_limits[max_viol] 

+

313 

+

314 # update the costs 

+

315 costs = np.array(list(map(cost_fn, particles_q))) 

+

316 

+

317 # update the 'bests' 

+

318 gbest[gbestcost > costs] = particles_q[gbestcost > costs] 

+

319 gbestcost[gbestcost > costs] = costs[gbestcost > costs] 

+

320 

+

321 idx = np.argmin(gbestcost) 

+

322 pbest = gbest[idx] 

+

323 pbestcost = gbestcost[idx] 

+

324 

+

325 # update the velocity 

+

326 phi1 = 1#np.random.rand() 

+

327 phi2 = 1#np.random.rand() 

+

328 w=0.25 

+

329 c1=0.5 

+

330 c2=0.25 

+

331 particles_v = w*particles_v + c1*phi1*(pbest - particles_q) + c2*phi2*(gbest - particles_q) 

+

332 

+

333 error = np.linalg.norm(self.endpoint_pos(pbest) - target_pos) 

+

334 if error < eps: 

+

335 break 

+

336 

+

337 end_time = time.time() 

+

338 if verbose: print("Runtime = %g, error = %g, n_iter=%d" % (end_time-start_time, error, k)) 

+

339 

+

340 return pbest 

+

341 
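A toy stand-alone sketch of the PSO update rule used above, v <- w*v + c1*(pbest - q) + c2*(gbest - q); note the method above swaps the conventional names (its gbest holds the per-particle bests, its pbest the global best):

import numpy as np
cost = lambda q: np.sum((q - 3.0) ** 2)            # minimum at [3, 3]
q, v = np.random.randn(10, 2), np.random.randn(10, 2)
pbest = q.copy()
pbest_cost = np.array([cost(p) for p in pbest])
for _ in range(50):
    q += v
    c = np.array([cost(p) for p in q])
    better = c < pbest_cost
    pbest[better], pbest_cost[better] = q[better], c[better]
    gbest = pbest[np.argmin(pbest_cost)]
    v = 0.25*v + 0.5*(pbest - q) + 0.25*(gbest - q)
print(pbest[np.argmin(pbest_cost)])                # converges toward [3, 3]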

+

342 def spatial_positions_of_joints(self, joint_angles): 

+

343 ''' 

+

344 Docstring  

+

345  

+

346 Parameters 

+

347 ---------- 

+

348  

+

349 Returns 

+

350 ------- 

+

351 ''' 

+

352 

+

353 _, allt = self.forward_kinematics(joint_angles, return_allt=True) 

+

354 pos = (allt[0:3, -1,:].T + self.base_loc).T 

+

355 # pos = np.hstack([np.zeros([3,1]), pos]) 

+

356 return pos 

+

357 

+

358 

+

359class PlanarXZKinematicChain(KinematicChain): 

+

360 ''' 

+

361 Kinematic chain restricted to movement in the XZ-plane 

+

362 ''' 

+

363 def _init_serial_link(self): 

+

364 base = robot.Link(alpha=pi/2, d=0, a=0) 

+

365 links = [base] 

+

366 for link_length in self.link_lengths: 

+

367 link1 = robot.Link(alpha=0, d=0, a=link_length) 

+

368 links.append(link1) 

+

369 

+

370 # link2 = robot.Link(alpha=pi/2) 

+

371 # link3 = robot.Link(d=-link_length) 

+

372 # links += [link1, link2, link3] 

+

373 

+

374 # By convention, we start the arm in the XY-plane 

+

375 # links[1].offset = -pi/2  

+

376 

+

377 self.robot = robot.SerialLink(links) 

+

378 

+

379 def calc_full_joint_angles(self, joint_angles): 

+

380 ''' 

+

381 only some joints rotate in the planar kinematic chain 

+

382 

+

383 Parameters 

+

384 ---------- 

+

385 joint_angles : np.ndarray of shape (self.n_links) 

+

386 Joint angles without the angle for the base link, which is fixed at 0 

+

387  

+

388 Returns 

+

389 ------- 

+

390 joint_angles_full : np.ndarray of shape (self.n_links+1) 

+

391 Add on the 0 at the proximal end for the base link angle 

+

392 ''' 

+

393 if not len(joint_angles) == self.n_links: 

+

394 raise ValueError("Incorrect number of joint angles specified!") 

+

395 

+

396 # # There are really 3 angles per joint to allow 3D rotation at each joint 

+

397 # joint_angles_full = np.zeros(self.n_links * 3)  

+

398 # joint_angles_full[1::3] = joint_angles 

+

399 

+

400 joint_angles_full = np.hstack([0, joint_angles]) 

+

401 return self.rotation_convention * joint_angles_full 

+

402 

+

403 def random_sample(self): 

+

404 ''' 

+

405 Sample the joint configuration space within the limits of each joint 

+

406  

+

407 Parameters 

+

408 ---------- 

+

409 None 

+

410  

+

411 Returns 

+

412 ------- 

+

413 None 

+

414 ''' 

+

415 if hasattr(self, 'joint_limits'): 

+

416 joint_limits = self.joint_limits 

+

417 else: 

+

418 joint_limits = [(-np.pi, np.pi)] * self.n_links 

+

419 

+

420 q_start = [] 

+

421 for lim_min, lim_max in joint_limits: 

+

422 q_start.append(np.random.uniform(lim_min, lim_max)) 

+

423 return np.array(q_start) 

+

424 

+

425 def full_angles_to_subset(self, joint_angles): 

+

426 ''' 

+

427 Docstring  

+

428  

+

429 Parameters 

+

430 ---------- 

+

431  

+

432 Returns 

+

433 ------- 

+

434 ''' 

+

435 # return joint_angles[1::3] 

+

436 return joint_angles[1:] 

+

437 

+

438 def apply_joint_limits(self, joint_angles): 

+

439 ''' 

+

440 Docstring  

+

441  

+

442 Parameters 

+

443 ---------- 

+

444  

+

445 Returns 

+

446 ------- 

+

447 ''' 

+

448 if not hasattr(self, 'joint_limits'): 

+

449 return joint_angles 

+

450 else: 

+

451 angles = [] 

+

452 limit_hit = [] 

+

453 for angle, (lim_min, lim_max) in zip(joint_angles, self.joint_limits): 

+

454 limit_hit.append(angle < lim_min or angle > lim_max) 

+

455 angle = max(lim_min, angle) 

+

456 angle = min(angle, lim_max) 

+

457 angles.append(angle) 

+

458 

+

459 return np.array(angles), np.array(limit_hit) 

+

460 

+

461 @property 

+

462 def n_joints(self): 

+

463 ''' 

+

464 In a planar arm, the number of joints equals the number of links 

+

465 ''' 

+

466 return len(self.link_lengths) 

+

467 

+

468 def spatial_positions_of_joints(self, *args, **kwargs): 

+

469 ''' 

+

470 Docstring  

+

471  

+

472 Parameters 

+

473 ---------- 

+

474  

+

475 Returns 

+

476 ------- 

+

477 ''' 

+

478 pos_all_joints = super(PlanarXZKinematicChain, self).spatial_positions_of_joints(*args, **kwargs) 

+

479 return pos_all_joints #(pos_all_joints[:,::3].T + self.base_loc).T 

+

480 

+

481 def create_ik_subchains(self): 

+

482 ''' 

+

483 Docstring  

+

484  

+

485 Parameters 

+

486 ---------- 

+

487  

+

488 Returns 

+

489 ------- 

+

490 ''' 

+

491 proximal_link_lengths = self.link_lengths[:2] 

+

492 distal_link_lengths = self.link_lengths[2:] 

+

493 self.proximal_chain = PlanarXZKinematicChain2Link(proximal_link_lengths) 

+

494 if len(self.link_lengths) > 2: 

+

495 self.distal_chain = PlanarXZKinematicChain(distal_link_lengths) 

+

496 else: 

+

497 self.distal_chain = None 

+

498 

+

499 def inverse_kinematics(self, target_pos, **kwargs): 

+

500 ''' 

+

501 Docstring  

+

502  

+

503 Parameters 

+

504 ---------- 

+

505  

+

506 Returns 

+

507 ------- 

+

508 ''' 

+

509 target_pos = target_pos.copy() 

+

510 target_pos -= self.base_loc 

+

511 if not hasattr(self, 'proximal_chain') or not hasattr(self, 'distal_chain'): 

+

512 self.create_ik_subchains() 

+

513 

+

514 if len(self.link_lengths) > 2: 

+

515 distal_angles = kwargs.pop('distal_angles', None) 

+

516 

+

517 if distal_angles is None: 

+

518 # Sample randomly from the joint limits (-pi, pi) if not specified 

+

519 if not hasattr(self, 'joint_limits') or len(self.joint_limits) < len(self.link_lengths): 

+

520 joint_limits = [(-pi, pi)] * len(self.distal_chain.link_lengths) 

+

521 else: 

+

522 joint_limits = self.joint_limits[2:] 

+

523 distal_angles = np.array([np.random.uniform(*limits) for limits in joint_limits]) 

+

524 

+

525 distal_displ = self.distal_chain.endpoint_pos(distal_angles) 

+

526 proximal_endpoint_pos = target_pos - distal_displ 

+

527 proximal_angles = self.proximal_chain.inverse_kinematics(proximal_endpoint_pos).ravel() 

+

528 angles = distal_angles.copy() 

+

529 joint_angles = proximal_angles.tolist() 

+

530 angles[0] -= np.sum(proximal_angles) 

+

531 ik_angles = np.hstack([proximal_angles, angles]) 

+

532 ik_angles = np.array([np.arctan2(np.sin(angle), np.cos(angle)) for angle in ik_angles]) 

+

533 return ik_angles 

+

534 else: 

+

535 return self.proximal_chain.inverse_kinematics(target_pos).ravel() 

+

536 

+

537 def jacobian(self, theta, old=False): 

+

538 ''' 

+

539 Returns the first derivative of the forward kinematics function for x and z endpoint positions:  

+

540 [[dx/dtheta_1, ..., dx/dtheta_N] 

+

541 [dz/dtheta_1, ..., dz/dtheta_N]] 

+

542  

+

543 Parameters 

+

544 ---------- 

+

545 theta : np.ndarray of shape (N,) 

+

546 Valid configuration for the arm (the jacobian calculations are specific to the configuration of the arm) 

+

547  

+

548 Returns 

+

549 ------- 

+

550 J : np.ndarray of shape (2, N) 

+

551 Manipulator jacobian in the format above 

+

552 ''' 

+

553 if old: 

+

554 # Calculate jacobian based on hand calculation specific to this type of chain 

+

555 l = self.link_lengths 

+

556 N = len(theta) 

+

557 J = np.zeros([2, len(l)]) 

+

558 for m in range(N): 

+

559 for i in range(m, N): 

+

560 J[0, m] += -l[i]*np.sin(sum(self.rotation_convention*theta[:i+1])) 

+

561 J[1, m] += l[i]*np.cos(sum(self.rotation_convention*theta[:i+1])) 

+

562 return J 

+

563 else: 

+

564 # Use the robotics toolbox and the generic D-H convention jacobian 

+

565 J = self.robot.jacob0(self.calc_full_joint_angles(theta)) 

+

566 return np.array(J[[0,2], 1:]) 

+

567 
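A numeric sketch of the hand-derived planar Jacobian in the `old` branch above (two links; column m sums contributions of all links distal to joint m):

import numpy as np
l, theta = [10., 10.], np.array([0.3, 0.5])
J = np.zeros([2, 2])
for m in range(2):
    for i in range(m, 2):
        J[0, m] += -l[i] * np.sin(theta[:i+1].sum())
        J[1, m] += l[i] * np.cos(theta[:i+1].sum())
print(J)   # rows are [dx/dtheta_m] and [dz/dtheta_m]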

+

568 def endpoint_potent_null_split(self, q, vel, return_J=False): 

+

569 ''' 

+

570 (Approximately) split joint velocities into an endpoint potent component,  

+

571 which moves the endpoint, and an endpoint null component which only causes self-motion 

+

572 ''' 

+

573 J = self.jacobian(q) 

+

574 J_pinv = np.linalg.pinv(J) 

+

575 J_task = np.dot(J_pinv, J) 

+

576 J_null = np.eye(self.n_joints) - J_task 

+

577 

+

578 vel_task = np.dot(J_task, vel) 

+

579 vel_null = np.dot(J_null, vel) 

+

580 if return_J: 

+

581 return vel_task, vel_null, J, J_pinv 

+

582 else: 

+

583 return vel_task, vel_null 

+

584 

+

585 def config_change_nullspace_workspace(self, config1, config2): 

+

586 ''' 

+

587 For two configurations, determine how much joint displacement is in the "null" space and how much is in the "task" space 

+

588 

+

589 Docstring  

+

590  

+

591 Parameters 

+

592 ---------- 

+

593  

+

594 Returns 

+

595 ------- 

+

596 ''' 

+

597 config = config1 

+

598 vel = config2 - config1 

+

599 endpt1 = self.endpoint_pos(config1) 

+

600 endpt2 = self.endpoint_pos(config2) 

+

601 task_displ = np.linalg.norm(endpt1 - endpt2) 

+

602 

+

603 # compute total displ of individual joints 

+

604 total_joint_displ = 0 

+

605 

+

606 n_joints = len(config1) 

+

607 for k in range(n_joints): 

+

608 jnt_k_vel = np.zeros(n_joints) 

+

609 jnt_k_vel[k] = vel[k] 

+

610 single_joint_displ_pos = self.endpoint_pos(config + jnt_k_vel) 

+

611 total_joint_displ += np.linalg.norm(endpt1 - single_joint_displ_pos) 

+

612 

+

613 return task_displ, total_joint_displ 

+

614 

+

615 def detect_collision(self, theta, obstacle_pos): 

+

616 ''' 

+

617 Detect a collision between the chain and a circular object 

+

618 ''' 

+

619 spatial_joint_pos = self.spatial_positions_of_joints(theta).T  # base_loc already applied 

+

620 plant_segments = [(x, y) for x, y in zip(spatial_joint_pos[:-1], spatial_joint_pos[1:])] 

+

621 dist_to_object = np.zeros(len(plant_segments)) 

+

622 for k, segment in enumerate(plant_segments): 

+

623 dist_to_object[k] = point_to_line_segment_distance(obstacle_pos, segment) 

+

624 return dist_to_object 

+

625 

+

626 def plot_joint_pos(self, joint_pos, ax=None, flip=False, **kwargs): 

+

627 if ax is None: 

+

628 plt.figure() 

+

629 ax = plt.subplot(111) 

+

630 

+

631 if isinstance(joint_pos, dict): 

+

632 joint_pos = np.vstack(list(joint_pos.values())) 

+

633 elif isinstance(joint_pos, np.ndarray) and np.ndim(joint_pos) == 1: 

+

634 joint_pos = joint_pos.reshape(1, -1) 

+

635 elif isinstance(joint_pos, tuple): 

+

636 joint_pos = np.array(joint_pos).reshape(1, -1) 

+

637 

+

638 for pos in joint_pos: 

+

639 spatial_pos = self.spatial_positions_of_joints(pos).T 

+

640 

+

641 shoulder_anchor = np.array([2., 0., -15.]) 

+

642 spatial_pos = spatial_pos# + shoulder_anchor 

+

643 if flip: 

+

644 ax.plot(-spatial_pos[:,0], spatial_pos[:,2], **kwargs) 

+

645 else: 

+

646 ax.plot(spatial_pos[:,0], spatial_pos[:,2], **kwargs) 

+

647 

+

648 return ax 

+

649 

+

650def point_to_line_segment_distance(point, segment): 

+

651 ''' 

+

652 Determine the distance between a point and a line segment. Used to determine collisions between robot arm links and virtual obstacles. 

+

653 Adapted from http://stackoverflow.com/questions/849211/shortest-distance-between-a-point-and-a-line-segment 

+

654 ''' 

+

655 v, w = segment 

+

656 l2 = np.sum(np.abs(v - w)**2) 

+

657 if l2 == 0: 

+

658 return np.linalg.norm(v - point) 

+

659 

+

660 t = np.dot(point - v, w - v)/l2 

+

661 if t < 0: 

+

662 return np.linalg.norm(point - v) 

+

663 elif t > 1: 

+

664 return np.linalg.norm(point - w) 

+

665 else: 

+

666 projection = v + t*(w-v) 

+

667 return np.linalg.norm(projection - point) 

+

668 

+

669 

+

670class PlanarXZKinematicChain2Link(PlanarXZKinematicChain): 

+

671 ''' Docstring ''' 

+

672 def __init__(self, link_lengths, *args, **kwargs): 

+

673 ''' 

+

674 Docstring  

+

675  

+

676 Parameters 

+

677 ---------- 

+

678  

+

679 Returns 

+

680 ------- 

+

681 ''' 

+

682 if not len(link_lengths) == 2: 

+

683 raise ValueError("Can't instantiate a 2-link arm with > 2 links!") 

+

684 

+

685 super(PlanarXZKinematicChain2Link, self).__init__(link_lengths, *args, **kwargs) 

+

686 

+

687 def inverse_kinematics(self, pos, **kwargs): 

+

688 ''' 

+

689 Inverse kinematics for a two-link kinematic chain. These equations can be solved 

+

690 deterministically.  

+

691 

+

692 Docstring  

+

693  

+

694 Parameters 

+

695 ---------- 

+

696 pos : np.ndarray of shape (3,) 

+

697 Desired endpoint position where the coordinate system origin is the base of the arm. y coordinate must be 0 

+

698  

+

699 Returns 

+

700 ------- 

+

701 np.ndarray of shape (2,) 

+

702 Joint angles which yield the endpoint position with the forward kinematics of this manipulator 

+

703 ''' 

+

704 pos -= self.base_loc 

+

705 l_upperarm, l_forearm = self.link_lengths 

+

706 

+

707 if np.ndim(pos) == 1: 

+

708 pos = pos.reshape(1,-1) 

+

709 

+

710 # require the y-coordinate to be 0, i.e. flat on the screen 

+

711 x, y, z = pos[:,0], pos[:,1], pos[:,2] 

+

712 assert np.all(np.abs(np.array(y)) < 1e-10) 

+

713 

+

714 L = np.sqrt(x**2 + z**2) 

+

715 cos_el_pflex = (L**2 - l_forearm**2 - l_upperarm**2) / (2*l_forearm*l_upperarm) 

+

716 

+

717 cos_el_pflex[ (cos_el_pflex > 1) & (cos_el_pflex < 1 + 1e-9)] = 1 

+

718 el_pflex = np.arccos(cos_el_pflex) 

+

719 

+

720 sh_pabd = np.arctan2(z, x) - np.arcsin(l_forearm * np.sin(np.pi - el_pflex) / L) 

+

721 return np.array([sh_pabd, el_pflex]) 

+
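A round-trip sketch for the deterministic 2-link IK above: feed the returned angles through the standard planar forward kinematics and recover the target (an illustration, not rig code):

import numpy as np
chain = PlanarXZKinematicChain2Link([10., 10.])
q1, q2 = chain.inverse_kinematics(np.array([12., 0., 8.]))
x = 10*np.cos(q1) + 10*np.cos(q1 + q2)   # l1*cos(q1) + l2*cos(q1+q2)
z = 10*np.sin(q1) + 10*np.sin(q1 + q2)
print(x, z)   # ~12, ~8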
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_sim_neurons_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_sim_neurons_py.html
new file mode 100644
index 00000000..0e9689d4
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_sim_neurons_py.html
@@ -0,0 +1,921 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\sim_neurons.py: 12%

#!/usr/bin/python
"""
Classes to simulate neural activity (spike firing rates) by various methods.
"""

import os

import numpy as np
from numpy.random import poisson, rand
from scipy.io import loadmat, savemat
from scipy.integrate import trapz, simps

ts_dtype = [('ts', float), ('chan', np.int32), ('unit', np.int32)]
ts_dtype_new = [('ts', float), ('chan', np.int32), ('unit', np.int32), ('arrival_ts', np.float64)]

############################
##### Gaussian encoder #####
############################
class KalmanEncoder(object):
    '''
    Models a BMI user as someone who, given an intended state x,
    generates a vector of neural features y according to the KF observation
    model equation: y = Cx + q.
    '''
    def __init__(self, ssm, n_features, int_neural_features=False, scale_noise=1.):
        self.ssm = ssm
        self.n_features = n_features
        self.int_neural_features = int_neural_features
        self.scale_noise = scale_noise

        drives_neurons = ssm.drives_obs
        nX = ssm.n_states

        # random observation model; columns for states which do not drive
        # the observations are zeroed out
        C = 3*np.random.standard_normal([n_features, nX])
        C[:, ~drives_neurons] = 0
        Q = np.identity(n_features)

        self.C = C
        self.Q = Q

    def __call__(self, intended_state, **kwargs):
        q = np.random.multivariate_normal(np.zeros(self.Q.shape[0]), self.Q).reshape(-1, 1)
        neural_features = np.dot(self.C, intended_state.reshape(-1, 1)) + self.scale_noise*q
        if self.int_neural_features:
            # round to integer "counts" and clip to the range [0, 10]
            nf = np.round(neural_features)
            nf[nf <= 0] = 0
            nf[nf > 10] = 10
            return nf
        else:
            return neural_features

    def get_units(self):
        '''
        Return fake indices corresponding to the simulated units, e.g., (1, 1) represents sig001a in the plexon system
        '''
        return np.array([(k, 1) for k in range(self.n_features)])
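A minimal usage sketch for KalmanEncoder. The _StubSSM class below is hypothetical; it stands in for a state_space_models.StateSpace instance and provides only the two attributes the encoder reads:

import numpy as np

class _StubSSM(object):
    """Hypothetical stand-in for a 7-state StateSpace."""
    n_states = 7
    drives_obs = np.array([False, False, False, True, False, True, True])

enc = KalmanEncoder(_StubSSM(), n_features=20)
x = np.zeros(7); x[3], x[5], x[6] = 5.0, -2.0, 1.0   # intended velocity + offset
y = enc(x)                                           # y = Cx + q, shape (20, 1)
print(y.shape, enc.get_units()[:3])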

###########################
##### Poisson encoder #####
###########################
class GenericCosEnc(object):
    '''
    Simulate neurons whose firing rate is a linear function of covariates; the rate parameter drives a Poisson draw.
    '''
    def __init__(self, C, ssm, return_ts=False, DT=0.1, call_ds_rate=6):
        '''
        Constructor for GenericCosEnc

        Parameters
        ----------
        C : np.ndarray of shape (N, K)
            N is the number of simulated neurons, K is the number of covariates driving neuronal activity.
            The product of C and the hidden state vector x should give the intended spike rates in Hz.
        ssm : state_space_models.StateSpace instance
            ARG_DESCR
        return_ts : bool, optional, default=False
            If True, fake timestamps are returned for each spike event in the same format
            as real spike data would be delivered over the network during a real experiment.
            If False, a vector of counts is returned instead. Specify True or False depending on
            which type of feature extractor you're using for your simulated task.
        DT : float, optional, default=0.1
            Sampling interval at which new spike processes are generated
        call_ds_rate : int, optional, default=6
            DT / call_ds_rate gives the interval between ticks of the main event loop

        Returns
        -------
        GenericCosEnc instance
        '''
        self.C = C
        self.ssm = ssm
        self.n_neurons = C.shape[0]
        self.call_count = 0
        self.call_ds_rate = call_ds_rate
        self.return_ts = return_ts
        self.DT = DT
        self.unit_inds = np.arange(1, self.n_neurons+1)

    def get_units(self):
        '''
        Retrieve the identities of the units in the encoder. Only used because units in real experiments have "names"
        '''
        # Just pretend that each unit is the 'a' unit on a separate electrode
        return np.array([(ind, 1) for ind in self.unit_inds])

    def gen_spikes(self, next_state, mode=None):
        """
        Simulate the spikes

        Parameters
        ----------
        next_state : np.array of shape (N, 1)
            The "next state" to be encoded by this population of neurons

        Returns
        -------
        time stamps or counts
            Either spike time stamps or a vector of unit spike counts is returned, depending on whether the 'return_ts' attribute is True
        """
        rates = np.dot(self.C, next_state)
        return self.return_spikes(rates, mode=mode)

    def return_spikes(self, rates, mode=None):
        rates[rates < 0] = 0  # Floor firing rates at 0 Hz
        counts = poisson(rates * self.DT)

        if np.logical_or(mode == 'ts', np.logical_and(mode is None, self.return_ts)):
            ts = []
            for k, ind in enumerate(self.unit_inds):
                # separate spike counts into individual time-stamps
                n_spikes = int(counts[k])
                fake_time = (self.call_count + 0.5) * 1./60
                if n_spikes > 0:
                    spike_data = [(fake_time, ind, 1) for m in range(n_spikes)]
                    ts += spike_data

            ts = np.array(ts, dtype=ts_dtype)
            return ts

        elif np.logical_or(mode == 'counts', np.logical_and(mode is None, self.return_ts is False)):
            return counts

    def __call__(self, next_state, mode=None):
        '''
        See CosEnc.__call__ for docs
        '''
        if self.call_count % self.call_ds_rate == 0:
            ts_data = self.gen_spikes(next_state, mode=mode)
        else:
            if self.return_ts:
                # return an empty list of time stamps
                ts_data = np.array([])
            else:
                # return a vector of 0's
                ts_data = np.zeros(self.n_neurons)

        self.call_count += 1
        return ts_data
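A minimal sketch of driving GenericCosEnc in counts mode, assuming a hand-built C matrix with cosine tuning in the x- and z-velocity columns of a 7-dimensional state; all sizes below are illustrative:

import numpy as np

pds = np.linspace(0, 2*np.pi, 4, endpoint=False)   # preferred directions
C = np.zeros((4, 7))
C[:, 3] = 20*np.cos(pds)
C[:, 5] = 20*np.sin(pds)

enc = GenericCosEnc(C, ssm=None, return_ts=False, DT=0.1, call_ds_rate=1)
x = np.zeros(7); x[3] = 1.0                  # move in +x
counts = enc(x, mode='counts')               # Poisson(rate * DT) per unit
print(counts)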

class FACosEnc(GenericCosEnc):
    '''
    Simulate neurons whose rate is a linear function of underlying factor modulation; the rate parameter drives a Poisson draw.
    '''
    def __init__(self, C, ssm, DT=0.1, call_ds_rate=6, return_ts=False, max_FR=20, **kwargs):
        super(FACosEnc, self).__init__(C, ssm, DT=DT, call_ds_rate=call_ds_rate, return_ts=return_ts)

        # Parse kwargs:
        self.n_neurons = kwargs.pop('n_neurons', self.n_neurons)
        self.unit_inds = np.arange(1, self.n_neurons+1)
        self.lambda_spk_update = 1
        self.wt_sources = kwargs.pop('wt_sources', None)
        self.n_facts = kwargs.pop('n_facts', [3, 3])
        self.state_vel_var = kwargs.pop('state_vel_var', 2.7)
        self.r2 = 1./self.state_vel_var

        # Establish the number of factors for tuned / untuned input sources:
        self.n_tun_factors = self.n_facts[0]
        self.n_unt_factors = self.n_facts[1]

        self.eps = 1e-15

        if 'psi_tun' in kwargs:
            print('using kwargs psi tun')
            self.psi_tun = kwargs['psi_tun']
            self.psi_unt_std = kwargs['psi_unt_std']
        else:
            # Establish mapping from kinematics to factors:
            self.psi_unt = np.zeros((self.n_neurons, 1))
            self.psi_unt_std = np.sqrt(7.) + np.zeros((self.n_neurons, ))

            # Matched to fit KF data -- unit vectors:
            self.psi_tun = np.random.normal(0, 1, (self.n_neurons, ssm.n_states))
            self.psi_tun[:, [0, 1, 2, 4, 6]] = 0
            self.psi_tun = self.psi_tun / np.tile(np.linalg.norm(self.psi_tun, axis=1)[:, np.newaxis], [1, ssm.n_states])
            self.psi_tun = self.psi_tun/np.sqrt(2)  # Due to 2 active states contributing to tuning

        self.v_ = 2*(np.random.random_sample(self.n_tun_factors)-0.5)

        self.V = np.zeros((self.n_tun_factors, ssm.n_states))
        self.V[:, 3] = self.v_
        self.V[:, 5] = 2.*(np.random.random_sample((self.n_tun_factors))-0.5)

        self.V = self.V / np.tile(np.linalg.norm(self.V, axis=1)[:, np.newaxis], [1, ssm.n_states])
        self.V = self.V / np.sqrt(2.)  # Due to 2 active states contributing to tuning

        self.U = 2.*(np.random.random_sample((self.n_neurons, self.n_tun_factors))-0.5)
        self.U = self.U / np.tile(np.linalg.norm(self.U, axis=1)[:, np.newaxis], [1, self.n_tun_factors])

        self.W = 2.*(np.random.random_sample((self.n_neurons, self.n_unt_factors))-0.5)
        self.W = self.W / np.tile(np.linalg.norm(self.W, axis=1)[:, np.newaxis], [1, self.n_unt_factors])
        self.W = self.W / np.sqrt(2)

        # REMEMBER -- the mean is for 0.1 sec. Results prior to 7-30-16 used a
        # mean of 2 spks / 100 ms; now drawn from an exponential distribution:
        self.mu = np.random.exponential(.75, size=(self.n_neurons, ))
        self.bin_step_count = -1

    def _gen_state(self):
        s = np.random.normal(0, 7, (7, 1))
        s[[1, 4], :] = 0
        s[-1, 0] = 1
        return s

    def gen_spikes(self, next_state, mode=None):
        self.ns_pk = next_state

        self.priv_tun_bins = np.random.poisson(self.lambda_spk_update, self.n_neurons)
        self.priv_unt_bins = np.random.poisson(1, self.n_neurons)

        self.shar_tun_bins = np.random.poisson(1, self.n_tun_factors)
        self.shar_unt_bins = np.random.poisson(1, self.n_unt_factors)

        if len(next_state.shape) == 1:
            next_state = np.array([next_state]).T

        # Private sources:
        priv_unt = []
        priv_tun = []

        for n in range(self.n_neurons):
            # Private untuned, if the Poisson draw is nonzero:
            if self.priv_unt_bins[n] > 0:
                cnt = []
                for z in range(self.priv_unt_bins[n]):
                    psi_unt = np.random.normal(0, self.psi_unt_std[n])
                    cnt.append(psi_unt)
                priv_unt.append(np.sum(cnt))
            else:
                priv_unt.append(0.)

            # Private tuned:
            if self.priv_tun_bins[n] > 0:
                cnt = []
                for z in range(self.priv_tun_bins[n]):
                    psi_tun = np.dot(self.psi_tun[n, :], next_state)
                    cnt.append(psi_tun)
                priv_tun.append(np.sum(cnt))
            else:
                priv_tun.append(0.)

        self.priv_tun = np.hstack(priv_tun)
        self.priv_unt = np.hstack(priv_unt)

        # Shared tuned:
        t_tun = np.zeros((self.n_neurons,))
        for zi in range(self.n_tun_factors):
            if self.shar_tun_bins[zi] > 0:
                for z in range(self.shar_tun_bins[zi]):
                    ns = np.array(next_state)
                    if len(ns.shape) < 2:
                        ns = ns[:, np.newaxis]
                    t_tun += self.U[:, zi]*np.dot(self.V[zi, :], ns)

        self.shar_tun = t_tun

        # Shared untuned:
        self.unt_fact = np.random.normal(0, np.sqrt(7.), (self.n_unt_factors, ))
        t_unt = np.zeros((self.n_neurons,))
        for zi in range(self.n_unt_factors):
            if self.shar_unt_bins[zi] > 0:
                for z in range(self.shar_unt_bins[zi]):
                    t_unt += self.W[:, zi] * self.unt_fact[zi]

        self.shar_unt = t_unt

        # Now weight everything together:
        if self.wt_sources is None:  # if no wt_sources, weight the sources equally
            w = np.array([1, 1, 1, 1]) / 4
        else:
            w = self.wt_sources

        counts = np.squeeze(np.array(w[0]*self.priv_unt + w[1]*self.priv_tun + w[2]*self.shar_unt + w[3]*self.shar_tun))

        # Add back the mean FR
        counts += self.mu

        if np.logical_or(mode == 'ts', np.logical_and(mode is None, self.return_ts)):
            ts = []
            for k, ind in enumerate(self.unit_inds):
                # separate spike counts into individual time-stamps
                n_spikes = int(counts[k])
                fake_time = (self.call_count + 0.5) * 1./60
                if n_spikes > 0:
                    spike_data = [(fake_time, ind, 1) for m in range(n_spikes)]
                    ts += spike_data

            ts = np.array(ts, dtype=ts_dtype)
            return ts

        elif np.logical_or(mode == 'counts', np.logical_and(mode is None, self.return_ts is False)):
            return counts

    def mod_poisson(self, x, dt=0.1):
        x[x < 0] = 0
        return poisson(x*dt)

    def y2_eq_r2_min_x2(self, x_arr, r2):
        # randomly choose a point on the circle y**2 = r2 - x**2 for each x
        y = []
        for x in x_arr:
            if np.random.random_sample() > 0.5:
                y.append(np.sqrt(r2 - x**2))
            else:
                y.append(-1*np.sqrt(r2 - x**2))
        return np.array(y)
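FACosEnc mixes four sources per bin: counts = w0*private_untuned + w1*private_tuned + w2*shared_untuned + w3*shared_tuned + mu. A sketch of instantiating it with only the private sources weighted, mirroring from_file_to_FACosEnc below; the stub state space is hypothetical:

import numpy as np

class _StubSSM(object):
    n_states = 7      # hypothetical stand-in; FACosEnc only reads n_states

C = np.random.rand(20, 7)
enc = FACosEnc(C, _StubSSM(), return_ts=False, call_ds_rate=1,
               n_neurons=20, wt_sources=[1, 1, 0, 0])
x = np.zeros(7); x[3], x[5], x[6] = 3., -1., 1.
print(enc(x, mode='counts').shape)    # (20,) weighted "counts" per unit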

class NormalizedCosEnc(GenericCosEnc):

    def __init__(self, bounds, *args, **kwargs):
        self.min = np.array([bounds[0], bounds[2], bounds[4]])
        self.range = np.array([bounds[1] - bounds[0], bounds[3] - bounds[2], bounds[5] - bounds[4]])
        self.range[self.range == 0] = 1
        self.gain = 100
        super(NormalizedCosEnc, self).__init__(*args, **kwargs)

    def gen_spikes(self, next_state, mode=None):
        """
        Simulate the spikes

        Parameters
        ----------
        next_state : np.array of shape (N, 1)
            The "next state" to be encoded by this population of neurons

        Returns
        -------
        time stamps or counts
            Either spike time stamps or a vector of unit spike counts is returned, depending on whether the 'return_ts' attribute is True
        """
        norm_state = np.divide(np.subtract(np.squeeze(next_state), self.min), self.range)
        rates = np.dot(self.C, norm_state) * self.gain
        return self.return_spikes(rates, mode=mode)

def from_file_to_FACosEnc(plot=False):
    from riglib.bmi import state_space_models as ssm
    import pickle
    import os
    import matplotlib.pyplot as plt

    dat = pickle.load(open(os.path.expandvars('/home/lab/preeya/fa_analysis/grom_data/co_obs_SNR_w_coefficients.pkl'), 'rb'))
    SSM = ssm.StateSpaceEndptVel2D()

    snr = {}
    eps = 10**-10

    if plot:
        f, ax = plt.subplots(nrows=3, ncols=3)

    for j, i in enumerate(np.sort(list(dat.keys()))):
        snr[i] = []
        d = dat[i]
        kwargs = {}
        kwargs['n_neurons'] = len(list(d.keys()))
        C = np.random.rand(kwargs['n_neurons'], SSM.n_states)
        kwargs['wt_sources'] = [1, 1, 0, 0]
        enc = FACosEnc(C, SSM, return_ts=True, **kwargs)

        for n in range(len(list(d.keys()))):
            # For individual units:
            enc.psi_tun[n, [3, 5, 6]] = d[n][3][0, :]  # Terrible construction.
            enc.mu[n] = 0
            # Now set the standard deviation: draw from VFB distribution of commands

        data, enc = sim_enc(enc)
        U = np.vstack(data['unt'])
        T = np.vstack(data['tun'])
        spk = U+T
        vel = np.hstack(data['ctl'])[[3, 5], :].T
        vel = np.hstack((np.array(vel), np.ones((len(vel), 1))))

        # Fit encoder:
        n_units = spk.shape[1]
        snr_act = []
        for n in range(n_units):
            snr_des = d[n][2]
            if np.isnan(snr_des):
                snr_des = .3
                print('success')
            snr_act.append(snr_des)
            s2 = spk[:, n]  # Spikes
            x = np.linalg.lstsq(vel, s2[:, np.newaxis])  # Regress spikes against velocities
            qnoise = np.var(s2[:, np.newaxis] - vel*np.mat(x[0]))  # Residuals
            # Explained variance vs. residual variance:
            qsig = np.var(vel*np.mat(x[0]))
            k = qsig/snr_des
            enc.psi_unt_std[n] = np.sqrt(k)

        # Fit simulation:
        data, enc = sim_enc(enc)
        U = np.vstack(data['unt'])
        T = np.vstack(data['tun'])
        spk = U+T
        vel = np.hstack(data['ctl'])[[3, 5], :].T
        vel = np.hstack((np.array(vel), np.ones((len(vel), 1))))
        snr_sim = []
        for n in range(n_units):
            s2 = spk[:, n]  # Spikes
            x = np.linalg.lstsq(vel, s2[:, np.newaxis])  # Regress spikes against velocities
            qnoise = np.var(s2[:, np.newaxis] - vel*np.mat(x[0]))  # Residuals
            # Explained variance vs. residual variance:
            qsig = np.var(vel*np.mat(x[0]))
            snr_sim.append(qsig/qnoise)

        if plot:
            axi = ax[j//3, j % 3]
            axi.plot(snr_sim, snr_act, '.')
            axi.set_title(list(dat.keys())[j])

        pickle.dump(enc, open(os.path.expandvars('$FA_GROM_DATA/sims/test_obs_vs_co_overlap/encoder_param_matched_'+str(i)+'.pkl'), 'wb'))

def test_sim_enc():
    int_enc_names = [26, 28, 31, 39, 40, 41, 42, 43, 44]
    import pickle, os
    import matplotlib.pyplot as plt
    real_data = pickle.load(open(os.path.expandvars('/home/lab/preeya/fa_analysis/grom_data/co_obs_SNR_w_coefficients.pkl'), 'rb'))
    f, ax = plt.subplots(nrows=3, ncols=3)
    for ie, e in enumerate(int_enc_names):
        axi = ax[ie//3, ie % 3]
        match_data = real_data[e]
        enc = pickle.load(open(os.path.expandvars('$FA_GROM_DATA/sims/test_obs_vs_co_overlap/encoder_param_matched_'+str(e)+'.pkl'), 'rb'))
        data, enc = sim_enc(enc)
        U = np.vstack(data['unt'])
        T = np.vstack(data['tun'])

        spk = U+T
        vel = np.hstack(data['ctl'])[[3, 5], :].T
        n_units = spk.shape[1]
        for n in range(n_units):
            s2 = spk[:, n]  # Spikes
            x = np.linalg.lstsq(s2[:, np.newaxis], vel)  # Regress spikes against velocities
            q = s2[:, np.newaxis] - vel*np.mat(x[0].T)  # Residuals
            # Explained variance vs. residual variance:
            ev = np.var(vel*np.mat(x[0].T))
            snr = ev/np.var(q)
            axi.plot(snr, match_data[n][2], '.')
            axi.set_xlim([0, 1.5])
            axi.set_ylim([0, 1.5])

def sim_enc(enc):
    enc.call_ds_rate = 1
    from tasks import sim_fa_decoding
    dat = {}
    dat['tun'] = []
    dat['unt'] = []
    dat['ctl'] = []

    assister = sim_fa_decoding.SuperSimpleEndPtAssister()

    for it in range(2000):
        current_state = (np.random.rand(7, 1) - 0.5)*30
        current_state[[1, 4], 0] = 0
        target_state = np.zeros((7, 1))
        ang = (np.random.permutation(8)[0])/8*np.pi*2
        target_state[[0, 2], 0] = np.array([10*np.cos(ang), 10*np.sin(ang)])

        ctrl = assister.calc_next_state(current_state, target_state)
        ts_data = enc(ctrl)
        dat['ctl'].append(ctrl)
        dat['tun'].append(enc.priv_tun)
        dat['unt'].append(enc.priv_unt)
    enc.call_ds_rate = 6
    return dat, enc


def make_FACosEnc(num):
    from riglib.bmi import state_space_models as ssm
    import pickle
    num_neurons = 20
    SSM = ssm.StateSpaceEndptVel2D()

    for n in range(num):
        kwargs = {}
        kwargs['n_neurons'] = num_neurons
        C = np.random.rand(num_neurons, SSM.n_states)
        enc = FACosEnc(C, SSM, return_ts=True, **kwargs)
        enc.psi_unt_std[:4] /= 40
        pickle.dump(enc, open('/storage/preeya/grom_data/sims/test_obs_vs_co_overlap/encoder_param_matched_'+str(n)+'.pkl', 'wb'))

class CursorVelCosEnc(GenericCosEnc):
    '''
    Cosine encoder tuned to the X-Z velocity of a cursor. Corresponds to the StateSpaceEndptVel2D state-space model.
    '''
    def __init__(self, n_neurons=25, mod_depth=14./0.2, baselines=10, **kwargs):
        C = np.zeros([n_neurons, 7])
        C[:, -1] = baselines

        angles = np.linspace(0, 2 * np.pi, n_neurons)
        C[:, 3] = mod_depth * np.cos(angles)
        C[:, 5] = mod_depth * np.sin(angles)

        super(CursorVelCosEnc, self).__init__(C, ssm=None, **kwargs)
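A usage sketch, assuming counts-mode output; the unit count and modulation depth below are illustrative:

import numpy as np

# 8 cosine-tuned units with preferred directions spaced around the circle;
# firing rate = baseline + mod_depth * cos(theta - preferred_direction)
enc = CursorVelCosEnc(n_neurons=8, mod_depth=70., baselines=10,
                      return_ts=False, call_ds_rate=1)
vel_state = np.zeros(7); vel_state[3] = 0.1; vel_state[-1] = 1.0
print(enc(vel_state, mode='counts'))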

#################################
##### Point-process encoder #####
#################################
class PointProcess(object):
    '''
    Simulate a single point process. Implemented by Suraj Gowda and Maryam Shanechi.
    '''
    def __init__(self, beta, dt, tau_samples=[], K=0, eps=1e-3):
        '''Docstring'''
        self.beta = beta.reshape(-1, 1)
        self.dt = dt
        self.tau_samples = tau_samples
        self.K = K
        self.eps = eps
        self.i = 0
        self.j = self.i + 1
        self.X = np.zeros([0, len(beta)])
        self._reset_res()
        self.tau = np.inf
        self.rate = np.nan

    def _exp_sample(self):
        '''Draw the next (rescaled) inter-spike threshold, either from the
        pre-specified list of samples or from a uniform draw.'''
        if len(self.tau_samples) > 0:
            self.tau = self.tau_samples.pop(0)
        else:
            u = np.random.rand()
            self.tau = np.log(1 - u)

    def _reset_res(self):
        '''Docstring'''
        self.resold = 1000
        self.resnew = np.nan

    def _integrate_rate(self):
        '''Integrate the conditional intensity since the last spike.'''
        loglambda = np.dot(self.X[self.last_spike_ind:self.j+1, :], self.beta)  # log of lambda*delta
        self.rate = np.ravel(np.exp(loglambda)/self.dt)

        if len(self.rate) > 2:
            self.resnew = self.tau + simps(self.rate, dx=self.dt, even='first')
        else:
            self.resnew = self.tau + trapz(self.rate, dx=self.dt)

    def _decide(self):
        '''Decide whether the integrated intensity has crossed the sampled threshold.'''
        if (self.resold > 0) and (self.resnew > self.resold):
            return True
        else:
            self.resold = self.resnew
            return False

    def _push(self, x_t):
        '''Append the new covariate vector to the stimulus history.'''
        self.X = np.vstack([self.X, x_t])

    def __call__(self, x_t):
        '''
        Simulate whether the cell should fire at time t based on the new stimulus x_t and previous stimuli (saved)

        Parameters
        ----------
        x_t : np.ndarray of size (N,)
            Current stimulus that the firing rate of the cell depends on.
            N should match the number of coefficients in beta.

        Returns
        -------
        spiking_bin : bool
            True or False depending on whether the cell has fired after the present stimulus.
        '''
        self._push(x_t)
        if np.abs(self.resold) < self.eps:
            spiking_bin = True
        else:
            self._integrate_rate()
            spiking_bin = self._decide()

        # Handle the spike
        if spiking_bin:
            self.last_spike_ind = self.j - 1
            self._reset_res()
            self._exp_sample()
            self._integrate_rate()
            self.resold = self.resnew

        self.j += 1
        return spiking_bin

    def _init_sampling(self, x_t):
        '''Docstring'''
        self._push(x_t)  # initialize the observed extrinsic covariates
        self._reset_res()
        self._exp_sample()
        self.j = 1
        self.last_spike_ind = 0  # initialization

    def sim_batch(self, X, verbose=False):
        '''Simulate the point process over a whole batch of stimuli X, returning a binary spike train.'''
        framelength = X.shape[0]
        spikes = np.zeros(framelength)

        self._init_sampling(X[0, :])

        while self.j < framelength:
            spiking_bin = self.__call__(X[self.j, :])
            if self.j < framelength and spiking_bin:
                spikes[self.last_spike_ind] = 1

        return spikes
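A sketch of batch-simulating one unit with sim_batch, assuming a single constant covariate; beta is chosen so that exp(beta)/dt works out to roughly 20 Hz, so about 20 spikes are expected over the 1 s window:

import numpy as np

dt = 0.005
beta = np.array([np.log(20*dt)])        # single covariate: the constant 1
pp = PointProcess(beta, dt)
X = np.ones((200, 1))                   # 200 bins of the constant covariate
spikes = pp.sim_batch(X)
print(spikes.sum(), 'spikes in', len(spikes)*dt, 's')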

class PointProcessEnsemble(object):
    '''
    Simulate an ensemble of point processes
    '''
    def __init__(self, beta, dt, init_state=None, tau_samples=None, eps=1e-3,
                 hist_len=0, units=None):
        '''
        Constructor for PointProcessEnsemble

        Parameters
        ----------
        beta : np.array of shape (n_units, n_covariates)
            Each row of the matrix specifies the relationship between a single point process in the ensemble and the common "stimuli"
        dt : float
            Sampling interval to integrate the point process likelihood over
        init_state : np.array, optional, default=[np.zeros(n_covariates-1), 1]
            Initial state of the common stimuli
        tau_samples : iterable, optional, default=None
            ARG_DESCR
        eps : float, optional, default=0.001
            ARG_DESCR
        hist_len : int, optional, default=0
            ARG_DESCR
        units : list of tuples, optional, default=None
            Identifiers for each element of the ensemble. One is automatically generated if none is provided

        Returns
        -------
        PointProcessEnsemble instance
        '''
        self.n_neurons, n_covariates = beta.shape
        if init_state is None:
            init_state = np.hstack([np.zeros(n_covariates - 1), 1])
        if tau_samples is None:
            tau_samples = [[]]*self.n_neurons
        point_process_units = []

        self.beta = beta

        for k in range(self.n_neurons):
            point_proc = PointProcess(beta[k, :], dt, tau_samples=tau_samples[k])
            point_proc._init_sampling(init_state)
            point_process_units.append(point_proc)

        self.point_process_units = point_process_units

        if units is None:
            self.units = np.vstack([(x, 1) for x in range(self.n_neurons)])
        else:
            self.units = units

    def get_units(self):
        '''Docstring'''
        return self.units

    def __call__(self, x_t):
        '''Advance every unit one time step and return the vector of binary spike counts.'''
        x_t = np.array(x_t).ravel()
        counts = np.array([unit(x_t) for unit in self.point_process_units]).astype(int)
        return counts
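A sketch of an ensemble of five units sharing a three-dimensional stimulus whose last covariate is the constant offset (matching the default init_state); the beta values are illustrative:

import numpy as np

n_units, dt = 5, 1./180
beta = np.random.normal(-3, 0.3, (n_units, 3))
ensemble = PointProcessEnsemble(beta, dt)
counts = ensemble(np.array([0.5, -0.2, 1.0]))   # one bin of binary spike counts
print(counts)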

class CLDASimPointProcessEnsemble(PointProcessEnsemble):
    '''
    PointProcessEnsemble intended to be called at 60 Hz and return simulated
    spike timestamps at 180 Hz
    '''
    def __init__(self, *args, **kwargs):
        '''
        see PointProcessEnsemble.__init__
        '''
        super(CLDASimPointProcessEnsemble, self).__init__(*args, **kwargs)
        self.call_count = -1

    def __call__(self, x_t):
        '''
        The ensemble is called at 60 Hz but the timestamps reflect spike
        bins determined at 180 Hz

        Parameters
        ----------
        x_t : np.ndarray
            Common stimulus vector for the ensemble

        Returns
        -------
        np.ndarray of spike timestamps in the ts_dtype_new format
        '''
        ts_data = []
        for k in range(3):
            counts = super(CLDASimPointProcessEnsemble, self).__call__(x_t)
            nonzero_units, = np.nonzero(counts)
            fake_time = self.call_count * 1./60 + (k + 0.5)*1./180
            for unit_ind in nonzero_units:
                ts = (fake_time, self.units[unit_ind, 0], self.units[unit_ind, 1], fake_time)
                ts_data.append(ts)

        self.call_count += 1
        return np.array(ts_data, dtype=ts_dtype_new)
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_sskfdecoder_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_sskfdecoder_py.html
new file mode 100644
index 00000000..15c26c07
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_sskfdecoder_py.html
@@ -0,0 +1,241 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\sskfdecoder.py: 0%

'''
Classes for BMI decoding using the steady-state Kalman filter. Based
heavily on the kfdecoder module.
'''
import numpy as np
from scipy.io import loadmat
import pickle

from . import bmi
from . import train
from . import kfdecoder

class SteadyStateKalmanFilter(bmi.GaussianStateHMM):
    """
    Low-level KF in steady-state

    Model:
       x_{t+1} = Ax_t + w_t;   w_t ~ N(0, W)
       y_t = Cx_t + q_t;       q_t ~ N(0, Q)
    """
    model_attrs = ['F', 'K']

    def __init__(self, *args, **kwargs):
        '''Docstring'''
        if len(list(kwargs.keys())) == 0:
            ## This condition should only be true in the unpickling phase
            pass
        else:
            if 'A' in kwargs and 'C' in kwargs and 'W' in kwargs and 'Q' in kwargs:
                A = kwargs.pop('A')
                C = kwargs.pop('C')
                W = kwargs.pop('W')
                Q = kwargs.pop('Q')
                kf = kfdecoder.KalmanFilter(A, W, C, Q, **kwargs)
                F, K = kf.get_sskf()
            elif 'F' in kwargs and 'K' in kwargs:
                F = kwargs['F']
                K = kwargs['K']
            self.F = F
            self.K = K
            self._pickle_init()

    def _init_state(self, init_state=None, init_cov=None):
        """
        Initialize the state of the filter with a mean and covariance (uncertainty)
        """
        ## Initialize the BMI state
        nS = self.n_states
        if init_state is None:
            init_state = np.mat(np.zeros([nS, 1]))
            if self.include_offset:
                init_state[-1, 0] = 1
        if init_cov is None:
            init_cov = np.mat(np.zeros([nS, nS]))
        self.state = bmi.GaussianState(init_state, init_cov)

    def _pickle_init(self):
        '''Docstring'''
        nS = self.F.shape[0]
        self.I = np.mat(np.eye(nS))

    def get_sskf(self):
        '''Docstring'''
        return self.F, self.K

    def _forward_infer(self, st, obs_t, Bu=None, u=None, target_state=None,
                       obs_is_control_independent=True, bias_comp=False, **kwargs):
        '''
        Estimate p(x_t | ..., y_{t-1}, y_t)
        '''
        F, K = self.F, self.K
        if Bu is not None:
            post_state_mean = F*st.mean + K*obs_t + Bu
        else:
            post_state_mean = F*st.mean + K*obs_t

        I = self.I
        post_state = I*st  # Force memory reallocation for the Gaussian
        post_state.mean = post_state_mean
        return post_state

    def __getstate__(self):
        '''
        Pickle only the F and the K matrices
        '''
        return dict(F=self.F, K=self.K)

    @property
    def n_states(self):
        '''Docstring'''
        return self.F.shape[0]

    @property
    def include_offset(self):
        '''Docstring'''
        return np.all(self.F[-1, :-1] == 0) and (self.F[-1, -1] == 1)

    def get_K_null(self):
        '''
        $$y_{null} = K_{null} * y_t$$ gives the "null" component of the spike inputs, i.e. $$K_t*y_{null} = 0_{N\times 1}$$
        '''
        K = np.mat(self.K)
        n_neurons = K.shape[1]
        K_null = np.eye(n_neurons) - np.linalg.pinv(K) * K
        return K_null

class SSKFDecoder(bmi.Decoder, bmi.BMI):
    '''Docstring'''
    pass
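The steady-state recursion implemented by _forward_infer reduces to x_t = F x_{t-1} + K y_t once F and K have converged. A standalone sketch of that loop with hypothetical matrices:

import numpy as np

n_states, n_obs = 7, 20
F = np.mat(np.eye(n_states) * 0.9)               # hypothetical converged matrices
K = np.mat(np.random.randn(n_states, n_obs) * 0.01)

x = np.mat(np.zeros((n_states, 1)))
for t in range(10):
    y = np.mat(np.random.randn(n_obs, 1))        # one bin of neural observations
    x = F * x + K * y                            # steady-state KF update
print(x.T)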
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_state_space_models_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_state_space_models_py.html
new file mode 100644
index 00000000..131265d4
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_state_space_models_py.html
@@ -0,0 +1,594 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\state_space_models.py: 30%

#!/usr/bin/python
"""
State-space models for different types of BMIs (i.e. representations of various
BMI "plants"), and methods to manipulate the parameterizations of such models
(e.g., time resampling of a linear discrete-time representation of a continuous-time model).
"""
import numpy as np

############
## Constants
############
pi = np.pi


class State(object):
    '''
    A 1D component of a state-space, e.g., vertical velocity
    '''
    def __init__(self, name, stochastic=False, drives_obs=False, min_val=np.nan, max_val=np.nan, order=-1, aux=False):
        '''
        Constructor for State

        Parameters
        ----------
        name : string
            Name of the state
        stochastic : bool, optional
            Specify whether the state is stochastic (estimated from observation)
            or deterministic (updated entirely by model). Default is deterministic.
        drives_obs : bool, optional
            Specify whether the state will be reflected in the observations if it is
            used as a 'hidden' state. By default, the state and any observations will not be directly related.
        min_val : float, optional
            Hard (nonlinear) constraint on the minimum value of the state. By default (np.nan), no constraint is applied.
        max_val : float, optional
            Hard (nonlinear) constraint on the maximum value of the state. By default (np.nan), no constraint is applied.
        order : int
            Specification of the 'order' that this state would contribute to a differential equation,
            e.g., integrated position is order -1, position states are order 0, velocity states are order 1, acceleration is order 2, etc.
            Constant states should be order NaN.

        Returns
        -------
        State instance
        '''
        assert not name == 'q', "'q' is a reserved keyword (symbol for generalized robot coordinates) and cannot be used as a state name"
        self.name = name
        self.stochastic = stochastic
        self.drives_obs = drives_obs
        self.min_val = min_val
        self.max_val = max_val
        self.order = order
        self.aux = aux
        self._eq_comp_excl = []

    def __repr__(self):
        return str(self.name)

    def __eq__(self, other):
        '''
        State instances are equal if all their attributes (name, stochastic, etc.) are equal
        '''
        if not isinstance(other, State):
            return False
        else:
            for x in self.__dict__:
                if x in other.__dict__:
                    if not (x == '_eq_comp_excl') and not (x in self._eq_comp_excl) and not (x in other._eq_comp_excl):
                        if not (self.__dict__[x] == other.__dict__[x] or (np.isnan(self.__dict__[x]) and np.isnan(other.__dict__[x]))):
                            return False
            return True

    def __setstate__(self, data):
        self.__dict__ = data
        if '_eq_comp_excl' not in self.__dict__:
            self._eq_comp_excl = []
class StateSpace(object):
    '''
    A collection of multiple 'State' instances forms a StateSpace
    '''
    def __init__(self, *states, **kwargs):
        '''
        Constructor for StateSpace

        Parameters
        ----------
        states : packed tuple
            State instances specified in comma-separated arguments

        Returns
        -------
        StateSpace instance
        '''
        if 'statelist' in kwargs:
            self.states = kwargs['statelist']
        else:
            self.states = list(states)

    def __len__(self):
        return len(self.states)

    def __repr__(self):
        return 'State space: ' + str(self.state_names)

    @property
    def is_stochastic(self):
        '''
        An array of booleans specifying each state as stochastic
        '''
        return np.array([x.stochastic for x in self.states])

    @property
    def drives_obs(self):
        '''
        An array of booleans specifying each state as observation-driving
        '''
        return np.array([x.drives_obs for x in self.states])

    @property
    def is_aux_state(self):
        return np.array([x.aux for x in self.states])

    @property
    def state_names(self):
        '''
        A list of string names for each state
        '''
        return [x.name for x in self.states]

    @property
    def bounding_box(self):
        '''
        A tuple of min values and max values for each bounded state
        '''
        min_bounds = np.array([x for x in [x.min_val for x in self.states] if x is not np.nan])
        max_bounds = np.array([x for x in [x.max_val for x in self.states] if x is not np.nan])
        return (min_bounds, max_bounds)

    @property
    def states_to_bound(self):
        '''
        A list of the names of all the states which have limits on the values they can take.
        '''
        return [x.name for x in [x for x in self.states if x.min_val is not np.nan]]

    @property
    def n_states(self):
        '''
        Number of states in the space
        '''
        return len(self.states)

    @property
    def train_inds(self):
        '''
        An array of the indices of the stochastic states, i.e., the states estimated from training data
        '''
        return [k for k in range(self.n_states) if self.states[k].stochastic]

    @property
    def drives_obs_inds(self):
        '''
        A list of the indices of the states which are related to observations when
        used as a hidden state-space. Used when seeding Decoders.
        '''
        return [k for k in range(self.n_states) if self.states[k].drives_obs]

    @property
    def state_order(self):
        '''
        An array listing the 'order' of each state (see State.__init__ for a description of 'order')
        '''
        return np.array([x.order for x in self.states])

    def get_ssm_matrices(self, *args, **kwargs):
        '''
        Returns the parameters of the composite state-space models for use in Decoders.
        Must be overridden in child classes as there is no way to specify this generically.
        '''
        raise NotImplementedError

    def __eq__(self, other):
        '''
        State spaces are equal if all the states are equal and listed in the same order
        '''
        if not isinstance(other, StateSpace):
            return False
        else:
            return self.states == other.states
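A sketch of composing a small state space from these primitives; the state names below are hypothetical:

# one bounded position state and one stochastic, observation-driving velocity state
pos = State('slider_p', stochastic=False, drives_obs=False, min_val=-10., max_val=10., order=0)
vel = State('slider_v', stochastic=True, drives_obs=True, order=1)
space = StateSpace(pos, vel)
print(space.state_names, space.train_inds, space.drives_obs_inds)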

class LinearVelocityStateSpace(StateSpace):
    def __init__(self, states, vel_decay=0.8, w=7, Delta=0.1):
        self.states = states
        self.vel_decay = vel_decay
        self.w = w
        self.Delta = Delta

        # check that there are an equal number of pos and vel states
        assert len(np.nonzero(self.state_order == 0)[0]) == len(np.nonzero(self.state_order == 1)[0])

    def __setstate__(self, state):
        self.__dict__ = state
        if not hasattr(self, 'Delta'):
            self.Delta = 0.1

        if not hasattr(self, 'vel_decay'):
            self.vel_decay = 0.8

        if not hasattr(self, 'w'):
            self.w = 7

    def get_ssm_matrices(self, update_rate=0.1):
        '''
        For the linear stochastic state-space model
            x_{t+1} = Ax_t + Bu_t + w_t;  w_t ~ N(0, W),
        this function specifies the matrices A, B and W

            A = [I_N    \Delta I_N   0
                 0_N    a*I_N        0
                 0      0            1]

            W = [0_N    0_N          0
                 0_N    w*I_N        0
                 0      0            0]

            B = [0_N
                 1000\Delta I_N
                 0]

        Parameters
        ----------
        update_rate : float, optional
            Time between iterations of the discrete-time model. Default is 0.1 sec.

        Returns
        -------
        tuple of 3 np.mat matrices
            A, B and W as specified in the mathematical model above
        '''
        if not (update_rate is None):
            a_resamp, w_resamp = resample_scalar_ssm(self.vel_decay, self.w, Delta_old=self.Delta, Delta_new=update_rate)
            Delta = update_rate
        else:
            a_resamp = self.vel_decay
            w_resamp = self.w
            Delta = self.Delta

        ndim = len(np.nonzero(self.state_order == 1)[0])
        A = _gen_A(1, Delta, 0, a_resamp, 1, ndim=ndim)
        W = _gen_A(0, 0, 0, w_resamp, 0, ndim=ndim)

        # Control input matrix of the SSM
        I = np.mat(np.eye(ndim))
        B = np.vstack([0*I, Delta*1000 * I, np.zeros([1, ndim])])

        # account for offset state
        has_offset = self.states[-1] == offset_state
        if not has_offset:
            A = A[:-1, :-1]
            W = W[:-1, :-1]
            B = B[:-1, :]

        return A, B, W

    def __eq__(self, other):
        states_equal = super(LinearVelocityStateSpace, self).__eq__(other)
        A1, B1, W1 = self.get_ssm_matrices()
        A2, B2, W2 = other.get_ssm_matrices()
        return states_equal and np.array_equal(A1, A2) and np.array_equal(B1, B2) and np.array_equal(W1, W2)
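A sketch of the matrices documented above, built directly for ndim = 2 (x- and z-velocity) with the default decay and noise parameters; np.block is used only for illustration here, since the class itself goes through _gen_A:

import numpy as np

ndim, Delta, a, w = 2, 0.1, 0.8, 7.0
I = np.eye(ndim)
A = np.block([[I,   Delta*I, np.zeros((ndim, 1))],
              [0*I, a*I,     np.zeros((ndim, 1))],
              [np.zeros((1, 2*ndim)), np.ones((1, 1))]])
W = np.zeros_like(A); W[ndim:2*ndim, ndim:2*ndim] = w*I
B = np.vstack([0*I, 1000*Delta*I, np.zeros((1, ndim))])
print(A.shape, W.shape, B.shape)   # (5, 5), (5, 5), (5, 2)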

#######################################################################
##### Specific StateSpace types for particular experiments/plants #####
#######################################################################
# These class declarations may not actually be best placed in this module,
# but moving them now would cause problems with unpickling older decoder
# objects, which are saved with these state space model types. So put new
# state-space models elsewhere!
offset_state = State('offset', stochastic=False, drives_obs=True, order=np.nan)
offset_state._eq_comp_excl.append('order')

class StateSpaceNLinkPlanarChain(LinearVelocityStateSpace):
    '''
    State-space model for an N-link kinematic chain
    '''
    def __init__(self, n_links=2, **kwargs):
        self.n_links = n_links
        pos_states = []
        vel_states = []

        for k in range(n_links):
            pos_state_k = State('theta_%d' % k, stochastic=False, drives_obs=False, min_val=-pi, max_val=0, order=0)
            vel_state_k = State('omega_%d' % k, stochastic=True, drives_obs=True, order=1)
            pos_states.append(pos_state_k)
            vel_states.append(vel_state_k)

        states = pos_states + vel_states + [offset_state]
        super(StateSpaceNLinkPlanarChain, self).__init__(states, **kwargs)

    def __setstate__(self, state):
        self.__dict__ = state
        if not hasattr(self, 'Delta'):
            self.Delta = 0.1

        if not hasattr(self, 'vel_decay'):
            self.vel_decay = 0.8

        if not hasattr(self, 'w'):
            self.w = 0.01

class StateSpaceEndptVel2D(LinearVelocityStateSpace):
    '''
    StateSpace with 2D velocity in the X-Z plane
    '''
    def __init__(self, **kwargs):
        states = [
            State('hand_px', stochastic=False, drives_obs=False, min_val=-25., max_val=25., order=0),
            State('hand_py', stochastic=False, drives_obs=False, order=0),
            State('hand_pz', stochastic=False, drives_obs=False, min_val=-14., max_val=14., order=0),
            State('hand_vx', stochastic=True, drives_obs=True, order=1),
            State('hand_vy', stochastic=False, drives_obs=False, order=1),
            State('hand_vz', stochastic=True, drives_obs=True, order=1),
            offset_state]
        super(StateSpaceEndptVel2D, self).__init__(states, **kwargs)

    def __setstate__(self, state):
        self.__dict__ = state
        if not hasattr(self, 'Delta'):
            self.Delta = 0.1

        if not hasattr(self, 'vel_decay'):
            self.vel_decay = 0.8

        if not hasattr(self, 'w'):
            self.w = 7

class StateSpaceEndptVel3D(LinearVelocityStateSpace):
    def __init__(self, **kwargs):
        states = [
            State('hand_px', stochastic=False, drives_obs=False, min_val=-25., max_val=25., order=0),
            State('hand_py', stochastic=False, drives_obs=False, order=0),
            State('hand_pz', stochastic=False, drives_obs=False, min_val=-14., max_val=14., order=0),
            State('hand_vx', stochastic=True, drives_obs=True, order=1),
            State('hand_vy', stochastic=True, drives_obs=True, order=1),
            State('hand_vz', stochastic=True, drives_obs=True, order=1),
            offset_state]
        super(StateSpaceEndptVel3D, self).__init__(states, **kwargs)

    def __setstate__(self, state):
        self.__dict__ = state
        if not hasattr(self, 'Delta'):
            self.Delta = 0.1

        if not hasattr(self, 'vel_decay'):
            self.vel_decay = 0.8

        if not hasattr(self, 'w'):
            self.w = 7

class StateSpaceEndptPos1D(StateSpace):
    '''StateSpace for 1D pos control (e.g. RatBMI)'''
    def __init__(self, **kwargs):
        states = State('cursor_p', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0)
        super(StateSpaceEndptPos1D, self).__init__(states, **kwargs)

    def __setstate__(self, state):
        self.__dict__ = state
        if not hasattr(self, 'Delta'):
            self.Delta = 0.1

        if not hasattr(self, 'vel_decay'):
            self.vel_decay = 0.8

        if not hasattr(self, 'w'):
            self.w = 7

class StateSpaceEndptPos3D(StateSpace):
    '''StateSpace for 3D pos control'''
    def __init__(self, **kwargs):
        self.states = [
            State('hand_px', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0),
            State('hand_py', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0),
            State('hand_pz', stochastic=False, drives_obs=True, min_val=-10e6, max_val=10e6, order=0)
        ]

    def __setstate__(self, state):
        self.__dict__ = state
        if not hasattr(self, 'Delta'):
            self.Delta = 0.1

        if not hasattr(self, 'vel_decay'):
            self.vel_decay = 0.8

        if not hasattr(self, 'w'):
            self.w = 7

############################
##### Helper functions #####
############################
def resample_ssm(A, W, Delta_old=0.1, Delta_new=0.005, include_offset=True):
    '''
    Change the effective sampling rate of a linear random-walk discrete-time state-space model,
    that is, state-space models of the form

        x_{t+1} = Ax_t + w_t,  w_t ~ N(0, W)

    Parameters
    ----------
    A : np.mat of shape (K, K)
        State transition model
    W : np.mat of shape (K, K)
        Noise covariance estimate
    Delta_old : float, optional, default=0.1
        Old sampling interval
    Delta_new : float, optional, default=0.005
        New sampling interval
    include_offset : bool, optional, default=True
        Indicates whether the state-space matrices include an offset state

    Returns
    -------
    A_new, W_new
        New state-space model parameters at the new sampling rate.
    '''
    A = A.copy()
    W = W.copy()
    if include_offset:
        orig_nS = A.shape[0]
        A = A[:-1, :-1]
        W = W[:-1, :-1]

    loop_ratio = Delta_new/Delta_old
    N = 1./loop_ratio
    A_new = A**loop_ratio
    nS = A.shape[0]
    I = np.mat(np.eye(nS))
    W_new = W * ((I - A_new**N) * (I - A_new).I - I).I
    if include_offset:
        A_expand = np.mat(np.zeros([orig_nS, orig_nS]))
        A_expand[:-1, :-1] = A_new
        A_expand[-1, -1] = 1
        W_expand = np.mat(np.zeros([orig_nS, orig_nS]))
        W_expand[:-1, :-1] = W_new
        return A_expand, W_expand
    else:
        return A_new, W_new

def resample_scalar_ssm(a, w, Delta_old=0.1, Delta_new=0.005):
    '''
    Similar to resample_ssm, but for a scalar (1-d) state-space model,
    where the problem can be solved without complicated matrix roots

    Parameters
    ----------
    a : float
        State transition model
    w : float
        Noise variance estimate
    Delta_old : float, optional, default=0.1
        Old sampling interval
    Delta_new : float, optional, default=0.005
        New sampling interval

    Returns
    -------
    a_new, w_new
        New state-space model parameters at the new sampling rate.
    '''
    loop_ratio = Delta_new/Delta_old
    a_delta_new = a**loop_ratio
    w_delta_new = w / ((1 - a_delta_new**(2*(1./loop_ratio)))/(1 - a_delta_new**2))

    mu = 1
    sigma = 0
    for k in range(int(1./loop_ratio)):
        mu = a_delta_new*mu
        sigma = a_delta_new * sigma * a_delta_new + w_delta_new
    return a_delta_new, w_delta_new
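A worked check of the resampling arithmetic above: composing six 60 Hz steps should recover the original 10 Hz parameters, since the per-step variance is chosen so that w60 * (1 + a60^2 + ... + a60^10) = w10:

a10, w10 = 0.8, 7.0
a60, w60 = resample_scalar_ssm(a10, w10, Delta_old=0.1, Delta_new=1./60)
var = 0.0
for _ in range(6):
    var = a60*var*a60 + w60     # propagate the variance through six fine steps
print(a60**6, var)              # ~0.8 and ~7.0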

def _gen_A(t, s, m, n, off, ndim=3):
    """
    Utility function for generating block-diagonal matrices
    used by the KF:

        [t*I, s*I, 0
         m*I, n*I, 0
         0,   0,   off]

    Parameters
    ----------
    t, s, m, n, off : float
        Scalar entries of the block matrix above
    ndim : int, optional, default=3
        Number of states in each block, e.g. 3 states for (x, y, z) position

    Returns
    -------
    np.mat of shape (N, N); N = 2*ndim + 1
    """
    A = np.zeros([2*ndim+1, 2*ndim+1])
    A_lower_dim = np.array([[t, s], [m, n]])
    A[0:2*ndim, 0:2*ndim] = np.kron(A_lower_dim, np.eye(ndim))
    A[-1, -1] = off
    return np.mat(A)


if __name__ == '__main__':
    a_10hz = 0.8
    w_10hz = 0.0007

    Delta_old = 0.1
    Delta_new = 1./60
    a_60hz, w_60hz = resample_scalar_ssm(a_10hz, w_10hz, Delta_old=Delta_old, Delta_new=Delta_new)
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_train_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_train_py.html
new file mode 100644
index 00000000..aaab8f84
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_bmi_train_py.html
@@ -0,0 +1,1476 @@
Coverage for C:\GitHub\brain-python-interface\riglib\bmi\train.py: 7%

'''
Methods to create various types of Decoder objects from data(files)
'''
import re
import pickle
import sys
import stat
import os
import subprocess

import numpy as np
from scipy.io import loadmat
from riglib.dio import parse

import tables
import pdb
from . import kfdecoder, ppfdecoder
from . import state_space_models

############
## Constants
############
pi = np.pi

################################################################################
## Functions to synchronize task-generated HDF files and neural recording files
################################################################################
def sys_eq(sys1, sys2):
    '''
    Determine if two system strings match. A separate function is required because sometimes
    the NIDAQ card doesn't properly transmit the first character of the name of the system.

    Parameters
    ----------
    sys1 : string
        Name of system from the neural file
    sys2 : string
        Name of system to match

    Returns
    -------
    Boolean indicating whether sys1 and sys2 match
    '''
    if sys2 == 'task':
        if sys1 in ['TAS\x00TASK', 'btqassskh', 'btqassskkkh', 'tasktasktask', 'task\x00task\x00task']:
            return True
        elif sys1[:4] in ['tqas', 'tacs', 'ttua', 'bttu', 'tttu']:
            return True

    return sys1 in [sys2, sys2[1:], sys2.upper()]
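A few illustrative matches for sys_eq; aside from the garbled names listed in the function itself, the strings below are hypothetical:

# dropped first character, upper-cased, and known-garbled variants all match
for name in ['task', 'ask', 'TASK', 'btqassskh', 'tasktasktask']:
    assert sys_eq(name, 'task')
assert not sys_eq('eyetracker', 'task')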

53 

+

54 

+

55#FAKE_BLACKROCK_TMASK = True 

+

56FAKE_BLACKROCK_TMASK = False 

+

57 

+

58######################################## 

+

59## Neural data synchronization functions 

+

60######################################## 

+

61def _get_tmask(files, tslice, sys_name='task'): 

+

62 if 'plexon' in files: 

+

63 fn = _get_tmask_plexon 

+

64 fname = str(files['plexon']) 

+

65 elif 'blackrock' in files: 

+

66 if FAKE_BLACKROCK_TMASK: 

+

67 fn = _get_tmask_blackrock_fake 

+

68 fname = files['hdf'] 

+

69 else: 

+

70 fn = _get_tmask_blackrock 

+

71 fname = [str(name) for name in files['blackrock'] if '.nev' in name][0] # only one of them 

+

72 else: 

+

73 raise Exception("Neural data file(s) not found!") 

+

74 

+

75 return fn(fname, tslice, sys_name=sys_name) 

+

76 

+

77def _get_tmask_plexon(plx, tslice, sys_name='task'): 

+

78 ''' 

+

79 Find the rows of the plx file to use for training the decoder 

+

80 

+

81 Parameters 

+

82 ---------- 

+

83 plx : plexfile instance 

+

84 The plexon file to sync 

+

85 tslice : list of length 2 

+

86 Specify the start and end time to examine the file, in seconds 

+

87 sys_name : string, optional 

+

88 The "system" being synchronized. When the task is running, each data source  

+

89 (i.e., each HDF table) is allowed to be asynchronous and thus is independently  

+

90 synchronized with the neural recording system. 

+

91 

+

92 Returns 

+

93 ------- 

+

94 tmask: np.ndarray of shape (N, ) of booleans 

+

95 Specifies which entries of "rows" (see below) are within the time bounds 

+

96 rows: np.ndarray of shape (N, ) of integers 

+

97 The times at which rows of the specified HDF table were received in the neural recording box 

+

98 ''' 

+

99 # Open plx file 

+

100 from plexon import plexfile 

+

101 if isinstance(plx, str): 

+

102 plx = plexfile.openFile(plx) 

+

103 

+

104 # Get the list of all the systems registered in the neural data file 

+

105 events = plx.events[:].data 

+

106 reg = parse.registrations(events) 

+

107 

+

108 if len(list(reg.keys())) > 0: 

+

109 # find the key for the specified system data 

+

110 syskey = None 

+

111 for key, system in list(reg.items()): 

+

112 if sys_eq(system[0], sys_name): 

+

113 syskey = key 

+

114 break 

+

115 

+

116 if syskey is None: 

+

117 print((list(reg.items()))) 

+

118 raise Exception('riglib.bmi.train._get_tmask: Training data source not found in neural data file!') 

+

119 elif len(list(reg.keys())) == 0: 

+

120 # try to find how many systems' rowbytes were in the HDF file 

+

121 rowbyte_data = parse.rowbyte(events) 

+

122 if len(list(rowbyte_data.keys())) == 1: 

+

123 print("No systems registered, but only one system registered with rowbytes! Using it anyway instead of throwing an error") 

+

124 syskey = list(rowbyte_data.keys())[0] 

+

125 else: 

+

126 raise Exception("No systems registered and I don't know which sys to use to train!") 

+

127 

+

128 # get the corresponding hdf rows 

+

129 rows = parse.rowbyte(events)[syskey][:,0] 

+

130 

+

131 # Determine which rows are within the time bounds 

+

132 lower, upper = 0 < rows, rows < rows.max() + 1 

+

133 l, u = tslice 

+

134 if l is not None: 

+

135 lower = l < rows 

+

136 if u is not None: 

+

137 upper = rows < u 

+

138 tmask = np.logical_and(lower, upper) 

+

139 return tmask, rows 

+

140 
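# Sketch of the time-bound masking performed above, on synthetic row
# times (values hypothetical):
#     rows = np.array([0.5, 1.0, 1.5, 2.0])        # seconds, neural clock
#     l, u = 0.8, 1.8                               # tslice bounds
#     tmask = np.logical_and(l < rows, rows < u)    # [False, True, True, False]
#     neurows = rows[tmask]                         # sync times used for training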

+

141def _get_tmask_blackrock(nev_fname, tslice, sys_name='task'): 

+

142 ''' Find the rows of the nev file to use for training the decoder.''' 

+

143 

+

144 if nev_fname[-4:] != '.hdf': 

+

145 nev_hdf_fname = nev_fname + '.hdf' 

+

146 

+

147 if not os.path.isfile(nev_hdf_fname): 

+

148 # convert .nev file to hdf file using our own blackrock_parse_files: 

+

149 from db.tracker import models 

+

150 task_entry = int(nev_fname[-8:-4]) 

+

151 _, _ = models.parse_blackrock_file(nev_fname, 0, task_entry) 

+

152 else: 

+

153 nev_hdf_fname = nev_fname 

+

154 

+

155 #import h5py 

+

156 #nev_hdf = h5py.File(nev_hdf_fname, 'r') 

+

157 nev_hdf = tables.openFile(nev_hdf_fname) 

+

158 

+

159 #path = 'channel/digital00001/digital_set' 

+

160 #ts = nev_hdf.get(path).value['TimeStamp'] 

+

161 #msgs = nev_hdf.get(path).value['Value'] 

+

162 

+

163 ts = nev_hdf.root.channel.digital0001.digital_set[:]['TimeStamp'] 

+

164 msgs = nev_hdf.root.channel.digital0001.digital_set[:]['Value'] + 2**16 

+

165 

+

166 msgtype = np.right_shift(np.bitwise_and(msgs, parse.msgtype_mask), 8).astype(np.uint8) 

+

167 # auxdata = np.right_shift(np.bitwise_and(msgs, auxdata_mask), 8).astype(np.uint8) 

+

168 auxdata = np.right_shift(np.bitwise_and(msgs, parse.auxdata_mask), 8+3).astype(np.uint8) 

+

169 rawdata = np.bitwise_and(msgs, parse.rawdata_mask) 

+

170 

+

171 # data is an N x 4 matrix that will be the argument to parse.registrations() 

+

172 data = np.vstack([ts, msgtype, auxdata, rawdata]).T 

+

173 

+

174 # get system registrations 

+

175 reg = parse.registrations(data) 

+

176 syskey = None 

+

177 

+

178 for key, system in list(reg.items()): 

+

179 if sys_eq(system[0], sys_name): 

+

180 syskey = key 

+

181 break 

+

182 

+

183 if syskey is None: 

+

184 raise Exception('No source registration saved in the file!') 

+

185 

+

186 # get the corresponding hdf rows 

+

187 rows = parse.rowbyte(data)[syskey][:,0] 

+

188 

+

189 rows = rows / 30000. 

+

190 

+

191 lower, upper = 0 < rows, rows < rows.max() + 1 

+

192 if tslice is None: 

+

193 l = None; 

+

194 u = None; 

+

195 else: 

+

196 l, u = tslice 

+

197 

+

198 if l is not None: 

+

199 lower = l < rows 

+

200 if u is not None: 

+

201 upper = rows < u 

+

202 tmask = np.logical_and(lower, upper) 

+

203 

+

204 return tmask, rows 

+

205 

+

206def _get_tmask_blackrock_fake(hdf_fname, tslice, **kwargs): 

+

207 # need to create fake "rows" and "tmask" variables 

+

208 

+

209 print('WARNING: Using _get_tmask_blackrock_fake function!!') 

+

210 

+

211 binlen = 0.1 

+

212 strobe_rate = 10 

+

213 hdf = tables.openFile(hdf_fname) 

+

214 

+

215 n_rows = hdf.root.task[:]['plant_pos'].shape[0] 

+

216 first_ts = binlen 

+

217 rows = np.linspace(first_ts, first_ts + (n_rows-1)*(1./strobe_rate), num=n_rows) 

+

218 lower, upper = 0 < rows, rows < rows.max() + 1 

+

219 l, u = tslice 

+

220 if l is not None: 

+

221 lower = l < rows 

+

222 if u is not None: 

+

223 upper = rows < u 

+

224 tmask = np.logical_and(lower, upper) 

+

225 

+

226 return tmask, rows 

+

227 

+

228################################################################################ 

+

229## Feature extraction 

+

230################################################################################ 

+

231def _get_neural_features_plx(files, binlen, extractor_fn, extractor_kwargs, tslice=None, units=None, source='task', strobe_rate=60.): 

+

232 ''' 

+

233 Extract the neural features used to train the decoder 

+

234 

+

235 Parameters 

+

236 ---------- 

+

237 files: dict 

+

238 keys of the dictionary are file-types (e.g., hdf, plexon, etc.), values are file names 

+

239 binlen: float 

+

240 Specifies the temporal resolution of the feature extraction 

+

241 extractor_fn: callable 

+

242 Function must have the call signature  

+

243 neural_features, units, extractor_kwargs = extractor_fn(files, neurows, binlen, units, extractor_kwargs) 

+

244 extractor_kwargs: dictionary 

+

245 Additional keyword arguments to the extractor_fn (specific to each feature extractor) 

+

246 

+

247 Returns 

+

248 ------- 

+

249 neural_features: np.ndarray of shape (n_features, n_timepoints) 

+

250 Values of each feature to be used in training the decoder parameters 

+

251 units: np.ndarray of shape (N, -1) 

+

252 Specifies identity of each neural feature 

+

253 extractor_kwargs: dictionary 

+

254 Keyword arguments used to construct the feature extractor used online 

+

255 ''' 

+

256 

+

257 hdf = tables.openFile(files['hdf']) 

+

258 

+

259 plx_fname = str(files['plexon']) 

+

260 from plexon import plexfile 

+

261 try: 

+

262 plx = plexfile.openFile(plx_fname) 

+

263 except IOError: 

+

264 raise Exception("Could not open .plx file: %s" % plx_fname) 

+

265 

+

266 # Use all of the units if none are specified 

+

267 if units is None: 

+

268 units = np.array(plx.units).astype(np.int32) 

+

269 

+

270 if tslice is None: 

+

271 tslice = (1., plx.length-1) 

+

272 

+

273 tmask, rows = _get_tmask_plexon(plx, tslice, sys_name=source) 

+

274 neurows = rows[tmask] 

+

275 

+

276 neural_features, units, extractor_kwargs = extractor_fn(files, neurows, binlen, units, extractor_kwargs) 

+

277 

+

278 return neural_features, units, extractor_kwargs 

+

279 

+

280def _get_neural_features_blackrock(files, binlen, extractor_fn, extractor_kwargs, tslice=None, units=None, source='task', strobe_rate=20.): 

+

281 if units is None: 

+

282 raise Exception('"units" variable is None in preprocess_files!') 

+

283 

+

284 # Note: blackrock units are actually 0-based, but the units to be used for training 

+

285 # (which comes from web interface) are 1-based; to account for this, add 1 

+

286 # to unit numbers when reading from .nev file 

+

287 

+

288 # notes: 

+

289 # tmask --> logical vector of same length as rows that is True for rows inside the tslice 

+

290 # rows --> times (in units of s, measured on neural system) that correspond to each row of the task in hdf file 

+

291 # kin --> every 6th row of kinematics within the tslice boundaries 

+

292 # neurows --> the rows inside the tslice 

+

293 

+

294 if FAKE_BLACKROCK_TMASK: 

+

295 tmask, rows = _get_tmask_blackrock_fake(files['hdf'], tslice) 

+

296 else: 

+

297 nev_fname = [name for name in files['blackrock'] if '.nev' in name][0] # only one of them 

+

298 

+

299 #tmask, rows = _get_tmask_blackrock(nev_fname, tslice, syskey_fn=lambda x: x[0] in [source, source[1:]])  

+

300 

+

301 tmask, rows = _get_tmask_blackrock(nev_fname, tslice, sys_name=source) 

+

302 neurows = rows[tmask] 

+

303 

+

304 neural_features, units, extractor_kwargs = extractor_fn(files, neurows, binlen, units, extractor_kwargs, strobe_rate=strobe_rate) 

+

305 

+

306 return neural_features, units, extractor_kwargs 

+

307 

+

308def _get_neural_features_tdt(files, binlen, extractor_fn, extractor_kwargs, tslice=None, units=None, source='task', strobe_rate=10.): 

+

309 raise NotImplementedError 

+

310 

+

311def get_neural_features(files, binlen, extractor_fn, extractor_kwargs, units=None, tslice=None, source='task', strobe_rate=60): 

+

312 ''' 

+

313 Dispatch feature extraction to the matching neural data file type (plexon / blackrock / tdt) 

+

314 

+

315 Parameters 

+

316 ---------- 

+

317 

+

318 Returns 

+

319 ------- 

+

320 ''' 

+

321 

+

322 hdf = tables.openFile(files['hdf']) 

+

323 

+

324 if 'plexon' in files: 

+

325 fn = _get_neural_features_plx 

+

326 elif 'blackrock' in files: 

+

327 fn = _get_neural_features_blackrock 

+

328 strobe_rate = 20. 

+

329 elif 'tdt' in files: 

+

330 fn = _get_neural_features_tdt 

+

331 else: 

+

332 raise Exception('Could not find any recognized neural data files!') 

+

333 

+

334 neural_features, units, extractor_kwargs = fn(files, binlen, extractor_fn, extractor_kwargs, tslice=tslice, units=units, source=source, strobe_rate=strobe_rate) 

+

335 

+

336 return neural_features, units, extractor_kwargs 

+

337 
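# Hedged usage sketch for the dispatcher above; file paths are hypothetical
# and the extractor class is the one used elsewhere in this module:
#     from riglib.bmi import extractor
#     files = dict(hdf='/storage/rawdata/test.hdf', plexon='/storage/plexon/test.plx')
#     nf, units, kw = get_neural_features(
#         files, 0.1, extractor.BinnedSpikeCountsExtractor.extract_from_file, dict())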

+

338################################################################################ 

+

339## Kinematic data retrieval 

+

340################################################################################ 

+

341def null_kin_extractor(files, binlen, tmask, update_rate_hz=60., pos_key='cursor', vel_key=None): 

+

342 hdf = tables.openFile(files['hdf']) 

+

343 kin = np.squeeze(hdf.root.task[:][pos_key]) 

+

344 

+

345 inds, = np.nonzero(tmask) 

+

346 step_fl = binlen/(1./update_rate_hz) 

+

347 if step_fl < 1: # more than one spike bin per kinematic obs 

+

348 velocity = np.vstack([np.zeros(kin.shape[1]), np.diff(kin, axis=0) * update_rate_hz]); kin = np.hstack([kin, velocity]) # 'velocity' was previously undefined here; computed by differencing, mirroring get_plant_pos_vel below 

+

349 

+

350 n_repeats = int((1./update_rate_hz)/binlen) 

+

351 inds = np.sort(np.hstack([inds]*n_repeats)) 

+

352 kin = kin[inds] 

+

353 else: 

+

354 step = int(binlen/(1./update_rate_hz)) 

+

355 inds = inds[::step] 

+

356 kin = kin[inds] 

+

357 

+

358 print(("kin.shape", kin.shape)) 

+

359 return kin 

+

360 

+

361 

+

362def get_plant_pos_vel(files, binlen, tmask, update_rate_hz=60., pos_key='cursor', vel_key=None): 

+

363 ''' 

+

364 Get positions and velocity from 'task' table of HDF file 

+

365 

+

366 Parameters 

+

367 ---------- 

+

368 

+

369 Returns 

+

370 ------- 

+

371 ''' 

+

372 if pos_key == 'plant_pos': # used for ibmi tasks 

+

373 vel_key = 'plant_vel' 

+

374 

+

375 hdf = tables.openFile(files['hdf']) 

+

376 kin = hdf.root.task[:][pos_key] 

+

377 

+

378 inds, = np.nonzero(tmask) 

+

379 step_fl = binlen/(1./update_rate_hz) 

+

380 if step_fl < 1: # more than one spike bin per kinematic obs 

+

381 if vel_key is not None: 

+

382 velocity = hdf.root.task[:][vel_key] 

+

383 else: 

+

384 velocity = np.diff(kin, axis=0) * update_rate_hz 

+

385 velocity = np.vstack([np.zeros(kin.shape[1]), velocity]) 

+

386 kin = np.hstack([kin, velocity]) 

+

387 

+

388 n_repeats = int((1./update_rate_hz)/binlen) 

+

389 inds = np.sort(np.hstack([inds]*n_repeats)) 

+

390 kin = kin[inds] 

+

391 else: 

+

392 step = int(binlen/(1./update_rate_hz)) 

+

393 inds = inds[::step] 

+

394 try: 

+

395 kin = kin[inds] 

+

396 if vel_key is not None: 

+

397 velocity = hdf.root.task[inds][vel_key] 

+

398 else: 

+

399 velocity = np.diff(kin, axis=0) * 1./binlen 

+

400 velocity = np.vstack([np.zeros(kin.shape[1]), velocity]) 

+

401 

+

402 except: 

+

403 kin2 = np.zeros((len(inds), kin.shape[1])) 

+

404 vel2 = np.zeros((len(inds), kin.shape[1])) 

+

405 

+

406 ix = np.nonzero(inds < len(kin))[0] 

+

407 kin2[ix, :] = kin[inds[ix], :] 

+

408 kin = kin2.copy() 

+

409 

+

410 if vel_key is not None: 

+

411 vel2[ix, :] = hdf.root.task[inds[ix]][vel_key] 

+

412 else: 

+

413 vel2 = np.diff(kin, axis=0) * 1./binlen 

+

414 vel2 = np.vstack([np.zeros(kin.shape[1]), vel2]) 

+

415 

+

416 velocity = vel2.copy() 

+

417 

+

418 kin = np.hstack([kin, velocity]) 

+

419 

+

420 return kin 

+

421 
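# Worked example of the resampling arithmetic shared by both kinematic
# extractors above (numbers hypothetical):
#     update_rate_hz = 60., binlen = 0.1    -> step_fl = 6.0 >= 1, so every
#         6th masked HDF row is kept (kinematics downsampled to 10 Hz)
#     update_rate_hz = 60., binlen = 1./120 -> step_fl = 0.5 < 1, so each
#         masked row is repeated n_repeats = int((1./60)/binlen) = 2 times
#         to match the finer neural binning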

+

422 

+

423################################################################################ 

+

424## Main training functions 

+

425################################################################################ 

+

426def create_onedimLFP(files, extractor_cls, extractor_kwargs, kin_extractor, ssm, units, update_rate=0.1, tslice=None, kin_source='task', 

+

427 pos_key='cursor', vel_key=None, zscore=False): 

+

428 ## get neural features 

+

429 from . import extractor 

+

430 f_extractor = extractor.LFPMTMPowerExtractor(None, **extractor_kwargs) 

+

431 from . import onedim_lfp_decoder as old 

+

432 return old.create_decoder(units, ssm, extractor_cls, f_extractor.extractor_kwargs) 

+

433 

+

434def test_ratBMIdecoder(te_id=None, update_rate=0.1, tslice=None, kin_source='task', pos_key='cursor', vel_key=None, **kwargs): 

+

435 from db import dbfunctions as dbfn 

+

436 te = dbfn.TaskEntry(te_id) 

+

437 files = dict(hdf=te.hdf_filename, plexon=te.plx_filename) 

+

438 

+

439 entry = te.id 

+

440 from . import extractor 

+

441 extractor_cls = extractor.BinnedSpikeCountsExtractor 

+

442 

+

443 neural_features, units, extractor_kwargs = get_neural_features(files, 0.1, extractor_cls.extract_from_file, dict(), tslice=None) 

+

444 extractor_kwargs['units'] = units 

+

445 

+

446 from . import rat_bmi_decoder 

+

447 nsteps = kwargs.pop('nsteps', 10) 

+

448 prob_t1 = kwargs.pop('prob_t1', 0.985) 

+

449 prob_t2 = kwargs.pop('prob_t2', 0.015) 

+

450 timeout = kwargs.pop('timeout', 30.) 

+

451 timeout_pause = kwargs.pop('timeout_pause', 3.) 

+

452 freq_lim = kwargs.pop('freq_lim', (1000., 20000.)) 

+

453 e1_inds = kwargs.pop('e1_inds', None) 

+

454 e2_inds = kwargs.pop('e2_inds', None) 

+

455 

+

456 e1_inds, e2_inds, FR_to_freq_fn, units, t1, t2, mid = rat_bmi_decoder.calc_decoder_from_baseline_file(neural_features, units, nsteps, prob_t1, prob_t2, timeout, 

+

457 timeout_pause, freq_lim, e1_inds, e2_inds) 

+

458 

+

459 task_params = dict(nsteps=nsteps, prob_t1=prob_t1, prob_t2=prob_t2, timeout_pause=timeout_pause, timeout=timeout, freq_lim=freq_lim, 

+

460 e1_inds=e1_inds, e2_inds=e2_inds, te_name=te.name, FR_to_freq_fn=FR_to_freq_fn, units=units, te_id=te_id, t1=t1, t2=t2, mid=mid, 

+

461 extractor_kwargs=extractor_kwargs) 

+

462 

+

463 return task_params 

+

464 

+

465def test_IsmoreSleepDecoder(te_id, e1_units, e2_units, nsteps=1, prob_t1 = 0.985, prob_t2 = 0.015, timeout = 15., 

+

466 timeout_pause=0., freq_lim = [-1, 1], targets_matrix=None, session_length=0, saturate_perc=90, 

+

467 skip_sim=False): 

+

468 

+

469 from db import dbfunctions as dbfn 

+

470 te = dbfn.TaskEntry(te_id) 

+

471 files = dict(hdf=te.hdf_filename, blackrock=te.blackrock_filenames) 

+

472 entry = te.id 

+

473 from . import extractor 

+

474 extractor_cls = extractor.BinnedSpikeCountsExtractor 

+

475 

+

476 units = np.vstack((e1_units, e2_units)) 

+

477 argsort = np.argsort(units[:, 0]) 

+

478 units = units[argsort, :] 

+

479 

+

480 unit_ids = np.hstack((['e1']*len(e1_units) + ['e2']*len(e2_units))) 

+

481 sorted_unit_ids = unit_ids[argsort] 

+

482 

+

483 e1_inds = np.nonzero(sorted_unit_ids=='e1')[0] 

+

484 e2_inds = np.nonzero(sorted_unit_ids=='e2')[0] 

+

485 

+

486 neural_features, units, extractor_kwargs = get_neural_features(files, 0.1, extractor_cls.extract_from_file, 

+

487 dict(), tslice=None, units=units) 

+

488 

+

489 neural_features_unbinned, units, extractor_kwargs = get_neural_features(files, 0.05, extractor_cls.extract_from_file, 

+

490 dict(), tslice=None, units=units) 

+

491 

+

492 import riglib.bmi.rat_bmi_decoder 

+

493 

+

494 kwargs = dict(targets_matrix=targets_matrix, session_length=session_length, 

+

495 saturate_perc=saturate_perc, skip_sim=skip_sim) 

+

496 

+

497 decoder, nrewards = riglib.bmi.rat_bmi_decoder.calc_decoder_from_baseline_file(neural_features, 

+

498 neural_features_unbinned, units, nsteps, prob_t1, prob_t2, timeout, timeout_pause, freq_lim, 

+

499 e1_inds, e2_inds, sim_fcn='ismore', **kwargs) 

+

500 

+

501 decoder.extractor_cls = extractor_cls 

+

502 decoder.extractor_kwargs = extractor_kwargs 

+

503 pickle.dump(decoder, open('/storage/decoders/sleep_from_te'+str(te_id)+'.pkl', 'wb')) 

+

504 from db.tracker import dbq 

+

505 dbq.save_bmi('sleep_from_te'+str(te_id), te_id, '/storage/decoders/sleep_from_te'+str(te_id)+'.pkl') 

+

506 return decoder, nrewards 

+

507 

+

508def create_ratBMIdecoder(task_params): 

+

509 from . import extractor 

+

510 task_params['extractor_cls'] = extractor.BinnedSpikeCountsExtractor 

+

511 from . import rat_bmi_decoder 

+

512 from . import state_space_models 

+

513 rat_decoder= rat_bmi_decoder.create_decoder(state_space_models.StateSpaceEndptPos1D(), task_params) 

+

514 rat_decoder.extractor_kwargs = task_params['extractor_kwargs'] 

+

515 import tempfile 

+

516 import pickle 

+

517 from db.tracker import dbq 

+

518 

+

519 rat_decoder.te_id = task_params['te_id'] 

+

520 tf = tempfile.NamedTemporaryFile('wb') 

+

521 pickle.dump(rat_decoder, tf, 2) 

+

522 tf.flush() 

+

523 

+

524 name = task_params['te_name'] + '_rat_bmi_decoder' 

+

525 dbq.save_bmi(name, int(task_params['te_id']), tf.name) 

+

526 

+

527 

+

528def add_fa_dict_to_decoder(decoder_training_te, dec_ix, fa_te): 

+

529 #First make sure we're training from the correct task entry: spike counts n_units == BMI units 

+

530 from db import dbfunctions as dbfn 

+

531 te = dbfn.TaskEntry(fa_te) 

+

532 hdf = te.hdf 

+

533 sc_n_units = hdf.root.task[0]['spike_counts'].shape[0] 

+

534 

+

535 

+

536 from db.tracker import models 

+

537 te_arr = models.Decoder.objects.filter(entry=decoder_training_te) 

+

538 search_flag = 1 

+

539 for te in te_arr: 

+

540 ix = te.path.find('_') 

+

541 if search_flag: 

+

542 if int(te.path[ix+1:ix+3]) == dec_ix: 

+

543 decoder_old = te 

+

544 search_flag = 0 

+

545 

+

546 if search_flag: 

+

547 raise Exception('No decoder from ', str(decoder_training_te), ' and matching index: ', str(dec_ix)) 

+

548 

+

549 from tasks.factor_analysis_tasks import FactorBMIBase 

+

550 FA_dict = FactorBMIBase.generate_FA_matrices(fa_te) 

+

551 

+

552 import pickle 

+

553 dec = pickle.load(open(decoder_old.filename, 'rb')) 

+

554 dec.trained_fa_dict = FA_dict 

+

555 dec_n_units = dec.n_units 

+

556 

+

557 if dec_n_units != sc_n_units: 

+

558 raise Exception('Cant use TE for BMI training and FA training -- n_units mismatch') 

+

559 

+

560 from db import trainbmi 

+

561 trainbmi.save_new_decoder_from_existing(dec, decoder_old, suffix='_w_fa_dict_from_'+str(fa_te)) 

+

562 

+

563def train_FADecoder_from_KF(FA_nfactors, FA_te_id, decoder, use_scaled=True, use_main=True): 

+

564 

+

565 from tasks.factor_analysis_tasks import FactorBMIBase 

+

566 FA_dict = FactorBMIBase.generate_FA_matrices(FA_nfactors, FA_te_id) 

+

567 

+

568 # #Now, retrain:  

+

569 binlen = decoder.binlen 

+

570 

+

571 from db import dbfunctions as dbfn 

+

572 te_id = dbfn.TaskEntry(decoder.te_id) 

+

573 files = dict(plexon=te_id.plx_filename, hdf = te_id.hdf_filename) 

+

574 extractor_cls = decoder.extractor_cls 

+

575 extractor_kwargs = decoder.extractor_kwargs 

+

576 kin_extractor = get_plant_pos_vel 

+

577 ssm = decoder.ssm 

+

578 update_rate = decoder.binlen 

+

579 units = decoder.units 

+

580 tslice = (0., te_id.length) 

+

581 

+

582 ## get kinematic data 

+

583 kin_source = 'task' 

+

584 tmask, rows = _get_tmask(files, tslice, sys_name=kin_source) 

+

585 kin = kin_extractor(files, binlen, tmask, pos_key='cursor', vel_key=None) 

+

586 

+

587 ## get neural features 

+

588 neural_features, units, extractor_kwargs = get_neural_features(files, binlen, extractor_cls.extract_from_file, extractor_kwargs, tslice=tslice, units=units, source=kin_source) 

+

589 

+

590 #Get shared input:  

+

591 T = neural_features.shape[0] 

+

592 demean = neural_features.T - np.tile(FA_dict['fa_mu'], [1, T]) 

+

593 

+

594 if use_main: 

+

595 main_shar = (FA_dict['fa_main_shared'] * demean) 

+

596 main_priv = (demean - main_shar) 

+

597 FA = FA_dict['FA_model'] 

+

598 

+

599 else: 

+

600 shar = (FA_dict['fa_sharL']* demean) 

+

601 shar_sc = np.multiply(shar, np.tile(FA_dict['fa_shar_var_sc'], [1, T])) + np.tile(FA_dict['fa_mu'], [1, T]) 

+

602 shar_unsc = shar + np.tile(FA_dict['fa_mu'], [1, T]) 

+

603 if use_scaled: 

+

604 neural_features = shar_sc[:,:-1] 

+

605 else: 

+

606 neural_features = shar_unsc[:,:-1] 

+

607 

+

608 # Remove 1st kinematic sample and last neural features sample to align the  

+

609 # velocity with the neural features 

+

610 kin = kin[1:].T 

+

611 

+

612 decoder2 = train_KFDecoder_abstract(ssm, kin, neural_features, units, update_rate, tslice=tslice) 

+

613 decoder2.extractor_cls = extractor_cls 

+

614 decoder2.extractor_kwargs = extractor_kwargs 

+

615 decoder2.te_id = decoder.te_id 

+

616 decoder2.trained_fa_dict = FA_dict 

+

617 

+

618 import datetime 

+

619 now = datetime.datetime.now() 

+

620 tp = now.isoformat() 

+

621 import pickle 

+

622 fname = os.path.expandvars('$FA_GROM_DATA/decoder_')+tp+'.pkl' 

+

623 f = open(fname, 'wb') # binary mode required for pickle 

+

624 pickle.dump(decoder2, f) 

+

625 f.close() 

+

626 return decoder2, fname 

+

627 

+

628def conv_KF_to_splitFA_dec(decoder_training_te, dec_ix, fa_te, search_suffix = 'w_fa_dict_from_', use_shar_z=False, tslice=None): 

+

629 

+

630 from db import dbfunctions as dbfn 

+

631 te = dbfn.TaskEntry(fa_te) 

+

632 hdf = te.hdf 

+

633 sc_n_units = hdf.root.task[0]['spike_counts'].shape[0] 

+

634 

+

635 from db.tracker import models 

+

636 te_arr = models.Decoder.objects.filter(entry=decoder_training_te) 

+

637 search_flag = 1 

+

638 for te in te_arr: 

+

639 ix = te.path.find('_') 

+

640 if search_flag: 

+

641 if int(te.path[ix+1:ix+3]) == dec_ix: 

+

642 decoder = pickle.load(open(te.filename, 'rb')) 

+

643 if hasattr(decoder, 'trained_fa_dict'): 

+

644 ix = te.path.find('w_fa_dict_from_') 

+

645 if ix > 1: 

+

646 fa_te_train = te.path[ix+len(search_suffix):ix+len(search_suffix)+4] 

+

647 if int(fa_te_train) == fa_te: 

+

648 decoder_old = te 

+

649 #search_flag = 0 

+

650 

+

651 # if search_flag: 

+

652 # raise Exception('No decoder from ', str(decoder_training_te), ' and matching index: ', str(dec_ix), ' with FA training from: ',str(fa_te)) 

+

653 # else: 

+

654 print(('Using old decoder: ', decoder_old.path)) 

+

655 

+

656 decoder = pickle.load(open(decoder_old.filename, 'rb')) 

+

657 if hasattr(decoder, 'trained_fa_dict'): 

+

658 FA_dict = decoder.trained_fa_dict 

+

659 else: 

+

660 raise Exception('Make an FA dict decoder first, then re-train that') 

+

661 

+

662 from db import dbfunctions as dbfn 

+

663 te_id = dbfn.TaskEntry(fa_te) 

+

664 

+

665 files = dict(plexon=te_id.plx_filename, hdf = te_id.hdf_filename) 

+

666 extractor_cls = decoder.extractor_cls 

+

667 extractor_kwargs = decoder.extractor_kwargs 

+

668 extractor_kwargs['discard_zero_units'] = False 

+

669 kin_extractor = get_plant_pos_vel 

+

670 ssm = decoder.ssm 

+

671 update_rate = binlen = decoder.binlen 

+

672 units = decoder.units 

+

673 if tslice is None: 

+

674 tslice = (0., te_id.length) 

+

675 

+

676 ## get kinematic data 

+

677 kin_source = 'task' 

+

678 tmask, rows =_get_tmask(files, tslice, sys_name=kin_source) 

+

679 kin = kin_extractor(files, binlen, tmask, pos_key='cursor', vel_key=None) 

+

680 

+

681 ## get neural features 

+

682 neural_features, units, extractor_kwargs = get_neural_features(files, binlen, extractor_cls.extract_from_file, extractor_kwargs, tslice=tslice, units=units, source=kin_source) 

+

683 

+

684 #Get main shared input: 

+

685 T = neural_features.shape[0] 

+

686 demean = neural_features.T - np.tile(FA_dict['fa_mu'], [1, T]) 

+

687 

+

688 #Neural features in time x spikes:  

+

689 FA = FA_dict['FA_model'] 

+

690 z = FA.transform(demean.T) 

+

691 z = z.T 

+

692 z = z[:FA_dict['fa_main_shar_n_dim'], :] 

+

693 

+

694 #z = FA_dict['u_svd'].T*FA_dict['uut_psi_inv']*demean 

+

695 

+

696 shar_z = FA_dict['fa_main_shared'] * demean 

+

697 priv = demean - shar_z 

+

698 

+

699 #Time by features: 

+

700 if use_shar_z: 

+

701 neural_features2 = np.vstack((shar_z, priv)) # presumably the shared-space reconstruction for the '_split_shar_z' variant; both branches previously stacked 'z' 

+

702 suffx = '_split_shar_z' 

+

703 else: 

+

704 neural_features2 = np.vstack((z, priv)) 

+

705 suffx = '_split_z' 

+

706 decoder_split = train_KFDecoder_abstract(ssm, kin.T, neural_features2, units, update_rate, tslice=tslice) 

+

707 decoder_split.n_features = len(units) 

+

708 decoder_split.trained_fa_dict = FA_dict 

+

709 

+

710 decoder_split.extractor_cls = extractor_cls 

+

711 decoder_split.extractor_kwargs = extractor_kwargs 

+

712 

+

713 from db import trainbmi 

+

714 trainbmi.save_new_decoder_from_existing(decoder_split, decoder_old, suffix=suffx) 

+

715 

+

716def train_KFDecoder(files, extractor_cls, extractor_kwargs, kin_extractor, ssm, units, update_rate=0.1, tslice=None, 

+

717 kin_source='task', pos_key='cursor', vel_key=None, zscore=False, filter_kin=True, simple_lin_reg=False, 

+

718 use_data_kwargs=None, **kwargs): 

+

719 ''' 

+

720 Create a new KFDecoder using maximum-likelihood, from kinematic observations and neural observations 

+

721 

+

722 Parameters 

+

723 ----------  

+

724 files : dict 

+

725 Dictionary of files which contain training data. Keys are file types, values are file names. 

+

726 Kinematic data is assumed to be stored in an 'hdf' file and neural data assumed to be in 'plx' or 'nev' files 

+

727 extractor_cls : class 

+

728 Class of feature extractor to instantiate 

+

729 extractor_kwargs : dict  

+

730 Parameters to specify for feature extractor to instantiate it to specification 

+

731 kin_extractor : callable 

+

732 Function to extract kinematics from the HDF file. 

+

733 ssm : state_space_models.StateSpace instance 

+

734 State space model for the Decoder object being created. 

+

735 units : np.iterable  

+

736 Spiking units are specified as tuples of (electrode channel, electrode unit) 

+

737 update_rate : float, optional 

+

738 Time in seconds between decoder updates. default=0.1 

+

739 tslice : iterable of length 2, optional 

+

740 Start and end times in seconds to specify the portion of the training data to use for ML estimation. By default, the whole dataset will be used 

+

741 kin_source : string, optional 

+

742 Table from the HDF file to grab kinematic data. Default is the 'task' table. 

+

743 pos_key : string, optional 

+

744 Column of HDF table to use for position data. Default is 'cursor', recognized options are {'cursor', 'joint_angles', 'plant_pos'} 

+

745 vel_key : string 

+

746 Column of HDF table to use for velocity data. Default is None; velocity is computed by single-step numerical differencing (or an alternate method) 

+

747 zscore : Bool  

+

748 Determines whether to zscore neural_data or not 

+

749 kwargs: 

+

750 mFR: mean firing rate to use to zscore units 

+

751 sdFR: standard dev. to use to zscore units 

+

752 

+

753 Returns 

+

754 ------- 

+

755 KFDecoder instance 

+

756 ''' 

+

757 import sys 

+

758 print(files) 

+

759 # sys.stdout.write(files) 

+

760 # sys.stdout.write(extractor_cls) 

+

761 # sys.stdout.write(extractor_kwargs.keys()) 

+

762 # sys.stdout.write(units) 

+

763 binlen = update_rate 

+

764 

+

765 from config import config 

+

766 

+

767 ## get kinematic data 

+

768 tmask, rows = _get_tmask(files, tslice, sys_name=kin_source) 

+

769 kin = kin_extractor(files, binlen, tmask, pos_key=pos_key, vel_key=vel_key, update_rate_hz=config.hdf_update_rate_hz) 

+

770 

+

771 ## get neural features 

+

772 if 'blackrock' in list(files.keys()): 

+

773 strobe_rate = 20. 

+

774 else: 

+

775 strobe_rate = 60. 

+

776 

+

777 neural_features, units, extractor_kwargs = get_neural_features(files, binlen, extractor_cls.extract_from_file, 

+

778 extractor_kwargs, tslice=tslice, units=units, source=kin_source, strobe_rate=strobe_rate) 

+

779 

+

780 # Remove 1st kinematic sample and last neural features sample to align the  

+

781 # velocity with the neural features 

+

782 kin = kin[1:].T 

+

783 neural_features = neural_features[:-1].T 

+

784 

+

785 if filter_kin: 

+

786 filts = get_filterbank(fs=1./update_rate) 

+

787 kin_filt = np.zeros_like(kin) 

+

788 for chan in range(14): 

+

789 for filt in filts[chan]: 

+

790 kin_filt[chan, :] = filt(kin[chan, :]) 

+

791 else: 

+

792 kin_filt = kin.copy() 

+

793 

+

794 if simple_lin_reg: 

+

795 from sklearn.linear_model import Ridge 

+

796 decoder = Ridge(1000.0, fit_intercept=True, normalize=False) 

+

797 

+

798 if use_data_kwargs is not None: 

+

799 

+

800 # HDF rows to use in training 

+

801 X = [] 

+

802 Y = [] 

+

803 

+

804 for pair in use_data_kwargs['pairs']: 

+

805 X.append(neural_features[:, pair[0]]) 

+

806 Y.append(kin_filt[:, pair[1]]) 

+

807 

+

808 # Convert these hdf rows to  

+

809 decoder.fit(np.vstack((X)), np.vstack((Y))) 

+

810 

+

811 else: 

+

812 decoder = train_KFDecoder_abstract(ssm, kin_filt, neural_features, units, update_rate, tslice=tslice, zscore=zscore, **kwargs) 

+

813 decoder.extractor_cls = extractor_cls 

+

814 decoder.extractor_kwargs = extractor_kwargs 

+

815 

+

816 return decoder, neural_features, kin_filt 

+

817 
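# Hedged end-to-end sketch for train_KFDecoder; paths and unit list are
# hypothetical, the ssm/extractor classes are the ones used in this module:
#     from riglib.bmi import extractor, state_space_models
#     files = dict(hdf='/storage/rawdata/test.hdf', plexon='/storage/plexon/test.plx')
#     units = np.array([[1, 1], [2, 1], [3, 2]])
#     ssm = state_space_models.StateSpaceEndptVel2D()
#     dec, nf, kin_filt = train_KFDecoder(files, extractor.BinnedSpikeCountsExtractor,
#         dict(units=units), get_plant_pos_vel, ssm, units, update_rate=0.1)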

+

818def get_filterbank(n_channels=14, fs=1000.): 

+

819 from ismore.filter import Filter 

+

820 from scipy.signal import butter 

+

821 band = [.001, 1] # Hz 

+

822 nyq = 0.5 * fs 

+

823 low = band[0] / nyq 

+

824 high = band[1] / nyq 

+

825 high = np.min([high, 0.99]) 

+

826 bpf_coeffs = butter(4, [low, high], btype='band') 

+

827 

+

828 channel_filterbank = [None]*n_channels 

+

829 for k in range(n_channels): 

+

830 filts = [Filter(bpf_coeffs[0], bpf_coeffs[1])] 

+

831 channel_filterbank[k] = filts 

+

832 return channel_filterbank 

+

833 
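# Note on the filterbank above: each channel gets one 4th-order Butterworth
# band-pass with a fixed [0.001, 1] Hz band (high edge clipped to 0.99 of
# Nyquist). As called from train_KFDecoder, fs = 1/update_rate = 10 Hz, so
# only the slow, sub-1 Hz movement components survive before the KF
# observation model is fit.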

+

834 

+

835def train_KFDecoderDrift(files, extractor_cls, extractor_kwargs, kin_extractor, ssm, units, update_rate=0.1, tslice=None, 

+

836 kin_source='task', pos_key='cursor', vel_key=None, zscore=False, **kwargs): 

+

837 ''' 

+

838 Create a new KFDecoder using maximum-likelihood, from kinematic observations and neural observations 

+

839 

+

840 Parameters 

+

841 ----------  

+

842 files : dict 

+

843 Dictionary of files which contain training data. Keys are file types, values are file names. 

+

844 Kinematic data is assumed to be stored in an 'hdf' file and neural data assumed to be in 'plx' or 'nev' files 

+

845 extractor_cls : class 

+

846 Class of feature extractor to instantiate 

+

847 extractor_kwargs : dict  

+

848 Parameters to specify for feature extractor to instantiate it to specification 

+

849 kin_extractor : callable 

+

850 Function to extract kinematics from the HDF file. 

+

851 ssm : state_space_models.StateSpace instance 

+

852 State space model for the Decoder object being created. 

+

853 units : np.iterable  

+

854 Spiking units are specified as tuples of (electrode channel, electrode unit) 

+

855 update_rate : float, optional 

+

856 Time in seconds between decoder updates. default=0.1 

+

857 tslice : iterable of length 2, optional 

+

858 Start and end times in seconds to specify the portion of the training data to use for ML estimation. By default, the whole dataset will be used 

+

859 kin_source : string, optional 

+

860 Table from the HDF file to grab kinematic data. Default is the 'task' table. 

+

861 pos_key : string, optional 

+

862 Column of HDF table to use for position data. Default is 'cursor', recognized options are {'cursor', 'joint_angles', 'plant_pos'} 

+

863 vel_key : string 

+

864 Column of HDF table to use for velocity data. Default is None; velocity is computed by single-step numerical differencing (or an alternate method) 

+

865 zscore : Bool  

+

866 Determines whether to zscore neural_data or not 

+

867 kwargs: 

+

868 mFR: mean firing rate to use to zscore units 

+

869 sdFR: standard dev. to use to zscore units 

+

870 

+

871 Returns 

+

872 ------- 

+

873 KFDecoder instance 

+

874 ''' 

+

875 import sys 

+

876 print(files) 

+

877 # sys.stdout.write(files) 

+

878 # sys.stdout.write(extractor_cls) 

+

879 # sys.stdout.write(extractor_kwargs.keys()) 

+

880 # sys.stdout.write(units) 

+

881 binlen = update_rate 

+

882 

+

883 from config import config 

+

884 

+

885 ## get kinematic data 

+

886 tmask, rows = _get_tmask(files, tslice, sys_name=kin_source) 

+

887 kin = kin_extractor(files, binlen, tmask, pos_key=pos_key, vel_key=vel_key, update_rate_hz=config.hdf_update_rate_hz) 

+

888 

+

889 ## get neural features 

+

890 neural_features, units, extractor_kwargs = get_neural_features(files, binlen, extractor_cls.extract_from_file, 

+

891 extractor_kwargs, tslice=tslice, units=units, source=kin_source) 

+

892 

+

893 # Remove 1st kinematic sample and last neural features sample to align the  

+

894 # velocity with the neural features 

+

895 kin = kin[1:].T 

+

896 neural_features = neural_features[:-1].T 

+

897 

+

898 kwargs['driftKF'] = True 

+

899 decoder = train_KFDecoder_abstract(ssm, kin, neural_features, units, update_rate, 

+

900 tslice=tslice, zscore=zscore, **kwargs) 

+

901 

+

902 decoder.extractor_cls = extractor_cls 

+

903 decoder.extractor_kwargs = extractor_kwargs 

+

904 

+

905 return decoder 

+

906 

+

907def train_KFDecoder_abstract(ssm, kin, neural_features, units, update_rate, tslice=None, regularizer=0., 

+

908 zscore=False, **kwargs): 

+

909 print(kwargs) 

+

910 print('end of kwargs') 

+

911 

+

912 #### Train the actual KF decoder matrices #### 

+

913 if type(zscore) is bool: 

+

914 pass 

+

915 else: 

+

916 if zscore == 'True': 

+

917 zscore = True 

+

918 elif zscore == 'False': 

+

919 zscore = False 

+

920 else: 

+

921 raise Exception 

+

922 

+

923 print(('zscore value: ', zscore, type(zscore))) 

+

924 

+

925 if zscore: 

+

926 if 'mFR' in kwargs and 'sdFR' in kwargs: 

+

927 print('using kwargs mFR, sdFR to zscore') 

+

928 mFR = kwargs['mFR'] 

+

929 sdFR = kwargs['sdFR'] 

+

930 else: 

+

931 print('computing own mFR, sdFR to zscore') 

+

932 mFR = np.mean(neural_features, axis=1) 

+

933 sdFR = np.std(neural_features, axis=1) 

+

934 if kwargs.get('zscore_set_std_to_one', False): # hasattr() never matches dict keys; use .get() 

+

935 sdFR = np.ones_like(mFR) 

+

936 neural_features = (neural_features - mFR[:, np.newaxis])*(1./sdFR[:, np.newaxis]) 

+

937 

+

938 else: 

+

939 mFR = np.squeeze(np.mean(neural_features, axis=1)) 

+

940 sdFR = np.squeeze(np.std(neural_features, axis=1)) 

+

941 

+

942 if 'noise_rej' in kwargs: 

+

943 if kwargs['noise_rej']: 

+

944 sum_pop = np.sum(neural_features, axis = 0) 

+

945 bins_noisy = np.nonzero(sum_pop > kwargs['noise_rej_cutoff'])[0] 

+

946 print(('replacing %d noisy bins of total %d bins w/ mFR for decoder training!' % (len(bins_noisy), len(sum_pop)))) 

+

947 neural_features[:, bins_noisy] = mFR[:, np.newaxis] 

+

948 else: 

+

949 kwargs['noise_rej'] = False 

+

950 kwargs['noise_rej_cutoff'] = -1. 

+

951 

+

952 n_features = len(mFR) 

+

953 

+

954 # C should be trained on all of the stochastic state variables, excluding the offset terms 

+

955 C = np.zeros((n_features, ssm.n_states)) 

+

956 C[:, ssm.drives_obs_inds], Q = kfdecoder.KalmanFilter.MLE_obs_model(kin[ssm.train_inds, :], neural_features, regularizer=regularizer) 

+

957 

+

958 

+

959 # Set state space model 

+

960 A, B, W = ssm.get_ssm_matrices(update_rate=update_rate) 

+

961 

+

962 # instantiate KFdecoder 

+

963 driftKF = kwargs.pop('driftKF', False) 

+

964 if driftKF: 

+

965 print(('Training Drift Decoder. Noise Rejection? ', kwargs['noise_rej'])) 

+

966 kf = kfdecoder.KalmanFilterDriftCorrection(A, W, C, Q, is_stochastic=ssm.is_stochastic) 

+

967 else: 

+

968 kf = kfdecoder.KalmanFilter(A, W, C, Q, is_stochastic=ssm.is_stochastic) 

+

969 

+

970 decoder = kfdecoder.KFDecoder(kf, units, ssm, binlen=update_rate, tslice=tslice) 

+

971 

+

972 if zscore: 

+

973 decoder.init_zscore(mFR, sdFR) 

+

974 print('zscore init') 

+

975 else: 

+

976 print('no init_zscore') 

+

977 

+

978 

+

979 # Compute sufficient stats for C and Q matrices (used for RML CLDA) 

+

980 from .clda import KFRML 

+

981 n_features, n_states = C.shape 

+

982 R = np.mat(np.zeros([n_states, n_states])) 

+

983 S = np.mat(np.zeros([n_features, n_states])) 

+

984 R_small, S_small, T, ESS = KFRML.compute_suff_stats(kin[ssm.train_inds, :], neural_features) 

+

985 

+

986 R[np.ix_(ssm.drives_obs_inds, ssm.drives_obs_inds)] = R_small 

+

987 S[:,ssm.drives_obs_inds] = S_small 

+

988 

+

989 decoder.filt.R = R 

+

990 decoder.filt.S = S 

+

991 decoder.filt.T = T 

+

992 decoder.filt.ESS = ESS 

+

993 decoder.n_features = n_features 

+

994 

+

995 decoder.filt.noise_rej = kwargs['noise_rej'] 

+

996 decoder.filt.noise_rej_cutoff = kwargs['noise_rej_cutoff'] 

+

997 decoder.filt.noise_rej_mFR = mFR 

+

998 # decoder.extractor_cls = extractor_cls 

+

999 # decoder.extractor_kwargs = extractor_kwargs 

+

1000 

+

1001 return decoder 

+

1002 
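# The ML fit above estimates the KF observation model
#     y_t = C x_t + q_t,   q_t ~ N(0, Q)
# from paired kinematics x_t and neural features y_t; A and W come from the
# state-space model, not from data. Minimal sketch of the optional z-scoring
# step on synthetic features:
#     nf = np.random.poisson(5., size=(20, 1000)).astype(float)
#     mFR, sdFR = np.mean(nf, axis=1), np.std(nf, axis=1)
#     nf_z = (nf - mFR[:, np.newaxis]) / sdFR[:, np.newaxis]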

+

1003def train_PPFDecoder(files, extractor_cls, extractor_kwargs, kin_extractor, ssm, units, update_rate=0.1, tslice=None, kin_source='task', 

+

1004 pos_key='cursor', vel_key=None, zscore=False): 

+

1005 ''' 

+

1006 Create a new PPFDecoder using maximum-likelihood, from kinematic observations and neural observations 

+

1007 

+

1008 Parameters 

+

1009 ----------  

+

1010 files : dict 

+

1011 Dictionary of files which contain training data. Keys are file types, values are file names. 

+

1012 Kinematic data is assumed to be stored in an 'hdf' file and neural data assumed to be in 'plx' or 'nev' files 

+

1013 extractor_cls : class 

+

1014 Class of feature extractor to instantiate 

+

1015 extractor_kwargs : dict  

+

1016 Parameters to specify for feature extractor to instantiate it to specification 

+

1017 kin_extractor : callable 

+

1018 Function to extract kinematics from the HDF file. 

+

1019 ssm : state_space_models.StateSpace instance 

+

1020 State space model for the Decoder object being created. 

+

1021 units : np.iterable  

+

1022 Spiking units are specified as tuples of (electrode channel, electrode unit) 

+

1023 update_rate : float, optional 

+

1024 Time in seconds between decoder updates. default=0.1 

+

1025 tslice : iterable of length 2, optional 

+

1026 Start and end times in seconds to specify the portion of the training data to use for ML estimation. By default, the whole dataset will be used 

+

1027 kin_source : string, optional 

+

1028 Table from the HDF file to grab kinematic data. Default is the 'task' table. 

+

1029 pos_key : string, optional 

+

1030 Column of HDF table to use for position data. Default is 'cursor', recognized options are {'cursor', 'joint_angles', 'plant_pos'} 

+

1031 vel_key : string 

+

1032 Column of HDF table to use for velocity data. Default is None; velocity is computed by single-step numerical differencing (or an alternate method) 

+

1033 

+

1034 Returns 

+

1035 ------- 

+

1036 PPFDecoder instance 

+

1037 ''' 

+

1038 binlen = 1./180 #update_rate 

+

1039 

+

1040 ## get kinematic data 

+

1041 tmask, rows = _get_tmask(files, tslice, sys_name=kin_source) 

+

1042 kin = kin_extractor(files, binlen, tmask, pos_key=pos_key, vel_key=vel_key) 

+

1043 

+

1044 ## get neural features 

+

1045 neural_features, units, extractor_kwargs = get_neural_features(files, binlen, extractor_cls.extract_from_file, extractor_kwargs, tslice=tslice, units=units, source=kin_source) 

+

1046 

+

1047 # Remove 1st kinematic sample and last neural features sample to align the  

+

1048 # velocity with the neural features 

+

1049 kin = kin[1:].T 

+

1050 neural_features = neural_features[:-1].T 

+

1051 

+

1052 decoder = train_PPFDecoder_abstract(ssm, kin, neural_features, units, update_rate, tslice=tslice) 

+

1053 

+

1054 decoder.extractor_cls = extractor_cls 

+

1055 decoder.extractor_kwargs = extractor_kwargs 

+

1056 

+

1057 return decoder 

+

1058 

+

1059def train_PPFDecoder_abstract(ssm, kin, neural_features, units, update_rate, tslice=None): 

+

1060 binlen = 1./180 #update_rate 

+

1061 # squash any spike counts greater than 1 (doesn't work with PPF model) 

+

1062 neural_features[neural_features > 1] = 1 

+

1063 

+

1064 #### Train the PPF decoder matrices #### 

+

1065 n_features = neural_features.shape[0] # number of neural features 

+

1066 

+

1067 # C should be trained on all of the stochastic state variables, excluding the offset terms 

+

1068 C = np.zeros([n_features, ssm.n_states]) 

+

1069 C[:, ssm.drives_obs_inds], pvals = ppfdecoder.PointProcessFilter.MLE_obs_model(kin[ssm.train_inds, :], neural_features) 

+

1070 

+

1071 # Set state space model 

+

1072 A, B, W = ssm.get_ssm_matrices(update_rate=update_rate) 

+

1073 

+

1074 # instantiate Decoder 

+

1075 ppf = ppfdecoder.PointProcessFilter(A, W, C, B=B, dt=update_rate, is_stochastic=ssm.is_stochastic) 

+

1076 decoder = ppfdecoder.PPFDecoder(ppf, units, ssm, binlen=binlen, tslice=tslice) 

+

1077 

+

1078 # Compute sufficient stats for C matrix (used for RML CLDA) 

+

1079 from .clda import KFRML 

+

1080 n_features, n_states = C.shape 

+

1081 S = np.mat(np.zeros([n_features, n_states])) 

+

1082 S_small, = decoder.compute_suff_stats(kin[ssm.train_inds, :], neural_features) 

+

1083 

+

1084 S[:,ssm.drives_obs_inds] = S_small 

+

1085 

+

1086 decoder.filt.S = S 

+

1087 decoder.n_features = n_features 

+

1088 

+

1089 return decoder 

+

1090 

+

1091################### 

+

1092## Helper functions 

+

1093################### 

+

1094def unit_conv(starting_unit, ending_unit): 

+

1095 '''  

+

1096 Convert between units, e.g. cm to m 

+

1097 Lookup table for conversion factors between units; this function exists 

+

1098 only to avoid hard-coded constants in most of the code 

+

1099 

+

1100 Parameters 

+

1101 ---------- 

+

1102 starting_unit : string 

+

1103 Name of current unit for the quantity, e.g., 'cm' 

+

1104 ending_unit : string 

+

1105 Name of desired unit for the quantity, e.g., 'm' 

+

1106 

+

1107 Returns 

+

1108 ------- 

+

1109 float 

+

1110 Multiplicative scale factor to convert a scalar in the 'starting_unit' to the 'ending_unit' 

+

1111 ''' 

+

1112 

+

1113 if starting_unit == ending_unit: 

+

1114 return 1 

+

1115 elif (starting_unit, ending_unit) == ('cm', 'm'): 

+

1116 return 0.01 

+

1117 elif (starting_unit, ending_unit) == ('m', 'cm'): 

+

1118 return 100 

+

1119 else: 

+

1120 raise ValueError("Unrecognized starting/ending unit") 

+

1121 
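# Examples:
#     unit_conv('cm', 'm')   # -> 0.01
#     unit_conv('m', 'cm')   # -> 100
#     unit_conv('cm', 'cm')  # -> 1
#     unit_conv('cm', 'in')  # -> ValueError: pair not in the lookup table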

+

1122def lookup_cells(cells): 

+

1123 '''  

+

1124 Convert string names of units to 'machine' format. 

+

1125 Take a list of neural units specified as a list of strings and convert  

+

1126 to the 2D array format used to specify neural units to train decoders 

+

1127 

+

1128 Parameters 

+

1129 ---------- 

+

1130 cells : string 

+

1131 String of cell names to parse, e.g., '1a, 2b' 

+

1132 

+

1133 Returns 

+

1134 ------- 

+

1135 list of 2-tuples 

+

1136 Each element of the list is a tuple of (channel, unit), e.g., [(1, 1), (2, 2)] 

+

1137 ''' 

+

1138 cellname = re.compile(r'(\d{1,3})\s*(\w{1})') 

+

1139 cells = [ (int(c), ord(u) - 96) for c, u in cellname.findall(cells)] 

+

1140 return cells 

+

1141 
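# Example: channel number followed by unit letter, a -> 1, b -> 2, ...
#     lookup_cells('1a, 2b, 10c')  # -> [(1, 1), (2, 2), (10, 3)]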

+

1142def inflate(A, current_states, full_state_ls, axis=0): 

+

1143 ''' 

+

1144 'Inflate' a matrix by filling in rows/columns with zeros 

+

1145 

+

1146 Rows (axis=0) or columns (axis=1) listed in current_states are copied into their positions in full_state_ls; all other entries are zero 

+

1147 

+

1148 Parameters 

+

1149 ---------- 

+

1150 

+

1151 Returns 

+

1152 ------- 

+

1153 

+

1154 ''' 

+

1155 try: 

+

1156 nS = len(full_state_ls) 

+

1157 except: 

+

1158 nS = full_state_ls.n_states 

+

1159 

+

1160 if axis == 0: 

+

1161 A_new = np.zeros([nS, A.shape[1]]) 

+

1162 elif axis == 1: 

+

1163 A_new = np.zeros([A.shape[0], nS]) 

+

1164 

+

1165 try: 

+

1166 new_inds = [full_state_ls.index(x) for x in current_states] 

+

1167 except: 

+

1168 new_inds = [full_state_ls.state_names.index(x) for x in current_states] 

+

1169 if axis == 0: 

+

1170 A_new[new_inds, :] = A 

+

1171 elif axis == 1: 

+

1172 A_new[:, new_inds] = A 

+

1173 

+

1174 return A_new 

+

1175 
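# Example of inflate() padding a 1x2 matrix out to a full 5-state list:
#     A = np.array([[1., 2.]])
#     inflate(A, ['v_x', 'v_y'], ['p_x', 'p_y', 'v_x', 'v_y', 'off'], axis=1)
#     # -> [[0., 0., 1., 2., 0.]]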

+

1176 

+

1177####################### 

+

1178## Simulation functions 

+

1179####################### 

+

1180def _train_PPFDecoder_2D_sim(stochastic_states, neuron_driving_states, units, 

+

1181 bounding_box, states_to_bound, include_y=True, dt=0.1, v=0.4): 

+

1182 ''' 

+

1183 Train a simulation PPFDecoder 

+

1184 

+

1185 Docstring 

+

1186 

+

1187 Parameters 

+

1188 ---------- 

+

1189 

+

1190 Returns 

+

1191 ------- 

+

1192 ''' 

+

1193 raise NotImplementedError 

+

1194 

+

1195def rand_KFDecoder(ssm, units, dt=0.1): 

+

1196 ''' 

+

1197 Make a KFDecoder with the observation model initialized randomly 

+

1198 

+

1199 Parameters 

+

1200 ---------- 

+

1201 ssm : state_space_models.StateSpace instance 

+

1202 State-space model for the KFDecoder. Should specify the A and W matrices 

+

1203 units : np.array of shape (N, 2) 

+

1204 Unit labels to assign to each row of the C matrix 

+

1205 

+

1206 Returns 

+

1207 ------- 

+

1208 KFDecoder instance 

+

1209 ''' 

+

1210 n_neurons = units.shape[0] 

+

1211 binlen = dt 

+

1212 

+

1213 A, B, W = ssm.get_ssm_matrices(update_rate=dt) 

+

1214 drives_neurons = ssm.drives_obs 

+

1215 is_stochastic = ssm.is_stochastic 

+

1216 nX = ssm.n_states 

+

1217 

+

1218 C = np.random.standard_normal([n_neurons, nX]) 

+

1219 C[:, ~drives_neurons] = 0 

+

1220 Q = 10 * np.identity(n_neurons) 

+

1221 

+

1222 kf = kfdecoder.KalmanFilter(A, W, C, Q, is_stochastic=is_stochastic) 

+

1223 

+

1224 mFR = 0 

+

1225 sdFR = 1 

+

1226 decoder = kfdecoder.KFDecoder(kf, units, ssm, mFR=mFR, sdFR=sdFR, binlen=binlen) 

+

1227 

+

1228 decoder.kf.R = np.mat(np.identity(decoder.kf.C.shape[1])) 

+

1229 decoder.kf.S = decoder.kf.C 

+

1230 decoder.kf.T = decoder.kf.Q + decoder.kf.S*decoder.kf.S.T 

+

1231 decoder.kf.ESS = 3000. 

+

1232 

+

1233 decoder.ssm = ssm 

+

1234 decoder.n_features = n_neurons 

+

1235 

+

1236 # decoder.bounder = make_rect_bounder_from_ssm(ssm) 

+

1237 

+

1238 return decoder 

+

1239 

+

1240_train_KFDecoder_2D_sim_2 = rand_KFDecoder 

+

1241 
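# Simulation sketch: build a random decoder for 20 synthetic units
# (unit labels hypothetical):
#     units = np.vstack([(ch, 1) for ch in range(1, 21)])
#     ssm = state_space_models.StateSpaceEndptVel2D()
#     dec = rand_KFDecoder(ssm, units, dt=0.1)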

+

1242def load_from_mat_file(decoder_fname, bounding_box=None, 

+

1243 states=['p_x', 'p_y', 'v_x', 'v_y', 'off'], states_to_bound=[]): 

+

1244 """ 

+

1245 Create KFDecoder from MATLAB decoder file used in a Dexterit-based 

+

1246 BMI 

+

1247 

+

1248 The A/W/H/Q matrices, mean/std firing rates, and unit labels are read from the MATLAB 'decoder' struct 

+

1249 

+

1250 Parameters 

+

1251 ---------- 

+

1252 

+

1253 Returns 

+

1254 ------- 

+

1255 """ 

+

1256 decoder_data = loadmat(decoder_fname)['decoder'] 

+

1257 A = decoder_data['A'][0,0] 

+

1258 W = decoder_data['W'][0,0] 

+

1259 H = decoder_data['H'][0,0] 

+

1260 Q = decoder_data['Q'][0,0] 

+

1261 mFR = decoder_data['mFR'][0,0].ravel() 

+

1262 sdFR = decoder_data['sdFR'][0,0].ravel() 

+

1263 

+

1264 pred_sigs = [str(x[0]) for x in decoder_data['predSig'][0,0].ravel()] 

+

1265 unit_lut = {'a':1, 'b':2, 'c':3, 'd':4} 

+

1266 units = [(int(sig[3:6]), unit_lut[sig[-1]]) for sig in pred_sigs] 

+

1267 

+

1268 drives_neurons = np.array([False, False, True, True, True]) 

+

1269 

+

1270 kf = kfdecoder.KalmanFilter(A, W, H, Q) 

+

1271 dec = kfdecoder.KFDecoder(kf, mFR, sdFR, units, bounding_box, states, drives_neurons, states_to_bound) 

+

1272 

+

1273 # Load bounder for position state 

+

1274 from state_bounders import RectangularBounder 

+

1275 bounding_box_data = loadmat('/Users/sgowda/bmi/workspace/decoder_switching/jeev_center_out_bmi_targets.mat') 

+

1276 center_pos = bounding_box_data['centerPos'].ravel() 

+

1277 px_min, py_min = center_pos - 0.09 

+

1278 px_max, py_max = center_pos + 0.09 

+

1279 bounding_box = [(px_min, px_max), (py_min, py_max)] 

+

1280 bounder = RectangularBounder([px_min, py_min], [px_max, py_max], ['p_x', 'p_y']) 

+

1281 dec.bounder = bounder 

+

1282 

+

1283 return dec 

+

1284 

+

1285def rescale_KFDecoder_units(dec, scale_factor=10): 

+

1286 ''' 

+

1287 Convert the units of a KFDecoder, e.g. from mm to cm 

+

1288 

+

1289 C and W matrices of KalmanFilter must be updated for the new units.  

+

1290 A and Q are unitless and thus remain the same 

+

1291 

+

1292 Parameters 

+

1293 ---------- 

+

1294 dec : KFDecoder instance 

+

1295 KFDecoder object 

+

1296 scale_factor : numerical 

+

1297 defines how much bigger the new unit is than the old one 

+

1298 ''' 

+

1299 inds = np.nonzero((np.diag(dec.kf.W) > 0) * dec.drives_neurons)[0] 

+

1300 nS = dec.kf.W.shape[0] 

+

1301 S_diag = np.ones(nS) 

+

1302 S_diag[inds] = scale_factor 

+

1303 S = np.mat(np.diag(S_diag)) 

+

1304 #S = np.mat(np.diag([1., 1, 1, 10, 10, 10, 1])) 

+

1305 dec.kf.C *= S 

+

1306 dec.kf.W *= S.I * S.I 

+

1307 try: 

+

1308 dec.kf.C_xpose_Q_inv_C = S.T * dec.kf.C_xpose_Q_inv_C * S 

+

1309 dec.kf.C_xpose_Q_inv = S.T * dec.kf.C_xpose_Q_inv 

+

1310 except: 

+

1311 pass 

+

1312 dec.bounding_box = tuple([x / scale_factor for x in dec.bounding_box]) 

+

1313 return dec 

+

1314 

+

1315def _train_PPFDecoder_sim_known_beta(beta, units, dt=0.005, dist_units='m'): 

+

1316 ''' 

+

1317 Create a PPFDecoder object to decode 2D velocity from a known 'beta' matrix 

+

1318 

+

1319 Docstring 

+

1320 

+

1321 Parameters 

+

1322 ---------- 

+

1323 

+

1324 Returns 

+

1325 ------- 

+

1326 ''' 

+

1327 units_mult_lut = dict(m=1., cm=0.01) 

+

1328 units_mult = units_mult_lut[dist_units] 

+

1329 

+

1330 ssm = state_space_models.StateSpaceEndptVel2D() 

+

1331 A, _, W = ssm.get_ssm_matrices(update_rate=dt) 

+

1332 

+

1333 # rescale beta for units 

+

1334 beta[:,3:6] *= units_mult 

+

1335 

+

1336 # Control input matrix for SSM for control inputs 

+

1337 I = np.mat(np.eye(3)) 

+

1338 B = np.vstack([0*I, dt*1000 * I, np.zeros([1,3])]) 

+

1339 

+

1340 # instantiate Decoder 

+

1341 ppf = ppfdecoder.PointProcessFilter(A, W, beta, dt=dt, is_stochastic=ssm.is_stochastic, B=B) 

+

1342 dec = ppfdecoder.PPFDecoder(ppf, units, ssm, binlen=dt) 

+

1343 

+

1344 n_stoch_states = len(np.nonzero(ssm.drives_obs)[0]) 

+

1345 n_units = len(units) 

+

1346 dec.H = np.dstack([np.eye(3)*100] * n_units).transpose(2, 0, 1) 

+

1347 dec.M = np.mat(np.ones([n_units, n_stoch_states])) * np.exp(-1.6) 

+

1348 dec.S = np.mat(np.ones([n_units, n_stoch_states])) * np.exp(-1.6) 

+

1349 

+

1350 # Force decoder to run at max 60 Hz 

+

1351 dec.bminum = 0 

+

1352 return dec 

+

1353 

+

1354def load_PPFDecoder_from_mat_file(fname, state_units='cm'): 

+

1355 ''' 

+

1356 Load a PPFDecoder from a MATLAB file containing the A/W parameters and the 'beta' observation model 

+

1357 

+

1358 Parameters 

+

1359 ---------- 

+

1360 

+

1361 Returns 

+

1362 ------- 

+

1363 ''' 

+

1364 data = loadmat(fname) 

+

1365 a = data['A'][2,2] 

+

1366 w = data['W'][0,0] 

+

1367 

+

1368 if 'T_loop' in data: 

+

1369 dt = data['T_loop'][0,0] 

+

1370 else: 

+

1371 dt = 0.005 

+

1372 

+

1373 spike_rate_dt = 0.001 # This is hardcoded b/c the value in the MATLAB file is probably wrong. 

+

1374 A = state_space_models._gen_A(1, dt, 0, a, 1, ndim=3) 

+

1375 W = state_space_models._gen_A(0, 0, 0, w, 0, ndim=3) 

+

1376 

+

1377 if 'beta_hat' in data: 

+

1378 beta = data['beta_hat'][:,:,0] 

+

1379 else: 

+

1380 beta = data['beta'] 

+

1381 

+

1382 beta = ppfdecoder.PointProcessFilter.frommlab(beta) 

+

1383 beta[:,:-1] /= unit_conv('m', state_units) 

+

1384 #beta_full = inflate(beta, states_explaining_neural_activity_2D_vel_decoding, states_3D_endpt, axis=1) 

+

1385 #states = states_3D_endpt#['hand_px', 'hand_py', 'hand_pz', 'hand_vx', 'hand_vy', 'hand_vz', 'offset'] 

+

1386 #states = ['hand_px', 'hand_py', 'hand_pz', 'hand_vx', 'hand_vy', 'hand_vz', 'offset'] 

+

1387 states = state_space_models.StateSpaceEndptVel2D() 

+

1388 neuron_driving_states = ['hand_vx', 'hand_vz', 'offset'] 

+

1389 beta_full = inflate(beta, neuron_driving_states, states, axis=1) 

+

1390 

+

1391 stochastic_states = ['hand_vx', 'hand_vz'] 

+

1392 try: 

+

1393 is_stochastic = [x in stochastic_states for x in states] 

+

1394 except: 

+

1395 is_stochastic = [x in stochastic_states for x in states.state_names] 

+

1396 

+

1397 unit_names = [str(x[0]) for x in data['decoder']['predSig'][0,0][0]] 

+

1398 units = [(int(x[3:6]), ord(x[-1]) - (ord('a') - 1)) for x in unit_names] 

+

1399 units = np.vstack(units) 

+

1400 

+

1401 ppf = ppfdecoder.PointProcessFilter(A, W, beta_full, dt, is_stochastic=is_stochastic) 

+

1402 ppf.spike_rate_dt = spike_rate_dt 

+

1403 

+

1404 try: 

+

1405 drives_neurons = np.array([x in neuron_driving_states for x in states]) 

+

1406 except: 

+

1407 drives_neurons = np.array([x in neuron_driving_states for x in states.state_names]) 

+

1408 dec = ppfdecoder.PPFDecoder(ppf, units, states, binlen=dt) 

+

1409 

+

1410 if state_units == 'cm': 

+

1411 dec.filt.W[3:6, 3:6] *= unit_conv('m', state_units)**2 

+

1412 return dec 

+
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_button_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_button_py.html
new file mode 100644
index 00000000..a4ce6e00
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_button_py.html
@@ -0,0 +1,166 @@
Coverage for C:\GitHub\brain-python-interface\riglib\button.py: 0%

1''' 

+

2Interact with a peripheral 'button' device hooked up to an FTDI chip 

+

3''' 

+

4 

+

5import threading 

+

6import queue 

+

7import ftdi 

+

8import time 

+

9 

+

10class Button(threading.Thread): 

+

11 '''Background thread that polls an FTDI button device (bit-bang mode) and queues press events''' 

+

12 def __init__(self): 

+

13 ''' 

+

14 Docstring 

+

15 

+

16 Parameters 

+

17 ---------- 

+

18 

+

19 Returns 

+

20 ------- 

+

21 ''' 

+

22 super(Button, self).__init__() 

+

23 self.port = None 

+

24 port = ftdi.ftdi_new() 

+

25 usb_open = ftdi.ftdi_usb_open_string(port, "s:0x403:0x6001:2eb80091") 

+

26 assert usb_open == 0, ftdi.ftdi_get_error_string(port) 

+

27 

+

28 ftdi.ftdi_set_bitmode(port, 0xFF, ftdi.BITMODE_BITBANG) 

+

29 self.port = port 

+

30 self.queue = queue.Queue() 

+

31 self.daemon = True 

+

32 self.start() 

+

33 

+

34 def run(self): 

+

35 ''' 

+

36 Docstring 

+

37 

+

38 Parameters 

+

39 ---------- 

+

40 

+

41 Returns 

+

42 ------- 

+

43 ''' 

+

44 last = None 

+

45 while True: 

+

46 k = self._check() 

+

47 if last == 0 and k != 0: 

+

48 self.queue.put(k) 

+

49 last = k 

+

50 time.sleep(0.01) 

+

51 

+

52 def _check(self): 

+

53 ''' 

+

54 Docstring 

+

55 

+

56 Parameters 

+

57 ---------- 

+

58 

+

59 Returns 

+

60 ------- 

+

61 ''' 

+

62 test = ' ' 

+

63 ftdi.ftdi_read_pins(self.port, test) 

+

64 return ord(test) 

+

65 

+

66 def pressed(self): 

+

67 ''' 

+

68 Docstring 

+

69 

+

70 Parameters 

+

71 ---------- 

+

72 

+

73 Returns 

+

74 ------- 

+

75 ''' 

+

76 try: 

+

77 return self.queue.get_nowait() 

+

78 except: 

+

79 return None 

+

80 

+

81 def __del__(self): 

+

82 ''' 

+

83 Docstring 

+

84 

+

85 Parameters 

+

86 ---------- 

+

87 

+

88 Returns 

+

89 ------- 

+

90 ''' 

+

91 if self.port is not None: 

+

92 ftdi.ftdi_disable_bitbang(self.port) 

+

93 ftdi.ftdi_usb_close(self.port) 

+

94 ftdi.ftdi_deinit(self.port) 

+

95 

+

96if __name__ == "__main__": 

+

97 import time 

+

98 btn = Button() 

+

99 while True: 

+

100 k = btn.pressed() 

+

101 if k is not None: 

+

102 print(k) 

+
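Button.run above is a rising-edge detector: it reports a press only when the polled pin byte goes from 0 to nonzero, so a held button fires once rather than on every poll. The same pattern in isolation (read_pins and on_press are hypothetical stand-ins, not part of this module):

import time

def watch_button(read_pins, on_press, poll_s=0.01):
    last = None
    while True:
        k = read_pins()              # hypothetical stand-in for ftdi_read_pins
        if last == 0 and k != 0:     # 0 -> nonzero transition == new press
            on_press(k)
        last = k
        time.sleep(poll_s)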
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_calibrations_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_calibrations_py.html
new file mode 100644
index 00000000..0399b285
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_calibrations_py.html
@@ -0,0 +1,371 @@
Coverage for C:\GitHub\brain-python-interface\riglib\calibrations.py: 22%
'''
Calibration for the EyeLink eyetracker
'''

import numpy as np

class Profile(object):
    '''
    Docstring

    Parameters
    ----------

    Returns
    -------
    '''
    def __init__(self, data, actual, system=None, **kwargs):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.data = np.array(data)
        self.actual = np.array(actual)
        self.system = system
        self.kwargs = kwargs
        self._init()

    def _init(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        #Sanitize the data, clearing out entries which are invalid
        valid = ~np.isnan(self.data).any(1)
        self.data = self.data[valid,:]
        self.actual = self.actual[valid,:]

    def performance(self, blocks=5):
        '''Perform cross-validation to check the performance of this decoder.

        This function holds out data, trains new decoders using only the training data
        to check the actual performance of the current decoder.

        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        valid = ~np.isnan(self.data).any(1)
        data = self.data[valid]
        actual = self.actual[valid]

        nd = self.data.shape[1]
        dim = tuple(range(nd)), tuple(range(nd, 2*nd))

        order = np.random.permutation(len(self.data))
        idx = set(order)
        bedge = len(order) / float(blocks)

        ccs = np.zeros((blocks,))
        for b in range(blocks):
            val = order[int(b*bedge):int((b+1)*bedge)]
            trn = np.array(list(idx - set(val)))

            cal = self.__class__(data[trn], actual[trn], **self.kwargs)
            corr = np.corrcoef(cal(data[val]).T, actual[val].T)
            ccs[b] = corr[dim].mean()

        return ccs

    def __call__(self, data):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        raise NotImplementedError

class EyeProfile(Profile):
    '''
    Docstring

    Parameters
    ----------

    Returns
    -------
    '''
    def __init__(self, data, actual, **kwargs):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        super(EyeProfile, self).__init__(data, actual, system="eyetracker", **kwargs)

    def _init(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        super(EyeProfile, self)._init()
        valid = ~(self.data == (-32768, -32768)).all(1)
        self.data = self.data[valid,:]
        self.actual = self.actual[valid,:]

class ThinPlate(Profile):
    '''Interpolates arbitrary input dimensions into arbitrary output dimensions using thin plate splines'''
    def __init__(self, data, actual, smooth=0, **kwargs):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.smooth = smooth
        super(ThinPlate, self).__init__(data, actual, **kwargs)

    def _init(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        super(ThinPlate, self)._init()
        self.funcs = []
        from scipy.interpolate import Rbf
        for a in self.actual.T:
            f = Rbf(*np.vstack([self.data.T, a]), function='thin_plate', smooth=self.smooth)
            self.funcs.append(f)

    def __call__(self, data):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        raw = np.atleast_2d(data).T
        return np.array([func(*raw) for func in self.funcs]).T

    def __getstate__(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        state = self.__dict__.copy()
        del state['funcs']
        return state

    def __setstate__(self, state):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        super(ThinPlate, self).__setstate__(state)
        self._init()

class ThinPlateEye(ThinPlate, EyeProfile):
    '''
    Docstring

    Parameters
    ----------

    Returns
    -------
    '''
    pass

def crossval(cls, data, actual, proportion=0.7, parameter="smooth", xval_range=np.linspace(0,10,20)**2):
    '''
    Docstring

    Parameters
    ----------

    Returns
    -------
    '''
    actual = np.array(actual)
    data = np.array(data)

    ccs = np.zeros(len(xval_range))
    for i, smooth in enumerate(xval_range):
        cal = cls(data, actual, **{parameter:smooth})
        ccs[i] = cal.performance().mean()

    best = xval_range[ccs.argmax()]
    return cls(data, actual, **{parameter:best}), best, ccs

class Affine(Profile):
    '''Runs a linear affine interpolation between data and actual'''
    def __init__(self, data, actual):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.data = data
        self.actual = actual
        #self.xfm = np.linalg.lstsq()


class AutoAlign(object):
    '''Runs the autoalignment filter to center everything into the chair coordinates'''
    def __init__(self, reference):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        print("Making autoaligner from reference %s"%reference)
        from riglib.stereo_opengl import xfm
        self._quat = xfm.Quaternion
        self.ref = np.load(reference)['reference']
        self.xfm = xfm.Quaternion()
        self.off1 = np.array([0,0,0])
        self.off2 = np.array([0,0,0])

    def __call__(self, data):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        mdata = data.mean(0)[:, :3]
        avail = (data[:,-6:, -1] > 0).all(0)
        if avail[:3].all():
            #ABC reference
            cdata = mdata[-6:-3] - mdata[-6]
            self.off1 = mdata[-6]
            self.off2 = self.ref[0]
            rot1 = self._quat.rotate_vecs(cdata[1], self.ref[1] - self.ref[0])
            rot2 = self._quat.rotate_vecs((rot1*cdata[2]), self.ref[2] - self.ref[0])
            self.xfm = rot2*rot1
        elif avail[3:].all():
            #DEF reference
            cdata = mdata[-3:] - mdata[-3]
            self.off1 = mdata[-3]
            self.off2 = self.ref[3]
            rot1 = self._quat.rotate_vecs(cdata[1], self.ref[4] - self.ref[3])
            rot2 = self._quat.rotate_vecs((rot1*cdata[2]), self.ref[5] - self.ref[3])
            self.xfm = rot2*rot1
        rdata = self.xfm*(mdata[:-6] - self.off1) + self.off2
        rdata[(data[:,:-6,-1] < 0).any(0)] = np.nan
        return np.hstack([rdata, np.ones((len(rdata),1))])[np.newaxis]
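A usage sketch for the cross-validation helpers above; the eyetracker arrays here are invented placeholders, while ThinPlateEye and crossval are used exactly as defined in this module:

import numpy as np

# Invented placeholder data: 200 raw gaze samples and the matching
# known target positions on the screen.
raw = np.random.randn(200, 2)
actual = np.random.randn(200, 2)

# Sweep the 'smooth' parameter, score each setting via Profile.performance
# (mean cross-validated correlation), then refit at the best value.
cal, best_smooth, ccs = crossval(ThinPlateEye, raw, actual)
calibrated = cal(raw)   # apply the fitted thin-plate calibration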
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio___init___py.html
new file mode 100644
index 00000000..20d20d95
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio___init___py.html
@@ -0,0 +1,65 @@
Coverage for C:\GitHub\brain-python-interface\riglib\dio\__init__.py: 100% (the module body is a single blank line)
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio_nidaq___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio_nidaq___init___py.html
new file mode 100644
index 00000000..a9fa4f72
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio_nidaq___init___py.html
@@ -0,0 +1,230 @@
Coverage for C:\GitHub\brain-python-interface\riglib\dio\nidaq\__init__.py: 0%
'''
NIDAQ Digital I/O interface code. Higher-level Python wrapper for pcidio
'''

from riglib.dio import parse
try:
    import pcidio
except:
    pass

class SendAll(object):
    '''
    Interface for sending all the task-generated data through the NIDAQ interface card
    '''
    def __init__(self, device="/dev/comedi0"):
        '''
        Constructor for SendAll

        Parameters
        ----------
        device : string, optional
            comedi device to open and use for data sending; only tested for NI PCI-6503

        Returns
        -------
        SendAll instance
        '''
        self.systems = dict()

        try:
            pcidio
        except:
            raise Exception('Cannot import pcidio. Did you run its build script?')
        if pcidio.init(device) != 0:
            raise ValueError("Could not initialize comedi system")

    def close(self):
        '''
        Release access to the NIDAQ card so that other processes can access the device later.
        '''
        if pcidio.closeall() != 0:
            raise ValueError("Unable to close comedi system")

    def register(self, system, dtype):
        '''
        Send information about the registration system (name and datatype) in string form, one byte at a time.

        Parameters
        ----------
        system : string
            Name of the system being registered
        dtype : np.dtype instance
            Datatype of incoming data, for later decoding of the binary data during analysis

        Returns
        -------
        None
        '''
        print("nidaq register %s" % system)

        # Save the index of the system being registered (arbitrary number corresponding to the order in which systems were registered)
        self.systems[system] = pcidio.register_sys(system, str(dtype.descr))

    def sendMsg(self, msg):
        '''
        Send a string message to the recording system, e.g., as related to the task_msgs HDF table

        Parameters
        ----------
        msg : string
            Message to send

        Returns
        -------
        None
        '''
        pcidio.sendMsg(str(msg))

    def sendRow(self, system, idx):
        '''
        The function pcidio.sendRow tries to send an array through the card one byte at a time.
        It's unclear how the underlying code is supposed to get the 'idx' pointer that is required to use this function...
        .. and this function appears to be unused and probably should be removed.
        Abandoned in place in case there is a future unknown use for this functionality
        '''
        if system in self.systems:
            pcidio.sendRow(self.systems[system], idx)

    def rstart(self, state):
        '''
        Remotely start recording from the plexon system

        Parameters
        ----------
        state : int
            0 or 1 depending on whether you want the system to start or stop
            For the plexon system, this is actually not used, and instead the comedi python bindings generate the required pulse.

        Returns
        -------
        None
        '''
        print("Sending rstart command")
        pcidio.rstart(state)

    def send(self, system, data):
        '''
        Send data through the DIO device

        Parameters
        ----------
        system : string
            Name of system where the data originated
        data : object
            Data to send. Must have a '.tostring()' attribute

        Returns
        -------
        None
        '''
        if system in self.systems:
            pcidio.sendData(self.systems[system], data.tostring())

class SendRow(SendAll):
    '''
    Send the number of rows to the plexon system. Used by riglib.hdfwriter.PlexRelayWriter, which is never actually used anywhere....
    '''
    def send(self, system, data):
        '''
        Send data to a registered system

        Parameters
        ----------
        system : string
            Name of system where the data originated
        data : object
            Argument is ignored, since only the count is sent and not the actual data

        Returns
        -------
        None
        '''
        if system in self.systems:
            pcidio.sendRowCount(self.systems[system])

class SendRowByte(SendAll):
    '''
    Send only an 8-bit data word corresponding to the 8 lower bits of the current row number of the HDF table
    '''
    def send(self, system, data):
        '''
        Send data to a registered system

        Parameters
        ----------
        system : string
            Name of system where the data originated
        data : object
            Argument is ignored, since only the count is sent and not the actual data

        Returns
        -------
        None
        '''
        if system in self.systems:
            pcidio.sendRowByte(self.systems[system])
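Since SendRowByte transmits only the low 8 bits of the HDF row counter, offline alignment has to unwrap the modulo-256 sequence recovered by riglib.dio.parse.rowbyte. A minimal sketch of that unwrapping, under the assumption that fewer than 256 rows are ever missed between recorded words:

import numpy as np

def unwrap_row_bytes(row_bytes):
    # row_bytes: sequence of 8-bit row counters (0..255) recorded by the
    # neural system; returns monotonically increasing row indices.
    row_bytes = np.asarray(row_bytes, dtype=np.int64)
    wraps = np.cumsum(np.diff(row_bytes, prepend=row_bytes[0]) < 0)
    return row_bytes + 256 * wraps

# e.g. [254, 255, 0, 1] -> [254, 255, 256, 257]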
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio_parse_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio_parse_py.html
new file mode 100644
index 00000000..dc2214fa
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_dio_parse_py.html
@@ -0,0 +1,235 @@
Coverage for C:\GitHub\brain-python-interface\riglib\dio\parse.py: 20%
'''
Parse digital data from neural recording system into task data/messages/synchronization pulses
'''

import numpy as np

msgtype_mask = 0b0000111 << 8
auxdata_mask = 0b1111000 << 8
rawdata_mask = 0b11111111

MSG_TYPE_DATA = 0
MSG_TYPE_MESSAGE = 1
MSG_TYPE_REGISTER = 2
MSG_TYPE_REGISTER_SHAPE = 3
MSG_TYPE_ROW = 4
MSG_TYPE_ROWBYTE = 5

def parse_data(strobe_data):
    '''
    Parse out 'strobe' digital data into header/registrations + actual data
    '''
    reg_parsed_data = registrations(strobe_data)
    msg_parsed_data = messages(strobe_data)
    rowbyte_parsed_data = rowbyte(strobe_data)

    data = dict(messages=msg_parsed_data)
    for key in rowbyte_parsed_data:
        if key in reg_parsed_data:
            sys_name = reg_parsed_data[key][0]
            try:
                sys_dtype = np.dtype(eval(reg_parsed_data[key][1]))
            except:
                sys_dtype = reg_parsed_data[key][1]
            data[sys_name] = dict(row_ts=rowbyte_parsed_data[key])
        else:
            data[key] = dict(row_ts=rowbyte_parsed_data[key])

    return data

def _split(data, flip=False):
    '''
    Helper function to take the 16-bit integer saved in the neural data file
    and map it back to the three fields of the message type (see docs on
    communication protocol for details)

    Parameters
    ----------
    data : np.ndarray
        Integer data and timestamps as stored in the neural data file when messages were sent during experiment

    Returns
    -------
    np.ndarray
        Raw message data split into the fields (type, aux, "payload")
    '''
    # If the data is a 1D array, extract the timestamps and the raw event codes
    if len(data.shape) < 2:
        data = np.array(data[data['chan'] == 257][['ts', 'unit']].tolist())
    msgs = data[:,1].astype(np.int16)

    if not flip:
        msgs = ~msgs # bit-flip the binary messages
    msgtype = np.right_shift(np.bitwise_and(msgs, msgtype_mask), 8).astype(np.uint8)
    auxdata = np.right_shift(np.bitwise_and(msgs, auxdata_mask), 8).astype(np.uint8)
    rawdata = np.bitwise_and(msgs, rawdata_mask)
    return np.vstack([data[:,0], msgtype, auxdata, rawdata]).T
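# Worked example of the bit layout handled by _split (the word is invented):
# bits 0-7 payload, bits 8-10 message type, bits 11-14 auxiliary data.
word = (0b0101 << 11) | (MSG_TYPE_ROWBYTE << 8) | 0x2A
msgtype = (word & msgtype_mask) >> 8   # -> 5 (MSG_TYPE_ROWBYTE)
auxdata = (word & auxdata_mask) >> 8   # -> 40, aux bits still shifted up by 3
payload = word & rawdata_mask          # -> 0x2A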

def registrations(data, map_system=False):
    '''
    Parse the DIO data from the neural recording system to determine which
    data sources were registered by the experiment software

    Parameters
    ----------
    data: np.array
        Digital bit data sent to the neural recording box. This data can either
        be a 1D record array with fields ('chan', 'unit', 'ts') or an N x 4
        regular array where the four columns are (timestamp, message type,
        auxiliary data, raw data)

    Returns
    -------
    systems : dict
        In the dictionary, keys are the ID # of the system (assigned sequentially
        during registration time when the task is initializing). Values are
        tuples of (name, dtype)
    '''
    if data.ndim < 2 or data.shape[1] != 4:
        if map_system:
            data = _split(data, flip=True)
        else:
            data = _split(data)

    ts, msgtype, auxdata, rawdata = data[:,0], data[:,1], data[:,2], data[:,3].astype(np.uint8)
    idx = msgtype == MSG_TYPE_REGISTER #data[:,1] == MSG_TYPE_REGISTER
    shape_idx = msgtype == MSG_TYPE_REGISTER_SHAPE #data[:,1] == MSG_TYPE_REGISTER_SHAPE

    regsysid = auxdata[idx] #data[idx][:,2] #should have more than
    #one value for more than one registration

    regshapeid = auxdata[shape_idx] #data[shape_idx][:,2]
    names = rawdata[idx] #data[idx][:,3].astype(np.uint8)

    dtype_data = rawdata[shape_idx]

    systems = dict()
    for sys in np.unique(regsysid):
        name = names[regsysid == sys].tostring()
        name = name[:-1] # Remove null terminator
        dtype = dtype_data[regshapeid == sys].tostring() #data[shape_idx][regshapeid == sys][:,3].astype(np.uint8).tostring()
        dtype = dtype[:-1] # Remove null terminator
        systems[sys] = name, dtype #name[:-1], dtype[:-1]
        #import pdb; pdb.set_trace()
    return systems

def rowbyte(data, **kwargs):
    '''
    Parameters
    ----------
    data: np.array
        see docs for registrations for shape/dtype
    kwargs : dict
        see docs for _split to see which kwargs are allowed

    Returns
    -------
    systems : dict
        In the dictionary, keys are the ID # of the system (assigned sequentially
        during registration time when the task is initializing). Values are
        np.array of shape (N, 2) where the columns are (timestamp, rawdata)
    '''
    if data.ndim < 2 or data.shape[1] != 4:
        data = _split(data, **kwargs)

    msgs = data[data[:,1] == MSG_TYPE_ROWBYTE]
    systems = dict()
    for i in np.unique(msgs[:,2]):
        systems[i] = msgs[msgs[:,2] == i][:,[0,-1]]
    return systems

def messages(data, **kwargs):
    '''
    Parse out any string messages sent byte-by-byte to the neural recording system

    Parameters
    ----------
    data : np.ndarray
        Integer data and timestamps as stored in the neural data file when messages were sent during experiment
        OR, the result of the _split function

    Returns
    -------
    record array
        fields of 'time' and 'state' (message)
    '''
    if data.ndim < 2 or data.shape[1] != 4:
        data = _split(data, **kwargs)

    times = data[data[:,1] == 1, 0]
    names = data[data[:,1] == 1,-1].astype(np.uint8)

    tidx, = np.nonzero(names == 0)
    tidx = np.hstack([0, tidx+1])

    msgs = []
    for s, e in np.vstack([tidx[:-1], tidx[1:]-1]).T:
        msgs.append((times[s], names[s:e].tostring()))

    return np.array(msgs, dtype=[('time', np.float), ('state', 'S256')])
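To make the null-terminator logic in messages() concrete, here is a tiny worked example with invented bytes ('Hi\0' then 'Go\0'); it mirrors the index arithmetic above:

import numpy as np

names = np.array([72, 105, 0, 71, 111, 0], dtype=np.uint8)  # b'Hi\0Go\0'
times = np.array([0.10, 0.11, 0.12, 0.50, 0.51, 0.52])

tidx, = np.nonzero(names == 0)     # terminator positions: [2, 5]
tidx = np.hstack([0, tidx + 1])    # message start indices: [0, 3, 6]
for s, e in np.vstack([tidx[:-1], tidx[1:] - 1]).T:
    print(times[s], names[s:e].tobytes())   # (0.1, b'Hi') then (0.5, b'Go')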
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_Pygame_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_Pygame_py.html
new file mode 100644
index 00000000..5bb2ba39
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_Pygame_py.html
@@ -0,0 +1,158 @@
Coverage for C:\GitHub\brain-python-interface\riglib\experiment\Pygame.py: 48%
'''
General-purpose "pygame" experiment class. Mostly unused
'''

import os
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
import pygame

from . import traits
from .experiment import LogExperiment

class Pygame(LogExperiment):
    '''
    'Window' used by the older tasks (dots, eyemove, rds, redgreen, sensorymapping)
    '''
    background = (0,0,0)
    fps = 60

    timeout_time = traits.Float(4., desc="Timeout (in seconds) during a trial while waiting for a response")
    penalty_time = traits.Float(5., desc="Length of penalty (in seconds) for incorrect and premature responses")
    reward_time = traits.Float(5, desc='Time of reward in seconds')

    def screen_init(self):
        '''
        Initialize the pygame display
        '''
        os.environ['SDL_VIDEO_WINDOW_POS'] = "0,0"
        os.environ['SDL_VIDEO_X11_WMCLASS'] = "monkey_experiment"

        pygame.init()
        flags = pygame.DOUBLEBUF | pygame.HWSURFACE | pygame.NOFRAME
        pygame.display.set_mode((3840,1080), flags)

        self.surf = pygame.display.get_surface()
        self.clock = pygame.time.Clock()
        self.event = None

    def draw_frame(self):
        raise NotImplementedError

    def clear_screen(self):
        self.surf.fill(self.background)
        pygame.display.flip()

    def _get_event(self):
        for e in pygame.event.get(pygame.KEYDOWN):
            return (e.key, e.type)

    def flip_wait(self):
        pygame.display.flip()
        self.event = self._get_event()
        self.clock.tick(self.fps)

    def _while_wait(self):
        self.surf.fill(self.background)
        self.flip_wait()

    def _while_trial(self):
        self.draw_frame()
        self.flip_wait()

    def _while_reward(self):
        self.clear_screen()

    def _while_penalty(self):
        self.surf.fill((181,0,45))
        self.flip_wait()

    def _test_start_trial(self, ts):
        return self.event is not None

    def _test_correct(self, ts):
        raise NotImplementedError

    def _test_incorrect(self, ts):
        raise NotImplementedError

    def _test_timeout(self, ts):
        return ts > self.timeout_time

    def _test_post_reward(self, ts):
        return ts > self.reward_time

    def _test_post_penalty(self, ts):
        return ts > self.penalty_time

    def _start_None(self):
        pygame.display.quit()

    def _start_reward(self):
        pass

    def _start_wait(self):
        pass
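Pygame leaves draw_frame and the correct/incorrect tests abstract. A hypothetical minimal subclass (the class name and the drawing are invented) illustrates how the _while_*/_test_* hooks fit together:

class RedSquareTask(Pygame):
    # Invented example task: draw a red square each frame and treat any
    # keypress (captured by flip_wait into self.event) as a correct response.
    def draw_frame(self):
        self.surf.fill(self.background)
        pygame.draw.rect(self.surf, (255, 0, 0), (100, 100, 50, 50))

    def _test_correct(self, ts):
        return self.event is not None

    def _test_incorrect(self, ts):
        return False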
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment___init___py.html
new file mode 100644
index 00000000..c9905e6f
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment___init___py.html
@@ -0,0 +1,162 @@
Coverage for C:\GitHub\brain-python-interface\riglib\experiment\__init__.py: 61%
'''
Experiment constructors. 'Experiment' instances are the combination of
a task and a list of features. Rather than have a separate class for
all the possible combinations of tasks and features, a custom class for
the experiment is created programmatically using 'type'. The created class
has methods of the base task as well as all the selected features.
'''
import numpy as np

##################
##### Traits #####
##################
try:
    import traits.api as traits
except ImportError:
    import enthought.traits.api as traits


class InstanceFromDB(traits.Instance):
    def __init__(self, *args, **kwargs):
        if 'bmi3d_db_model' in kwargs:
            self.bmi3d_db_model = kwargs['bmi3d_db_model']
        else:
            raise ValueError("If using trait 'InstanceFromDB', must specify bmi3d_db_model!")

        # save the arguments for the database
        #self.bmi3d_query_kwargs = kwargs.pop('bmi3d_query_kwargs', dict())
        if 'bmi3d_query_kwargs' in kwargs:
            self.bmi3d_query_kwargs = kwargs['bmi3d_query_kwargs']
        else:
            self.bmi3d_query_kwargs = dict()

        super(InstanceFromDB, self).__init__(*args, **kwargs)


class DataFile(InstanceFromDB):
    def __init__(self, *args, **kwargs):
        kwargs['bmi3d_db_model'] = 'DataFile'
        super(DataFile, self).__init__(*args, **kwargs)


class OptionsList(traits.Enum):
    def __init__(self, *args, **kwargs):
        if 'bmi3d_input_options' not in kwargs:
            kwargs['bmi3d_input_options'] = args[0]

        super(OptionsList, self).__init__(*args, **kwargs)


traits.InstanceFromDB = InstanceFromDB
traits.DataFile = DataFile
traits.OptionsList = OptionsList


from . import experiment
from . import generate
from . import report
from .experiment import Experiment, LogExperiment, Sequence, TrialTypes, FSMTable, StateTransitions

try:
    from .Pygame import Pygame
except:
    import warnings
    warnings.warn('riglib/experiment/__init__.py: could not import Pygame (note capital P)')
    Pygame = object

def make(exp_class, feats=()):
    '''
    Creates a class which inherits from a base experiment class as well as a set of optional features.
    This function is a *metafunction* as it returns a custom class construction.

    Parameters
    ----------
    exp_class : class
        Base class containing the finite state machine of the task
    feats : iterable of classes
        Additional classes from which to also inherit

    Returns
    -------
    class
        New class which inherits from the base 'exp_class' and the selected 'feats'
    '''
    if len(feats) == 0:
        return exp_class
    else:
        # construct the class list to define inheritance order for the custom task
        # inherit from the features first, then the base class
        clslist = tuple(feats) + (exp_class,)

        print("metaclass constructor")
        print(clslist)
        print(feats)

        # return custom class
        return type(exp_class.__name__, clslist, dict())
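A usage sketch for make(); SaveHDF and RewardSystem are invented placeholder feature classes, but the MRO behavior is exactly what the function above produces:

class SaveHDF(object):        # invented placeholder feature
    pass

class RewardSystem(object):   # invented placeholder feature
    pass

# Builds type('LogExperiment', (SaveHDF, RewardSystem, LogExperiment), {})
Exp = make(LogExperiment, feats=(SaveHDF, RewardSystem))
assert Exp.__mro__[1] is SaveHDF   # features come first, so their overrides win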
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_experiment_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_experiment_py.html
new file mode 100644
index 00000000..345efc9b
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_experiment_py.html
@@ -0,0 +1,707 @@
Coverage for C:\GitHub\brain-python-interface\riglib\experiment\experiment.py: 29%
'''
Experimental task base classes, contains mostly code to run the generic
finite state machine representing different phases of the task
'''

import time
import random
import traceback
import collections
import re
import os
import tables
import traceback
import io
import numpy as np

from . import traits
from riglib.fsm import FSMTable, StateTransitions, ThreadedFSM

try:
    os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
    import pygame
except ImportError:
    import warnings
    warnings.warn("experiment.py: Cannot import 'pygame'")

from collections import OrderedDict

min_per_hour = 60.
sec_per_min = 60.


def _get_trait_default(trait):
    '''
    Function which tries to determine the default value for a trait in the class declaration
    '''
    _, default = trait.default_value()
    if isinstance(default, tuple) and len(default) > 0:
        try:
            func, args, _ = default
            default = func(*args)
        except:
            pass
    return default


class Experiment(ThreadedFSM, traits.HasTraits):
    '''
    Common ancestor of all task/experiment classes
    '''
    status = dict(
        wait = dict(start_trial="trial", premature="penalty", stop=None),
        trial = dict(correct="reward", incorrect="penalty", timeout="penalty"),
        reward = dict(post_reward="wait"),
        penalty = dict(post_penalty="wait"),
    )
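    # Reading this table: status['wait']['start_trial'] == 'trial' means that
    # while in the 'wait' state, the FSM moves to 'trial' as soon as the
    # matching test method (_test_start_trial) returns True; a target state of
    # None (the 'stop' event) terminates the task.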

    # For analysis purposes, it's useful to declare which task states are "terminal" states and signify the end of a trial
    trial_end_states = []

    # Set the initial state to 'wait'. The 'wait' state has special behavior for the Sequence class (see below)
    state = "wait"

    # Flag to indicate that the task object has not been constructed or initialized
    _task_init_complete = False

    # Flag to set in order to stop the FSM gracefully
    stop = False

    # Rate at which FSM is called. Set to 60 Hz by default to match the typical monitor update rate
    fps = 60 # Hz

    cycle_count = 0

    # set this flag to true if certain things should only happen in debugging mode
    debug = False
    terminated_in_error = False

    ## GUI/database-related attributes
    # Flag to specify if you want to be able to create a BMI Decoder object from the web interface
    is_bmi_seed = False

    # Trait GUI manipulation
    exclude_parent_traits = [] # List of possible parent traits that you don't want to be set from the web interface
    ordered_traits = [] # Traits in this list appear in order at the top of the web interface parameters
    hidden_traits = [] # These traits are hidden on the web interface, and can be displayed by clicking the 'Show' radiobutton on the web interface

    # Runtime settable traits
    session_length = traits.Float(0, desc="Time until task automatically stops. Length of 0 means no auto stop.")

    # Initialization functions -----------------------------------------------
    @classmethod
    def pre_init(cls, **kwargs):
        '''
        Jobs to do before creating the task object go here (or this method should be overridden in child classes).
        Examples might include sending a trigger to start a recording device (e.g., neural system), since you might want
        recording to be guaranteed to start before any task event loop activity occurs.
        '''
        print('running experiment.Experiment.pre_init')
        pass

    def __init__(self, verbose=True, **kwargs):
        '''
        Constructor for Experiment. This is the standard python object constructor

        Parameters
        ----------
        kwargs: optional keyword-arguments
            Any user-specified parameters for experiment traits, to be passed to the traits.HasTraits parent.

        Returns
        -------
        Experiment instance
        '''
        traits.HasTraits.__init__(self, **kwargs)
        ThreadedFSM.__init__(self)
        self.verbose = verbose
        self.task_start_time = self.get_time()
        self.reportstats = collections.OrderedDict()
        self.reportstats['State'] = None #State stat is automatically updated for all experiment classes
        self.reportstats['Runtime'] = '' #Runtime stat is automatically updated for all experiment classes
        self.reportstats['Trial #'] = 0 #Trial # stat must be updated by individual experiment classes
        self.reportstats['Reward #'] = 0 #Rewards stat is updated automatically for all experiment classes

        # If the FSM is set up in the old style (explicit dictionaries instead of wrapper data types), convert to the newer FSMTable
        if isinstance(self.status, dict):
            self.status = FSMTable.construct_from_dict(self.status)

        # Attribute for task entry dtype, used to create a numpy record array which is updated every iteration of the FSM
        # See http://docs.scipy.org/doc/numpy/user/basics.rec.html for details on how to create a record array dtype
        self.dtype = []

        self.cycle_count = 0
        self.clock = pygame.time.Clock()

        self.pause = False

        ## Figure out which traits to not save to the HDF file
        ## Large/complex python objects cannot be saved as HDF file attributes
        ctraits = self.class_traits()
        self.object_trait_names = [ctr for ctr in list(ctraits.keys()) if ctraits[ctr].trait_type.__class__.__name__ in ['Instance', 'InstanceFromDB', 'DataFile']]

        if self.verbose: print("finished executing Experiment.__init__")

    def init(self):
        '''
        Initialization method to run *after* object construction (see self.start).
        This may be necessary in some cases where features are used with multiple inheritance to extend tasks
        (this is the standard way of creating custom base experiment + features classes through the browser interface).
        With multiple inheritance, it's difficult/annoying to make guarantees about the order of operations for
        each of the individual __init__ functions from each of the parents. Instead, this function runs after all the
        __init__ functions have finished running if any subsequent initialization is necessary before the main event loop
        can execute properly. Examples include initialization of the decoder state/parameters.
        '''
        # Timestamp for rough loop timing
        self.last_time = self.get_time()
        self.cycle_count = 0

        # Create task_data record array
        # NOTE: all data variables MUST be declared prior to this point. So child classes overriding the 'init' method must
        # declare their variables using the 'add_dtype' function BEFORE calling the 'super' method.
        try:
            self.dtype = np.dtype(self.dtype)
            self.task_data = np.zeros((1,), dtype=self.dtype)
        except:
            print("Error creating the task_data record array")
            traceback.print_exc()
            print(self.dtype)
            self.task_data = None

        # Register the "task" source with the sinks
        if not hasattr(self, 'sinks'): # this attribute might be set in one of the other 'init' functions from other inherited classes
            from riglib import sink
            self.sinks = sink.sinks

        try:
            self.sinks.register("task", self.dtype)
        except:
            traceback.print_exc()
            raise Exception("Error registering task source")

        self._task_init_complete = True

    def add_dtype(self, name, dtype, shape):
        '''
        Add to the dtype of the task. The task's dtype attribute is used to determine
        which attributes to save to file.
        '''
        new_field = (name, dtype, shape)
        existing_field_names = [x[0] for x in self.dtype]
        if name in existing_field_names:
            raise Exception("Duplicate add_dtype function call for task data field: %s" % name)
        else:
            self.dtype.append(new_field)

    def screen_init(self):
        '''
        This method is implemented by the riglib.stereo_opengl.Window class, which is not used by all tasks. However,
        since Experiment is the ancestor of all tasks, a stub function is here so that any children
        using the window can safely use 'super'.
        '''
        pass

    # Trait functions --------------------------------------------------------
    @classmethod
    def class_editable_traits(cls):
        '''
        Class method to retrieve the list of editable traits for the given experiment.
        The default behavior for an experiment class is to make all traits editable except for those
        listed in the attribute 'exclude_parent_traits'.

        Parameters
        ----------
        None

        Returns
        -------
        editable_traits: list of strings
            Names of traits which are designated to be runtime-editable
        '''
        # traits = super(Experiment, cls).class_editable_traits()
        from traits.trait_base import not_event, not_false
        traits = cls.class_trait_names(type=not_event, editable=not_false)
        editable_traits = [x for x in traits if x not in cls.exclude_parent_traits]
        return editable_traits

    @classmethod
    def get_trait_info(cls, trait_name, ctraits=None):
        """Get dictionary of information on a given trait"""
        if ctraits is None:
            ctraits = cls.class_traits()

        trait_params = dict()
        trait_params['type'] = ctraits[trait_name].trait_type.__class__.__name__
        trait_params['default'] = _get_trait_default(ctraits[trait_name])
        trait_params['desc'] = ctraits[trait_name].desc
        trait_params['hidden'] = 'hidden' if cls.is_hidden(trait_name) else 'visible'
        if hasattr(ctraits[trait_name], 'label'):
            trait_params['label'] = ctraits[trait_name].label
        else:
            trait_params['label'] = trait_name

        if trait_params['type'] == "InstanceFromDB":
            # a database instance. pass back the model and the query parameters and let the db
            # handle the rest
            trait_params['options'] = (ctraits[trait_name].bmi3d_db_model, ctraits[trait_name].bmi3d_query_kwargs)

        elif trait_params['type'] == 'Instance':
            raise ValueError("You should use the 'InstanceFromDB' trait instead of the 'Instance' trait!")

        elif trait_params['type'] == "Enum":
            raise ValueError("You should use the 'OptionsList' trait instead of the 'Enum' trait!")

        elif trait_params['type'] == "OptionsList":
            trait_params['options'] = ctraits[trait_name].bmi3d_input_options

        elif trait_params['type'] == "DataFile":
            trait_params['options'] = ("DataFile", ctraits[trait_name].bmi3d_query_kwargs)

        return trait_params

    @classmethod
    def get_params(cls):
        # Use an ordered dict so that params actually stay in the order they're added, instead of random (hash) order
        params = OrderedDict()

        ctraits = cls.class_traits()

        # add all the traits that are explicitly instructed to appear at the top of the menu
        ordered_traits = cls.ordered_traits
        for trait in ordered_traits:
            if trait in cls.class_editable_traits():
                params[trait] = cls.get_trait_info(trait, ctraits=ctraits)

        # add all the remaining non-hidden traits
        for trait in cls.class_editable_traits():
            if trait not in params and not cls.is_hidden(trait):
                params[trait] = cls.get_trait_info(trait, ctraits=ctraits)

        # add any hidden traits
        for trait in cls.class_editable_traits():
            if trait not in params:
                params[trait] = cls.get_trait_info(trait, ctraits=ctraits)
        return params

    def get_trait_values(self):
        '''
        Retrieve all the values of the 'trait' objects
        '''
        trait_values = dict()
        for trait in self.class_editable_traits():
            trait_values[trait] = getattr(self, trait)
        return trait_values

    @classmethod
    def is_hidden(cls, trait):
        '''
        Return true if the given trait is not meant to be shown on the GUI by default, i.e. hidden

        Parameters
        ----------
        trait: string
            Name of trait to check

        Returns
        -------
        bool
        '''
        return trait in cls.hidden_traits

    @classmethod
    def get_desc(cls, params, report):
        return "An experiment!"

    # FSM functions ----------------------------------------------------------
    def run(self):
        '''
        Method to run the finite state machine of the task. Code that needs to execute
        immediately before the task starts running in child classes should be of the form:

        def run(self):
            do stuff
            try:
                super(class_name, self).run()
            finally:
                clean up stuff

        The try block may or may not be necessary. For example, if you're opening a UDP port, you may want to always
        close the socket whether or not the main loop executes properly so that you don't lose the
        reference to the socket.
        '''

        ## Initialize the FSM before the loop
        self.screen_init()
        self.reportstats['State'] = self.state
        super(Experiment, self).run()

    def _cycle(self):
        '''
        Code that needs to run every task loop iteration goes here
        '''
        super(Experiment, self)._cycle()

        # Send task data to any registered sinks
        if self.task_data is not None:
            self.sinks.send("task", self.task_data)

    def _test_stop(self, ts):
        '''
        FSM 'test' function. Returns the 'stop' attribute of the task
        '''
        if self.session_length > 0 and (self.get_time() - self.task_start_time) > self.session_length:
            self.end_task()
        return self.stop

    def _test_time_expired(self, ts):
        '''
        Generic function to test if time has expired. For a state 'wait', the function looks up the
        variable 'wait_time' and uses that as a time.
        '''
        state_time_var_name = self.state + '_time'
        try:
            state_time = getattr(self, state_time_var_name)
        except AttributeError:
            raise AttributeError("Cannot find attribute %s; may not be able to use generic time_expired event for state %s" % (state_time_var_name, self.state))

        assert isinstance(state_time, (float, int))
        return ts > state_time
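    # e.g. a task sitting in the 'reward' state and wired to a 'time_expired'
    # event would look up self.reward_time here, returning True once ts
    # (seconds spent in the current state) exceeds it.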

371 # UI interaction functions ----------------------------------------------- 

+

372 @classmethod 

+

373 def _time_to_string(cls, sec): 

+

374 ''' 

+

375 Convert a time in seconds to a string of format hh:mm:ss. 

+

376 ''' 

+

377 nhours = int(sec/(min_per_hour * sec_per_min)) 

+

378 nmins = int((sec-nhours*min_per_hour*sec_per_min)/sec_per_min) 

+

379 nsecs = int(sec - nhours*min_per_hour*sec_per_min - nmins*sec_per_min) 

+

380 return str(nhours).zfill(2) + ':' + str(nmins).zfill(2) + ':' + str(nsecs).zfill(2) 

+

381 

+

382 def update_report_stats(self): 

+

383 ''' 

+

384 Function to update any relevant report stats for the task. Values are saved in self.reportstats, 

+

385 an ordered dictionary. Keys are strings that will be displayed as the label for the stat in the web interface, 

+

386 values can be numbers or strings. Called every time task state changes. 

+

387 ''' 

+

388 self.reportstats['Runtime'] = self._time_to_string(self.get_time() - self.task_start_time) 

+

389 

+

390 @classmethod 

+

391 def offline_report(self, event_log): 

+

392 '''Returns an ordered dict with report stats to be displayed when past session of this task is selected 

+

393 in the web interface. Not called while task is running, only offline, so stats must come from information 

+

394 available in a sessions event log. Inputs are task object and event_log.''' 

+

395 offline_report = collections.OrderedDict() 

+

396 if len(event_log) == 0: 

+

397 explength = 0 

+

398 else: 

+

399 explength = event_log[-1][-1] - event_log[0][-1] 

+

400 offline_report['Runtime'] = self._time_to_string(explength) 

+

401 n_trials = 0 

+

402 n_success_trials = 0 

+

403 n_error_trials = 0 

+

404 for k, (state, event, t) in enumerate(event_log): 

+

405 if state == "reward": 

+

406 n_trials += 1 

+

407 n_success_trials += 1 

+

408 elif re.match('.*?_penalty', state): 

+

409 n_trials += 1 

+

410 n_error_trials += 1 

+

411 offline_report['Total trials'] = n_trials 

+

412 offline_report['Total rewards'] = n_success_trials 

+

413 try: 

+

414 offline_report['Rewards/min'] = np.round((n_success_trials/explength) * 60, decimals=2) 

+

415 except: 

+

416 offline_report['Rewards/min'] = 0 

+

417 if n_trials == 0: 

+

418 offline_report['Success rate'] = None 

+

419 else: 

+

420 offline_report['Success rate'] = str(np.round(float(n_success_trials)/n_trials*100,decimals=2)) + '%' 

+

421 return offline_report 

+

422 

+

423 def record_annotation(self, msg): 

+

424 """ Record a user-input annotation """ 

+

425 pass 

+

426 

+

427 # UI cleanup functions --------------------------------------------------- 

+

428 def cleanup(self, database, saveid, **kwargs): 

+

429 ''' 

+

430 Commands to execute at the end of a task. 

+

431 

+

432 Parameters 

+

433 ---------- 

+

434 database : object 

+

435 Needs to have the methods save_bmi, save_data, etc. For instance, the db.tracker.dbq module or an RPC representation of the database 

+

436 saveid : int 

+

437 TaskEntry database record id to link files/data to 

+

438 kwargs : optional dict arguments 

+

439 Optional arguments to dbq methods. NOTE: kwargs cannot be used when 'database' is an RPC object. 

+

440 

+

441 Returns 

+

442 ------- 

+

443 None 

+

444 ''' 

+

445 if self.verbose: print("experimient.Experiment.cleanup executing") 

+

446 

+

447 def cleanup_hdf(self): 

+

448 '''  

+

449 Method for adding data to hdf file after hdf sink is closed by  

+

450 system at end of task. The HDF file is re-opened and any extra task  

+

451 data kept in RAM is written 

+

452 ''' 

+

453 if hasattr(self, "h5file"): 

+

454 traits = self.class_editable_traits() 

+

455 

+

456 if hasattr(tables, 'open_file'): # function name depends on version 

+

457 h5file = tables.open_file(self.h5file.name, mode='a') 

+

458 else: 

+

459 h5file = tables.openFile(self.h5file.name, mode='a') 

+

460 

+

461 for trait in traits: 

+

462 if (trait not in self.object_trait_names): # don't save traits which are complicated python objects to the HDF file # and (trait not in ['bmi', 'decoder', 'ref_trajectories']): 

+

463 h5file.root.task.attrs[trait] = getattr(self, trait) 

+

464 h5file.close() 

+

465 

+

466 def terminate(self): 

+

467 ''' 

+

468 Cleanup commands for tasks executed using the "test" button 

+

469 ''' 

+

470 pass 

+

471 

+

472 

+

473class LogExperiment(Experiment): 

+

474 ''' 

+

475 Extension of the experiment class which logs state transitions 

+

476 ''' 

+

477 trial_end_states = [] 

+

478 

+

479 def cleanup(self, database, saveid, **kwargs): 

+

480 ''' 

+

481 Commands to execute at the end of a task.  

+

482 Save the task event log to the database 

+

483 

+

484 see riglib.Experiment.cleanup for argument descriptions 

+

485 ''' 

+

486 if self.verbose: print("experiment.LogExperiment.cleanup") 

+

487 super(LogExperiment, self).cleanup(database, saveid, **kwargs) 

+

488 dbname = kwargs['dbname'] if 'dbname' in kwargs else 'default' 

+

489 if dbname == 'default': 

+

490 database.save_log(saveid, self.event_log) 

+

491 else: 

+

492 database.save_log(saveid, self.event_log, dbname=dbname) 

+

493 

+

494 ########################################################## 

+

    ##### Functions to calculate statistics from the log #####
    ##########################################################
    def calc_state_occurrences(self, state_name):
        '''
        Calculate the number of times the task enters a particular state

        Parameters
        ----------
        state_name : string
            Name of state to track

        Returns
        -------
        Count of state occurrences
        '''
        times = np.array([state[1] for state in self.state_log if state[0] == state_name])
        return len(times)

    def calc_trial_num(self):
        '''
        Counts the number of trials which have finished.
        '''
        trialtimes = [state[1] for state in self.state_log if state[0] in self.trial_end_states]
        return len(trialtimes)

    def calc_events_per_min(self, event_name, window):
        '''
        Calculates the rate of event_name, per minute

        Parameters
        ----------
        event_name : string
            Name of state representing "event"
        window : float
            Number of seconds into the past to look to calculate the current event rate estimate.

        Returns
        -------
        rate : float
            Rate of the specified event, per minute
        '''
        rewardtimes = np.array([state[1] for state in self.state_log if state[0] == event_name])
        if (self.get_time() - self.task_start_time) < window:
            divideby = (self.get_time() - self.task_start_time)/sec_per_min
        else:
            divideby = window/sec_per_min
        return np.sum(rewardtimes >= (self.get_time() - window))/divideby


class Sequence(LogExperiment):
    '''
    Task where the targets or other information relevant to the start of each trial
    are presented by a Python generator
    '''

    # List of staticmethods of the class which can be used to create a sequence of targets for each trial
    sequence_generators = []

    @classmethod
    def get_default_seq_generator(cls):
        '''
        Define the default sequence generator as the first one listed in the 'sequence_generators' attribute
        '''
        return getattr(cls, cls.sequence_generators[0])

    def __init__(self, gen=None, *args, **kwargs):
        '''
        Constructor for Sequence

        Parameters
        ----------
        gen : Python generator
            Object with a 'next' attribute used in the special "wait" state to get the target sequence for the next trial.
        kwargs : optional keyword arguments
            Passed to the super constructor

        Returns
        -------
        Sequence instance
        '''
        if gen is None:
            raise ValueError("Experiment classes which inherit from Sequence must specify a target generator!")

        if np.iterable(gen):
            from .generate import runseq
            self.gen = runseq(self, seq=gen)
        elif hasattr(gen, '__next__'):  # python 3 renamed 'next' to '__next__'
            self.gen = gen
        else:
            raise ValueError("Input argument to Sequence 'gen' must be of 'generator' type!")

        super(Sequence, self).__init__(*args, **kwargs)

    def _start_wait(self):
        '''
        At the start of the wait state, the generator (self.gen) is queried for
        new information needed to start the trial. If the generator runs out, the task stops.
        '''
        if self.debug:
            print("_start_wait")

        try:
            self.next_trial = next(self.gen)
        except StopIteration:
            self.end_task()

        self._parse_next_trial()

    def _parse_next_trial(self):
        '''
        Interpret the data coming from the generator. If the generator yields a dictionary,
        then the keys of the dictionary automatically get set as attributes.

        Override or add additional code in child classes if different behavior is desired.
        '''
        if isinstance(self.next_trial, dict):
            for key in self.next_trial:
                setattr(self, '_gen_%s' % key, self.next_trial[key])


class TrialTypes(Sequence):
    '''
    This module is deprecated, used by some older tasks (dots, rds)
    '''
    trial_types = []

    status = dict(
        wait = dict(start_trial="picktrial", premature="penalty", stop=None),
        reward = dict(post_reward="wait"),
        penalty = dict(post_penalty="wait"),
    )

    def __init__(self, gen, **kwargs):
        super(TrialTypes, self).__init__(gen, **kwargs)
        assert len(self.trial_types) > 0

        for ttype in self.trial_types:
            self.status[ttype] = {
                "%s_correct" % ttype: "reward",
                "%s_incorrect" % ttype: "incorrect",
                "timeout": "incorrect"}
            # Associate all trial type endings to the end_trial function defined by Sequence
            # setattr(self, "_end_%s"%ttype, self._end_trial)

    def _start_picktrial(self):
        self.set_state(self.next_trial)

    def _start_incorrect(self):
        self.set_state("penalty")
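To make the Sequence/generator contract above concrete, here is a minimal sketch (the generator function and subclass are invented for illustration, not part of the repo): a generator yields one dict per trial, and _parse_next_trial exposes each key as a '_gen_'-prefixed attribute.

    def corner_targets(n_trials=4):
        # One dict per trial; in the 'wait' state, Sequence._parse_next_trial
        # will set self._gen_target from the 'target' key.
        for k in range(n_trials):
            yield {"target": (k % 2, 1 - k % 2)}

    class CornerTask(Sequence):
        def _parse_next_trial(self):
            super()._parse_next_trial()
            print("next target:", self._gen_target)

    # task = CornerTask(gen=corner_targets())  # gen is consumed one item per 'wait' state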
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_generate_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_generate_py.html
new file mode 100644
index 00000000..4abf347d
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_generate_py.html
@@ -0,0 +1,203 @@
Coverage for C:\GitHub\brain-python-interface\riglib\experiment\generate.py: 21%

'''
Various generic generators to combine with tasks. These appear to be mostly deprecated at this point
'''

import random
import itertools

import numpy as np

from .experiment import TrialTypes, Sequence


def block_random(*args, **kwargs):
    '''
    A generic block randomizer. Each block contains every combination of the
    parameter values being varied, in shuffled order.

    Parameters
    ----------
    args : iterables
        Parameter values to combine and shuffle
    nblocks : int, keyword argument
        Number of blocks to chain together

    Returns
    -------
    seq : list
        Block-random sequence of items where the length of each block is the product of the lengths of the parameters being varied
    '''
    n_blocks = kwargs.pop('nblocks')
    inds = [np.arange(len(arg)) for arg in args]
    from itertools import product
    items = []
    for x in product(*inds):
        item = [arg[i] for arg, i in zip(args, x)]
        items.append(item)

    n_items = len(items)
    seq = []
    for k in range(n_blocks):
        inds = np.arange(n_items)
        np.random.shuffle(inds)
        for i in inds:
            seq.append(items[i])

    return seq

def runseq(exp, seq=None, reps=1):
    '''
    Turns a sequence into a Python generator by iterating through the sequence and yielding each sequence element

    Parameters
    ----------
    exp : Class object for task
        Used only if the experiment has 'trial_types' instead of a target sequence list
    seq : iterable
        Target sequence yielded to the task during the 'wait' state
    reps : int, optional, default=1
        Number of times to repeat the sequence

    Returns
    -------
    Generator object corresponding to the sequence
    '''
    if hasattr(exp, "trial_types"):
        assert max(seq)+1 == len(exp.trial_types)
        for _ in range(reps):
            for s in seq:
                yield exp.trial_types[s]
    else:
        print("runseq")
        for _ in range(reps):
            for s in seq:
                yield s

##################################################
##### Old functions, for use with TrialTypes #####
##################################################
def endless(exp, probs=None):
    '''
    Deprecated
    '''
    if probs is None:
        while True:
            yield random.choice(exp.trial_types)
    else:
        assert len(probs) == len(exp.trial_types)
        probs = np.insert(np.cumsum(_fix_missing(probs)), 0, 0)
        assert probs[-1] == 1, "Probabilities do not add up to 1!"
        while True:
            rand = random.random()
            p = np.nonzero(rand < probs)[0].min()
            yield exp.trial_types[p-1]

def sequence(length, probs=2):
    '''
    Deprecated
    '''
    try:
        opts = len(probs)
        probs = _fix_missing(probs)
    except TypeError:
        opts = probs
        probs = [1 / float(opts)] * opts
    return np.random.permutation([i for i, p in enumerate(probs) for _ in range(int(length*p))])

def _fix_missing(probs):
    '''
    Deprecated
    '''
    total, n = list(map(sum, list(zip(*((i, 1) for i in probs if i is not None)))))
    if n < len(probs):
        p = (1 - total) / (len(probs) - n)
        probs = [i or p for i in probs]
    return probs

class AdaptiveTrials(object):
    '''
    Deprecated
    '''
    def __init__(self, exp, blocklen=8):
        assert issubclass(exp, TrialTypes)
        self.blocklen = blocklen
        self.trial_types = exp.trial_types
        self.new_block()

    def new_block(self):
        # integer division so the list-multiply below receives an int under Python 3
        perblock = self.blocklen // len(self.trial_types)
        block = [[t]*perblock for t in self.trial_types]
        self.block = list(itertools.chain(*block))
        random.shuffle(self.block)

    def __next__(self):
        if len(self.block) < 1:
            self.new_block()
        return self.block[0]

    def correct(self):
        self.block.pop(0)

    def incorrect(self):
        ondeck = self.block.pop(0)
        self.block.append(ondeck)
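As a usage sketch of block_random (the parameter values are invented for the example), each block contains every (target, delay) combination exactly once, shuffled:

    import numpy as np
    from riglib.experiment.generate import block_random

    targets = [1, 2]
    delays = [0.2, 0.5, 1.0]
    seq = block_random(targets, delays, nblocks=2)

    assert len(seq) == 2 * len(targets) * len(delays)  # 2 blocks of 6 items
    print(seq[:6])  # first shuffled block, e.g. [[2, 0.5], [1, 1.0], ...]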
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_mocks_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_mocks_py.html
new file mode 100644
index 00000000..df53bff3
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_mocks_py.html
@@ -0,0 +1,225 @@
Coverage for C:\GitHub\brain-python-interface\riglib\experiment\mocks.py: 67%

''' Mock experiment classes for testing '''
from .experiment import LogExperiment, FSMTable, StateTransitions, Sequence

event1to2 = [False, True, False, False, False, False, False, False]
event1to3 = [False, False, False, False, True, False, False, False]
event2to3 = [False, False, True, False, False, False, False, False]
event2to1 = [False, False, False, False, False, False, False, False]
event3to2 = [False, False, False, False, False, True, False, False]
event3to1 = [False, False, False, True, False, False, False, False]

class MockLogExperiment(LogExperiment):
    status = FSMTable(
        state1=StateTransitions(event1to2='state2', event1to3='state3'),
        state2=StateTransitions(event2to3='state3', event2to1='state1'),
        state3=StateTransitions(event3to2='state2', event3to1='state1'),
    )
    state = 'state1'

    def __init__(self, *args, **kwargs):
        self.iter_idx = 0
        super(MockLogExperiment, self).__init__(*args, **kwargs)

    def _cycle(self):
        self.iter_idx += 1
        super(MockLogExperiment, self)._cycle()

    def _start_state3(self): pass
    def _while_state3(self): pass
    def _end_state3(self): pass
    def _start_state2(self): pass
    def _while_state2(self): pass
    def _end_state2(self): pass
    def _start_state1(self): pass
    def _while_state1(self): pass
    def _end_state1(self): pass
    ################## State transition test functions ##################
    def _test_event3to1(self, time_in_state): return event3to1[self.iter_idx]
    def _test_event3to2(self, time_in_state): return event3to2[self.iter_idx]
    def _test_event2to3(self, time_in_state): return event2to3[self.iter_idx]
    def _test_event2to1(self, time_in_state): return event2to1[self.iter_idx]
    def _test_event1to3(self, time_in_state): return event1to3[self.iter_idx]
    def _test_event1to2(self, time_in_state): return event1to2[self.iter_idx]
    def _test_stop(self, time_in_state):
        return self.iter_idx >= len(event1to2) - 1

    def get_time(self):
        return self.iter_idx



class MockSequence(Sequence):
    event1to2 = [False, True, False, True, False, True, False, True, False, True, False]
    event1to3 = [False, False, False, False, False, False, False, False, False, False, False]
    event2to3 = [False, False, False, False, False, False, False, False, False, False, False]
    event2to1 = [False, False, True, False, True, False, True, False, True, False, False]
    event3to2 = [False, False, False, False, False, False, False, False, False, False, False]
    event3to1 = [False, False, False, False, False, False, False, False, False, False, False]

    status = FSMTable(
        wait=StateTransitions(event1to2='state2', event1to3='state3'),
        state2=StateTransitions(event2to3='state3', event2to1='wait'),
        state3=StateTransitions(event3to2='state2', event3to1='wait'),
    )
    state = 'wait'

    def __init__(self, *args, **kwargs):
        self.iter_idx = 0
        self.target_history = []
        super(MockSequence, self).__init__(*args, **kwargs)

    def _cycle(self):
        self.iter_idx += 1
        super(MockSequence, self)._cycle()

    def _start_state3(self): pass
    def _while_state3(self): pass
    def _end_state3(self): pass
    def _start_state2(self): pass
    def _while_state2(self): pass
    def _end_state2(self): pass
    def _start_state1(self): pass
    def _while_state1(self): pass
    def _end_state1(self): pass
    ################## State transition test functions ##################
    def _test_event3to1(self, time_in_state): return self.event3to1[self.iter_idx]
    def _test_event3to2(self, time_in_state): return self.event3to2[self.iter_idx]
    def _test_event2to3(self, time_in_state): return self.event2to3[self.iter_idx]
    def _test_event2to1(self, time_in_state): return self.event2to1[self.iter_idx]
    def _test_event1to3(self, time_in_state): return self.event1to3[self.iter_idx]
    def _test_event1to2(self, time_in_state): return self.event1to2[self.iter_idx]
    def _test_stop(self, time_in_state):
        return self.iter_idx >= len(event1to2) - 1

    def get_time(self):
        return self.iter_idx

    def _start_wait(self):
        super(MockSequence, self)._start_wait()
        self.target_history.append(self.next_trial)


class MockSequenceWithGenerators(Sequence):
    status = FSMTable(
        wait=StateTransitions(start_trial="trial"),
        trial=StateTransitions(target_reached="reward"),
        reward=StateTransitions(reward_complete="wait"),
    )
    state = 'wait'

    sequence_generators = ['gen_fn1', 'gen_fn2']

    def __init__(self, *args, **kwargs):
        self.target_history = []
        self.sim_state_seq = []
        for k in range(4):
            self.sim_state_seq += [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2]

        self.current_state = None
        super().__init__(*args, **kwargs)

    def init(self):
        self.add_dtype("target_state", int, (1,))
        self.add_dtype("current_state", int, (1,))
        super().init()

    def get_time(self):
        return self.cycle_count * 1.0/60

    @staticmethod
    def gen_fn1(n_targets=4):
        target_seq = [1, 2] * n_targets
        return [{"target": x} for x in target_seq]

    @staticmethod
    def gen_fn2(n_targets=4):
        target_seq = [1, 2] * n_targets
        return [{"target": x} for x in target_seq]

    def _test_start_trial(self, ts):
        return True

    def _test_reward_complete(self, ts):
        return True

    def _cycle(self):
        self.current_state = self.sim_state_seq[self.cycle_count % len(self.sim_state_seq)]
        self.task_data["target_state"] = self._gen_target
        self.task_data["current_state"] = self.current_state

        if self.cycle_count == 21:
            self.record_annotation("test annotation")

        super()._cycle()

    def _test_target_reached(self, ts):
        return self.current_state == self._gen_target

from . import traits
class MockSequenceWithTraits(MockSequenceWithGenerators):
    options_trait = traits.OptionsList(["option1", "option2"], desc='options', label="Options")
    float_trait = traits.Float(15, desc="float", label="Float")
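A rough sketch of how these mocks feed the Sequence machinery (import paths assume the riglib package layout): gen_fn1 returns a list of per-trial dicts, and runseq turns any such list into the generator that Sequence consumes in its 'wait' state.

    from riglib.experiment.mocks import MockSequenceWithGenerators
    from riglib.experiment.generate import runseq

    seq = MockSequenceWithGenerators.gen_fn1(n_targets=2)  # [{'target': 1}, {'target': 2}, ...]
    gen = runseq(None, seq=seq)  # None has no 'trial_types' attribute, so items are yielded as-is
    print(next(gen))             # {'target': 1}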
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_report_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_report_py.html
new file mode 100644
index 00000000..8d33c71b
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_experiment_report_py.html
@@ -0,0 +1,152 @@
Coverage for C:\GitHub\brain-python-interface\riglib\experiment\report.py: 14%

'''
Generate reporting stats for completed/ongoing tasks
'''


import numpy as np
from .experiment import LogExperiment, TrialTypes
import re

def trialtype(exp):
    '''
    Tabulate trial outcomes (correct/incorrect/timeout) by trial type from an experiment's event log

    Parameters
    ----------
    exp : LogExperiment instance
        Experiment to report on

    Returns
    -------
    dict
        Reward count, premature-response times, and per-trial-type outcome times
    '''
    assert isinstance(exp, LogExperiment), "Cannot report on non-logged experiments"
    ttypes = exp.trial_types if isinstance(exp, TrialTypes) else ["trial"]
    trials = dict([(t, dict(correct=[], incorrect=[], timeout=[])) for t in ttypes])
    report = dict(rewards=0, prematures=[], trials=trials)
    trial = None

    for state, event, t in exp.event_log:
        if trial is not None:
            if "incorrect" in event:
                report['trials'][state]["incorrect"].append(t - trial)
            elif "correct" in event:
                report['trials'][state]["correct"].append(t - trial)
            elif "timeout" in event:
                report['trials'][state]["timeout"].append(t - trial)
            trial = None

        if event == "start_trial":
            trial = t
        elif state == "reward":
            report['rewards'] += 1

    return report

def general(Exp, event_log, repdict, ntrials, nrewards, reward_len):
    '''
    Summarize overall trial counts and reward rates

    Parameters
    ----------
    ntrials : int
        Number of completed trials
    nrewards : int
        Number of rewarded trials
    reward_len : float
        Duration of the reward state

    Returns
    -------
    dict
        Trial count, reward length, and (reward, non-reward) rates
    '''
    report = dict(trials=0, reward_len=[])

    report['trials'] = ntrials
    report['reward_len'] = reward_len, nrewards

    if report['trials'] == 0:
        report['rates'] = (0, 0)
    else:
        report['rates'] = (float(nrewards)/ntrials, 1 - (float(nrewards)/ntrials))

    return report

def print_report(report):
    '''
    Prints a report generated by trialtype(exp)

    Parameters
    ----------
    report : dict
        Report dictionary in the format returned by trialtype()

    Returns
    -------
    None
    '''
    repstr = ["%8s: %d"%("rewards", report['rewards'])]
    ttrial = 0
    for tname, tdict in list(report['trials'].items()):
        total = len(tdict['correct']) + len(tdict['incorrect']) + len(tdict['timeout'])
        ttrial += total
        if total == 0:
            cper = 0
        else:
            cper = float(len(tdict['correct'])) / total * 100
        cRT = np.mean(tdict['correct'])
        repstr.append("%8s: %g%%, RT=%g"%(tname, cper, cRT))
    repstr.insert(0, "%8s: %d"%("total", ttrial))
    print("\n".join(repstr))
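For reference, the report dict layout that print_report expects, with invented values:

    rep = dict(rewards=3,
               prematures=[],
               trials={"trial": dict(correct=[1.2, 0.9, 1.5], incorrect=[2.0], timeout=[])})
    print_report(rep)
    #    total: 4
    #  rewards: 3
    #    trial: 75%, RT=1.2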
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_eyetracker_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_eyetracker_py.html
new file mode 100644
index 00000000..e3fde1fc
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_eyetracker_py.html
@@ -0,0 +1,298 @@
Coverage for C:\GitHub\brain-python-interface\riglib\eyetracker.py: 0%

'''
Base code for including the 'eyetracker' features in experiments
'''


import time
import itertools
import numpy as np
try:
    import pylink
except ImportError:
    print("Couldn't find eyetracker module")

class Simulate(object):
    '''
    Feature (task add-on) to simulate the eyetracker.
    '''
    update_freq = 500
    dtype = np.dtype((np.float, (2,)))

    def __init__(self, fixations=[(0,0), (-0.6,0.3), (0.6,0.3)], isi=500, slen=15):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        from scipy.interpolate import interp1d
        flen = list(range(len(fixations)+1))
        t = list(itertools.chain(*[(i*isi + slen*i, (i+1)*isi + slen*i) for i in flen]))[:-1]
        xy = np.append(np.tile(fixations, (1, 2)).reshape(-1, 2), [fixations[0]], axis=0)
        self.mod = t[-1] / 1000.
        self.interp = interp1d(np.array(t)/1000., xy, kind='linear', axis=0)
        self.fixations = fixations
        self.isi = isi

    def start(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        print("eyetracker.simulate.start()")
        self.stime = time.time()

    def retrieve(self, filename):
        '''
        For the simulator, there is no need to retrieve a file

        Parameters
        ----------

        Returns
        -------
        '''
        pass

    def get(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        time.sleep(1./self.update_freq)

        data = self.interp((time.time() - self.stime) % self.mod) + np.random.randn(2)*.01
        # expand dims
        data_2 = np.expand_dims(data, axis=0)
        return data_2

    def stop(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        return

    def sendMsg(self, msg):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        pass

class System(object):
    '''
    System representing the EyeLink eyetracker. Compatible with riglib.source.DataSource
    '''
    update_freq = 500
    dtype = np.dtype((np.float, (2,)))

    def __init__(self, address='10.0.0.2'):
        '''
        Constructor for the System representing the EyeLink eyetracker

        Parameters
        ----------
        address : IP address string
            IP address of the EyeLink host machine

        Returns
        -------
        System instance
        '''
        self.tracker = pylink.EyeLink(address)
        self.tracker.setOfflineMode()

    def start(self, filename=None):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        print("eyetracker.System.start()")
        self.filename = filename
        if filename is None:
            self.filename = "%s.edf" % time.strftime("%Y%m%d")  # %Y-%m-%d_%I:%M:%p
        self.tracker.openDataFile(self.filename)
        # pylink.beginRealTimeMode(100)
        print("\n\ntracker.startRecording")
        self.tracker.startRecording(1, 0, 1, 0)

    def stop(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.tracker.stopRecording()
        pylink.endRealTimeMode()

    def get(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        samp = self.tracker.getNextData()
        while samp != pylink.SAMPLE_TYPE:
            time.sleep(.001)
            samp = self.tracker.getNextData()
        try:
            data = np.array(self.tracker.getFloatData().getLeftEye().getGaze())
            if data.sum() < -1e4:
                return np.array([np.nan, np.nan])
        except:
            return np.array([np.nan, np.nan])

        return data

    def set_filter(self, filt):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.filter = filt

    def retrieve(self, filename):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.tracker.setOfflineMode()
        pylink.msecDelay(1)
        self.tracker.closeDataFile()
        self.tracker.receiveDataFile(self.filename, filename)

    def sendMsg(self, msg):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.tracker.sendMessage(msg)

    def __del__(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.tracker.close()
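A quick sketch of polling the simulated eyetracker (assumes a NumPy version where np.float, used in the dtype above, still exists):

    from riglib.eyetracker import Simulate

    eye = Simulate()     # default 3-fixation scan pattern
    eye.start()
    sample = eye.get()   # sleeps ~1/update_freq s, then returns a 1 x 2 gaze sample
    print(sample.shape)  # (1, 2)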
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_filter_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_filter_py.html
new file mode 100644
index 00000000..41943fc4
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_filter_py.html
@@ -0,0 +1,117 @@
Coverage for C:\GitHub\brain-python-interface\riglib\filter.py: 0%

'''
Generic class for implementing filters describable by rational z-transforms (ratio of polynomials)
'''
from scipy.signal import lfilter
import numpy as np

class Filter(object):
    def __init__(self, b=[], a=[1.]):
        '''
        Constructor for Filter

        Parameters
        ----------
        b : array_like
            The numerator coefficient vector in a 1-D sequence.
        a : array_like
            The denominator coefficient vector in a 1-D sequence. If ``a[0]``
            is not 1, then both `a` and `b` are normalized by ``a[0]``.

        Returns
        -------
        Filter instance
        '''
        self.b = np.array(b)
        self.a = np.array(a)

        # normalize the constants
        self.b /= self.a[0]
        self.a /= self.a[0]

        self.zi = np.zeros(max(len(a), len(b))-1)


    def __call__(self, samples):
        '''
        Run the filter on the most recent set of samples

        Parameters
        ----------
        samples : np.ndarray of shape (N,)
            samples to filter (x input)

        Returns
        -------
        np.ndarray
            Most recent N outputs of the filter
        '''
        # promote scalars to arrays so the lfilter function doesn't complain
        if isinstance(samples, (float, int)):
            samples = np.array([samples])

        filt_output, self.zi = lfilter(self.b, self.a, samples, zi=self.zi)
        return filt_output
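A usage sketch for Filter (the Butterworth design parameters are illustrative; assumes scipy is installed): because the filter state zi is carried between calls, filtering in chunks matches filtering the whole signal at once.

    import numpy as np
    from scipy.signal import butter
    from riglib.filter import Filter

    b, a = butter(4, 0.2)   # 4th-order low-pass at 0.2 x Nyquist
    lpf = Filter(b=b, a=a)

    x = np.random.randn(100)
    y = np.concatenate([lpf(x[:50]), lpf(x[50:])])  # same result as one 100-sample call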
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm___init___py.html
new file mode 100644
index 00000000..3f09f209
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm___init___py.html
@@ -0,0 +1,65 @@
Coverage for C:\GitHub\brain-python-interface\riglib\fsm\__init__.py: 100%

from .fsm.fsm import FSMTable, StateTransitions, Clock, FSM, ThreadedFSM
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_fsm___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_fsm___init___py.html
new file mode 100644
index 00000000..2c4e618f
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_fsm___init___py.html
@@ -0,0 +1,65 @@
Coverage for C:\GitHub\brain-python-interface\riglib\fsm\fsm\__init__.py: 100%

from .fsm import FSMTable, StateTransitions, Clock, FSM, ThreadedFSM
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_fsm_fsm_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_fsm_fsm_py.html
new file mode 100644
index 00000000..2f7fc8e2
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_fsm_fsm_py.html
@@ -0,0 +1,428 @@
Coverage for C:\GitHub\brain-python-interface\riglib\fsm\fsm\fsm.py: 39%

1"""Finite state machine implementation """ 

+

2import time 

+

3import random 

+

4import threading 

+

5import traceback 

+

6import collections 

+

7import re 

+

8import os 

+

9import tables 

+

10import traceback 

+

11import io 

+

12import numpy as np 

+

13 

+

14from collections import OrderedDict 

+

15 

+

16min_per_hour = 60. 

+

17sec_per_min = 60. 

+

18 

+

19 

+

20class FSMTable(object): 

+

21 def __init__(self, **kwargs): 

+

22 self.states = OrderedDict() 

+

23 for state_name, transitions in list(kwargs.items()): 

+

24 self.states[state_name] = transitions 

+

25 

+

26 def __getitem__(self, key): 

+

27 return self.states[key] 

+

28 

+

29 def get_possible_state_transitions(self, current_state): 

+

30 return list(self.states[current_state].items()) 

+

31 

+

32 def _lookup_next_state(self, current_state, transition_event): 

+

33 return self.states[current_state][transition_event] 

+

34 

+

35 def __iter__(self): 

+

36 return list(self.states.keys()).__iter__() 

+

37 

+

38 @staticmethod 

+

39 def construct_from_dict(status): 

+

40 outward_transitions = OrderedDict() 

+

41 for state in status: 

+

42 outward_transitions[state] = StateTransitions(stoppable=False, **status[state]) 

+

43 return FSMTable(**outward_transitions) 

+

44 

+

45 

+

46class StateTransitions(object): 

+

47 def __init__(self, stoppable=True, **kwargs): 

+

48 self.state_transitions = OrderedDict() 

+

49 for event, next_state in list(kwargs.items()): 

+

50 self.state_transitions[event] = next_state 

+

51 

+

52 if stoppable and not ('stop' in self.state_transitions): 

+

53 self.state_transitions['stop'] = None 

+

54 

+

55 def __getitem__(self, key): 

+

56 return self.state_transitions[key] 

+

57 

+

58 def __iter__(self): 

+

59 transition_events = list(self.state_transitions.keys()) 

+

60 return transition_events.__iter__() 

+

61 

+

62 def items(self): 

+

63 return list(self.state_transitions.items()) 

+

64 

+

65 

+

66class Clock(object): 

+

67 def tick(self, fps): 

+

68 import time 

+

69 time.sleep(1.0/fps) 

+

70 

+

71 

+

72class FSM(object): 

+

73 status = FSMTable( 

+

74 wait = StateTransitions(start_trial="trial", premature="penalty", stop=None), 

+

75 trial = StateTransitions(correct="reward", incorrect="penalty", timeout="penalty"), 

+

76 reward = StateTransitions(post_reward="wait"), 

+

77 penalty = StateTransitions(post_penalty="wait"), 

+

78 ) 

+

79 state = "wait" 

+

80 debug = False 

+

81 fps = 60 # frames per second 

+

82 

+

83 log_exclude = set() # List out state/trigger pairs to exclude from logging 

+

84 

+

85 def __init__(self, *args, **kwargs): 

+

86 self.verbose = kwargs.pop('verbose', False) 

+

87 

+

88 # state and event transitions 

+

89 self.state_log = [] 

+

90 self.event_log = [] 

+

91 

+

92 self.clock = Clock() 

+

93 

+

94 # Timestamp for rough loop timing 

+

95 self.last_time = self.get_time() 

+

96 self.cycle_count = 0 

+

97 

+

98 @property 

+

99 def update_rate(self): 

+

100 ''' 

+

101 Attribute for update rate of task. Using @property in case any future modifications 

+

102 decide to change fps on initialization 

+

103 ''' 

+

104 return 1./self.fps 

+

105 

+

106 def print_to_terminal(self, *args): 

+

107 ''' 

+

108 Print to the terminal rather than the websocket if the websocket is being used by the 'Notify' feature (see db.tasktrack) 

+

109 ''' 

+

110 if len(args) == 1: 

+

111 print(args[0]) 

+

112 else: 

+

113 print(args) 

+

114 

+

115 def init(self): 

+

116 '''Interface for child classes to run initialization code after object 

+

117 construction''' 

+

118 pass 

+

119 

+

120 def run(self): 

+

121 ''' 

+

122 Generic method to run the finite state machine of the task. Code that needs to execute  

+

123 imediately before the task starts running in child classes should be of the form: 

+

124 

+

125 def run(self): 

+

126 do stuff 

+

127 try: 

+

128 super(class_name, self).run() 

+

129 finally: 

+

130 clean up stuff 

+

131 

+

132 The try block may or may not be necessary. For example, if you're opening a UDP port, you may want to always 

+

133 close the socket whether or not the main loop executes properly so that you don't loose the  

+

134 reference to the socket.  

+

135 ''' 

+

136 

+

137 ## Initialize the FSM before the loop 

+

138 self.set_state(self.state) 

+

139 

+

140 while self.state is not None: 

+

141 if self.debug: 

+

142 # allow ungraceful termination if in debugging mode so that pdb  

+

143 # can catch the exception in the appropriate place 

+

144 self.fsm_tick() 

+

145 else: 

+

146 # in "production" mode (not debugging), try to capture & log errors gracefully 

+

147 try: 

+

148 self.fsm_tick() 

+

149 except: 

+

150 self.print_to_terminal("Error in FSM tick") 

+

151 self.state = None 

+

152 self.terminated_in_error = True 

+

153 

+

154 self.termination_err = io.StringIO() 

+

155 traceback.print_exc(None, self.termination_err) 

+

156 self.termination_err.seek(0) 

+

157 

+

158 self.print_to_terminal(self.termination_err.read()) 

+

159 self.termination_err.seek(0) 

+

160 if self.verbose: print("end of FSM.run, task state is", self.state) 

+

161 

+

162 def run_sync(self): 

+

163 self.init() 

+

164 self.run() 

+

165 

+

166 ########################################################### 

+

167 ##### Finite state machine (FSM) transition functions ##### 

+

168 ########################################################### 

+

169 def fsm_tick(self): 

+

170 ''' 

+

171 Execute the commands corresponding to a single tick of the event loop 

+

172 ''' 

+

173 # Execute commands 

+

174 self.exec_state_specific_actions(self.state) 

+

175 

+

176 # Execute the commands which must run every loop, independent of the FSM state 

+

177 # (e.g., running the BMI decoder) 

+

178 self._cycle() 

+

179 

+

180 current_state = self.state 

+

181 

+

182 # iterate over the possible events which could move the task out of the current state 

+

183 for event in self.status[current_state]: 

+

184 if self.test_state_transition_event(event): # if the event has occurred 

+

185 # execute commands to end the current state 

+

186 self.end_state(current_state) 

+

187 

+

188 # trigger the transition for the event 

+

189 self.trigger_event(event) 

+

190 

+

191 # stop searching for transition events (transition events must be  

+

192 # mutually exclusive for this FSM to function properly) 

+

193 break 

+

194 

+

195 def test_state_transition_event(self, event): 

+

196 event_test_fn_name = "_test_%s" % event 

+

197 if hasattr(self, event_test_fn_name): 

+

198 event_test_fn = getattr(self, event_test_fn_name) 

+

199 time_since_state_started = self.get_time() - self.start_time 

+

200 return event_test_fn(time_since_state_started) 

+

201 else: 

+

202 return False 

+

203 

+

204 def end_state(self, state): 

+

205 end_state_fn_name = "_end_%s" % state 

+

206 if hasattr(self, end_state_fn_name): 

+

207 end_state_fn = getattr(self, end_state_fn_name) 

+

208 end_state_fn() 

+

209 

+

210 def start_state(self, state): 

+

211 state_start_fn_name = "_start_%s" % state 

+

212 if hasattr(self, state_start_fn_name): 

+

213 state_start_fn = getattr(self, state_start_fn_name) 

+

214 state_start_fn() 

+

215 

+

216 def exec_state_specific_actions(self, state): 

+

217 if hasattr(self, "_while_%s" % state): 

+

218 getattr(self, "_while_%s" % state)() 

+

219 

+

220 def trigger_event(self, event): 

+

221 ''' 

+

222 Transition the task state to a new state, where the next state depends on the current state as well as the trigger event 

+

223 

+

224 Parameters 

+

225 ---------- 

+

226 event: string 

+

227 Based on the current state, a particular event will trigger a particular state transition (Mealy machine) 

+

228 

+

229 Returns 

+

230 ------- 

+

231 None 

+

232 ''' 

+

233 log = (self.state, event) not in self.log_exclude 

+

234 if log: 

+

235 self.event_log.append((self.state, event, self.get_time())) 

+

236 

+

237 fsm_edges = self.status[self.state] 

+

238 next_state = fsm_edges[event] 

+

239 self.set_state(next_state, log=log) 

+

240 

+

241 def set_state(self, condition, log=True): 

+

242 ''' 

+

243 Change the state of the task 

+

244 

+

245 Parameters 

+

246 ---------- 

+

247 condition: string 

+

248 Name of new state. The state name must be a key in the 'status' dictionary attribute of the task 

+

249 

+

250 Returns 

+

251 ------- 

+

252 None 

+

253 ''' 

+

254 # Record the time at which the new state is entered. Used for timed states, e.g., the reward state 

+

255 self.start_time = self.get_time() 

+

256 

+

257 if log: 

+

258 self.state_log.append((condition, self.start_time)) 

+

259 self.state = condition 

+

260 

+

261 self.start_state(condition) 

+

262 

+

263 def get_time(self): 

+

264 ''' 

+

265 Abstraction to get the current time. By default, state transitions are based on wall clock time, not on iteration count. 

+

266 To get simulations to run faster than real time, this function must be overwritten. 

+

267 

+

268 Returns 

+

269 ------- 

+

270 float: The current time in seconds 

+

271 ''' 

+

272 return time.time() 

+

273 

+

274 def _cycle(self): 

+

275 ''' 

+

276 Code that needs to run every task loop iteration goes here 

+

277 ''' 

+

278 self.cycle_count += 1 

+

279 if self.fps > 0: 

+

280 self.clock.tick(self.fps) 

+

281 

+

282 def iter_time(self): 

+

283 ''' 

+

284 Determine the time elapsed since the last time this function was called 

+

285 ''' 

+

286 start_time = self.get_time() 

+

287 loop_time = start_time - self.last_time 

+

288 self.last_time = start_time 

+

289 return loop_time 

+

290 

+

291 @classmethod 

+

292 def parse_fsm(cls): 

+

293 ''' 

+

294 Print out the FSM of the task in a semi-readable form 

+

295 ''' 

+

296 for state in cls.status: 

+

297 print('When in state "%s"' % state) 

+

298 for trigger_event, next_state in list(cls.status[state].items()): 

+

299 print('\tevent "%s" moves the task to state "%s"' % (trigger_event, next_state)) 

+

300 

+

301 @classmethod 

+

302 def auto_gen_fsm_functions(cls): 

+

303 ''' 

+

304 Parse the FSM to write all the _start, _end, _while, and _test functions 

+

305 ''' 

+

306 events_to_test = [] 

+

307 for state in cls.status: 

+

308 # make _start function  

+

309 print('''def _start_%s(self): pass''' % state) 

+

310 

+

311 # make _while function 

+

312 print('''def _while_%s(self): pass''' % state) 

+

313 # make _end function 

+

314 print('''def _end_%s(self): pass''' % state) 

+

315 for event, _ in cls.status.get_possible_state_transitions(state): 

+

316 events_to_test.append(event) 

+

317 

+

318 print("################## State trnasition test functions ##################") 

+

319 

+

320 for event in events_to_test: 

+

321 if event == 'stop': continue 

+

322 print('''def _test_%s(self, time_in_state): return False''' % event) 

+

323 

+

324 def end_task(self): 

+

325 ''' 

+

326 End the FSM gracefully on the next iteration by setting the task's "stop" flag. 

+

327 ''' 

+

328 self.stop = True 

+

329 

+

330 def _test_stop(self, ts): 

+

331 '''  

+

332 FSM 'test' function. Returns the 'stop' attribute of the task 

+

333 ''' 

+

334 return self.stop 

+

335 

+

336 

+

337class ThreadedFSM(FSM, threading.Thread): 

+

338 """ FSM + infrastructure to run FSM in its own thread """ 

+

339 def __init__(self): 

+

340 FSM.__init__(self) 

+

341 threading.Thread.__init__(self) 

+

342 

+

343 def start(self): 

+

344 ''' 

+

345 From the python docs on threading.Thread: 

+

346 Once a thread object is created, its activity must be started by  

+

347 calling the thread's start() method. This invokes the run() method in a  

+

348 separate thread of control. 

+

349 

+

350 Prior to the thread's start method being called, the secondary init function (self.init) is executed. 

+

351 After the threading.Thread.start is executed, the 'run' method is executed automatically in a separate thread. 

+

352 

+

353 Returns 

+

354 ------- 

+

355 None 

+

356 ''' 

+

357 self.init() 

+

358 threading.Thread.start(self) 

+

359 

+

360 def join(self): 

+

361 ''' 

+

362 Code to run before re-joining the FSM thread  

+

363 ''' 

+

364 threading.Thread.join(self) 

+
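A minimal sketch of subclassing FSM (the state names and timings are invented for the example): two timed states alternate until end_task sets the 'stop' flag, which the automatically-added 'stop' event picks up.

    from riglib.fsm import FSM, FSMTable, StateTransitions

    class Blinker(FSM):
        status = FSMTable(
            on=StateTransitions(turn_off='off'),
            off=StateTransitions(turn_on='on'),
        )
        state = 'on'

        def __init__(self):
            super().__init__()
            self.n_blinks = 0

        def _start_on(self):
            self.n_blinks += 1
            if self.n_blinks > 3:
                self.end_task()  # sets self.stop; the 'stop' transition ends the run loop

        def _test_turn_off(self, time_in_state):
            return time_in_state > 0.5

        def _test_turn_on(self, time_in_state):
            return time_in_state > 0.5

    Blinker().run_sync()  # alternates 'on'/'off' at ~0.5 s per state, then exits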
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_setup_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_setup_py.html
new file mode 100644
index 00000000..61b7f604
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_fsm_setup_py.html
@@ -0,0 +1,82 @@
Coverage for C:\GitHub\brain-python-interface\riglib\fsm\setup.py: 0%

import setuptools

with open("README.md", "r") as fh:
    long_description = fh.read()

setuptools.setup(
    name="fsm",
    version="0.1.0",
    author="Carmena Lab",
    description="Python Finite State Machine",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
    ],
)
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter___init___py.html
new file mode 100644
index 00000000..9f8480f9
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter___init___py.html
@@ -0,0 +1,68 @@
Coverage for C:\GitHub\brain-python-interface\riglib\hdfwriter\__init__.py: 100%

# This is the __init__.py file if you are using the HDFWriter from
# riglib, without doing its own setup.
from .hdfwriter.hdfwriter import MsgTable
from .hdfwriter.hdfwriter import HDFWriter
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_hdfwriter___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_hdfwriter___init___py.html
new file mode 100644
index 00000000..6bb5991c
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_hdfwriter___init___py.html
@@ -0,0 +1,66 @@
Coverage for C:\GitHub\brain-python-interface\riglib\hdfwriter\hdfwriter\__init__.py: 100%

from .hdfwriter import MsgTable
from .hdfwriter import HDFWriter
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_hdfwriter_hdfwriter_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_hdfwriter_hdfwriter_py.html
new file mode 100644
index 00000000..08eae065
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_hdfwriter_hdfwriter_py.html
@@ -0,0 +1,202 @@
Coverage for C:\GitHub\brain-python-interface\riglib\hdfwriter\hdfwriter\hdfwriter.py: 32%

'''
Base code for 'saveHDF' feature in experiments for writing data to an HDF file during experiment
'''

import tables
import numpy as np

compfilt = tables.Filters(complevel=5, complib="zlib", shuffle=True)

class MsgTable(tables.IsDescription):
    '''
    Pytables custom table atom type used for the HDF tables named *_msgs
    '''
    time = tables.UIntCol()
    msg = tables.StringCol(256)

class HDFWriter(object):
    '''
    Used by the SaveHDF feature (features.hdf_features.SaveHDF) to save data
    to an HDF file in "real-time", as the task is running
    '''
    def __init__(self, filename):
        '''
        Constructor for HDFWriter

        Parameters
        ----------
        filename : string
            Name of file to use to send data

        Returns
        -------
        HDFWriter instance
        '''
        print("HDFWriter: Saving datafile to %s" % filename)
        self.h5 = tables.open_file(filename, "w")
        self.data = {}
        self.msgs = {}
        self.f = []

    def register(self, name, dtype, include_msgs=True):
        '''
        Create a table in the HDF file corresponding to the specified source name and data type

        Parameters
        ----------
        name : string
            Name of the system being registered
        dtype : np.dtype instance
            Datatype of incoming data, for later decoding of the binary data during analysis
        include_msgs : boolean, optional, default=True
            Flag to indicate whether a table should be created for "msgs" from the current source (default True)

        Returns
        -------
        None
        '''
        print("HDFWriter registered %r" % name)
        print(dtype)
        if dtype.subdtype is not None:
            # just a simple dtype with a shape
            dtype, sliceshape = dtype.subdtype
            arr = self.h5.create_earray("/", name, tables.Atom.from_dtype(dtype),
                shape=(0,)+sliceshape, filters=compfilt)
        else:
            arr = self.h5.create_table("/", name, dtype, filters=compfilt)

        self.data[name] = arr
        if include_msgs:
            msg = self.h5.create_table("/", name+"_msgs", MsgTable, filters=compfilt)
            self.msgs[name] = msg

    def send(self, system, data):
        '''
        Add a new row to the HDF table for 'system' and fill it with the 'data' values

        Parameters
        ----------
        system : string
            Name of system where the data originated
        data : object
            Data to send. Must have a '.tostring()' attribute

        Returns
        -------
        None
        '''
        if system in self.data:
            if data is not None:
                self.data[system].append(data)

    def sendMsg(self, msg):
        '''
        Write a string to the *_msgs table for each system registered with the HDF sink

        Parameters
        ----------
        msg : string
            Message to link to the current row of the HDF table

        Returns
        -------
        None
        '''
        for system in list(self.msgs.keys()):
            row = self.msgs[system].row
            row['time'] = len(self.data[system])
            row['msg'] = msg
            row.append()

    def sendAttr(self, system, attr, value):
        '''
        While the HDF writer process is running, set an attribute of the table
        (not sure that this has ever been tested..)

        Parameters
        ----------
        system : string
            Name of the table where the attribute should be set
        attr : string
            Name of the attribute
        value : object
            Value of the attribute to set

        Returns
        -------
        None
        '''
        if system in self.data:
            self.data[system].attrs[attr] = value

    def close(self):
        '''
        Close the HDF file so that it saves properly after the process terminates
        '''
        self.h5.close()
        print("Closed hdf")
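A short usage sketch for HDFWriter (the file name and dtype are invented for the example): a shaped dtype lands in an EArray, and sendMsg timestamps messages against the current row count.

    import numpy as np
    from riglib.hdfwriter import HDFWriter

    writer = HDFWriter("example.hdf")
    writer.register("cursor", np.dtype((np.float64, (2,))))  # EArray of 2-vectors

    writer.send("cursor", np.zeros((1, 2)))  # append one row
    writer.sendMsg("trial_start")            # row['time'] = rows written so far
    writer.close()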
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_setup_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_setup_py.html
new file mode 100644
index 00000000..ecf36001
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_hdfwriter_setup_py.html
@@ -0,0 +1,82 @@
Coverage for C:\GitHub\brain-python-interface\riglib\hdfwriter\setup.py: 0%

import setuptools

with open("README.md", "r") as fh:
    long_description = fh.read()

setuptools.setup(
    name="hdfwriter",
    version="0.1.0",
    author="Carmena Lab",
    description="HDFWriter",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
    ],
)
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_kinarmdata_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_kinarmdata_py.html
new file mode 100644
index 00000000..297e1520
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_kinarmdata_py.html
@@ -0,0 +1,136 @@
Coverage for C:\GitHub\brain-python-interface\riglib\kinarmdata.py: 0%

'''
Code for getting data from the kinarm
'''

import time
import numpy as np
from riglib.source import DataSourceSystem
from riglib import kinarmsocket

class Kinarmdata(DataSourceSystem):
    '''
    Client for data streamed from the kinarm, compatible with riglib.source.DataSource
    '''
    update_freq = 1000.

    # dtype is the numpy data type of items that will go
    # into the (multi-channel, in this case) datasource's ringbuffer
    dtype = np.dtype((np.float, (3, 50)))

    def __init__(self, addr=("192.168.0.8", 9090)):
        '''
        Constructor for Kinarmdata; connects to the server

        Parameters
        ----------
        addr : tuple of length 2
            (client (self) IP address, client UDP port)

        '''
        self.conn = kinarmsocket.KinarmSocket(addr)
        self.conn.connect()

    def start(self):
        '''
        Start receiving data
        '''
        self.data = self.conn.get_data()

    def stop(self):
        '''
        Disconnect from the kinarmdata socket
        '''
        self.conn.disconnect()

    def get(self):
        '''
        Get a new kinarm sample
        '''
        # while True:
        #     try:
        #         d = self.data.next()
        #     except:
        #         break

        return next(self.data)

def make(cls=DataSourceSystem, *args, **kwargs):
    '''
    This ridiculous function dynamically creates a class with a new init function

    Parameters
    ----------

    Returns
    -------
    '''
    def init(self, *args, **kwargs):
        super(self.__class__, self).__init__(*args, **kwargs)

    dtype = np.dtype((np.float, (3, 50)))
    return type(cls.__name__, (cls,), dict(dtype=dtype, __init__=init))
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_kinarmsocket_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_kinarmsocket_py.html
new file mode 100644
index 00000000..996c9ef2
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_kinarmsocket_py.html
@@ -0,0 +1,146 @@
Coverage for C:\GitHub\brain-python-interface\riglib\kinarmsocket.py: 0%

import socket
import struct
import time
import numpy as np

class KinarmSocket(object):
    ''' Class for receiving UDP packets from the Kinarm. The Kinarm must be running the Dexterit-E task
    called 'PK_send_UDP_kinarm', or any task with a 'UDP Send Binary' block, which is located in the
    xPC target tab of the Simulink library. Note: the normal 'UDP Send' input will NOT work -- packets
    will not be sent, since the Dexterit task computer does not receive them (the xPC machine does).

    The socket port that packets are sent to is also configured in this UDP Send Binary block.

    Data format is double --> binary byte packing with byte alignment = 1 (from Matlab: "The byte alignment
    field specifies how the data types are aligned. The possible values are: 1, 2, 4, and 8. The byte
    alignment scheme is simple, and starts each element in the list of signals on a boundary specified
    by the alignment relative to the start of the vector.")

    Packets received are a 3 x 50 matrix. Refer to the Dexterit manual (on the desktop of the Dexterit task
    computer) for what the rows / columns refer to (page 71/82).
    '''

    def __init__(self, addr=('192.168.0.8', 9090)):
        ''' Self IP address is 192.168.0.8, set manually. Once the system is fully migrated over to BMI3D,
        this will need to be adjusted for the new IPs'''

        # Set up UDP socket (specified by socket.SOCK_DGRAM; TCP would use socket.SOCK_STREAM)
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

        # Free up the port for reuse if the task was just run:
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        # Bind address
        self.socket.bind(addr)
        self.data_size = 1200 # 3 x 50 x 8 bytes (double)
        self.remote_ip = '192.168.0.2'

    def _test_connected(self):
        try:
            data, add = self.socket.recvfrom(1200)
            unpacked_ = struct.unpack('150d', data)
            assert len(unpacked_) == 150
            self.connected = True
        except:
            print('Make sure the Kinarm task is running - error in receiving packets')
            self.connected = False

    def connect(self):
        self._test_connected()

    def get_data(self):
        '''
        A generator which yields packets as they are received
        '''
        assert self.connected, "Socket is not connected, cannot get data"

        while self.connected:
            packet, address = self.socket.recvfrom(self.data_size)

            # Make sure the packet is from the correct address:
            if address[0] == self.remote_ip:
                arrival_ts = time.time()
                data = np.array(struct.unpack('150d', packet))

                if data.shape[0] == 150:
                    # reshape data into 3 x 50
                    kindata = data.reshape(50, 3).T
                    #kindata = np.hstack((kindata, np.zeros((3, 1))))
                    #kindata[:, -1] = arrival_ts
                    yield kindata

    def disconnect(self):
        self.socket.close()
        self.connected = False

    def __del__(self):
        self.disconnect()
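For reference, the 1200-byte packet size above follows directly from the '150d' struct format (150 doubles x 8 bytes each). A minimal, hardware-free sketch that packs a fake 3 x 50 Dexterit frame and depacketizes it with the same steps as get_data():

import struct
import numpy as np

# A fake flattened 3 x 50 frame (values are arbitrary, for illustration only)
frame = np.arange(150, dtype=np.float64)
packet = struct.pack('150d', *frame)
assert len(packet) == struct.calcsize('150d') == 1200  # 150 doubles * 8 bytes

# Same depacketization as KinarmSocket.get_data()
data = np.array(struct.unpack('150d', packet))
kindata = data.reshape(50, 3).T
print(kindata.shape)  # (3, 50)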
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_master8stimulation_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_master8stimulation_py.html
new file mode 100644
index 00000000..92d3df8c
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_master8stimulation_py.html
@@ -0,0 +1,174 @@
Coverage for C:\GitHub\brain-python-interface\riglib\master8stimulation.py: 0%
'''
Code for triggering stimulation from the Master-8 voltage-programmable stimulator
'''

import time
import tempfile
import random
import traceback
import numpy as np
import fnmatch
import os
import subprocess
import comedi

from riglib import calibrations, bmi
from riglib.bmi import extractor

###### CONSTANTS
sec_per_min = 60

"""
Define one class to send a TTL pulse for one cycle of stimulation (StimulationPulse). Define a second class to implement this
at a fixed rate.

Left off at line 79. Need to figure out how we re-trigger starting the stimulus pulse train again.
"""


class TTLStimulation(object):
    '''During the stimulation phase, send a timed TTL pulse to the Master-8 stimulator'''
    hold_time = float(1)
    stimulation_pulse_length = float(0.2*1e6)
    stimulation_frequency = float(3)

    status = dict(
        pulse = dict(pulse_end="interpulse_period", stop=None),
        interpulse_period = dict(another_pulse="pulse", pulse_train_end="pulse_off", stop=None),
        pulse_off = dict(stop=None)
    )

    com = comedi.comedi_open('/dev/comedi0')
    pulse_count = 0 # initialize the number of pulses that have occurred

    def __init__(self, *args, **kwargs):
        super(TTLStimulation, self).__init__(*args, **kwargs)
        # total pulses during a stimulation pulse train; assumes hold_time is in s
        self.number_of_pulses = int(self.hold_time*self.stimulation_frequency)

    def init(self):
        super(TTLStimulation, self).init()

    #### TEST FUNCTIONS ####

    def _test_pulse_end(self, ts):
        # return True once more time has elapsed than the specified pulse duration
        pulse_length = self.stimulation_pulse_length*1e-6 # assumes stimulation_pulse_length is in us
        return ts >= pulse_length

    def _test_another_pulse(self, ts):
        interpulse_time = (1/self.stimulation_frequency) - self.stimulation_pulse_length*1e-6 # period minus the duration of a pulse
        return ts >= interpulse_time

    def _test_pulse_train_end(self, ts):
        # end the train once the number of pulses is completed (or the animal ends holding early)
        return (self.pulse_count > self.number_of_pulses)

    #### STATE FUNCTIONS ####

    def _start_pulse(self):
        '''
        At the start of the stimulation state, send a TTL pulse
        '''
        #super(TTLStimulation, self)._start_pulse()
        subdevice = 0
        write_mask = 0x800000
        val = 0x800000
        base_channel = 0
        comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel)
        self.pulse_count = self.pulse_count + 1
        #self.stimulation_start = self.get_time() - self.start_time

    def _end_pulse(self):
        subdevice = 0
        write_mask = 0x800000
        val = 0x000000
        base_channel = 0
        comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel)

    def _start_interpulse_period(self):
        super(TTLStimulation, self)._start_interpulse_period()
        subdevice = 0
        write_mask = 0x800000
        val = 0x000000
        base_channel = 0
        comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel)

    def _end_interpulse_period(self):
        super(TTLStimulation, self)._end_interpulse_period()
        subdevice = 0
        write_mask = 0x800000
        val = 0x000000
        base_channel = 0
        comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel)
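The pulse-train bookkeeping above reduces to two numbers: the pulse count over the hold period and the gap between pulse offsets. A standalone sketch of that arithmetic using the class defaults (hold_time = 1 s, pulse length = 0.2 s expressed in us, frequency = 3 Hz), with no comedi device required:

hold_time = 1.0                   # s
stimulation_pulse_length = 0.2e6  # us
stimulation_frequency = 3.0       # Hz

# pulses per train, as computed in TTLStimulation.__init__
number_of_pulses = int(hold_time * stimulation_frequency)
# gap between the end of one pulse and the start of the next, as in _test_another_pulse
interpulse_time = (1 / stimulation_frequency) - stimulation_pulse_length * 1e-6
print(number_of_pulses, interpulse_time)  # 3 pulses, ~0.133 s between pulses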
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_motiontracker_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_motiontracker_py.html
new file mode 100644
index 00000000..541b5c90
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_motiontracker_py.html
@@ -0,0 +1,364 @@
Coverage for C:\GitHub\brain-python-interface\riglib\motiontracker.py: 0%
'''
Base code for 'motiontracker' feature, compatible with PhaseSpace motiontracker
'''

import os
import time
import numpy as np

try:
    from OWL import *
except:
    OWL_MODE2 = False
    print("Cannot find phasespace driver")

cwd = os.path.split(os.path.abspath(__file__))[0]

class Simulate(object):
    '''Simulated motion tracker that generates noisy sinusoidal marker trajectories for testing.'''
    update_freq = 240
    def __init__(self, marker_count=8, radius=(10, 2, 5), offset=(-20,0,0), speed=(5,5,4)):
        '''Set up the marker count, per-axis radii, offset and speeds of the simulated markers.'''
        self.n = marker_count
        self.radius = radius
        self.offset = np.array(offset)
        self.speed = speed

        self.offsets = np.random.rand(self.n)*np.pi

    def start(self):
        '''Record the simulation start time.'''
        self.stime = time.time()

    def get(self):
        '''Return one frame of simulated marker data, shape (1, marker_count, 4).'''
        time.sleep(1./self.update_freq)
        ts = (time.time() - self.stime)
        data = np.zeros((self.n, 3))
        for i, p in enumerate(self.offsets):
            x = self.radius[0] * np.cos(ts / self.speed[0] * 2*np.pi + p)
            y = self.radius[1] * np.sin(ts / self.speed[1] * 2*np.pi + p)
            z = self.radius[2] * np.sin(ts / self.speed[2] * 2*np.pi + p)
            data[i] = x,y,z

        # expand the dimensions for HDFWriter saving
        data_temp = np.hstack([data + np.random.randn(self.n, 3) * 0.1, np.ones((self.n, 1))])
        data_temp_expand = np.expand_dims(data_temp, axis = 0)

        return data_temp_expand

    def stop(self):
        '''Nothing to clean up for the simulation.'''
        return


class System(object):
    '''Interface to the PhaseSpace motion tracker through the OWL client library.'''
    update_freq = 240
    def __init__(self, marker_count=8, server_name='10.0.0.11', init_flags=OWL_MODE2):
        '''Connect to the PhaseSpace server and create a point tracker.'''
        self.marker_count = marker_count
        if(owlInit(server_name, init_flags) < 0):
            raise Exception(owl_get_error("init error", owlGetError()))

        # flush requests and check for errors
        if(owlGetStatus() == 0):
            raise Exception(owl_get_error("error in point tracker setup", owlGetError()))

        # set the streaming frequency
        owlSetFloat(OWL_FREQUENCY, OWL_MAX_FREQUENCY)

        # create a point tracker
        self.tracker = 0
        owlTrackeri(self.tracker, OWL_CREATE, OWL_POINT_TRACKER)
        self._init_markers()

    def _init_markers(self):
        '''Register the LED markers with the point tracker.'''
        # set markers
        for i in range(self.marker_count):
            owlMarkeri(MARKER(self.tracker, i), OWL_SET_LED, i)
        owlTracker(self.tracker, OWL_ENABLE)
        self.coords = np.zeros((self.marker_count, 4))

    def start(self, filename=None):
        '''Enable streaming (and, eventually, tell PhaseSpace to start a recording).'''
        self.filename = filename
        if filename is not None:
            # figure out the command to tell phasespace to start a recording
            pass
        owlSetInteger(OWL_STREAMING, OWL_ENABLE)
        #owlSetInteger(OWL_INTERPOLATION, 4)

    def stop(self):
        '''Disable streaming (and, eventually, stop any recording).'''
        if self.filename is not None:
            # tell phasespace to stop recording
            pass
        owlSetInteger(OWL_STREAMING, OWL_DISABLE)

    def get(self):
        '''Block until a new set of marker coordinates arrives, then return the (marker_count, 4) array.'''
        markers = []
        n = owlGetMarkers(markers, self.marker_count)
        while n == 0:
            time.sleep(.001)
            n = owlGetMarkers(markers, self.marker_count)

        for i, m in enumerate(markers):
            self.coords[i] = m.x, m.y, m.z, m.cond

        return self.coords

    def __del__(self):
        '''Tear down the markers, the tracker, and the OWL connection.'''
        for i in range(self.marker_count):
            owlMarker(MARKER(self.tracker, i), OWL_CLEAR_MARKER)
        owlTracker(self.tracker, OWL_DESTROY)
        owlDone()

class AligningSystem(System):
    '''System variant that additionally streams the 6 alignment LEDs.'''
    def _init_markers(self):
        '''Register the task markers plus 6 alignment markers with the tracker.'''
        MAX = 32
        for i in range(self.marker_count):
            owlMarkeri(MARKER(self.tracker, i), OWL_SET_LED, i)
        for i in range(6):
            owlMarkeri(MARKER(self.tracker, self.marker_count+i), OWL_SET_LED, MAX+i)
        self.marker_count += 6
        owlTracker(self.tracker, OWL_ENABLE)
        self.coords = np.zeros((self.marker_count, 4))

def owl_get_error(s, n):
    """Return a human-readable OWL error string."""
    if(n < 0): return "%s: %d" % (s, n)
    elif(n == OWL_NO_ERROR): return "%s: No Error" % s
    elif(n == OWL_INVALID_VALUE): return "%s: Invalid Value" % s
    elif(n == OWL_INVALID_ENUM): return "%s: Invalid Enum" % s
    elif(n == OWL_INVALID_OPERATION): return "%s: Invalid Operation" % s
    else: return "%s: 0x%x" % (s, n)


def make(marker_count, cls=System, **kwargs):
    """This ridiculous function dynamically creates a class with a new init function"""
    def init(self, **kwargs):
        super(self.__class__, self).__init__(marker_count=marker_count, **kwargs)

    dtype = np.dtype((np.float, (marker_count, 4)))
    if cls == AligningSystem:
        dtype = np.dtype((np.float, (marker_count+6, 4)))
    return type(cls.__name__, (cls,), dict(dtype=dtype, __init__=init))


def make_autoalign_reference(data, filename=os.path.join(cwd, "alignment2.npz")):
    '''Creates an alignment that can be used with the autoaligner'''
    from .stereo_opengl import xfm
    assert data.shape[1:] == (6, 3)
    mdata = np.median(data,0)
    cdata = mdata - mdata[0]
    rot1 = xfm.Quaternion.rotate_vecs(np.cross(cdata[2], cdata[1]), [0,1,0])
    rdata = rot1*cdata
    rot2 = xfm.Quaternion.rotate_vecs(rdata[1], [1, 0, 0])
    np.savez(filename, data=data, reference=rot2*rot1*cdata)
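make() returns a subclass whose dtype advertises the marker layout to the data-source machinery. A minimal sketch exercising it with the hardware-free Simulate class (no PhaseSpace server needed, and assuming a NumPy old enough to still provide np.float, as the module itself does):

from riglib import motiontracker

Sim = motiontracker.make(8, cls=motiontracker.Simulate)
print(Sim.dtype)        # one (8, 4) float record per frame: x, y, z, condition

tracker = Sim()
tracker.start()
frame = tracker.get()   # blocks ~1/240 s, returns shape (1, 8, 4)
print(frame.shape)
tracker.stop()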
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_mp_calc_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_mp_calc_py.html
new file mode 100644
index 00000000..59b288ee
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_mp_calc_py.html
@@ -0,0 +1,284 @@
Coverage for C:\GitHub\brain-python-interface\riglib\mp_calc.py: 17%
import multiprocessing as mp
import time

import numpy as np
import queue

class MPCompute(mp.Process):
    """
    Generic class for running computations that occur infrequently
    but take longer than a single BMI loop iteration
    """
    def __init__(self, work_queue, result_queue, fn):
        '''
        Constructor for MPCompute

        Parameters
        ----------
        work_queue : mp.Queue
            Jobs start when an entry is found in work_queue
        result_queue : mp.Queue
            Results of jobs are placed back onto result_queue

        Returns
        -------
        MPCompute instance
        '''
        # run base constructor
        super(MPCompute, self).__init__()

        self.work_queue = work_queue
        self.result_queue = result_queue
        self.done = mp.Event()
        self.fn = fn

    def _check_for_job(self):
        '''
        Non-blocking check to see if data is present in the input queue
        '''
        try:
            job = self.work_queue.get_nowait()
        except:
            job = None
        return job

    def run(self):
        '''
        The main loop. Starts automatically when the process is spawned. See mp.Process.run for additional docs.
        Every 0.5 seconds, check for a new computation to carry out.
        '''
        while not self.done.is_set():
            job = self._check_for_job()

            # unpack the data
            if not (job is None):
                new_params = self.calc(*job[0], **job[1])
                self.result_queue.put(new_params)

            # Pause to lower the process's effective priority
            time.sleep(0.5)

    def calc(self, *args, **kwargs):
        '''
        Run the actual calculation function
        '''
        return self.fn(*args, **kwargs)

    def stop(self):
        '''
        Set the flag to stop the 'while' loop in the 'run' method gracefully
        '''
        self.done.set()


class FuncProxy(object):
    '''
    Wrapper for MPCompute computations running in another process
    '''
    def __init__(self, fn, multiproc=False, waiting_resp=None, init_resp=None, verbose=False):
        self.verbose = verbose
        self.multiproc = multiproc
        if self.multiproc:
            # create the queues
            self.work_queue = mp.Queue()
            self.result_queue = mp.Queue()

            # Instantiate the process
            self.calculator = MPCompute(self.work_queue, self.result_queue, fn)

            # spawn the process
            self.calculator.start()
        else:
            self.fn = fn

        assert waiting_resp in [None, 'prev'], "Unrecognized waiting_resp"
        self.waiting_resp = waiting_resp

        self.prev_result = (init_resp, 0)
        self.prev_input = None
        self.waiting = False

    def reset(self):
        '''Forget the previous input so that the next call forces a recomputation.'''
        self.prev_input = None

    def _stuff(self):
        '''Poll the result queue; return a fresh result if one is ready, otherwise fall back according to waiting_resp.'''
        try:
            output_data = self.result_queue.get_nowait()
            self.prev_result = output_data
            self.waiting = False
            return output_data, True
        except queue.Empty:
            if self.waiting_resp is None:
                return None
            elif self.waiting_resp == 'prev':
                return self.prev_result, False
        except:
            import traceback
            traceback.print_exc()

    def input_same(self, stuff):
        '''Compare (args, kwargs) against the previously submitted input to decide whether to recompute.'''
        args, kwargs = stuff
        if self.prev_input is None:
            return False

        args_same = True
        for a1, a2 in zip(args, self.prev_input[0]):
            try:
                args_same = args_same and np.all(a1 == a2)
            except ValueError:
                args_same = args_same and np.array_equal(a1, a2)

        kwargs_same = list(kwargs.keys()) == list(self.prev_input[1].keys())

        for key1, key2 in zip(list(kwargs.keys()), list(self.prev_input[1].keys())):
            k1 = kwargs[key1]
            k2 = kwargs[key2]
            if key1 == 'q_start':
                continue
            try:
                kwargs_same = kwargs_same and np.all(k1 == k2)
            except:
                kwargs_same = kwargs_same and np.array_equal(k1, k2)

        return args_same and kwargs_same


    def __call__(self, *args, **kwargs):
        '''Reuse the cached result, poll the worker, or queue/run a new job, depending on the input.'''
        input_data = (args, kwargs)
        input_same_as_last = self.input_same(input_data) #input_data == self.prev_input
        if self.multiproc:
            if input_same_as_last and not self.waiting:
                return self.prev_result, False

            elif input_same_as_last and self.waiting:
                # Return the new result if it's available, otherwise the previous result
                return self._stuff()

            elif not input_same_as_last:
                if self.verbose: print("queuing job")
                self.work_queue.put(input_data)
                self.prev_input = input_data
                self.waiting = True
                return self._stuff()
        else:
            if input_same_as_last:
                return self.prev_result, False
            else:
                self.prev_input = input_data
                self.prev_result = self.fn(*args, **kwargs)
                return self.prev_result, True

    def __del__(self):
        '''
        Stop the child process if one was spawned
        '''
        if self.multiproc:
            self.calculator.stop()
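A minimal usage sketch of FuncProxy: identical inputs are served from the cache, while new inputs either run inline (multiproc=False, as here) or are queued to the worker process. The slow_fit function is hypothetical, a stand-in for an expensive recomputation:

import time
from riglib.mp_calc import FuncProxy

def slow_fit(x):
    time.sleep(1.0)   # stand-in for a long-running computation
    return x * 2

proxy = FuncProxy(slow_fit, multiproc=False)
print(proxy(21))      # runs the function: (42, True)
print(proxy(21))      # same input, served from cache: (42, False)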
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_mp_proxy_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_mp_proxy_py.html
new file mode 100644
index 00000000..53fcbe42
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_mp_proxy_py.html
@@ -0,0 +1,118 @@
Coverage for C:\GitHub\brain-python-interface\riglib\mp_proxy.py: 100%
import os
import sys
import time
import inspect
import traceback
import multiprocessing as mp
from multiprocessing import sharedctypes as shm
import ctypes

import numpy as np


class FuncProxy(object):
    '''
    Interface for calling functions in remote processes.
    '''
    def __init__(self, name, pipe, event=None):
        '''
        Constructor for FuncProxy

        Parameters
        ----------
        name : string
            Name of remote function to call
        pipe : mp.Pipe instance
            multiprocessing pipe through which to send data (function name, arguments) and receive the result
        event : mp.Event instance
            A flag to set which is multiprocessing-compatible (visible to both the current and the remote processes)

        Returns
        -------
        FuncProxy instance
        '''
        self.pipe = pipe
        self.name = name
        self.event = event

    def __call__(self, *args, **kwargs):
        '''
        Return the result of the remote function call

        Parameters
        ----------
        *args, **kwargs : positional arguments, keyword arguments
            To be passed to the remote function associated when the object was created

        Returns
        -------
        function result
        '''
        self.pipe.send((self.name, args, kwargs))
        if self.event is not None:
            self.event.set()
        return self.pipe.recv()
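FuncProxy only covers the caller's half of the protocol; the remote process must pop (name, args, kwargs) tuples off the other end of the pipe and send a result back. A minimal sketch of both halves, with a hypothetical 'add' handler standing in for the real remote function table:

import multiprocessing as mp
from riglib.mp_proxy import FuncProxy

def _server(pipe):
    # Service one request: unpack (name, args, kwargs), dispatch, reply
    name, args, kwargs = pipe.recv()
    handlers = {'add': lambda a, b: a + b}   # hypothetical dispatch table
    pipe.send(handlers[name](*args, **kwargs))

if __name__ == '__main__':
    here, there = mp.Pipe()
    proc = mp.Process(target=_server, args=(there,))
    proc.start()
    add = FuncProxy('add', here)
    print(add(2, 3))   # 5
    proc.join()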
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_NatNetClient_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_NatNetClient_py.html
new file mode 100644
index 00000000..fc1d50c3
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_NatNetClient_py.html
@@ -0,0 +1,580 @@
Coverage for C:\GitHub\brain-python-interface\riglib\optitrack_client\NatNetClient.py: 0%
#Copyright © 2018 Naturalpoint
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

# OptiTrack NatNet direct depacketization library for Python 3.x

import socket
import struct
from threading import Thread

def trace( *args ):
    pass # print( "".join(map(str,args)) )

# Create structs for reading various object types to speed up parsing.
Vector3 = struct.Struct( '<fff' )
Quaternion = struct.Struct( '<ffff' )
FloatValue = struct.Struct( '<f' )
DoubleValue = struct.Struct( '<d' )

class NatNetClient:
    def __init__( self ):
        # Change this value to the IP address of the NatNet server.
        self.serverIPAddress = "10.155.206.1"

        # Change this value to the IP address of your local network interface
        self.localIPAddress = '10.155.205.164'

        # This should match the multicast address listed in Motive's streaming settings.
        self.multicastAddress = "239.255.42.99"

        # NatNet Command channel
        self.commandPort = 1510

        # NatNet Data channel
        self.dataPort = 1511

        # Set this to a callback method of your choice to receive per-rigid-body data at each frame.
        self.rigidBodyListener = None

        # Set this to a callback method of your choice to receive per-frame summary data.
        # (Referenced by __unpackMocapData below, so it must exist even if unused.)
        self.newFrameListener = None

        # NatNet stream version. This will be updated to the actual version the server is using during initialization.
        self.__natNetStreamVersion = (3,0,0,0)

    # Client/server message ids
    NAT_PING = 0
    NAT_PINGRESPONSE = 1
    NAT_REQUEST = 2
    NAT_RESPONSE = 3
    NAT_REQUEST_MODELDEF = 4
    NAT_MODELDEF = 5
    NAT_REQUEST_FRAMEOFDATA = 6
    NAT_FRAMEOFDATA = 7
    NAT_MESSAGESTRING = 8
    NAT_DISCONNECT = 9
    NAT_UNRECOGNIZED_REQUEST = 100

    # Create a data socket to attach to the NatNet stream
    def __createDataSocket( self, port ):
        result = socket.socket( socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP ) # UDP
        result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        #mreq=socket.inet_aton(self.multicastAddress)+socket.inet_aton(self.localIPAddress)
        #mreq=socket.inet_aton(self.multicastAddress)+socket.inet_aton(socket.INADDR_ANY)
        mreq = struct.pack('4sL', socket.inet_aton(self.multicastAddress), socket.INADDR_ANY)
        result.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)

        #result.bind( (self.localIPAddress, port) ) # bug: why would you bind to yourself if you want to receive data from a remote server!

        result.bind( ("", port) )
        return result

    # Create a command socket to attach to the NatNet stream
    def __createCommandSocket( self ):
        result = socket.socket( socket.AF_INET, socket.SOCK_DGRAM )
        result.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        result.bind( ('', 0) )
        result.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

        return result

    # Unpack a rigid body object from a data packet
    def __unpackRigidBody( self, data ):
        offset = 0

        # ID (4 bytes)
        id = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4
        trace( "ID:", id )

        # Position and orientation
        pos = Vector3.unpack( data[offset:offset+12] )
        offset += 12
        trace( "\tPosition:", pos[0],",", pos[1],",", pos[2] )
        rot = Quaternion.unpack( data[offset:offset+16] )
        offset += 16
        trace( "\tOrientation:", rot[0],",", rot[1],",", rot[2],",", rot[3] )

        # Send information to any listener.
        if self.rigidBodyListener is not None:
            self.rigidBodyListener( id, pos, rot )

        # RB Marker Data ( Before version 3.0. After Version 3.0 Marker data is in description )
        if( self.__natNetStreamVersion[0] < 3 and self.__natNetStreamVersion[0] != 0 ):
            # Marker count (4 bytes)
            markerCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
            offset += 4
            markerCountRange = range( 0, markerCount )
            trace( "\tMarker Count:", markerCount )

            # Marker positions
            for i in markerCountRange:
                pos = Vector3.unpack( data[offset:offset+12] )
                offset += 12
                trace( "\tMarker", i, ":", pos[0],",", pos[1],",", pos[2] )

            if( self.__natNetStreamVersion[0] >= 2 ):
                # Marker ID's
                for i in markerCountRange:
                    id = int.from_bytes( data[offset:offset+4], byteorder='little' )
                    offset += 4
                    trace( "\tMarker ID", i, ":", id )

                # Marker sizes
                for i in markerCountRange:
                    size = FloatValue.unpack( data[offset:offset+4] )
                    offset += 4
                    trace( "\tMarker Size", i, ":", size[0] )

        if( self.__natNetStreamVersion[0] >= 2 ):
            markerError, = FloatValue.unpack( data[offset:offset+4] )
            offset += 4
            trace( "\tMarker Error:", markerError )

        # Version 2.6 and later
        if( ( ( self.__natNetStreamVersion[0] == 2 ) and ( self.__natNetStreamVersion[1] >= 6 ) ) or self.__natNetStreamVersion[0] > 2 or self.__natNetStreamVersion[0] == 0 ):
            param, = struct.unpack( 'h', data[offset:offset+2] )
            trackingValid = ( param & 0x01 ) != 0
            offset += 2
            trace( "\tTracking Valid:", 'True' if trackingValid else 'False' )

        return offset

    # Unpack a skeleton object from a data packet
    def __unpackSkeleton( self, data ):
        offset = 0

        id = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4
        trace( "ID:", id )

        rigidBodyCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4
        trace( "Rigid Body Count:", rigidBodyCount )
        for j in range( 0, rigidBodyCount ):
            offset += self.__unpackRigidBody( data[offset:] )

        return offset

    # Unpack data from a motion capture frame message
    def __unpackMocapData( self, data ):
        trace( "Begin MoCap Frame\n-----------------\n" )

        data = memoryview( data )
        offset = 0

        # Frame number (4 bytes)
        frameNumber = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4
        trace( "Frame #:", frameNumber )

        # Marker set count (4 bytes)
        markerSetCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4
        trace( "Marker Set Count:", markerSetCount )

        for i in range( 0, markerSetCount ):
            # Model name
            modelName, separator, remainder = bytes(data[offset:]).partition( b'\0' )
            offset += len( modelName ) + 1
            trace( "Model Name:", modelName.decode( 'utf-8' ) )

            # Marker count (4 bytes)
            markerCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
            offset += 4
            trace( "Marker Count:", markerCount )

            for j in range( 0, markerCount ):
                pos = Vector3.unpack( data[offset:offset+12] )
                offset += 12
                #trace( "\tMarker", j, ":", pos[0],",", pos[1],",", pos[2] )

        # Unlabeled markers count (4 bytes)
        unlabeledMarkersCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4
        trace( "Unlabeled Markers Count:", unlabeledMarkersCount )

        for i in range( 0, unlabeledMarkersCount ):
            pos = Vector3.unpack( data[offset:offset+12] )
            offset += 12
            trace( "\tMarker", i, ":", pos[0],",", pos[1],",", pos[2] )

        # Rigid body count (4 bytes)
        rigidBodyCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4
        trace( "Rigid Body Count:", rigidBodyCount )

        for i in range( 0, rigidBodyCount ):
            offset += self.__unpackRigidBody( data[offset:] )

        # Version 2.1 and later
        skeletonCount = 0
        if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] > 0 ) or self.__natNetStreamVersion[0] > 2 ):
            skeletonCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
            offset += 4
            trace( "Skeleton Count:", skeletonCount )
            for i in range( 0, skeletonCount ):
                offset += self.__unpackSkeleton( data[offset:] )

        # Labeled markers (Version 2.3 and later)
        labeledMarkerCount = 0
        if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] > 3 ) or self.__natNetStreamVersion[0] > 2 ):
            labeledMarkerCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
            offset += 4
            trace( "Labeled Marker Count:", labeledMarkerCount )
            for i in range( 0, labeledMarkerCount ):
                id = int.from_bytes( data[offset:offset+4], byteorder='little' )
                offset += 4
                pos = Vector3.unpack( data[offset:offset+12] )
                offset += 12
                size = FloatValue.unpack( data[offset:offset+4] )
                offset += 4

                # Version 2.6 and later
                if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 6 ) or self.__natNetStreamVersion[0] > 2 or self.__natNetStreamVersion[0] == 0 ):
                    param, = struct.unpack( 'h', data[offset:offset+2] )
                    offset += 2
                    occluded = ( param & 0x01 ) != 0
                    pointCloudSolved = ( param & 0x02 ) != 0
                    modelSolved = ( param & 0x04 ) != 0

                # Version 3.0 and later
                if( ( self.__natNetStreamVersion[0] >= 3 ) or self.__natNetStreamVersion[0] == 0 ):
                    residual, = FloatValue.unpack( data[offset:offset+4] )
                    offset += 4
                    trace( "Residual:", residual )

        # Force Plate data (version 2.9 and later)
        if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 9 ) or self.__natNetStreamVersion[0] > 2 ):
            forcePlateCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
            offset += 4
            trace( "Force Plate Count:", forcePlateCount )
            for i in range( 0, forcePlateCount ):
                # ID
                forcePlateID = int.from_bytes( data[offset:offset+4], byteorder='little' )
                offset += 4
                trace( "Force Plate", i, ":", forcePlateID )

                # Channel Count
                forcePlateChannelCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
                offset += 4

                # Channel Data
                for j in range( 0, forcePlateChannelCount ):
                    trace( "\tChannel", j, ":", forcePlateID )
                    forcePlateChannelFrameCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
                    offset += 4
                    for k in range( 0, forcePlateChannelFrameCount ):
                        forcePlateChannelVal = int.from_bytes( data[offset:offset+4], byteorder='little' )
                        offset += 4
                        trace( "\t\t", forcePlateChannelVal )

        # Device data (version 2.11 and later)
        if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 11 ) or self.__natNetStreamVersion[0] > 2 ):
            deviceCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
            offset += 4
            trace( "Device Count:", deviceCount )
            for i in range( 0, deviceCount ):
                # ID
                deviceID = int.from_bytes( data[offset:offset+4], byteorder='little' )
                offset += 4
                trace( "Device", i, ":", deviceID )

                # Channel Count
                deviceChannelCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
                offset += 4

                # Channel Data
                for j in range( 0, deviceChannelCount ):
                    trace( "\tChannel", j, ":", deviceID )
                    deviceChannelFrameCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
                    offset += 4
                    for k in range( 0, deviceChannelFrameCount ):
                        deviceChannelVal = int.from_bytes( data[offset:offset+4], byteorder='little' )
                        offset += 4
                        trace( "\t\t", deviceChannelVal )

        # Timecode
        timecode = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4
        timecodeSub = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4

        # Timestamp (increased to double precision in 2.7 and later)
        if( ( self.__natNetStreamVersion[0] == 2 and self.__natNetStreamVersion[1] >= 7 ) or self.__natNetStreamVersion[0] > 2 ):
            timestamp, = DoubleValue.unpack( data[offset:offset+8] )
            offset += 8
        else:
            timestamp, = FloatValue.unpack( data[offset:offset+4] )
            offset += 4

        # Hires Timestamp (Version 3.0 and later)
        if( ( self.__natNetStreamVersion[0] >= 3 ) or self.__natNetStreamVersion[0] == 0 ):
            stampCameraExposure = int.from_bytes( data[offset:offset+8], byteorder='little' )
            offset += 8
            stampDataReceived = int.from_bytes( data[offset:offset+8], byteorder='little' )
            offset += 8
            stampTransmit = int.from_bytes( data[offset:offset+8], byteorder='little' )
            offset += 8

        # Frame parameters
        param, = struct.unpack( 'h', data[offset:offset+2] )
        isRecording = ( param & 0x01 ) != 0
        trackedModelsChanged = ( param & 0x02 ) != 0
        offset += 2

        # Send information to any listener.
        if self.newFrameListener is not None:
            #print(frameNumber)
            self.newFrameListener( frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount,
                                   labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged )

    # Unpack a marker set description packet
    def __unpackMarkerSetDescription( self, data ):
        offset = 0

        name, separator, remainder = bytes(data[offset:]).partition( b'\0' )
        offset += len( name ) + 1
        trace( "Markerset Name:", name.decode( 'utf-8' ) )

        markerCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4

        for i in range( 0, markerCount ):
            name, separator, remainder = bytes(data[offset:]).partition( b'\0' )
            offset += len( name ) + 1
            trace( "\tMarker Name:", name.decode( 'utf-8' ) )

        return offset

    # Unpack a rigid body description packet
    def __unpackRigidBodyDescription( self, data ):
        offset = 0

        # Version 2.0 or higher
        if( self.__natNetStreamVersion[0] >= 2 ):
            name, separator, remainder = bytes(data[offset:]).partition( b'\0' )
            offset += len( name ) + 1
            trace( "\tRigidBody Name:", name.decode( 'utf-8' ) )

        id = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4

        parentID = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4

        timestamp = Vector3.unpack( data[offset:offset+12] )
        offset += 12

        # Version 3.0 and higher, rigid body marker information contained in description
        if (self.__natNetStreamVersion[0] >= 3 or self.__natNetStreamVersion[0] == 0 ):
            markerCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
            offset += 4
            trace( "\tRigidBody Marker Count:", markerCount )

            markerCountRange = range( 0, markerCount )
            for marker in markerCountRange:
                markerOffset = Vector3.unpack(data[offset:offset+12])
                offset += 12
            for marker in markerCountRange:
                activeLabel = int.from_bytes(data[offset:offset+4], byteorder='little')
                offset += 4

        return offset

    # Unpack a skeleton description packet
    def __unpackSkeletonDescription( self, data ):
        offset = 0

        name, separator, remainder = bytes(data[offset:]).partition( b'\0' )
        offset += len( name ) + 1
        trace( "\tMarker Name:", name.decode( 'utf-8' ) )

        id = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4

        rigidBodyCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4

        for i in range( 0, rigidBodyCount ):
            offset += self.__unpackRigidBodyDescription( data[offset:] )

        return offset

    # Unpack a data description packet
    def __unpackDataDescriptions( self, data ):
        offset = 0
        datasetCount = int.from_bytes( data[offset:offset+4], byteorder='little' )
        offset += 4

        for i in range( 0, datasetCount ):
            type = int.from_bytes( data[offset:offset+4], byteorder='little' )
            offset += 4
            if( type == 0 ):
                offset += self.__unpackMarkerSetDescription( data[offset:] )
            elif( type == 1 ):
                offset += self.__unpackRigidBodyDescription( data[offset:] )
            elif( type == 2 ):
                offset += self.__unpackSkeletonDescription( data[offset:] )

    def __dataThreadFunction( self, socket ):
        while True:
            # Block for input
            data, addr = socket.recvfrom(1024) # 1024 byte buffer size
            if( len( data ) > 0 ):
                self.__processMessage( data )

    def __processMessage( self, data ):
        trace( "Begin Packet\n------------\n" )

        messageID = int.from_bytes( data[0:2], byteorder='little' )
        trace( "Message ID:", messageID )
        #print(messageID)
        #print(self.NAT_FRAMEOFDATA)

        packetSize = int.from_bytes( data[2:4], byteorder='little' )
        trace( "Packet Size:", packetSize )

        offset = 4
        if( messageID == self.NAT_FRAMEOFDATA ):
            self.__unpackMocapData( data[offset:] )
        elif( messageID == self.NAT_MODELDEF ):
            self.__unpackDataDescriptions( data[offset:] )
        elif( messageID == self.NAT_PINGRESPONSE ):
            offset += 256 # Skip the sending app's Name field
            offset += 4 # Skip the sending app's Version info
            self.__natNetStreamVersion = struct.unpack( 'BBBB', data[offset:offset+4] )
            offset += 4
        elif( messageID == self.NAT_RESPONSE ):
            if( packetSize == 4 ):
                commandResponse = int.from_bytes( data[offset:offset+4], byteorder='little' )
                offset += 4
            else:
                message, separator, remainder = bytes(data[offset:]).partition( b'\0' )
                offset += len( message ) + 1
                trace( "Command response:", message.decode( 'utf-8' ) )
        elif( messageID == self.NAT_UNRECOGNIZED_REQUEST ):
            trace( "Received 'Unrecognized request' from server" )
        elif( messageID == self.NAT_MESSAGESTRING ):
            message, separator, remainder = bytes(data[offset:]).partition( b'\0' )
            offset += len( message ) + 1
            trace( "Received message from server:", message.decode( 'utf-8' ) )
        else:
            trace( "ERROR: Unrecognized packet type" )

        trace( "End Packet\n----------\n" )
        #print('Finished Processing')

    def sendCommand( self, command, commandStr, socket, address ):
        # Compose the message in our known message format
        if( command == self.NAT_REQUEST_MODELDEF or command == self.NAT_REQUEST_FRAMEOFDATA ):
            packetSize = 0
            commandStr = ""
        elif( command == self.NAT_REQUEST ):
            packetSize = len( commandStr ) + 1
        elif( command == self.NAT_PING ):
            commandStr = "Ping"
            packetSize = len( commandStr ) + 1

        data = command.to_bytes( 2, byteorder='little' )
        data += packetSize.to_bytes( 2, byteorder='little' )

        data += commandStr.encode( 'utf-8' )
        data += b'\0'

        socket.sendto( data, address )

    def run( self ):
        # Create the data socket
        self.dataSocket = self.__createDataSocket( self.dataPort )
        if( self.dataSocket is None ):
            print( "Could not open data channel" )
            exit()

        # Create the command socket
        self.commandSocket = self.__createCommandSocket()
        if( self.commandSocket is None ):
            print( "Could not open command channel" )
            exit()

        # Create a separate thread for receiving data packets
        dataThread = Thread( target = self.__dataThreadFunction, args = (self.dataSocket, ))
        dataThread.start()

        # Create a separate thread for receiving command packets
        commandThread = Thread( target = self.__dataThreadFunction, args = (self.commandSocket, ))
        commandThread.start()

        self.sendCommand( self.NAT_REQUEST_MODELDEF, "", self.commandSocket, (self.serverIPAddress, self.commandPort) )
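Every NatNet packet starts with a 2-byte little-endian message ID followed by a 2-byte payload size, which is exactly what __processMessage() peels off first. A self-contained sketch that builds a NAT_PING request the same way sendCommand() does and reads the prefix back, with no server involved:

import struct

NAT_PING = 0
commandStr = "Ping"
packetSize = len(commandStr) + 1   # payload plus trailing NUL

# Compose the packet as sendCommand() does
data = NAT_PING.to_bytes(2, byteorder='little')
data += packetSize.to_bytes(2, byteorder='little')
data += commandStr.encode('utf-8') + b'\0'

# Peel the prefix back off, as __processMessage() does
messageID = int.from_bytes(data[0:2], byteorder='little')
payloadSize = int.from_bytes(data[2:4], byteorder='little')
print(messageID, payloadSize, bytes(data[4:4 + payloadSize]))  # 0 5 b'Ping\x00'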
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_PythonSample_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_PythonSample_py.html
new file mode 100644
index 00000000..5fb16079
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_PythonSample_py.html
@@ -0,0 +1,105 @@
Coverage for C:\GitHub\brain-python-interface\riglib\optitrack_client\PythonSample.py: 0%
#Copyright © 2018 Naturalpoint
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.


# OptiTrack NatNet direct depacketization sample for Python 3.x
#
# Uses the Python NatNetClient.py library to establish a connection (by creating a NatNetClient),
# and receive data via a NatNet connection and decode it using the NatNetClient library.

from NatNetClient import NatNetClient

# This is a callback function that gets connected to the NatNet client and called once per mocap frame.
def receiveNewFrame( frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount,
                     labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ):
    print( "Received frame", frameNumber )

# This is a callback function that gets connected to the NatNet client. It is called once per rigid body per frame
def receiveRigidBodyFrame( id, position, rotation ):
    print( "Received frame for rigid body", position )

# This will create a new NatNet client
streamingClient = NatNetClient()

# Configure the streaming client to call our rigid body handler on the emulator to send data out.
streamingClient.newFrameListener = receiveNewFrame
streamingClient.rigidBodyListener = receiveRigidBodyFrame

# Start up the streaming client now that the callbacks are set up.
# This will run perpetually, and operate on a separate thread.
streamingClient.run()
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client___init___py.html
new file mode 100644
index 00000000..5259c4dd
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client___init___py.html
@@ -0,0 +1,64 @@
Coverage for C:\GitHub\brain-python-interface\riglib\optitrack_client\__init__.py: 100%
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_optitrack_direct_pack_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_optitrack_direct_pack_py.html
new file mode 100644
index 00000000..25d0e9e2
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_optitrack_direct_pack_py.html
@@ -0,0 +1,151 @@
Coverage for C:\GitHub\brain-python-interface\riglib\optitrack_client\optitrack_direct_pack.py: 0%
from riglib.optitrack_client.NatNetClient import NatNetClient as TestClient
import numpy as np
from multiprocessing import Process, Lock
import pickle

mutex = Lock()

class System(object):
    """
    This is the DataSource interface for getting mocap data at BMI3D's request,
    compatible with DataSourceSystem.
    Uses data_array to keep track of the latest buffer.
    """
    rigidBodyCount = 1
    update_freq = 120
    dtype = np.dtype((np.float, (rigidBodyCount, 6))) # 6 degrees of freedom
    def __init__(self):
        self.rigid_body_count = 1 # for now, only one rigid body

        self.test_client = TestClient()
        self.num_length = 10 # slots for the buffer
        self.data_array = [None] * self.num_length
        self.rotation_buffer = [None] * self.num_length

    # This is a callback function that gets connected to the NatNet client and called once per mocap frame.
    def receiveNewFrame(self, frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount,
                        labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged ):
        #print( "Received frame", frameNumber )
        pass

    # This is a callback function that gets connected to the NatNet client. It is called once per rigid body per frame
    def receiveRigidBodyFrame(self, id, position, rotation ):
        #print( "Received frame for rigid body", position )

        # save to the running buffers with a lock
        with mutex:
            self.data_array.insert(0, position)
            self.data_array.pop()
            # save the rotation to the rotation buffer list
            # (the original inserted `position` here; `rotation` is what the buffer is for)
            self.rotation_buffer.insert(0, rotation)
            self.rotation_buffer.pop()


    def start(self):
        self.test_client.newFrameListener = self.receiveNewFrame
        self.test_client.rigidBodyListener = self.receiveRigidBodyFrame
        self.test_client.run()
        print('Started the interface thread')

    def stop(self):
        pass

    def get(self):
        current_value = None
        rotation_value = None
        pos_rot = None

        with mutex:
            current_value = self.data_array[0]
            rotation_value = self.rotation_buffer[0]

        # return the latest saved data
        if (current_value is not None) and (rotation_value is not None):
            pos_rot = np.concatenate((np.asarray(current_value), np.asarray(rotation_value)))

        pos_rot = np.expand_dims(pos_rot, axis = 0)
        print(pos_rot.shape)
        return pos_rot # return the (x, y, z, rotation) frame

class Simulation(System):
    '''
    This class does all the same things, except that when the optitrack is not broadcasting data,
    the get function returns random numbers instead.
    '''
    update_freq = 60 # Hz

    def get(self):
        mag_fac = 10
        current_value = np.random.rand(self.rigidBodyCount, 6) * mag_fac
        current_value = np.expand_dims(current_value, axis = 0)
        return current_value


if __name__ == "__main__":
    s = System()
    s.start()
    s.get()
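The class is written against BMI3D's DataSourceSystem pattern: start() spawns the listener and get() returns the newest sample. A polling sketch using the hardware-free Simulation subclass (no Motive server required, and assuming a NumPy old enough for the module's np.float dtype to import):

import time
from riglib.optitrack_client.optitrack_direct_pack import Simulation

source = Simulation()   # get() returns random (1, 1, 6) pose samples
for _ in range(5):
    sample = source.get()
    print(sample.shape, sample.ravel()[:3])  # (1, 1, 6); first 3 values play the role of x, y, z
    time.sleep(1 / Simulation.update_freq)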
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_optitrack_interface_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_optitrack_interface_py.html
new file mode 100644
index 00000000..46e3c10c
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_optitrack_interface_py.html
@@ -0,0 +1,190 @@
Coverage for C:\GitHub\brain-python-interface\riglib\optitrack_client\optitrack_interface.py: 0%
1import numpy as np 

+

2import sys, time 

+

3import socket 

+

4 

+

5 

+

6N_TEST_FRAMES = 1 #number of testing frames during start 

+

7class System(object): 

+

8 """ 

+

9 this is is the dataSource interface for getting the mocap at BMI3D's reqeust 

+

10 compatible with DataSourceSystem 

+

11 uses data_array to keep track of the lastest buffer 

+

12 """ 

+

13 port_num = 1230 #same as the optitrack #default to 1230 

+

14 HEADERSIZE = 10 

+

15 rece_byte_size = 512 

+

16 debug = True 

+

17 optitrack_ip_addr = "10.155.206.1" 

+

18 TIME_OUT_TIME = 2 

+

19 

+

20 

+

21 rigidBodyCount = 1 

+

22 update_freq = 120 

+

23 dtype = np.dtype((np.float, (rigidBodyCount, 6))) #6 degress of freedom 

+

24 

+

25 def __init__(self): 

+

26 self.rigid_body_count = 1 #for now,only one rigid body 

+

27 

+

28 

+

29 

+

30 

+

31 def start(self): 

+

32 #start to connect to the client 

+

33 #set up the socket 

+

34 self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 

+

35 self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 

+

36 self.s.settimeout(self.TIME_OUT_TIME) 

+

37 

+

38 print("connecting to the c# server") 

+

39 ''' 

+

40 self.s.bind(('', 1230)) #bind to all incoming request 

+

41 self.s.listen() #listen to one clinet 

+

42 ''' 

+

43 try: 

+

44 #clientsocket, address = self.s.accept() 

+

45 self.s.connect((self.optitrack_ip_addr, self.port_num)) 

+

46 except: 

+

47 print("cannot connect to Motive") 

+

48 print("Is the c# server running?") 

+

49 

+

50 #otherwise it works as expected and set client to be a  

+

51 #class property 

+

52 #self.clientsocket = clientsocket 

+

53 print(f"Connection to c# client \ 

+

54 {self.optitrack_ip_addr} has been established.") 

+

55 

+

56 #automatically pull 10 frames 

+

57 # and cal the mean round trip time 

+

58 t1 = time.perf_counter() 

+

59 for i in range(N_TEST_FRAMES): self.get() 

+

60 t2 = time.perf_counter() 

+

61 print(f'time to grab {N_TEST_FRAMES} frames : \ 

+

62 {(t2 - t1)} s ') 

+

63 

+

64 

+

65 def stop(self): 

+

66 msg = "stop" 

+

67 self.send_command(msg) 

+

68 #close the socket 

+

69 #self.s.close() 

+

70 print("socket closed!") 

+

71 

+

72 def get(self): 

+

73 #the property that gets one frame of data 

+

74 # 3 positions and 3 angles 

+

75 #the last element is frame number 

+

76 msg = "get" 

+

77 result_string = self.send_and_receive(msg) 

+

78 motive_frame = np.fromstring(result_string, sep=',') 

+

79 current_value = motive_frame[:6] #only using the motion data 

+

80 current_value.transpose() 

+

81 

+

82 

+

83 #for some weird reason, the string needs to be expanded.. 

+

84 #just send the motion data for now 

+

85 current_value = np.expand_dims(current_value, axis = 0) 

+

86 current_value = np.expand_dims(current_value, axis = 0) 

+

87 return current_value 

+

88 

+

89 

+

90 

+

91 def send_command(self, msg): 

+

92 #get the message in string and encode in bytes and send to the socket 

+

93 msg = f"{len(msg):<{self.HEADERSIZE}}"+msg 

+

94 msg_ascii = msg.encode("ascii") 

+

95 self.s.send(msg_ascii) 

+

96 

+

97 def send_and_receive(self, msg): 

+

98 #this function sends a command 

+

99 #and then wait for a response 

+

100 msg = f"{len(msg):<{self.HEADERSIZE}}"+msg 

+

101 msg_ascii = msg.encode("ascii") 

+

102 self.s.send(msg_ascii) 

+

103 result_in_bytes = self.s.recv(self.rece_byte_size) 

+

104 return str(result_in_bytes,encoding="ASCII") 

+

105 

+

106 

+

107class Simulation(System): 

+

108 ''' 

+

109 this class does all the things except when the optitrack is not broadcasting data 

+

110 the get function starts to return random numbers 

+

111 ''' 

+

112 update_freq = 60 #Hz 

+

113 def get(self): 

+

114 mag_fac = 10 

+

115 current_value = np.random.rand(self.rigidBodyCount, 6) * mag_fac 

+

116 current_value = np.expand_dims(current_value, axis = 0) 

+

117 return current_value 

+

118 

+

119 

+

120if __name__ == "__main__": 

+

121 s = System() 

+

122 s.start() 

+

123 s.send_command("start_rec") 

+

124 time.sleep(5) 

+

125 s.stop() 

+

126 print("finished") 

+
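For a quick smoke test of the interface above without Motive attached, the Simulation subclass can be polled directly. A minimal sketch (the import path assumes this file lives at riglib/optitrack_client/optitrack_interface.py, as in the coverage path above):

import time
from riglib.optitrack_client.optitrack_interface import Simulation

sim = Simulation()                       # no TCP connection needed
for _ in range(3):
    frame = sim.get()                    # random array, shape (1, rigidBodyCount, 6)
    print(frame.shape, frame.ravel()[:3])
    time.sleep(1.0 / sim.update_freq)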
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_NatNetClient_perframe_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_NatNetClient_perframe_py.html
new file mode 100644
index 00000000..4b962839
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_NatNetClient_perframe_py.html
@@ -0,0 +1,86 @@
[coverage.py HTML report — Coverage for C:\GitHub\brain-python-interface\riglib\optitrack_client\test_NatNetClient_perframe.py: 0%. Page navigation chrome omitted; the annotated source embedded in the page:]

import sys
from NatNetClient import NatNetClient

# This will create a new NatNet client
streamingClient = NatNetClient()

# __createDataSocket is private, so it must be reached through its name-mangled form
streamingClient.dataSocket = streamingClient._NatNetClient__createDataSocket(streamingClient.dataPort)
if streamingClient.dataSocket is None:
    print("Could not open data channel")
    sys.exit(1)

# Create the command socket
streamingClient.commandSocket = streamingClient._NatNetClient__createCommandSocket()
if streamingClient.commandSocket is None:
    print("Could not open command channel")
    sys.exit(1)


# receive some data

data, addr = streamingClient.dataSocket.recvfrom(1024)  # 1 KiB buffer
if len(data) > 0:
    streamingClient._NatNetClient__processMessage(data)
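The double-underscore attribute accesses above only work through Python's name mangling, which is why the mangled spellings are required outside the class. A self-contained illustration:

class Client:
    def __create_socket(self):            # mangled to _Client__create_socket
        return "data socket"

c = Client()
try:
    c.__create_socket()                   # raises AttributeError outside the class
except AttributeError as err:
    print(err)
print(c._Client__create_socket())         # the mangled name works, but a public wrapper is cleaner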
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_control_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_control_py.html
new file mode 100644
index 00000000..e162001a
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_control_py.html
@@ -0,0 +1,87 @@
[coverage.py HTML report — Coverage for C:\GitHub\brain-python-interface\riglib\optitrack_client\test_control.py: 0%. Page navigation chrome omitted; the annotated source embedded in the page:]

from NatNetClient import NatNetClient


# This is a callback function that gets connected to the NatNet client and called once per mocap frame.
def receiveNewFrame(frameNumber, markerSetCount, unlabeledMarkersCount, rigidBodyCount, skeletonCount,
                    labeledMarkerCount, timecode, timecodeSub, timestamp, isRecording, trackedModelsChanged):
    # print( "Received frame", frameNumber )
    pass

# This is a callback function that gets connected to the NatNet client. It is called once per rigid body per frame
def receiveRigidBodyFrame(id, position, rotation):
    # print( "Received frame for rigid body", position )
    pass

# This will create a new NatNet client
test_client = NatNetClient()

# Configure the streaming client to call our rigid body handler on the emulator to send data out.
test_client.newFrameListener = receiveNewFrame
test_client.rigidBodyListener = receiveRigidBodyFrame

test_client.sendCommand(test_client.NAT_REQUEST_MODELDEF, "", test_client.commandSocket,
                        (test_client.serverIPAddress, test_client.commandPort))
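The script above relies on NatNetClient's listener-attribute pattern: the client stores bare callables and invokes them once per frame (or per rigid body) from its receive thread. A minimal stand-in showing the dispatch mechanics (FrameSource is hypothetical, for illustration only):

from typing import Callable, Optional

class FrameSource:
    '''Toy stand-in for the NatNet listener pattern.'''
    def __init__(self):
        self.rigidBodyListener: Optional[Callable] = None

    def _on_packet(self, rb_id, position, rotation):
        # called once per rigid body per frame by the receive thread
        if self.rigidBodyListener is not None:
            self.rigidBodyListener(rb_id, position, rotation)

src = FrameSource()
src.rigidBodyListener = lambda rb_id, pos, rot: print(rb_id, pos)
src._on_packet(7, (0.1, 0.2, 0.3), (0.0, 0.0, 0.0, 1.0))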
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_optitrack_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_optitrack_py.html
new file mode 100644
index 00000000..af5a2104
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_optitrack_client_test_optitrack_py.html
@@ -0,0 +1,74 @@
[coverage.py HTML report — Coverage for C:\GitHub\brain-python-interface\riglib\optitrack_client\test_optitrack.py: 0%. Page navigation chrome omitted; the annotated source embedded in the page:]

from riglib.optitrack_client.optitrack_direct_pack import System
import time

num_length = 10
motion_data = System()
motion_data.start()

while True:
    print(motion_data.get())
    time.sleep(0.05)
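The loop above spins forever and leaves the connection open on Ctrl-C. A variant with a clean shutdown (a sketch, assuming optitrack_direct_pack.System exposes the same stop() command as the interface file earlier in this patch):

import time
try:
    while True:
        print(motion_data.get())
        time.sleep(0.05)
except KeyboardInterrupt:
    motion_data.stop()      # sends the "stop" command to the Motive-side server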
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_phidgets_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_phidgets_py.html
new file mode 100644
index 00000000..ff72d0a8
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_phidgets_py.html
@@ -0,0 +1,191 @@
[coverage.py HTML report — Coverage for C:\GitHub\brain-python-interface\riglib\phidgets.py: 0%. Page navigation chrome omitted; the annotated source embedded in the page:]

'''
Code for interacting with the Phidgets API
'''
import time
import itertools
import numpy as np

from Phidgets.Devices.InterfaceKit import InterfaceKit
from .source import DataSourceSystem

class System(DataSourceSystem):
    '''
    Generic DataSourceSystem interface for the Phidgets board: http://www.phidgets.com/products.php?category=0&product_id=1018_2
    '''
    update_freq = 1000

    def __init__(self, n_sensors=2, n_inputs=1):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.n_sensors = n_sensors
        self.n_inputs = n_inputs
        self.interval = 1. / self.update_freq

        self.sensordat = np.zeros((n_sensors,))
        self.inputdat = np.zeros((n_inputs,), dtype=np.bool)
        self.data = np.zeros((1,), dtype=self.dtype)

        self.kit = InterfaceKit()
        self.kit.openPhidget()
        self.kit.waitForAttach(2000)

    def start(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.tic = time.time()

    def stop(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        pass

    def get(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        toc = time.time() - self.tic
        if 0 < toc < self.interval:
            time.sleep(self.interval - toc)
        try:
            for i in range(self.n_sensors):
                self.sensordat[i] = self.kit.getSensorValue(i) / 1000.
            for i in range(self.n_inputs):
                self.inputdat[i] = self.kit.getInputState(i)
        except:
            print('sensor_error')
        self.data['sensors'] = self.sensordat
        self.data['inputs'] = self.inputdat
        self.tic = time.time()
        return self.data

    def sendMsg(self, msg):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        pass

    def __del__(self):
        '''
        Docstring

        Parameters
        ----------

        Returns
        -------
        '''
        self.kit.closePhidget()

def make(sensors, inputs, cls=System, **kwargs):
    '''
    Docstring
    This ridiculous function dynamically creates a class with a new init function

    Parameters
    ----------

    Returns
    -------
    '''
    def init(self, **kwargs):
        super(self.__class__, self).__init__(n_sensors=sensors, n_inputs=inputs, **kwargs)

    dtype = np.dtype([('sensors', np.float, (sensors,)), ('inputs', np.bool, (inputs,))])
    return type(cls.__name__, (cls,), dict(dtype=dtype, __init__=init))
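make() above manufactures a subclass at runtime so that the record dtype can depend on the sensor/input counts chosen at configuration time. A self-contained re-creation of the type() trick (np.float/np.bool are spelled with their modern numpy names here):

import numpy as np

def make_source(n_sensors, n_inputs, base=object):
    # build the per-iteration record layout first...
    dtype = np.dtype([('sensors', np.float64, (n_sensors,)),
                      ('inputs', np.bool_, (n_inputs,))])
    # ...then bake it into a brand-new class object
    return type(base.__name__, (base,), dict(dtype=dtype))

Src = make_source(2, 1)
print(Src.dtype)   # [('sensors', '<f8', (2,)), ('inputs', '?', (1,))]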
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plants_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plants_py.html
new file mode 100644
index 00000000..a6d47602
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plants_py.html
@@ -0,0 +1,624 @@
[coverage.py HTML report — Coverage for C:\GitHub\brain-python-interface\riglib\plants.py: 49%. Page navigation chrome omitted; the annotated source embedded in the page:]

#!/usr/bin/python
'''
Representations of plants (control systems)
'''
import os
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
import numpy as np
from .stereo_opengl.primitives import Cylinder, Sphere, Cone, Cube, Chain
from .stereo_opengl.models import Group
from riglib.bmi import robot_arms
from riglib.stereo_opengl.xfm import Quaternion

import sys
import time
import socket
import select
import numpy as np
from collections import namedtuple

from utils.constants import *
from riglib import source
import robot

import struct
from riglib.bmi.robot_arms import KinematicChain
import pygame
import math

class RefTrajectories(dict):
    '''
    Generic class to hold trajectories to be replayed by a plant.
    For now, this class is just a dictionary that has had its type changed
    '''
    pass


from riglib.source import DataSourceSystem
class FeedbackData(DataSourceSystem):
    '''
    Generic class for parsing UDP feedback data from a plant. Meant to be used with
    riglib.source.DataSource to grab and log data asynchronously.

    See DataSourceSystem for notes on the source interface
    '''

    MAX_MSG_LEN = 300
    sleep_time = 0

    # must define these in subclasses
    update_freq = None
    address = None
    dtype = None

    def __init__(self):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind(self.address)

        # self.file_ = open(self.feedback_filename, 'w')

    def start(self):
        self.listening = True
        self.data = self.get_feedback_data()

    def stop(self):
        self.listening = False
        self.sock.close()
        # self.file_.close()

    def __del__(self):
        # The stop commands for the socket should be issued before this object is garbage-collected, but just in case...
        self.stop()

    def get(self):
        return next(self.data)

    def get_feedback_data(self):
        '''Yield received feedback data.'''

        self.last_timestamp = -1

        while self.listening:
            r, _, _ = select.select([self.sock], [], [], 0)

            if r:  # if the list r is not empty
                feedback = self.sock.recv(self.MAX_MSG_LEN)
                ts_arrival = time.time()  # secs

                # print "feedback:", feedback
                # self.file_.write(feedback.rstrip('\r') + "\n")

                processed_feedback = self.process_received_feedback(feedback, ts_arrival)

                if processed_feedback['ts'] != self.last_timestamp:
                    yield processed_feedback

                self.last_timestamp = processed_feedback['ts']

            time.sleep(self.sleep_time)

    def process_received_feedback(self, feedback, ts_arrival):
        raise NotImplementedError('Implement in subclasses!')


class Plant(object):
    '''
    Generic interface for task-plant interaction
    '''
    hdf_attrs = []
    def __init__(self, *args, **kwargs):
        pass

    def drive(self, decoder):
        '''
        Call this function to 'drive' the plant to the state specified by the decoder

        Parameters
        ----------
        decoder : bmi.Decoder instance
            Decoder used to estimate the state of/control the plant

        Returns
        -------
        None
        '''
        # Instruct the plant to go to the decoder-specified intrinsic coordinates
        # decoder['q'] is a special __getitem__ case. See riglib.bmi.Decoder.__getitem__/__setitem__
        self.set_intrinsic_coordinates(decoder['q'])

        # Not all intrinsic coordinates will be achievable. So determine where the plant actually went
        intrinsic_coords = self.get_intrinsic_coordinates()

        # Update the decoder state with the current state of the plant, after the last command
        if not np.any(np.isnan(intrinsic_coords)):
            decoder['q'] = self.get_intrinsic_coordinates()

    def get_data_to_save(self):
        '''
        Get data to save regarding the state of the plant on every iteration of the event loop

        Parameters
        ----------
        None

        Returns
        -------
        dict:
            keys are strings, values are np.ndarray objects of data values
        '''
        return dict()

    def init(self):
        '''
        Secondary initialization after object construction. Does nothing by default
        '''
        pass

    def start(self):
        '''
        Start any auxiliary processes used by the plant
        '''
        pass

    def stop(self):
        '''
        Stop any auxiliary processes used by the plant
        '''
        pass

    def init_decoder(self, decoder):
        decoder['q'] = self.get_intrinsic_coordinates()


class AsynchronousPlant(Plant):
    def init(self):
        from riglib import sink
        sink.sinks.register(self.source)
        super(AsynchronousPlant, self).init()

    def start(self):
        '''
        Only start the DataSource after it has been registered with
        the SinkManager singleton (sink.sinks) in the call to init()
        '''
        self.source.start()
        super(AsynchronousPlant, self).start()

    def stop(self):
        self.source.stop()
        super(AsynchronousPlant, self).stop()

###################################################
##### Virtual plants for specific experiments #####
###################################################
class CursorPlant(Plant):
    '''
    Create a plant which is a 2-D or 3-D cursor on a screen/stereo display
    '''
    hdf_attrs = [('cursor', 'f8', (3,))]
    def __init__(self, endpt_bounds=None, cursor_radius=0.4, cursor_color=(.5, 0, .5, 1), starting_pos=np.array([0., 0., 0.]), vel_wall=True, **kwargs):
        self.endpt_bounds = endpt_bounds
        self.position = starting_pos
        self.starting_pos = starting_pos
        self.cursor_radius = cursor_radius
        self.cursor_color = cursor_color
        self._pickle_init()
        self.vel_wall = vel_wall

    def _pickle_init(self):
        self.cursor = Sphere(radius=self.cursor_radius, color=self.cursor_color)
        self.cursor.translate(*self.position, reset=True)
        self.graphics_models = [self.cursor]

    def draw(self):
        self.cursor.translate(*self.position, reset=True)

    def get_endpoint_pos(self):
        return self.position

    def set_endpoint_pos(self, pt, **kwargs):
        self.position = pt
        self.draw()

    def get_intrinsic_coordinates(self):
        return self.position

    def set_intrinsic_coordinates(self, pt):
        self.position = pt
        self.draw()

    def set_visibility(self, visible):
        self.visible = visible
        if visible:
            self.graphics_models[0].attach()
        else:
            self.graphics_models[0].detach()

    def _bound(self, pos, vel):
        pos = pos.copy()
        vel = vel.copy()
        if len(vel) == 0:
            vel_wall = self.vel_wall  # don't worry about vel if it's empty
            self.vel_wall = False
        if self.endpt_bounds is not None:
            if pos[0] < self.endpt_bounds[0]:
                pos[0] = self.endpt_bounds[0]
                if self.vel_wall: vel[0] = 0
            if pos[0] > self.endpt_bounds[1]:
                pos[0] = self.endpt_bounds[1]
                if self.vel_wall: vel[0] = 0

            if pos[1] < self.endpt_bounds[2]:
                pos[1] = self.endpt_bounds[2]
                if self.vel_wall: vel[1] = 0
            if pos[1] > self.endpt_bounds[3]:
                pos[1] = self.endpt_bounds[3]
                if self.vel_wall: vel[1] = 0

            if pos[2] < self.endpt_bounds[4]:
                pos[2] = self.endpt_bounds[4]
                if self.vel_wall: vel[2] = 0
            if pos[2] > self.endpt_bounds[5]:
                pos[2] = self.endpt_bounds[5]
                if self.vel_wall: vel[2] = 0
        if len(vel) == 0:
            self.vel_wall = vel_wall  # restore previous value
        return pos, vel

    def drive(self, decoder):
        pos = decoder['q'].copy()
        vel = decoder['qdot'].copy()

        pos, vel = self._bound(pos, vel)

        decoder['q'] = pos
        decoder['qdot'] = vel
        super(CursorPlant, self).drive(decoder)

    def get_data_to_save(self):
        return dict(cursor=self.position)


class AuditoryCursor(Plant):
    '''
    An auditory cursor that maps the decoded state to tone frequency
    '''
    hdf_attrs = [('aud_cursor_freq', 'f8', (1,))]

    def __init__(self, min_freq, max_freq, sound_duration=0.1):
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.bits = 16

        pygame.mixer.pre_init(44100, -self.bits, 2)
        pygame.init()

        duration = 0.1  # in seconds
        self.sample_rate = 44100
        self.n_samples = int(round(duration*self.sample_rate))
        self.max_sample = 2**(self.bits - 1) - 1
        self.freq = 0
        # set up our numpy array to handle 16-bit ints, which is what we set our mixer to expect with "bits" up above
        self.buf = np.zeros((self.n_samples, 2), dtype=np.int16)
        self.buf_ext = np.zeros((10*self.n_samples, 2), dtype=np.int16)
        self.t = np.arange(self.n_samples)/float(self.n_samples)*duration
        self.t0 = np.arange(self.n_samples)*0
        self.t_start = time.time()

    def drive(self, decoder):
        self.freq = decoder.filt.F

        if np.logical_and(decoder.cnt == 0, decoder.feedback):
            # Just got reset:
            if self.freq > self.max_freq:
                self.freq = self.max_freq
            elif self.freq < self.min_freq:
                self.freq = self.min_freq
            self.play_freq()

    def play_freq(self):
        self.buf[:,0] = np.round(self.max_sample*np.sin(2*math.pi*self.freq*self.t)).astype(int)
        self.buf[:,1] = np.round(self.max_sample*np.sin(2*math.pi*self.freq*self.t0)).astype(int)
        sound = pygame.sndarray.make_sound(self.buf)
        sound.play()

    def play_white_noise(self):
        self.buf_ext[:,0] = np.round(self.max_sample*np.random.normal(0, self.max_sample/2., (10*self.n_samples, ))).astype(int)
        self.buf_ext[:,1] = np.round(self.max_sample*np.zeros((10*self.n_samples, ))).astype(int)
        sound = pygame.sndarray.make_sound(self.buf_ext)
        sound.play()

    def get_intrinsic_coordinates(self):
        return self.freq


class onedimLFP_CursorPlant(CursorPlant):
    '''
    A square cursor confined to vertical movement
    '''
    hdf_attrs = [('lfp_cursor', 'f8', (3,))]

    def __init__(self, endpt_bounds, *args, **kwargs):
        self.lfp_cursor_rad = kwargs['lfp_cursor_rad']
        self.lfp_cursor_color = kwargs['lfp_cursor_color']
        args = [(), kwargs['lfp_cursor_color']]
        super(onedimLFP_CursorPlant, self).__init__(endpt_bounds, *args, **kwargs)

    def _pickle_init(self):
        self.cursor = Cube(side_len=self.lfp_cursor_rad, color=self.lfp_cursor_color)
        self.cursor.translate(*self.position, reset=True)
        self.graphics_models = [self.cursor]

    def drive(self, decoder):
        pos = decoder.filt.get_mean()
        pos = [-8, -2.2, pos]

        if self.endpt_bounds is not None:
            if pos[2] < self.endpt_bounds[4]:
                pos[2] = self.endpt_bounds[4]

            if pos[2] > self.endpt_bounds[5]:
                pos[2] = self.endpt_bounds[5]

        self.position = pos
        self.draw()

    def turn_off(self):
        self.cursor.detach()

    def turn_on(self):
        self.cursor.attach()

    def get_data_to_save(self):
        return dict(lfp_cursor=self.position)

class onedimLFP_CursorPlant_inverted(onedimLFP_CursorPlant):
    '''
    A square cursor confined to vertical movement, with the vertical axis inverted
    '''
    hdf_attrs = [('lfp_cursor', 'f8', (3,))]

    def drive(self, decoder):
        std_pos = decoder.filt.get_mean()
        inv_pos = [-8, -2.2, -1.0*std_pos]

        if self.endpt_bounds is not None:
            if inv_pos[2] < self.endpt_bounds[4]:
                inv_pos[2] = self.endpt_bounds[4]

            if inv_pos[2] > self.endpt_bounds[5]:
                inv_pos[2] = self.endpt_bounds[5]

        self.position = inv_pos
        self.draw()

class twodimLFP_CursorPlant(onedimLFP_CursorPlant):
    '''Same as 1d cursor but assumes decoder returns array'''
    def drive(self, decoder):
        # Pos = (Left-Right, 0, Up-Down)
        pos = decoder.filt.get_mean()
        pos = [pos[0], -2.2, pos[2]]
        # pos = [-8, -2.2, pos[2]]

        if self.endpt_bounds is not None:
            if pos[2] < self.endpt_bounds[4]:
                pos[2] = self.endpt_bounds[4]

            if pos[2] > self.endpt_bounds[5]:
                pos[2] = self.endpt_bounds[5]

        self.position = pos
        self.draw()


arm_color = (181/256., 116/256., 96/256., 1)
arm_radius = 0.6
pi = np.pi
class RobotArmGen2D(Plant):
    '''
    Generic virtual plant for creating a kinematic chain of any number of links but confined to the X-Z (vertical) plane
    '''
    def __init__(self, link_radii=arm_radius, joint_radii=arm_radius, link_lengths=[15,15,5,5], joint_colors=arm_color,
                 link_colors=arm_color, base_loc=np.array([2., 0., -15]), joint_limits=[(-pi,pi), (-pi,0), (-pi/2,pi/2), (-pi/2, 10*pi/180)], stay_on_screen=False, **kwargs):
        '''
        Instantiate the graphics and the virtual arm for a planar kinematic chain
        '''
        self.num_joints = num_joints = len(link_lengths)

        self.link_lengths = link_lengths
        self.curr_vecs = np.zeros([num_joints, 3])  # rows go from proximal to distal links

        # set initial vecs to correct orientations (arm starts out vertical)
        self.curr_vecs[0,2] = self.link_lengths[0]
        self.curr_vecs[1:,0] = self.link_lengths[1:]

        # Instantiate the kinematic chain object
        self.kin_chain = self.kin_chain_class(link_lengths, base_loc=base_loc)
        self.kin_chain.joint_limits = joint_limits

        self.base_loc = base_loc

        self.chain = Chain(link_radii, joint_radii, link_lengths, joint_colors, link_colors)
        self.cursor = Sphere(radius=arm_radius/2, color=link_colors)
        self.graphics_models = [self.chain.link_groups[0], self.cursor]

        self.chain.translate(*self.base_loc, reset=True)

        self.hdf_attrs = [('cursor', 'f8', (3,)), ('joint_angles','f8', (self.num_joints, )), ('arm_visible', 'f8', (1,))]

        self.visible = True  # arm is visible when initialized

        self.stay_on_screen = stay_on_screen
        self.joint_angles = np.zeros(self.num_joints)

    @property
    def kin_chain_class(self):
        return robot_arms.PlanarXZKinematicChain

    def get_endpoint_pos(self):
        '''
        Returns the current position of the non-anchored end of the arm.
        '''
        return self.kin_chain.endpoint_pos(self.joint_angles)

    def set_endpoint_pos(self, pos, **kwargs):
        '''
        Positions the arm according to the specified endpoint position.
        '''
        if pos is not None:
            # Run the inverse kinematics
            angles = self.perform_ik(pos, **kwargs)

            # Update the joint configuration
            self.set_intrinsic_coordinates(angles)

    def perform_ik(self, pos, **kwargs):
        angles = self.kin_chain.inverse_kinematics(pos, q_start=self.get_intrinsic_coordinates(), verbose=False, eps=0.008, **kwargs).ravel()
        return angles

    def calc_joint_angles(self, vecs):
        return np.arctan2(vecs[:,2], vecs[:,0])

    def get_intrinsic_coordinates(self):
        '''
        Returns the joint angles of the arm in radians
        '''
        return self.calc_joint_angles(self.curr_vecs)

    def set_intrinsic_coordinates(self, theta):
        '''
        Set the joints by specifying the angles in radians.

        Parameters
        ----------
        theta : np.ndarray
            Theta is a list of angles. If an element of theta = NaN, the angle should remain the same.

        Returns
        -------
        None
        '''
        new_endpt_pos = self.kin_chain.endpoint_pos(theta)
        if self.stay_on_screen and (new_endpt_pos[0] > 25 or new_endpt_pos[0] < -25 or new_endpt_pos[-1] < -14 or new_endpt_pos[-1] > 14):
            # ignore the command because it would push the endpoint off the screen
            return

        if not np.any(np.isnan(theta)):
            self.joint_angles = theta
            for i in range(self.num_joints):
                self.curr_vecs[i] = self.link_lengths[i]*np.array([np.cos(theta[i]), 0, np.sin(theta[i])])

        self.chain._update_link_graphics(self.curr_vecs)
        self.cursor.translate(*self.get_endpoint_pos(), reset=True)

    def get_data_to_save(self):
        return dict(cursor=self.get_endpoint_pos(), joint_angles=self.get_intrinsic_coordinates(), arm_visible=self.visible)

    def set_visibility(self, visible):
        self.visible = visible
        if visible:
            self.graphics_models[0].attach()
        else:
            self.graphics_models[0].detach()

class EndptControlled2LArm(RobotArmGen2D):
    '''
    2-link arm controlled in extrinsic coordinates (endpoint position)
    '''
    def __init__(self, *args, **kwargs):
        super(EndptControlled2LArm, self).__init__(*args, **kwargs)
        self.hdf_attrs = [('cursor', 'f8', (3,)), ('arm_visible','f8', (1,))]

    def get_intrinsic_coordinates(self):
        return self.get_endpoint_pos()

    def set_intrinsic_coordinates(self, pos, **kwargs):
        self.set_endpoint_pos(pos, **kwargs)

    def set_endpoint_pos(self, pos, **kwargs):
        if pos is not None:
            # Run the inverse kinematics
            theta = self.perform_ik(pos, **kwargs)
            self.joint_angles = theta

            for i in range(self.num_joints):
                if theta[i] is not None and ~np.isnan(theta[i]):
                    self.curr_vecs[i] = self.link_lengths[i]*np.array([np.cos(theta[i]), 0, np.sin(theta[i])])

            self.chain._update_link_graphics(self.curr_vecs)

    def get_data_to_save(self):
        return dict(cursor=self.get_endpoint_pos(), arm_visible=self.visible)

    @property
    def kin_chain_class(self):
        return robot_arms.PlanarXZKinematicChain2Link
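CursorPlant._bound above clamps each axis to its [min, max] pair and zeroes the velocity component that ran into the wall. A vectorized equivalent of that behavior (a sketch, not the class's actual implementation):

import numpy as np

def bound(pos, vel, bounds, vel_wall=True):
    """bounds = [x_min, x_max, y_min, y_max, z_min, z_max]."""
    pos = np.asarray(pos, dtype=float).copy()
    vel = np.asarray(vel, dtype=float).copy()
    lo, hi = np.array(bounds[0::2]), np.array(bounds[1::2])
    hit = (pos < lo) | (pos > hi)      # which axes crossed a wall
    pos = np.clip(pos, lo, hi)
    if vel_wall and vel.size:
        vel[hit] = 0.0                 # kill velocity into the wall
    return pos, vel

print(bound([30., 0., -20.], [1., 1., -1.], [-25, 25, -14, 14, -14, 14]))
# -> pos clamped to [25, 0, -14], vel becomes [0, 1, 0]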
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon___init___py.html
new file mode 100644
index 00000000..c944bf00
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon___init___py.html
@@ -0,0 +1,287 @@
[coverage.py HTML report — Coverage for C:\GitHub\brain-python-interface\riglib\plexon\__init__.py: 0%. Page navigation chrome omitted; the annotated source embedded in the page:]

'''
Base code for 'bmi' feature (both spikes and field potentials) when using the plexon system
'''

import time
import numpy as np
from . import plexnet
from collections import Counter
import os
import array

try:
    from config import config
    PL_IP = config.plexon_ip
    PL_PORT = int(config.plexon_port)
except:
    PL_IP = "127.0.0.1"  # default to localhost
    PL_PORT = 6000
PL_ADDR = (PL_IP, PL_PORT)

PL_SingleWFType = 1
PL_ExtEventType = 4
PL_ADDataType = 5

from riglib.source import DataSourceSystem

class Spikes(DataSourceSystem):
    '''
    Client for spike data streamed from plexon system, compatible with riglib.source.DataSource
    '''
    update_freq = 40000
    dtype = np.dtype([("ts", np.float), ("chan", np.int32), ("unit", np.int32), ("arrival_ts", np.float64)])

    def __init__(self, addr=PL_ADDR, channels=None):
        '''
        Constructor for plexon.Spikes

        Parameters
        ----------
        addr: tuple of length 2
            (IP address, UDP port)
        channels: optional, default = None
            list of channels (electrodes) from which to receive spike data

        Returns
        -------
        Spikes instance
        '''
        self.conn = plexnet.Connection(*addr)
        self.conn.connect(256, waveforms=False, analog=False)

        try:
            self.conn.select_spikes(channels)
        except:
            print("Cannot run select_spikes method; old system?")

    def start(self):
        '''
        Connect to the plexon server and start receiving data
        '''
        self.conn.start_data()

        # self.data is a generator (the result of self.conn.get_data() is a 'yield').
        # Calling 'self.data.next()' in the 'get' function pulls a new spike timestamp
        self.data = self.conn.get_data()

    def stop(self):
        '''
        Disconnect from the plexon server
        '''
        self.conn.stop_data()
        self.conn.disconnect()

    def get(self):
        '''
        Return a single spike timestamp/waveform. Must be polled continuously for additional spike data. The polling is automatically taken care of by riglib.source.DataSource
        '''
        d = next(self.data)
        while d.type != PL_SingleWFType:
            d = next(self.data)

        return np.array([(d.ts / self.update_freq, d.chan, d.unit, d.arrival_ts)], dtype=self.dtype)


class LFP(DataSourceSystem):
    '''
    Client for local field potential data streamed from plexon system, compatible with riglib.source.MultiChanDataSource
    '''
    update_freq = 1000.

    gain_digiamp = 1000.
    gain_headstage = 1.

    # like the Spikes class, dtype is the numpy data type of items that will go
    # into the (multi-channel, in this case) datasource's ringbuffer.
    # unlike the Spikes class, the get method below does not return objects of
    # this type (this has to do with the fact that a potentially variable
    # amount of LFP data is returned in d.waveform every time
    # self.data.next() is called)
    dtype = np.dtype('float')

    def __init__(self, addr=PL_ADDR, channels=None, chan_offset=512):
        '''
        Constructor for plexon.LFP

        Parameters
        ----------
        addr : tuple of length 2
            (IP address, UDP port)
        channels : optional, default = None
            list of channels (electrodes) from which to receive spike data
        chan_offset : int, optional, default=512
            Indexing offset from the first LFP channel to the indexing system used by the OPX system

        Returns
        -------
        plexon.LFP instance
        '''
        self.conn = plexnet.Connection(*addr)
        self.conn.connect(256, waveforms=False, analog=True)

        # for OPX system, field potential (FP) channels are numbered 513-768
        self.chan_offset = chan_offset
        channels_offset = [c + self.chan_offset for c in channels]
        try:
            self.conn.select_continuous(channels_offset)
        except:
            print("Cannot run select_continuous method")

    def start(self):
        '''
        Connect to the plexon server and start receiving data
        '''
        self.conn.start_data()
        self.data = self.conn.get_data()

    def stop(self):
        '''
        Disconnect from the plexon server
        '''
        self.conn.stop_data()
        self.conn.disconnect()

    def get(self):
        '''
        Get a new LFP sample/block of LFP samples from the Plexon server
        '''
        d = next(self.data)
        while d.type != PL_ADDataType:
            d = next(self.data)

        # values are currently signed integers in the range [-2048, 2047];
        # first convert to float
        waveform = np.array(d.waveform, dtype='float')

        # convert to units of mV
        waveform = waveform * 16 * (5000. / 2**15) * (1./self.gain_digiamp) * (1./self.gain_headstage)

        return (d.chan-self.chan_offset, waveform)


class Aux(DataSourceSystem):
    '''
    Client for auxiliary analog data streamed from plexon system, compatible with riglib.source.MultiChanDataSource
    '''
    update_freq = 1000.

    gain_digiamp = 1.
    gain_headstage = 1.

    # see comment above
    dtype = np.dtype('float')

    def __init__(self, addr=PL_ADDR, channels=None, chan_offset=768):
        '''
        Constructor for plexon.Aux

        Parameters
        ----------
        addr : tuple of length 2
            (IP address, UDP port)
        channels : optional, default = None
            list of channels (electrodes) from which to receive spike data
        chan_offset : int, optional, default=768
            Indexing offset from the first Aux channel to the indexing system used by the OPX system

        Returns
        -------
        plexon.Aux instance
        '''
        self.conn = plexnet.Connection(*addr)
        self.conn.connect(256, waveforms=False, analog=True)

        # for OPX system, the 32 auxiliary input (AI) channels are numbered 769-800
        self.chan_offset = chan_offset

        channels_offset = [c + self.chan_offset for c in channels]
        try:
            self.conn.select_continuous(channels_offset)
        except:
            print("Cannot run select_continuous method")

    def start(self):
        self.conn.start_data()
        self.data = self.conn.get_data()

    def stop(self):
        self.conn.stop_data()

    def get(self):
        d = next(self.data)
        while d.type != PL_ADDataType:
            d = next(self.data)

        # values are currently signed integers in the range [-2048, 2047];
        # first convert to float
        waveform = np.array(d.waveform, dtype='float')

        # convert to units of mV
        waveform = waveform * 16 * (5000. / 2**15) * (1./self.gain_digiamp) * (1./self.gain_headstage)

        return (d.chan-self.chan_offset, waveform)
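The raw-count-to-mV conversion used by LFP.get() and Aux.get() above works out as follows for the default gains (a worked check, not part of the module):

# raw A/D samples are signed 12-bit counts in [-2048, 2047]
gain_digiamp, gain_headstage = 1000., 1.
raw = 2047
mv = raw * 16 * (5000. / 2**15) / gain_digiamp / gain_headstage
print(round(mv, 3))   # 4.998 -> ~5 mV full scale after a 1000x amplifier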
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_checkbin_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_checkbin_py.html
new file mode 100644
index 00000000..d46101eb
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_checkbin_py.html
@@ -0,0 +1,77 @@
[coverage.py HTML report — Coverage for C:\GitHub\brain-python-interface\riglib\plexon\checkbin.py: 0%. Page navigation chrome omitted; the annotated source embedded in the page:]

import numpy as np

def bin(plx, times, binlen=.1):
    units = plx.units
    bins = np.zeros((len(times), len(units)))
    for i, t in enumerate(times):
        spikes = plx.spikes[t-binlen:t].data
        for j, (c, u) in enumerate(units):
            chan = spikes['chan'] == c
            unit = spikes['unit'] == u
            bins[i, j] = sum(np.logical_and(chan, unit))

    return bins
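bin() above counts, for each time window, the spikes whose (chan, unit) fields match each sorted unit. The masking trick on a structured array, shown on toy data:

import numpy as np

# four fake spikes with (chan, unit) labels
spikes = np.array([(1, 0), (1, 1), (2, 0), (1, 1)],
                  dtype=[('chan', 'i4'), ('unit', 'i4')])
count = np.sum((spikes['chan'] == 1) & (spikes['unit'] == 1))
print(count)  # 2 spikes from channel 1, unit 1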
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_plexnet_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_plexnet_py.html
new file mode 100644
index 00000000..a2fbaf89
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_plexnet_py.html
@@ -0,0 +1,479 @@
[coverage.py HTML report — Coverage for C:\GitHub\brain-python-interface\riglib\plexon\plexnet.py: 0%. Page navigation chrome omitted; the annotated source embedded in the page:]

'''
Client-side code to configure and receive neural data from the Plexon PC over
the network.
'''

import re
import math
import array
import struct
import socket
import time
from collections import namedtuple
import os
import numpy as np
import matplotlib.pyplot as plt

PACKETSIZE = 512


WaveData = namedtuple("WaveData", ["type", "ts", "chan", "unit", "waveform", "arrival_ts"])
chan_names = re.compile(r'^(\w{2,4})(\d{2,3})(\w)?')

class Connection(object):
    '''
    A wrapper around a UDP socket which sends the Omniplex PC commands and
    receives data. Must run in a separate process (e.g., through `riglib.source`)
    if you want to use it as part of a task (e.g., BMI control)
    '''
    PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_CONNECT_CLIENT = (10000)
    PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_DISCONNECT_CLIENT = (10999)
    PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_GET_PARAMETERS_MMF = (10100)
    PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_START_DATA_PUMP = (10200)
    PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_STOP_DATA_PUMP = (10300)
    PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_SELECT_SPIKE_CHANNELS = (10400)
    PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_SELECT_CONTINUOUS_CHANNELS = (10401)

    PLEXNET_COMMAND_FROM_SERVER_TO_CLIENT_MMF_SIZES = (10001)
    PLEXNET_COMMAND_FROM_SERVER_TO_CLIENT_SENDING_SERVER_AREA = (20003)
    PLEXNET_COMMAND_FROM_SERVER_TO_CLIENT_SENDING_DATA = (1)

    SPIKE_CHAN_SORTED_TIMESTAMPS = (0x01)
    SPIKE_CHAN_SORTED_WAVEFORMS = (0x02)
    SPIKE_CHAN_UNSORTED_TIMESTAMPS = (0x04)
    SPIKE_CHAN_UNSORTED_WAVEFORMS = (0x08)

    def __init__(self, addr, port):
        self.addr = (addr, port)
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.connect(self.addr)

        self.num_server_dropped = 0
        self.num_mmf_dropped = 0

        self._init = False

    def _recv(self):
        '''
        Receives a single PACKETSIZE chunk from the socket
        '''
        d = ''
        while len(d) < PACKETSIZE:
            d += self.sock.recv(PACKETSIZE - len(d))
        return d

    def connect(self, channels, waveforms=False, analog=True):
        '''Establish a connection with the plexnet remote server, then request and set parameters

        Parameters
        ----------
        channels : int
            Number of channels to initialize through the server
        waveforms : bool, optional
            Set to true if you want to stream spike waveforms (not available for MAP system?)
        analog : bool, optional
            Set to true if you want to receive data from the analog channels

        Returns
        -------
        None
        '''

        packet = array.array('i', '\x00'*PACKETSIZE)
        packet[0] = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_CONNECT_CLIENT
        packet[1] = True  # timestamp
        packet[2] = waveforms
        packet[3] = analog
        packet[4] = 1  # channels start
        packet[5] = channels+1

        self.sock.sendall(packet.tostring())

        resp = array.array('i', self._recv())

        if resp[0] == self.PLEXNET_COMMAND_FROM_SERVER_TO_CLIENT_MMF_SIZES:
            self.n_cmd = resp[3]
            if 0 < self.n_cmd < 32:
                sup_spike = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_SELECT_SPIKE_CHANNELS
                sup_cont = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_SELECT_CONTINUOUS_CHANNELS
                self.supports_spikes = any([b == sup_spike for b in resp[4:]])
                self.supports_cont = any([b == sup_cont for b in resp[4:]])

        print('supports spikes:', self.supports_spikes)
        print('supports continuous:', self.supports_cont)

        packet = array.array('i', '\x00'*PACKETSIZE)
        packet[0] = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_GET_PARAMETERS_MMF
        self.sock.sendall(packet.tostring())

        self.params = []

        gotServerArea = False
        while not gotServerArea:
            resp = array.array('i', self._recv())
            self.params.append(resp)

            if resp[0] == self.PLEXNET_COMMAND_FROM_SERVER_TO_CLIENT_SENDING_SERVER_AREA:
                self.n_spike = resp[15]
                self.n_cont = resp[17]
                print("Spike channels: %d, continuous channels: %d"%(self.n_spike, self.n_cont))
                gotServerArea = True

        self._init = True


    def select_spikes(self, channels=None, waveforms=True, unsorted=False):
        '''
        Sets the channels from which to receive spikes. This function always requests sorted data

        Parameters
        ----------
        channels : array_like, optional
            A list of channels which you want to see spikes from
        waveforms : bool, optional
            Request waveforms from all selected channels

        Returns
        -------
        None
        '''
        if not self._init:
            raise ValueError("Please initialize the connection first")
        if not self.supports_spikes:
            raise ValueError("Server does not support spike streaming!")
        packet = array.array('i', '\x00'*20)
        packet[0] = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_SELECT_SPIKE_CHANNELS
        packet[2] = 1
        packet[3] = self.n_spike
        raw = packet.tostring()

        # always send timestamps, waveforms are optional
        bitmask = 1 | waveforms << 1
        if unsorted:
            bitmask |= 1<<2 | waveforms << 3

        if channels is None:
            raw += array.array('b', [bitmask]*(PACKETSIZE - 20)).tostring()
        else:
            packet = array.array('b', '\x00'*(PACKETSIZE - 20))
            for c in channels:
                packet[c-1] = bitmask
            raw += packet.tostring()

        self.sock.sendall(raw)

    def select_continuous(self, channels=None):
        '''
        Sets the channels from which to receive continuous neural data (e.g., LFP)

        Parameters
        ----------
        channels : array_like, optional
            A list of channels which you want to receive continuous data from

        Returns
        -------
        None
        '''
        if not self._init:
            raise ValueError("Please initialize the connection first")
        if not self.supports_cont:
            raise ValueError("Server does not support continuous data streaming!")

        if channels is None:  # select all of them
            # print 'selecting all continuous channels'
            chan_selection = array.array('b', [1]*self.n_cont)
        else:
            # print 'selecting specified continuous channels'
            chan_selection = array.array('b', [0]*self.n_cont)
            for c in channels:
                # always true unless channels outside the range [1,...,self.n_cont] were specified
                if c-1 < len(chan_selection):
                    chan_selection[c-1] = 1

        n_packets = int(math.ceil(float(self.n_cont) / PACKETSIZE))
        HEADERSIZE = 20  # bytes
        chan_offset = 0

        # e.g., for 800 continuous channels, 2 "packets" are formed
        # chan_offset is 0 for the first packet
        # chan_offset is 492 for the second packet

        raw = ''
        for packet_num in range(n_packets):
            # print 'subpacket:', packet_num
            header = array.array('i', '\x00'*HEADERSIZE)
            header[0] = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_SELECT_CONTINUOUS_CHANNELS
            header[1] = packet_num
            header[2] = n_packets
            # header[3] = number of channels that are specified (or not specified) in the selection
            # bytes that follow
            header[4] = chan_offset  # channel offset in this packet

            if chan_offset + (PACKETSIZE-HEADERSIZE) < len(chan_selection[chan_offset:]):
                payload = chan_selection[chan_offset:chan_offset+PACKETSIZE-HEADERSIZE]
                n_selections = len(payload)

                chan_offset = chan_offset + len(payload)
            else:  # there are fewer than PACKETSIZE - HEADERSIZE channels left to specify
                payload = chan_selection[chan_offset:]
                n_selections = len(payload)

                # don't need to worry about incrementing chan_offset (reached end)

                # pad with zeros
                n_pad = PACKETSIZE - HEADERSIZE - len(payload)
                payload += array.array('b', [0]*n_pad)

            header[3] = n_selections
            raw += header.tostring()
            raw += payload.tostring()

        self.sock.sendall(raw)

    def start_data(self):
        '''Start the data pump from plexnet remote'''
        if not self._init:
            raise ValueError("Please initialize the connection first")
        packet = array.array('i', '\x00'*PACKETSIZE)
        packet[0] = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_START_DATA_PUMP
        self.sock.sendall(packet.tostring())
        self.streaming = True

    def stop_data(self):
        '''Stop the data pump from plexnet remote'''
        if not self._init:
            raise ValueError("Please initialize the connection first")
        packet = array.array('i', '\x00'*PACKETSIZE)
        packet[0] = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_STOP_DATA_PUMP
        self.sock.sendall(packet.tostring())
        self.streaming = False

    def disconnect(self):
        '''Disconnect from the plexnet remote server and close all network sockets'''
        if not self._init:
            raise ValueError("Please initialize the connection first")
        packet = array.array('i', '\x00'*PACKETSIZE)
        packet[0] = self.PLEXNET_COMMAND_FROM_CLIENT_TO_SERVER_DISCONNECT_CLIENT
        self.sock.sendall(packet.tostring())
        self.sock.close()

    def __del__(self):
        self.disconnect()

    def get_data(self):
        '''
        A generator which yields packets as they are received
        '''

        assert self._init, "Please initialize the connection first"
        hnames = 'type,Uts,ts,chan,unit,nwave,nword'.split(',')
        invalid = set([0, -1])

        while self.streaming:
            packet = self._recv()

            arrival_ts = time.time()
            ibuf = struct.unpack('4i', packet[:16])
            if ibuf[0] == 1:
                self.num_server_dropped = ibuf[2]
                self.num_mmf_dropped = ibuf[3]
                packet = packet[16:]

            while len(packet) > 16:
                header = dict(list(zip(hnames, struct.unpack('hHI4h', packet[:16]))))
                packet = packet[16:]

                if header['type'] not in invalid:
                    wavedat = None
                    if header['nwave'] > 0:
                        l = header['nwave'] * header['nword'] * 2
                        wavedat = array.array('h', packet[:l])
                        packet = packet[l:]

                    chan = header['chan']
                    # when returning continuous data, plexon reports the channel numbers
                    # as between 0--799 instead of 1--800 (but it doesn't do this
                    # when returning spike data!), so we have to add 1 to the channel number
                    if header['type'] == 5:  # 5 is PL_ADDataType
                        chan = header['chan'] + 1

                    ts = int(header['Uts']) << 32 | int(header['ts'])

                    yield WaveData(type=header['type'], chan=chan,
                                   unit=header['unit'], ts=ts, waveform=wavedat,
                                   arrival_ts=arrival_ts)

if __name__ == "__main__":
    import csv
    import time
    import argparse
    parser = argparse.ArgumentParser(description="Collects plexnet data for a set amount of time")
    parser.add_argument("address", help="Server's address")
    parser.add_argument("--port", type=int, help="Server's port (defaults to 6000)", default=6000)
    parser.add_argument("--len", type=float, help="Time (in seconds) to record data", default=30.)
    parser.add_argument("output", help="Output csv file")
    args = parser.parse_args()

    with open(args.output, "w") as f:
        csvfile = csv.DictWriter(f, WaveData._fields)
        csvfile.writeheader()

        # Initialize the connection
        print('initializing connection')
        conn = Connection(args.address, args.port)
        conn.connect(256, analog=True)  # Request all 256 channels

        print('selecting spike channels')
        spike_channels = []  # 2, 3, 4]
        unsorted = False  # True
        conn.select_spikes(spike_channels, unsorted=unsorted)
        # conn.select_spikes(unsorted=unsorted)

        print('selecting continuous channels')
        # cont_channels = 512 + np.array([1, 2, 5, 9, 10, 192, 250, 256])
        cont_channels = 512 + np.array([53])
        # cont_channels = [1, 532, 533, 768, 800]
        conn.select_continuous(cont_channels)
        # conn.select_continuous()  # select all 800 continuous channels

        # for saving to mat file
        write_to_mat = True
        n_samp = 2 * 1000*int(args.len)
        n_chan = len(cont_channels)
        data = np.zeros((n_chan, 2*n_samp), dtype='int16')
        idxs = np.zeros(n_chan)
        chan_to_row = dict()
        for i, chan in enumerate(cont_channels):
            chan_to_row[chan] = i

        ts = []
        arrival_ts = []
        t = []
        n_samples = 0
        n_samp = []
        got_first = False

        print('starting data')
        conn.start_data()  # start the data pump

        waves = conn.get_data()
        start = time.time()

        while (time.time()-start) < args.len:
            wave = next(waves)
            if not got_first and wave is not None:
                print(wave)
                first_ts = wave.ts
                first_arrival_ts = wave.arrival_ts
                got_first = True

            if wave is not None:
                csvfile.writerow(dict(wave._asdict()))

            if write_to_mat and wave is not None:
                row = chan_to_row[wave.chan]
                idx = idxs[row]
                n_pts = len(wave.waveform)
                data[row, idx:idx+n_pts] = wave.waveform
                idxs[row] += n_pts

            if wave is not None and wave.chan == 512+53:
                ts.append(wave.ts - first_ts)
                arrival_ts.append(wave.arrival_ts - first_arrival_ts)

                n_samples += len(wave.waveform)
                t.append(time.time() - start)
                n_samp.append(n_samples)

        # Stop the connection
        conn.stop_data()
        conn.disconnect()

        if write_to_mat:
            save_dict = dict()
            save_dict['data'] = data
            save_dict['channels'] = cont_channels

            print('saving data...', end=' ')
            import scipy.io as sio
            sio.matlab.savemat('plexnet_data_0222_8pm_1.mat', save_dict)
            print('done.')

        plt.figure()
        plt.subplot(2,1,1)
        plt.plot(arrival_ts, ts)
        plt.subplot(2,1,2)
        plt.plot(t, n_samp)
        plt.show()
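select_spikes() above packs the per-channel request into a one-byte bitmask: bit 0 is sorted timestamps (always on), bit 1 sorted waveforms, bits 2-3 the unsorted equivalents. The assembly, isolated as a worked example:

# spike-selection bitmask as assembled in select_spikes()
SORTED_TS, SORTED_WF = 0x01, 0x02
UNSORTED_TS, UNSORTED_WF = 0x04, 0x08

def spike_bitmask(waveforms=True, unsorted=False):
    mask = 1 | (waveforms << 1)            # always stream sorted timestamps
    if unsorted:
        mask |= (1 << 2) | (waveforms << 3)
    return mask

print(bin(spike_bitmask(True, False)))     # 0b11:   sorted ts + waveforms
print(bin(spike_bitmask(True, True)))      # 0b1111: sorted and unsorted, ts + waveforms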
+ + + diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_plexnet_softserver_oldfiles_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_plexnet_softserver_oldfiles_py.html new file mode 100644 index 00000000..54347380 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_plexnet_softserver_oldfiles_py.html @@ -0,0 +1,437 @@ + + + + + + Coverage for C:\GitHub\brain-python-interface\riglib\plexon\plexnet_softserver_oldfiles.py: 0% + + + + + + + + + +
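The file summarized above is a coverage rendering of the python-2-era plexnet client: a Connection class with connect, select_spikes, select_continuous, start_data, get_data, stop_data, and disconnect. For orientation, a condensed sketch of the client flow its __main__ block implements; the import path, server address, and channel numbers are placeholders, and since the module itself predates python 3 this is illustrative rather than runnable as-is:

    from plexnet import Connection  # hypothetical import path

    conn = Connection('10.0.0.13', 6000)             # placeholder server address/port
    conn.connect(192, waveforms=False, analog=True)  # request analog (continuous) data
    conn.select_continuous([65, 66])                 # pick two continuous channels
    conn.start_data()                                # start the data pump
    for wave in conn.get_data():                     # generator yielding WaveData packets
        if wave is not None:
            print(wave.chan, wave.ts, len(wave.waveform))
            break
    conn.stop_data()
    conn.disconnect()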
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_source_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_source_py.html
new file mode 100644
index 00000000..090041d2
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_source_py.html
@@ -0,0 +1,97 @@
+[generated coverage.py HTML report: "Coverage for C:\GitHub\brain-python-interface\riglib\plexon\source.py: 0%"; keyboard-shortcut navigation widget and annotated copy of the module source (a deprecated _PlexCont/Continuous DataSource wrapper around plexnet.Connection) omitted]
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_test_plexfile_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_test_plexfile_py.html
new file mode 100644
index 00000000..f3806049
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_plexon_test_plexfile_py.html
@@ -0,0 +1,80 @@
+[generated coverage.py HTML report: "Coverage for C:\GitHub\brain-python-interface\riglib\plexon\test_plexfile.py: 0%"; keyboard-shortcut navigation widget and annotated copy of the test source (index- and time-based slicing assertions against plexfile.openFile(...).lfp) omitted]
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_positioner___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_positioner___init___py.html
new file mode 100644
index 00000000..037894b4
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_positioner___init___py.html
@@ -0,0 +1,652 @@
+[generated coverage.py HTML report: "Coverage for C:\GitHub\brain-python-interface\riglib\positioner\__init__.py: 0%"; keyboard-shortcut navigation widget and annotated copy of the module source (the serial Positioner interface and the PositionerTaskController Sequence task) omitted]
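The positioner module summarized above converts a desired displacement in cm into motor microsteps using the constants hard-coded in its constructor (12.4 cm/rev on x and y, 8.0 cm/rev on z, 1/4 microstepping, 200 full steps per revolution). A minimal numpy sketch of the conversion done in _calc_steps_to_pos; the displacement vector here is hypothetical:

    import numpy as np

    cm_per_rev = np.array([12.4, 12.4, 8.0])  # measured travel per motor revolution (x, y, z)
    step_size = np.array([0.25, 0.25, 0.25])  # 1/4 microstepping on each axis
    full_steps_per_rev = 200.0

    displ_cm = np.array([5.0, -3.0, 2.0])     # hypothetical displacement from current location
    # cm -> revolutions -> full steps -> microsteps
    displ_microsteps = (displ_cm / cm_per_rev * full_steps_per_rev / step_size).astype(int)
    print(displ_microsteps)                   # array([ 322, -193,  200])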
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_positioner_calib_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_positioner_calib_py.html
new file mode 100644
index 00000000..fb284d1f
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_positioner_calib_py.html
@@ -0,0 +1,140 @@
+[generated coverage.py HTML report: "Coverage for C:\GitHub\brain-python-interface\riglib\positioner\calib.py: 0%"; keyboard-shortcut navigation widget and annotated copy of the module source (raw continuous_move step-count logs and the n_steps_min_to_max / n_steps_max_to_min calibration arrays) omitted]
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_reward_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_reward_py.html
new file mode 100644
index 00000000..0d3ac7e3
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_reward_py.html
@@ -0,0 +1,437 @@
+[generated coverage.py HTML report: "Coverage for C:\GitHub\brain-python-interface\riglib\reward.py: 33%"; keyboard-shortcut navigation widget and annotated copy of the module source (the Crist reward-system interfaces: the _xsum checksum helper, the Basic class, and the unused System thread) omitted]
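The reward module summarized above frames every serial command with an 8-bit checksum computed by _xsum: the byte values of the message summed modulo 256. A python 3 rendering for illustration (the original operates on python 2 strings via binascii):

    def xsum(msg: bytes) -> int:
        """8-bit checksum: sum of the message bytes, modulo 256."""
        return sum(msg) % 256

    # the version-0 reset command from the listing:
    assert xsum(b'@CPSNNN') == 16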
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_serial_dio_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_serial_dio_py.html
new file mode 100644
index 00000000..2e4207a4
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_serial_dio_py.html
@@ -0,0 +1,208 @@
+[generated coverage.py HTML report: "Coverage for C:\GitHub\brain-python-interface\riglib\serial_dio.py: 36%"; keyboard-shortcut navigation widget and annotated copy of the module source (the construct_word bit-packing helper and the SendRowByte Arduino DIO class) omitted]
+ Hide keyboard shortcuts +

Hot-keys on this page

+
+

+ r + m + x + p   toggle line displays +

+

+ j + k   next/prev highlighted chunk +

+

+ 0   (zero) top of page +

+

+ 1   (one) first highlighted chunk +

+
+
+
+

'''
DIO using a serial port + microcontroller instead of the NIDAQ card
'''
import serial
from collections import defaultdict
import struct
from numpy import binary_repr
from .dio.parse import MSG_TYPE_ROWBYTE, MSG_TYPE_REGISTER
import time

def construct_word(aux, msg_type, data, n_bits_data=8, n_bits_msg_type=3):
    word = (aux << (n_bits_data + n_bits_msg_type)) | (msg_type << n_bits_data) | data
    return word

baudrate = 115200

class SendRowByte(object):
    '''
    Interface for sending task-generated data through the NIDAQ interface card.
    Sends only an 8-bit data word corresponding to the 8 lower bits of the
    current row number of the HDF table.
    '''
    def __init__(self, device=None):
        '''
        Constructor for SendRowByte

        Parameters
        ----------
        device : string, optional
            Linux name of the serial port for the Arduino board, defined by setserial

        Returns
        -------
        SendRowByte instance
        '''
        self.systems = dict()
        self.port = serial.Serial('/dev/arduino_neurosync', baudrate=baudrate)
        self.n_systems = 0
        self.rowcount = defaultdict(int)

    def close(self):
        '''
        Release access to the Arduino serial port
        '''
        # stop recording; the command must be bytes, not str, under Python 3
        self.port.write(b'p')
        self.port.close()

    def register(self, system, dtype):
        '''
        Send information about the registration system (name and datatype) in string form, one byte at a time.

        Parameters
        ----------
        system : string
            Name of the system being registered
        dtype : np.dtype instance
            Datatype of incoming data, for later decoding of the binary data during analysis

        Returns
        -------
        None
        '''
        # Save the index of the system being registered (arbitrary number corresponding to the order in which systems were registered)
        self.n_systems += 1
        self.systems[system] = self.n_systems

        # if self.n_systems > 1:
        #     raise Exception("This currently only works for one system!")

        print("Arduino register %s" % system, self.systems[system])

        for sys_name_chr in system:
            reg_word = construct_word(self.systems[system], MSG_TYPE_REGISTER, ord(sys_name_chr))
            self._send_data_word_to_serial_port(reg_word)

        null_term_word = construct_word(self.systems[system], MSG_TYPE_REGISTER, 0)  # data payload is 0 for null terminator
        self._send_data_word_to_serial_port(null_term_word)

    def sendMsg(self, msg):
        '''
        Do nothing. Messages are stored with row numbers in the HDF table, so no need to also send the message over to the recording system.

        Parameters
        ----------
        msg : string
            Message to send

        Returns
        -------
        None
        '''
        # there's no point in sending a message, since every message is
        # stored in the HDF table anyway with a row number,
        # and every row number is automatically synced.
        pass

    def send(self, system, data):
        '''
        Send the row number for a data word to the neural system

        Parameters
        ----------
        system : string
            Name of system
        data : object
            This is unused. Only used in the parent's version where the actual data, and not just the HDF row number, is sent.

        Returns
        -------
        None
        '''
        if not (system in self.systems):
            # if the system is not registered, do nothing
            return

        current_sys_rowcount = self.rowcount[system]
        self.rowcount[system] += 1

        # construct the data packet
        word = construct_word(self.systems[system], MSG_TYPE_ROWBYTE, current_sys_rowcount % 256)
        self._send_data_word_to_serial_port(word)

    def _send_data_word_to_serial_port(self, word, verbose=False):
        if verbose:
            print(binary_repr(word, 16))
        # prefix with the 'd' marker byte; must be bytes, not str, under Python 3
        word_str = b'd' + struct.pack('<H', word)
        self.port.write(word_str)
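`construct_word` packs a 16-bit word with the system/aux index in the high bits, a 3-bit message type, and an 8-bit data payload in the low byte. A quick sanity check (the numeric value of MSG_TYPE_REGISTER is assumed here for illustration; the real constant lives in riglib.dio.parse):

MSG_TYPE_REGISTER = 0b001   # assumed value, for illustration only
word = construct_word(aux=2, msg_type=MSG_TYPE_REGISTER, data=ord('A'))
# aux occupies bits 11 and up, msg_type bits 8-10, data bits 0-7
assert word == (2 << 11) | (MSG_TYPE_REGISTER << 8) | 0x41
print(format(word, '016b'))  # -> 0001000101000001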
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_setup_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_setup_py.html
new file mode 100644
index 00000000..bf0751e1
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_setup_py.html
@@ -0,0 +1,108 @@
Coverage for C:\GitHub\brain-python-interface\riglib\setup.py: 0% (generated coverage.py HTML report; source listing follows)
#!/usr/bin/env python
'''
Install the cython code required to open plexon files.
'''

# System imports
from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
import numpy as np

psth = Extension("plexon.psth",
    ['plexon/psth.pyx', 'plexon/cpsth/psth.c'],
    include_dirs=['.', np.get_include(), 'plexon/', 'plexon/cpsth/'],
    # define_macros = [('DEBUG', None)],
    # extra_compile_args=["-g"],
    # extra_link_args=["-g"],
)

plexfile = Extension("plexon.plexfile",
    ['plexon/plexfile.pyx',
     'plexon/cplexfile/plexfile.c',
     'plexon/cplexfile/plexread.c',
     'plexon/cplexfile/dataframe.c',
     'plexon/cplexfile/inspect.c',
     'plexon/cpsth/psth.c'],
    include_dirs=['.',
                  np.get_include(),
                  'plexon/',
                  'plexon/cpsth/',
                  'plexon/cplexfile/'],
    # define_macros = [('DEBUG', None)],
    # extra_compile_args=["-g"],
    # extra_link_args=["-g"],
)

setup(name="Plexfile utilities",
      description="Utilities for dealing with the Plexon neural streaming interface",
      author="James Gao",
      version="0.1.0",
      packages=['plexon'],
      ext_modules=cythonize([psth, plexfile], include_path=['.', 'plexon/cython/', np.get_include()])
      )
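As with any distutils/Cython extension module, the extensions above would typically be compiled in place with:

python setup.py build_ext --inplace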
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_sink_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_sink_py.html
new file mode 100644
index 00000000..223f0a99
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_sink_py.html
@@ -0,0 +1,335 @@
Coverage for C:\GitHub\brain-python-interface\riglib\sink.py: 30% (generated coverage.py HTML report; source listing follows)
'''
Generic data sink. Sinks run in separate processes and interact with the main process through code here
'''

import os
import inspect
import traceback
import multiprocessing as mp

from . import source
from .mp_proxy import FuncProxy

class DataSink(mp.Process):
    '''
    Generic single-channel data sink
    '''
    def __init__(self, output, **kwargs):
        '''
        Constructor for DataSink

        Parameters
        ----------
        output : type
            Data sink class to be instantiated in the remote process
        kwargs : optional kwargs
            kwargs to instantiate the data sink

        Returns
        -------
        DataSink instance
        '''
        super(DataSink, self).__init__()
        self.output = output
        self.kwargs = kwargs
        self.cmd_event = mp.Event()
        self.cmd_pipe, self._cmd_pipe = mp.Pipe()
        self.pipe, self._pipe = mp.Pipe()
        self.status = mp.Value('b', 1)  # mp boolean used for terminating the remote process

        self.methods = set(filter(lambda n: inspect.isfunction(getattr(output, n)), dir(output)))
        # python 2 version: inspect.ismethod doesn't work because the object is not instantiated
        # self.methods = set(n for n in dir(output) if inspect.ismethod(getattr(output, n)))

    def run(self):
        '''
        Run the sink system in a remote process

        Parameters
        ----------
        None

        Returns
        -------
        None
        '''
        # instantiate the output interface
        output = self.output(**self.kwargs)

        while self.status.value > 0:
            # forward any (system, data) pairs sent from the main process
            if self._pipe.poll(.001):
                system, data = self._pipe.recv()
                output.send(system, data)

            # service RPC calls proxied over the command pipe
            if self.cmd_event.is_set():
                cmd, args, kwargs = self._cmd_pipe.recv()
                try:
                    if cmd == "getattr":
                        ret = getattr(output, args[0])
                    else:
                        ret = getattr(output, cmd)(*args, **kwargs)
                except Exception as e:
                    traceback.print_exc(file=open(os.path.expandvars('$BMI3D/log/data_sink_log'), 'a'))
                    ret = e
                self.cmd_event.clear()
                self._cmd_pipe.send(ret)

        # close the sink once the status bit has been set to 0
        output.close()
        print("ended datasink")

    def __getattr__(self, attr):
        '''
        Get the specified attribute of the sink in the remote process

        Parameters
        ----------
        attr : string
            Name of attribute

        Returns
        -------
        object
            Value of specified named attribute
        '''
        methods = object.__getattribute__(self, "methods")
        if attr in methods:
            return FuncProxy(attr, self.cmd_pipe, self.cmd_event)
        else:
            raise AttributeError("Can't get attribute: %s. Remote methods available: %s" % (attr, str(self.methods)))

    def send(self, system, data):
        '''
        Send data to the sink system running in the remote process

        Parameters
        ----------
        system : string
            Name of system (source) from which the data originated
        data : object
            Arbitrary data. The remote sink should know how to handle the data

        Returns
        -------
        None
        '''
        if self.status.value > 0:
            self.pipe.send((system, data))

    def stop(self):
        '''
        Instruct the sink to stop gracefully by setting the 'status' boolean
        '''
        self.status.value = 0

    def __del__(self):
        '''
        Stop the remote sink when the object is destructed
        '''
        self.stop()


class SinkManager(object):
    ''' Data Sink manager singleton to be used by features '''
    def __init__(self):
        '''
        Constructor for SinkManager
        '''
        self.sinks = []
        self.sources = []
        self.registrations = dict()

    def start(self, output, **kwargs):
        '''
        Create a DataSink running the given output class in a remote process
        and register all currently-known sources with it.

        Parameters
        ----------
        output : type
            Data sink class to run in the remote process
        kwargs : optional kwargs
            Passed to the sink's constructor

        Returns
        -------
        DataSink instance
        '''
        print("sinkmanager start %s" % output)
        sink = DataSink(output, **kwargs)
        sink.start()
        self.registrations[sink] = set()
        for source, dtype in self.sources:
            sink.register(source, dtype)
            self.registrations[sink].add((source, dtype))

        self.sinks.append(sink)
        return sink

    def register(self, system, dtype=None):
        '''
        Register a system with all the known sinks

        Parameters
        ----------
        system : source.DataSource, source.MultiChanDataSource, or string
            System to register with all the sinks
        dtype : None (deprecated)
            Even if specified, this is overwritten in the 'else:' condition below

        Returns
        -------
        None
        '''
        if isinstance(system, source.DataSource):
            name = system.name
            dtype = system.source.dtype
        elif isinstance(system, source.MultiChanDataSource):
            name = system.name
            dtype = system.send_to_sinks_dtype
        elif isinstance(system, str):
            name = system
        else:
            # assume that the system is a class
            name = system.__module__.split(".")[1]
            dtype = system.dtype

        self.sources.append((name, dtype))

        for s in self.sinks:
            if (name, dtype) not in self.registrations[s]:
                self.registrations[s].add((name, dtype))
                s.register(name, dtype)

    def send(self, system, data):
        '''
        Send data from the specified 'system' to all sinks which have been registered

        Parameters
        ----------
        system : string
            Name of the system sending the data
        data : np.array
            Generic data to be handled by each sink. Can be a record array, e.g., for task data.

        Returns
        -------
        None
        '''
        for s in self.sinks:
            s.send(system, data)

    def stop(self):
        '''
        Run the 'stop' method of all the registered sinks
        '''
        for s in self.sinks:
            s.stop()

    def __iter__(self):
        '''
        Returns a python iterator to allow looping over all the
        registered sinks, e.g., to send them all the same data
        '''
        for s in self.sinks:
            yield s

# Data Sink manager singleton to be used by features
sinks = SinkManager()


class PrintSink(object):
    '''A null sink which directly prints the received data'''
    def __init__(self):
        print("Starting print sink")

    def register(self, name, dtype):
        print("Registered name %s with dtype %r" % (name, dtype))

    def send(self, system, data):
        print("Received %s data: \n%r" % (system, data))

    def sendMsg(self, msg):
        print("### MESSAGE: %s" % msg)

    def close(self):
        print("Ended print sink")
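To make the flow above concrete, here is a minimal sketch (hypothetical usage, assuming riglib is importable as the coverage paths suggest) that routes a record array through the SinkManager singleton to a PrintSink running in its own process:

import numpy as np
from riglib import sink

dt = np.dtype([('cursor', np.float64, (3,))])
s = sink.sinks.start(sink.PrintSink)                # spawn the remote sink process
sink.sinks.register("task", dt)                     # announce the 'task' source to all sinks
sink.sinks.send("task", np.zeros((1,), dtype=dt))   # forwarded to every registered sink
sink.sinks.stop()                                   # ask the remote process to exit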
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_source_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_source_py.html
new file mode 100644
index 00000000..a5fb4b4f
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_source_py.html
@@ -0,0 +1,768 @@
Coverage for C:\GitHub\brain-python-interface\riglib\source.py: 23% (generated coverage.py HTML report; source listing follows)
'''
Generic data source module. Sources run in separate processes and continuously collect/save
data and interact with the main process through the methods here.
'''

import os
import sys
import time
import inspect
import traceback
import multiprocessing as mp
from multiprocessing import sharedctypes as shm
import ctypes

import numpy as np
from .mp_proxy import FuncProxy

from . import sink  # this circular import is not ideal..

class DataSourceSystem(object):
    '''
    Abstract base class for use with the generic DataSource infrastructure. Requirements:
        1) the class must have an attribute named 'dtype' which represents
           the data type of the source data. The datatype *cannot* change!
        2) the class must have an attribute named 'update_freq' which specifies
           the frequency at which new data samples will be ready for retrieval.
        3) 'start' method--no arguments
        4) 'stop' method--no arguments
        5) 'get' method--should return a single output argument
    '''
    dtype = np.dtype([])
    update_freq = 1

    def start(self):
        '''
        Initialization for the source
        '''
        pass

    def stop(self):
        '''
        Code to run when the data source is to be stopped
        '''
        pass

    def get(self):
        '''
        Retrieve the current data available from the source.
        '''
        pass


class DataSource(mp.Process):
    '''
    Generic single-channel data source
    '''
    def __init__(self, source, bufferlen=10, name=None, send_data_to_sink_manager=True, **kwargs):
        '''
        Parameters
        ----------
        source : class compatible with DataSourceSystem
            Class to be instantiated as the "system" with changing data values.
        bufferlen : float
            Number of seconds long to make the ringbuffer. Seconds are converted to number
            of samples based on the 'update_freq' attribute of the source
        name : string, optional, default=None
            Name of the sink, i.e., HDF table. If one is not provided, it will be inferred based
            on the name of the source module
        send_data_to_sink_manager : boolean, optional, default=True
            Flag to indicate whether data should be saved to a sink (e.g., HDF file)
        kwargs : optional keyword arguments
            Passed to the source during object construction if any are specified

        Returns
        -------
        DataSource instance
        '''
        super(DataSource, self).__init__()
        if name is not None:
            self.name = name
        else:
            self.name = source.__module__.split('.')[-1]
        self.filter = None
        self.source = source
        self.source_kwargs = kwargs
        self.bufferlen = bufferlen
        self.max_len = bufferlen * int(self.source.update_freq)
        self.slice_size = self.source.dtype.itemsize

        self.lock = mp.Lock()
        self.idx = shm.RawValue('l', 0)
        self.data = shm.RawArray('c', self.max_len * self.slice_size)
        self.pipe, self._pipe = mp.Pipe()
        self.cmd_event = mp.Event()
        self.status = mp.Value('b', 1)
        self.stream = mp.Event()
        self.last_idx = 0

        # self.methods = set(n for n in dir(source) if inspect.ismethod(getattr(source, n)))
        self.methods = set(filter(lambda n: inspect.isfunction(getattr(source, n)), dir(source)))

        # in DataSource.run, there is a call to "self.sinks.send(...)",
        # but if the DataSource was never registered with the sink manager,
        # then this line results in unnecessary IPC
        # so, set send_data_to_sink_manager to False if you want to avoid this
        self.send_data_to_sink_manager = send_data_to_sink_manager

    def start(self, *args, **kwargs):
        '''
        From Python's docs on the multiprocessing module:
        Start the process's activity.
        This must be called at most once per process object. It arranges for the object's run() method to be invoked in a separate process.

        Parameters
        ----------
        None

        Returns
        -------
        None
        '''
        self.sinks = sink.sinks
        super(DataSource, self).start(*args, **kwargs)

    def run(self):
        '''
        Main function executed by the mp.Process object. This function runs in the *remote* process, not in the main process
        '''
        try:
            system = self.source(**self.source_kwargs)
            system.start()
        except Exception as e:
            print("source.DataSource.run: unable to start source!")
            print(e)
            self.status.value = -1

        streaming = True
        size = self.slice_size
        while self.status.value > 0:
            if self.cmd_event.is_set():  # if a command has been sent from the main task
                cmd, args, kwargs = self._pipe.recv()
                self.lock.acquire()
                try:
                    if cmd == "getattr":
                        ret = getattr(system, args[0])
                    else:
                        ret = getattr(system, cmd)(*args, **kwargs)
                except Exception as e:
                    print("source.DataSource.run: unable to process RPC call")
                    traceback.print_exc()
                    ret = e
                self.lock.release()
                self._pipe.send(ret)
                self.cmd_event.clear()

            if self.stream.is_set():
                self.stream.clear()
                streaming = not streaming
                if streaming:
                    self.idx.value = 0
                    system.start()
                else:
                    system.stop()

            if streaming:
                data = system.get()
                if self.send_data_to_sink_manager:
                    self.sinks.send(self.name, data)
                if data is not None:
                    # if not isinstance(data, np.ndarray):
                    #     raise ValueError("source.DataSource.run: Data returned from \
                    #         source system must be an array to ensure type consistency!")
                    try:
                        self.lock.acquire()
                        # serialize the sample into the shared ring buffer
                        i = self.idx.value % self.max_len
                        self.data[i*size:(i+1)*size] = np.array(data).tostring()
                        self.idx.value += 1
                        self.lock.release()
                    except Exception as e:
                        print("source.DataSource.run, exception saving data to ring buffer")
                        print(e)
            else:
                time.sleep(.001)

        # stop the system once self.status.value has been set to a negative number
        system.stop()

    def get(self, all=False, **kwargs):
        '''
        Retrieve data from the remote process

        Parameters
        ----------
        all : boolean, optional, default=False
            If true, returns all the data currently available. Since a finite buffer is used,
            this is NOT the same as all the data observed. (see 'bufferlen' in __init__ for buffer size)
        kwargs : optional kwargs
            To be passed to self.filter, if it is set

        Returns
        -------
        np.recarray
            Datatype of record array is the dtype of the DataSourceSystem
        '''
        if self.status.value <= 0:
            raise Exception('\n\nError starting datasource: %s\n\n' % self.name)

        self.lock.acquire()
        i = (self.idx.value % self.max_len) * self.slice_size
        if all:
            if self.idx.value < self.max_len:
                data = self.data[:i]
            else:
                data = self.data[i:] + self.data[:i]
        else:
            mlen = min((self.idx.value - self.last_idx), self.max_len)
            last = ((self.idx.value - mlen) % self.max_len) * self.slice_size
            if last > i:
                data = self.data[last:] + self.data[:i]
            else:
                data = self.data[last:i]

        self.last_idx = self.idx.value
        self.lock.release()
        try:
            data = np.fromstring(data, dtype=self.source.dtype)
        except:
            print("can't get fromstring...")

        if self.filter is not None:
            return self.filter(data, **kwargs)
        return data

    def read(self, n_pts=1, **kwargs):
        '''
        Read the last n_pts out of the buffer? Not sure how this is different from .get, and it doesn't appear to be used in any existing code....
        '''
        if self.status.value <= 0:
            raise Exception('\n\nError starting datasource: %s\n\n' % self.name)

        self.lock.acquire()
        idx = self.idx.value % self.max_len
        i = idx * self.slice_size

        if n_pts > self.max_len:
            n_pts = self.max_len

        if idx >= n_pts:  # no wrap-around required
            data = self.data[(idx-n_pts)*self.slice_size:idx*self.slice_size]
        else:
            data = self.data[-(n_pts-idx)*self.slice_size:] + self.data[:idx*self.slice_size]

        self.lock.release()
        try:
            data = np.fromstring(data, dtype=self.source.dtype)
        except:
            print("can't get fromstring...")

        if self.filter is not None:
            return self.filter(data, **kwargs)
        return data

    def pause(self):
        '''
        Used to toggle the 'streaming' variable in the remote "run" process
        '''
        self.stream.set()

    def stop(self):
        '''
        Set self.status.value to negative so that the while loop in self.run() terminates
        '''
        self.status.value = -1

    def __del__(self):
        '''
        Make sure the remote process stops if the Source object is destroyed
        '''
        self.stop()

    def __getattr__(self, attr):
        '''
        Try to retrieve attributes from the remote DataSourceSystem if they are not found in the proximal Source object

        Parameters
        ----------
        attr : string
            Name of attribute to retrieve

        Returns
        -------
        object
            The arbitrary value associated with the named attribute, if it exists.
        '''
        methods = object.__getattribute__(self, "methods")  # this is done instead of "self.methods" to avoid infinite recursion in Windows
        if attr in methods:
            # if the attribute requested is an instance method of the 'source', return a proxy to the remote source's method
            return FuncProxy(attr, self.pipe, self.cmd_event)
        elif not attr.startswith("__"):
            # try to look up the attribute remotely
            self.pipe.send(("getattr", (attr,), {}))
            self.cmd_event.set()
            return self.pipe.recv()
        raise AttributeError(attr)
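# ---------------------------------------------------------------------------
# Editor's illustration (a sketch, not part of riglib.source): a minimal
# DataSourceSystem and its use through DataSource. 'RandomSystem' is a
# hypothetical 60 Hz source invented for this example.
#
# import time
# import numpy as np
# from riglib import source
#
# class RandomSystem(source.DataSourceSystem):
#     dtype = np.dtype([('value', np.float64)])
#     update_freq = 60
#
#     def get(self):
#         time.sleep(1. / self.update_freq)   # pace the stream
#         return np.array([(np.random.rand(),)], dtype=self.dtype)
#
# ds = source.DataSource(RandomSystem, bufferlen=1,
#                        send_data_to_sink_manager=False)
# ds.start()
# time.sleep(0.5)
# print(ds.get())    # record array of samples unread since the last get()
# ds.stop()
# ---------------------------------------------------------------------------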

class MultiChanDataSource(mp.Process):
    '''
    Multi-channel version of 'DataSource'
    '''
    def __init__(self, source, bufferlen=5, name=None, send_data_to_sink_manager=False, **kwargs):
        '''
        Parameters
        ----------
        source : class
            lower-level class for interacting directly with the incoming data (e.g., plexnet)
        bufferlen : int
            Constrains the maximum amount of data history stored by the source
        name : string, optional, default=None
            Name of the sink, i.e., HDF table. If one is not provided, it will be inferred based
            on the name of the source module
        send_data_to_sink_manager : boolean, optional, default=False
            Flag to indicate whether data should be saved to a sink (e.g., HDF file)
        kwargs : dict, optional, default = {}
            For the multi-channel data source, you MUST specify a 'channels' keyword argument.
            Note that kwargs['channels'] does not need to be a list of integers,
            it can also be a list of strings.
        '''
        super(MultiChanDataSource, self).__init__()
        if name is not None:
            self.name = name
        else:
            self.name = source.__module__.split('.')[-1]
        self.filter = None
        self.source = source
        self.source_kwargs = kwargs
        self.bufferlen = bufferlen
        self.max_len = int(bufferlen * self.source.update_freq)
        self.channels = kwargs['channels']
        self.chan_to_row = dict()
        for row, chan in enumerate(self.channels):
            self.chan_to_row[chan] = row

        self.n_chan = len(self.channels)
        dtype = self.source.dtype  # e.g., np.dtype('float') for LFP
        self.slice_size = dtype.itemsize
        self.idxs = shm.RawArray('l', self.n_chan)
        self.last_read_idxs = np.zeros(self.n_chan)
        rawarray = shm.RawArray('c', self.n_chan * self.max_len * self.slice_size)
        self.data = np.frombuffer(rawarray, dtype).reshape((self.n_chan, self.max_len))

        #self.fo2 = open('/storage/rawdata/test_rda_get.txt','w')
        #self.fo3 = open('/storage/rawdata/test_rda_run.txt','w')

        self.lock = mp.Lock()
        self.pipe, self._pipe = mp.Pipe()
        self.cmd_event = mp.Event()
        self.status = mp.Value('b', 1)
        self.stream = mp.Event()
        self.data_has_arrived = mp.Value('b', 0)

        self.methods = set(n for n in dir(source) if inspect.ismethod(getattr(source, n)))

        self.send_data_to_sink_manager = send_data_to_sink_manager
        if self.send_data_to_sink_manager:
            self.send_to_sinks_dtype = np.dtype([('chan' + str(chan), dtype) for chan in kwargs['channels']])
            self.next_send_idx = mp.Value('l', 0)
            self.wrap_flags = shm.RawArray('b', self.n_chan)  # zeros/Falses by default
            self.supp_hdf_file = kwargs['supp_file']

    def register_supp_hdf(self):
        try:
            from ismore.brainamp import brainamp_hdf_writer
        except:
            from riglib.ismore import brainamp_hdf_writer
        self.supp_hdf = brainamp_hdf_writer.BrainampData(self.supp_hdf_file, self.channels, self.send_to_sinks_dtype)

    def verify_data_arrival(self):
        try:
            from ismore.brainamp.brainamp_features import verify_data_arrival
        except:
            from riglib.ismore.brainamp.brainamp_features import verify_data_arrival

    def start(self, *args, **kwargs):
        '''
        From Python's docs on the multiprocessing module:
        Start the process's activity.
        This must be called at most once per process object. It arranges for the object's run() method to be invoked in a separate process.

        Parameters
        ----------
        None

        Returns
        -------
        None
        '''
        self.sinks = sink.sinks
        super(MultiChanDataSource, self).start(*args, **kwargs)

    def run(self):
        '''
        Main function executed by the mp.Process object. This function runs in the *remote* process, not in the main process
        '''
        print("Starting datasource %r" % self.source)
        if self.send_data_to_sink_manager:
            print("Registering Supplementary HDF file for datasource %r" % self.source)
            self.register_supp_hdf()

        try:
            system = self.source(**self.source_kwargs)
            system.start()
        except Exception as e:
            print(e)
            self.status.value = -1

        streaming = True
        size = self.slice_size
        while self.status.value > 0:
            if self.cmd_event.is_set():
                cmd, args, kwargs = self._pipe.recv()
                self.lock.acquire()
                try:
                    if cmd == "getattr":
                        ret = getattr(system, args[0])
                    else:
                        ret = getattr(system, cmd)(*args, **kwargs)
                except Exception as e:
                    traceback.print_exc()
                    ret = e
                self.lock.release()
                self._pipe.send(ret)
                self.cmd_event.clear()

            if self.stream.is_set():
                self.stream.clear()
                streaming = not streaming
                if streaming:
                    # note: MultiChanDataSource has no 'idx' attribute (it keeps
                    # per-channel 'idxs'); this line appears to be carried over
                    # from DataSource.run
                    self.idx.value = 0
                    system.start()
                else:
                    system.stop()

            if streaming:
                # system.get() must return a tuple (chan, data), where:
                #   chan is the channel number
                #   data is a numpy array with a dtype (or subdtype) of self.source.dtype
                chan, data = system.get()
                #self.fo3.write(str(data[0][0]) + ' ' + str(time.time()) + ' \n')
                if data is not None:
                    try:
                        self.lock.acquire()

                        try:
                            row = self.chan_to_row[chan]  # row in ringbuffer corresponding to this channel
                        except KeyError:
                            # data source was not configured to get data on this channel
                            pass
                        else:
                            n_pts = len(data)
                            max_len = self.max_len

                            if n_pts > max_len:
                                data = data[-max_len:]
                                n_pts = max_len

                            idx = self.idxs[row]  # for this channel, idx in ringbuffer
                            if idx + n_pts <= self.max_len:
                                self.data[row, idx:idx+n_pts] = data
                                idx = (idx + n_pts)
                                if idx == self.max_len:
                                    idx = 0
                                    if self.send_data_to_sink_manager:
                                        self.wrap_flags[row] = True
                            else:  # need to write data at both end and start of buffer
                                self.data[row, idx:] = data[:max_len-idx]
                                self.data[row, :n_pts-(max_len-idx)] = data[max_len-idx:]
                                idx = n_pts - (max_len - idx)
                                if self.send_data_to_sink_manager:
                                    self.wrap_flags[row] = True
                            self.idxs[row] = idx

                        self.lock.release()

                        # Set the flag indicating that data has arrived from the source
                        self.data_has_arrived.value = 1
                    except Exception as e:
                        print(e)

                if self.send_data_to_sink_manager:
                    self.lock.acquire()

                    # check if there is at least one column of data that
                    # has not yet been sent to the sink manager
                    if all(self.next_send_idx.value < idx + int(flag)*self.max_len for (idx, flag) in zip(self.idxs, self.wrap_flags)):
                        start_idx = self.next_send_idx.value
                        if not all(self.wrap_flags):
                            # look at minimum value of self.idxs only
                            # among channels which have not wrapped,
                            # in order to determine end_idx
                            end_idx = np.min([idx for (idx, flag) in zip(self.idxs, self.wrap_flags) if not flag])
                            idxs_to_send = list(range(start_idx, end_idx))
                        else:
                            min_idx = np.min(self.idxs[:])
                            idxs_to_send = list(range(start_idx, self.max_len)) + list(range(0, min_idx))

                            for row in range(self.n_chan):
                                self.wrap_flags[row] = False

                        # Send data to the supp hdf file, all columns at a time (1/21/2016)
                        data = np.array(list(map(tuple, self.data[:, idxs_to_send].T)), dtype=self.send_to_sinks_dtype)
                        self.supp_hdf.add_data(data)

                        self.next_send_idx.value = np.mod(idxs_to_send[-1] + 1, self.max_len)

                    self.lock.release()
            else:
                time.sleep(.001)

        if hasattr(self, "supp_hdf"):
            self.supp_hdf.close_data()
            print('end of supp hdf')

        system.stop()
        print("ended datasource %r" % self.source)

    def get(self, n_pts, channels, **kwargs):
        '''
        Return the most recent n_pts of data from the specified channels.

        Parameters
        ----------
        n_pts : int
            Number of data points to read
        channels : iterable
            Channels from which to read

        Returns
        -------
        np.ndarray of shape (n_chan, n_pts)
            Datatype of the array is the dtype of the DataSourceSystem
        '''
        if self.status.value <= 0:
            raise Exception('\n\nError starting datasource: %s\n\n' % self.name)

        self.lock.acquire()

        # these channels must be a subset of the channels passed into __init__
        n_chan = len(channels)
        data = np.zeros((n_chan, n_pts), dtype=self.source.dtype)

        if n_pts > self.max_len:
            n_pts = self.max_len

        for chan_num, chan in enumerate(channels):
            try:
                row = self.chan_to_row[chan]
            except KeyError:
                print('data source was not configured to get data on channel', chan)
            else:  # executed if try clause does not raise a KeyError
                idx = self.idxs[row]
                if idx >= n_pts:  # no wrap-around required
                    data[chan_num, :] = self.data[row, idx-n_pts:idx]
                else:
                    data[chan_num, :n_pts-idx] = self.data[row, -(n_pts-idx):]
                    data[chan_num, n_pts-idx:] = self.data[row, :idx]
                self.last_read_idxs[row] = idx

        self.lock.release()

        if self.filter is not None:
            return self.filter(data, **kwargs)
        return data

    def get_new(self, channels, **kwargs):
        '''
        Return the new (unread) data from the specified channels.

        Parameters
        ----------
        channels : iterable
            Channels from which to read
        kwargs : optional kwargs
            To be passed to self.filter, if it is set

        Returns
        -------
        list of np.ndarray objects
            Datatype of each array is the dtype of the DataSourceSystem
        '''
        if self.status.value <= 0:
            raise Exception('\n\nError starting datasource: %s\n\n' % self.name)

        self.lock.acquire()

        # these channels must be a subset of the channels passed into __init__
        n_chan = len(channels)
        data = []

        for chan in channels:
            try:
                row = self.chan_to_row[chan]
            except KeyError:
                print('data source was not configured to get data on channel', chan)
                data.append(None)
            else:  # executed if try clause does not raise a KeyError
                idx = self.idxs[row]
                last_read_idx = self.last_read_idxs[row]
                if last_read_idx <= idx:  # no wrap-around required
                    data.append(self.data[row, last_read_idx:idx])
                else:
                    data.append(np.hstack((self.data[row, last_read_idx:], self.data[row, :idx])))
                self.last_read_idxs[row] = idx

        self.lock.release()

        if self.filter is not None:
            return self.filter(data, **kwargs)
        return data

    def pause(self):
        '''
        Used to toggle the 'streaming' variable in the remote "run" process
        '''
        self.stream.set()

    def check_if_data_has_arrived(self):
        '''
        Return the flag set by the remote process once the first data has arrived from the source.
        '''
        return self.data_has_arrived.value

    def stop(self):
        '''
        Set self.status.value to negative so that the while loop in self.run() terminates
        '''
        self.status.value = -1
        # self.fo2.close()
        # self.fo3.close()

    def __del__(self):
        '''
        Make sure the remote process stops if the Source object is destroyed
        '''
        self.stop()

    def __getattr__(self, attr):
        '''
        Try to retrieve attributes from the remote DataSourceSystem if they are not found in the proximal Source object

        Parameters
        ----------
        attr : string
            Name of attribute to retrieve

        Returns
        -------
        object
            The arbitrary value associated with the named attribute, if it exists.
        '''
        if attr in self.methods:
            return FuncProxy(attr, self.pipe, self.cmd_event)
        elif not attr.startswith("__"):  # note: was attr.beginsWith, which is not a Python string method
            print("getting attribute %s" % attr)
            self.pipe.send(("getattr", (attr,), {}))
            self.cmd_event.set()
            return self.pipe.recv()
        raise AttributeError(attr)
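A corresponding sketch for the multi-channel source (hypothetical usage: 'EMGSystem' is an invented stand-in for a DataSourceSystem-style class whose get() yields (chan, samples) tuples for the named channels):

mcds = MultiChanDataSource(EMGSystem, bufferlen=5, channels=['chan1', 'chan2'])
mcds.start()
latest = mcds.get(n_pts=100, channels=['chan1', 'chan2'])   # shape (2, 100)
unread = mcds.get_new(channels=['chan1'])                   # list of one array
mcds.stop()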
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl___init___py.html
new file mode 100644
index 00000000..03203d5e
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl___init___py.html
@@ -0,0 +1,74 @@
Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\__init__.py: 67% (generated coverage.py HTML report; source listing follows)
'''
__init__ script for stereo_opengl module
'''

try:
    from .window import Window
    from .render import stereo
    from .textures import Texture
except:
    pass
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_environment_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_environment_py.html
new file mode 100644
index 00000000..d6b71d67
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_environment_py.html
@@ -0,0 +1,111 @@
Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\environment.py: 33% (generated coverage.py HTML report; source listing follows)
'''
Various graphical "environmental" or "world" add-ins to graphical tasks.
'''
from .models import Group
from .xfm import Quaternion
from riglib.stereo_opengl.primitives import Sphere, Cylinder


class Box(Group):
    '''
    Construct a 3D wireframe box in the world to add some depth cue references
    '''
    def __init__(self, **kwargs):
        '''
        Constructor for Box

        Parameters
        ----------
        kwargs : optional keyword arguments
            All passed to parent constructor

        Returns
        -------
        Box instance
        '''
        bcolor = (181/256., 116/256., 96/256., 1)
        sidelen = 16
        linerad = .1
        self.vert_box = Group([
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(-sidelen/2, -sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(sidelen/2, -sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(sidelen/2, sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(-sidelen/2, sidelen/2, -sidelen/2)])
        self.hor_box = Group([
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(-sidelen/2, -sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(sidelen/2, -sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(sidelen/2, sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(-sidelen/2, sidelen/2, -sidelen/2)])
        self.depth_box = Group([
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(-sidelen/2, -sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(sidelen/2, -sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(sidelen/2, sidelen/2, -sidelen/2),
            Cylinder(radius=linerad, height=sidelen, color=bcolor).translate(-sidelen/2, sidelen/2, -sidelen/2)])
        self.hor_box.xfm.rotate = Quaternion.rotate_vecs((0,0,1), (1,0,0))
        self.depth_box.xfm.rotate = Quaternion.rotate_vecs((0,0,1), (0,1,0))
        self.box = Group([self.hor_box, self.depth_box, self.vert_box])
        super(Box, self).__init__([self.box], **kwargs)
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_ik_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_ik_py.html
new file mode 100644
index 00000000..8102d80b
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_ik_py.html
@@ -0,0 +1,746 @@
Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\ik.py: 45% (generated coverage.py HTML report; source listing follows)
'''
This module implements IK functions for running inverse kinematics.
Currently, only a two-joint system can be modelled.
'''

import numpy as np

from .xfm import Quaternion
from .models import Group
from .primitives import Cylinder, Sphere, Cone
from .textures import TexModel
from .utils import cloudy_tex
from collections import OrderedDict
from riglib.bmi import robot_arms

pi = np.pi

joint_angles_dtype = [('sh_pflex', np.float64), ('sh_pabd', np.float64), ('sh_prot', np.float64), ('el_pflex', np.float64), ('el_psup', np.float64)]
joint_vel_dtype = [('sh_vflex', np.float64), ('sh_vabd', np.float64), ('sh_vrot', np.float64), ('el_vflex', np.float64), ('el_vsup', np.float64)]

def inv_kin_2D(pos, l_upperarm, l_forearm, vel=None):
    '''
    Inverse kinematics for a 2D arm. This function returns all 5 angles required
    to specify the pose of the exoskeleton (see riglib.bmi.train for the
    definitions of these angles). This pose is constrained to the x-z plane
    by forcing shoulder flexion/extension, elbow rotation and supination/pronation
    to always be 0.
    '''
    if np.ndim(pos) == 1:
        pos = pos.reshape(1, -1)

    # require the y-coordinate to be 0, i.e. flat on the screen
    x, y, z = pos[:,0], pos[:,1], pos[:,2]
    assert np.all(y == 0)

    if vel is not None:
        if np.ndim(vel) == 1:
            vel = vel.reshape(1, -1)
        assert pos.shape == vel.shape
        vx, vy, vz = vel[:,0], vel[:,1], vel[:,2]
        assert np.all(vy == 0)

    L = np.sqrt(x**2 + z**2)
    cos_el_pflex = (L**2 - l_forearm**2 - l_upperarm**2) / (2*l_forearm*l_upperarm)

    # clamp values just above 1 caused by floating-point error
    cos_el_pflex[(cos_el_pflex > 1) & (cos_el_pflex < 1 + 1e-9)] = 1
    el_pflex = np.arccos(cos_el_pflex)

    sh_pabd = np.arctan2(z, x) - np.arcsin(l_forearm * np.sin(np.pi - el_pflex) / L)
    angles = np.zeros(len(pos), dtype=joint_angles_dtype)
    angles['sh_pabd'] = sh_pabd
    angles['el_pflex'] = el_pflex
    if np.any(np.isnan(angles['el_pflex'])) or np.any(np.isnan(angles['sh_pabd'])):
        pass

    if vel is not None:
        joint_vel = np.zeros(len(pos), dtype=joint_vel_dtype)

        # Calculate the jacobian
        for k, angle in enumerate(angles):
            s1 = np.sin(angle['sh_pabd'])
            s2 = np.sin(angle['sh_pabd'] + angle['el_pflex'])
            c1 = np.cos(angle['sh_pabd'])
            c2 = np.cos(angle['sh_pabd'] + angle['el_pflex'])
            J = np.array([[-l_upperarm*s1 - l_forearm*s2, -l_forearm*s2],
                          [ l_upperarm*c1 + l_forearm*c2,  l_forearm*c2]])
            J_inv = np.linalg.inv(J)

            joint_vel_mat = np.dot(J_inv, vel[k, [0,2]])
            joint_vel[k]['sh_vabd'], joint_vel[k]['el_vflex'] = joint_vel_mat.ravel()

        return angles, joint_vel
    else:
        return angles
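# ---------------------------------------------------------------------------
# Editor's illustration (a sketch, not part of this module): a quick numerical
# check that inv_kin_2D and the planar forward kinematics close the loop for
# an arm with 15 + 15 unit links. The target point is arbitrary.
l_upperarm, l_forearm = 15., 15.
target = np.array([20., 0., 10.])   # y must be 0 (pose lies in the x-z plane)
angles = inv_kin_2D(target, l_upperarm, l_forearm)

sh, el = angles['sh_pabd'][0], angles['el_pflex'][0]
x = l_upperarm*np.cos(sh) + l_forearm*np.cos(sh + el)
z = l_upperarm*np.sin(sh) + l_forearm*np.sin(sh + el)
assert np.allclose([x, z], [target[0], target[2]])  # FK reproduces the target
# ---------------------------------------------------------------------------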

+

76def make_list(value, num_joints): 

+

77 ''' 

+

78 Helper function to allow joint/link properties of the chain to be specified 

+

79 as one value for all joints/links or as separate values for each 

+

80 ''' 

+

81 if isinstance(value, list) and len(value) == num_joints: 

+

82 return value 

+

83 else: 

+

84 return [value] * num_joints 

+

85 

+

86arm_color = (181/256., 116/256., 96/256., 1) 

+

87arm_radius = 0.6 

+

88 

+

89class Plant(object): 

+

90 def __init__(self, *args, **kwargs): 

+

91 super(Plant, self).__init__(*args, **kwargs) 

+

92 

+

93 def drive(self, decoder): 

+

94 self.set_intrinsic_coordinates(decoder['q']) 

+

95 intrinsic_coords = self.get_intrinsic_coordinates() 

+

96 if not np.any(np.isnan(intrinsic_coords)): 

+

97 decoder['q'] = self.get_intrinsic_coordinates() 

+

98 

+

99 

+

100class CursorPlant(Plant): 

+

101 def __init__(self, endpt_bounds=None, **kwargs): 

+

102 self.endpt_bounds = endpt_bounds 

+

103 self.position = np.array([0., 0., 0.]) #np.zeros(3) 

+

104 

+

105 def get_endpoint_pos(self): 

+

106 return self.position 

+

107 

+

108 def set_endpoint_pos(self, pt, **kwargs): 

+

109 self.position = pt 

+

110 

+

111 def get_intrinsic_coordinates(self): 

+

112 return self.position 

+

113 

+

114 def set_intrinsic_coordinates(self, pt): 

+

115 self.position = pt 

+

116 

+

117 def drive(self, decoder): 

+

118 pos = decoder['q'].copy() 

+

119 vel = decoder['qdot'].copy() 

+

120 

+

121 if self.endpt_bounds is not None: 

+

122 if pos[0] < self.endpt_bounds[0]: 

+

123 pos[0] = self.endpt_bounds[0] 

+

124 #vel[0] = 0 

+

125 if pos[0] > self.endpt_bounds[1]: 

+

126 pos[0] = self.endpt_bounds[1] 

+

127 #vel[0] = 0 

+

128 

+

129 if pos[1] < self.endpt_bounds[2]: 

+

130 pos[1] = self.endpt_bounds[2] 

+

131 #vel[1] = 0 

+

132 if pos[1] > self.endpt_bounds[3]: 

+

133 pos[1] = self.endpt_bounds[3] 

+

134 #vel[1] = 0 

+

135 

+

136 if pos[2] < self.endpt_bounds[4]: 

+

137 pos[2] = self.endpt_bounds[4] 

+

138 #vel[2] = 0 

+

139 if pos[2] > self.endpt_bounds[5]: 

+

140 pos[2] = self.endpt_bounds[5] 

+

141 #vel[2] = 0 

+

142 

+

143 decoder['q'] = pos 

+

144 decoder['qdot'] = vel 

+

145 super(CursorPlant, self).drive(decoder) 

+

146 

+

147 

+

148class RobotArmGen2D(Plant, Group): 

+

149 def __init__(self, link_radii=arm_radius, joint_radii=arm_radius, link_lengths=[15,15,5,5], joint_colors=arm_color, 

+

150 link_colors=arm_color, base_loc=np.array([2., 0., -15]), **kwargs): 

+

151 ''' 

+

152 Instantiate the graphics and the virtual arm for a planar kinematic chain 

+

153 ''' 

+

154 num_joints = len(link_lengths) 

+

155 self.num_joints = num_joints 

+

156 

+

157 self.link_radii = make_list(link_radii, num_joints) 

+

158 self.joint_radii = make_list(joint_radii, num_joints) 

+

159 self.link_lengths = make_list(link_lengths, num_joints) 

+

160 self.joint_colors = make_list(joint_colors, num_joints) 

+

161 self.link_colors = make_list(link_colors, num_joints) 

+

162 

+

163 self.curr_vecs = np.zeros([num_joints, 3]) #rows go from proximal to distal links 

+

164 

+

165 # set initial vecs to correct orientations (arm starts out vertical) 

+

166 self.curr_vecs[0,2] = self.link_lengths[0] 

+

167 self.curr_vecs[1:,0] = self.link_lengths[1:] 

+

168 

+

169 # Create links 

+

170 self.links = [] 

+

171 

+

172 for i in range(self.num_joints): 

+

173 joint = Sphere(radius=self.joint_radii[i], color=self.joint_colors[i]) 

+

174 

+

175 # The most distal link gets a tapered cylinder (for purely stylistic reasons) 

+

176 if i < self.num_joints - 1: 

+

177 link = Cylinder(radius=self.link_radii[i], height=self.link_lengths[i], color=self.link_colors[i]) 

+

178 else: 

+

179 link = Cone(radius1=self.link_radii[-1], radius2=self.link_radii[-1]/2, height=self.link_lengths[-1], color=self.link_colors[-1]) 

+

180 link_i = Group((link, joint)) 

+

181 self.links.append(link_i) 

+

182 

+

183 link_offsets = [0] + self.link_lengths[:-1] 

+

184 self.link_groups = [None]*self.num_joints 

+

185 for i in range(self.num_joints)[::-1]: 

+

186 if i == self.num_joints-1: 

+

187 self.link_groups[i] = self.links[i] 

+

188 else: 

+

189 self.link_groups[i] = Group([self.links[i], self.link_groups[i+1]]) 

+

190 

+

191 self.link_groups[i].translate(0, 0, link_offsets[i]) 

+

192 

+

193 # Call the parent constructor 

+

194 super(RobotArmGen2D, self).__init__([self.link_groups[0]], **kwargs) 

+

195 

+

196 # Instantiate the kinematic chain object 

+

197 if self.num_joints == 2: 

+

198 self.kin_chain = robot_arms.PlanarXZKinematicChain(link_lengths) 

+

199 self.kin_chain.joint_limits = [(-pi,pi), (-pi,0)] 

+

200 else: 

+

201 self.kin_chain = robot_arms.PlanarXZKinematicChain(link_lengths) 

+

202 

+

203 # TODO the code below is (obviously) specific to a 4-joint chain 

+

204 self.kin_chain.joint_limits = [(-pi,pi), (-pi,0), (-pi/2,pi/2), (-pi/2, 10*pi/180)] 

+

205 

+

206 self.base_loc = base_loc 

+

207 self.translate(*self.base_loc, reset=True) 

+

208 

+

209 def _update_links(self): 

+

210 for i in range(0, self.num_joints): 

+

211 # Rotate each joint to the vector specified by the corresponding row in self.curr_vecs 

+

212 # Annoyingly, the baseline orientation of the first group is always different from the  

+

213 # more distal attachments, so the rotations have to be found relative to the orientation  

+

214 # established at instantiation time. 

+

215 if i == 0: 

+

216 baseline_orientation = (0, 0, 1) 

+

217 else: 

+

218 baseline_orientation = (1, 0, 0) 

+

219 

+

220 # Find the normalized quaternion that represents the desired joint rotation 

+

221 self.link_groups[i].xfm.rotate = Quaternion.rotate_vecs(baseline_orientation, self.curr_vecs[i]).norm() 

+

222 

+

223 # Recompute any cached transformations after the change 

+

224 self.link_groups[i]._recache_xfm() 

+

225 

+

226 def get_endpoint_pos(self): 

+

227 ''' 

+

228 Returns the current position of the non-anchored end of the arm. 

+

229 ''' 

+

230 relangs = np.arctan2(self.curr_vecs[:,2], self.curr_vecs[:,0]) 

+

231 return self.perform_fk(relangs) + self.base_loc 

+

232 

+

233 def perform_fk(self, angs): 

+

234 absvecs = np.zeros(self.curr_vecs.shape) 

+

235 for i in range(self.num_joints): 

+

236 absvecs[i] = self.link_lengths[i]*np.array([np.cos(np.sum(angs[:i+1])), 0, np.sin(np.sum(angs[:i+1]))]) 

+

237 return np.sum(absvecs,axis=0) 

+

238 

+

239 def set_endpoint_pos(self, pos, **kwargs): 

+

240 ''' 

+

241 Positions the arm according to specified endpoint position.  

+

242 ''' 

+

243 if pos is not None: 

+

244 # Run the inverse kinematics 

+

245 angles = self.perform_ik(pos, **kwargs) 

+

246 

+

247 # Update the joint configuration  

+

248 self.set_intrinsic_coordinates(angles) 

+

249 

+

250 def perform_ik(self, pos, **kwargs): 

+

251 angles = self.kin_chain.inverse_kinematics(pos - self.base_loc, q_start=-self.get_intrinsic_coordinates(), verbose=False, eps=0.008, **kwargs) 

+

252 # print self.kin_chain.endpoint_pos(angles) 

+

253 

+

254 # Negate the angles. The convention in the robotics library is  

+

255 # inverted, i.e. in the robotics library, positive is clockwise  

+

256 # rotation whereas here CCW rotation is positive.  

+

257 angles = -angles 

+

258 return angles 

+

    def calc_joint_angles(self, vecs):
        return np.arctan2(vecs[:, 2], vecs[:, 0])

    def get_intrinsic_coordinates(self):
        '''
        Returns the joint angles of the arm in radians
        '''
        return self.calc_joint_angles(self.curr_vecs)

    def set_intrinsic_coordinates(self, theta):
        '''
        Set the joint angles in radians. Theta is a list of angles. If an element of theta is NaN, that angle remains unchanged.
        '''
        for i in range(self.num_joints):
            # The `is not None` guard runs first, so ~np.isnan() is never handed a None
            if theta[i] is not None and ~np.isnan(theta[i]):
                self.curr_vecs[i] = self.link_lengths[i]*np.array([np.cos(theta[i]), 0, np.sin(theta[i])])

        self._update_links()


class RobotArmGen3D(Plant, Group):
    def __init__(self, link_radii=arm_radius, joint_radii=arm_radius, link_lengths=[15,15,5,5], joint_colors=arm_color,
        link_colors=arm_color, base_loc=np.array([2., 0., -15]), **kwargs):
        '''
        Instantiate the graphics and the virtual arm for a kinematic chain
        '''
        num_joints = 2
        self.num_joints = 2

        self.link_radii = make_list(link_radii, num_joints)
        self.joint_radii = make_list(joint_radii, num_joints)
        self.link_lengths = make_list(link_lengths, num_joints)
        self.joint_colors = make_list(joint_colors, num_joints)
        self.link_colors = make_list(link_colors, num_joints)

        self.curr_vecs = np.zeros([num_joints, 3])  # rows go from proximal to distal links

        # set initial vecs to correct orientations (arm starts out vertical)
        self.curr_vecs[0, 2] = self.link_lengths[0]
        self.curr_vecs[1:, 0] = self.link_lengths[1:]

        # Create links
        self.links = []

        for i in range(self.num_joints):
            joint = Sphere(radius=self.joint_radii[i], color=self.joint_colors[i])

            # The most distal link gets a tapered cylinder (for purely stylistic reasons)
            if i < self.num_joints - 1:
                link = Cylinder(radius=self.link_radii[i], height=self.link_lengths[i], color=self.link_colors[i])
            else:
                link = Cone(radius1=self.link_radii[-1], radius2=self.link_radii[-1]/2, height=self.link_lengths[-1], color=self.link_colors[-1])
            link_i = Group((link, joint))
            self.links.append(link_i)

        link_offsets = [0] + self.link_lengths[:-1]
        self.link_groups = [None]*self.num_joints
        for i in range(self.num_joints)[::-1]:
            if i == self.num_joints-1:
                self.link_groups[i] = self.links[i]
            else:
                self.link_groups[i] = Group([self.links[i], self.link_groups[i+1]])

            self.link_groups[i].translate(0, 0, link_offsets[i])

        # Call the parent constructor
        super(RobotArmGen3D, self).__init__([self.link_groups[0]], **kwargs)

        # Instantiate the kinematic chain object
        if self.num_joints == 2:
            self.kin_chain = robot_arms.PlanarXZKinematicChain(link_lengths)
            self.kin_chain.joint_limits = [(-pi, pi), (-pi, 0)]
        else:
            self.kin_chain = robot_arms.PlanarXZKinematicChain(link_lengths)

            # TODO the code below is (obviously) specific to a 4-joint chain
            self.kin_chain.joint_limits = [(-pi, pi), (-pi, 0), (-pi/2, pi/2), (-pi/2, 10*pi/180)]

        self.base_loc = base_loc
        self.translate(*self.base_loc, reset=True)

    def _update_links(self):
        for i in range(0, self.num_joints):
            # Rotate each joint to the vector specified by the corresponding row in self.curr_vecs.
            # The baseline orientation of the first group differs from the more distal
            # attachments, so rotations are found relative to the instantiation-time orientation.
            if i == 0:
                baseline_orientation = (0, 0, 1)
            else:
                baseline_orientation = (1, 0, 0)

            # Find the normalized quaternion that represents the desired joint rotation
            self.link_groups[i].xfm.rotate = Quaternion.rotate_vecs(baseline_orientation, self.curr_vecs[i]).norm()

            # Recompute any cached transformations after the change
            self.link_groups[i]._recache_xfm()

    def get_endpoint_pos(self):
        '''
        Returns the current position of the non-anchored end of the arm.
        '''
        relangs_xz = np.arctan2(self.curr_vecs[:, 2], self.curr_vecs[:, 0])
        relangs_xy = np.arctan2(self.curr_vecs[:, 1], self.curr_vecs[:, 0])
        return self.perform_fk(relangs_xz, relangs_xy) + self.base_loc

    def perform_fk(self, angs_xz, angs_xy):
        # NOTE: the xz and xy angle sets are treated independently here, so the
        # resulting per-link vector is not guaranteed to preserve the link's length.
        absvecs = np.zeros(self.curr_vecs.shape)
        for i in range(self.num_joints):
            absvecs[i] = self.link_lengths[i]*np.array([np.cos(np.sum(angs_xz[:i+1])), np.sin(np.sum(angs_xy[:i+1])), np.sin(np.sum(angs_xz[:i+1]))])
        return np.sum(absvecs, axis=0)

    def set_endpoint_pos(self, pos, **kwargs):
        '''
        Positions the arm according to the specified endpoint position.
        '''
        if pos is not None:
            # Run the inverse kinematics
            angles = self.perform_ik(pos, **kwargs)

            # Update the joint configuration
            self.set_intrinsic_coordinates(angles)

    def perform_ik(self, pos, **kwargs):
        angles = self.kin_chain.inverse_kinematics(pos - self.base_loc, q_start=-self.get_intrinsic_coordinates(), verbose=False, eps=0.008, **kwargs)

        # Negate the angles. The convention in the robotics library is
        # inverted, i.e. in the robotics library, positive is clockwise
        # rotation whereas here CCW rotation is positive.
        angles = -angles
        return angles

    def calc_joint_angles(self, vecs):
        return np.arctan2(vecs[:, 2], vecs[:, 0]), np.arctan2(vecs[:, 1], vecs[:, 0])

    def get_intrinsic_coordinates(self):
        '''
        Returns the joint angles of the arm in radians
        '''
        return self.calc_joint_angles(self.curr_vecs)

    def set_intrinsic_coordinates(self, theta_xz, theta_xy):
        '''
        Set the joint angles in radians. If an element of either angle list is NaN, that joint remains unchanged.
        '''
        for i in range(self.num_joints):
            if theta_xz[i] is not None and ~np.isnan(theta_xz[i]) and theta_xy[i] is not None and ~np.isnan(theta_xy[i]):
                self.curr_vecs[i] = self.link_lengths[i]*np.array([np.cos(theta_xz[i]), np.sin(theta_xy[i]), np.sin(theta_xz[i])])

        self._update_links()

    def _midpos(self, target):
        m, n = self.link_lengths
        x, y, z = target

        # this heinous equation brought to you by Wolfram Alpha
        # it is ONE of the solutions to this system of equations:
        #   a^2 + b^2 + (z/2)^2 = m^2
        #   (x-a)^2 + (y-b)^2 + (z/2)^2 = n^2
        if x > 0:
            a = (m**2*x**2+y*np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2-2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))-n**2*x**2+x**4+x**2*y**2)/(2*x*(x**2+y**2))
            b = (m**2*y-np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2-2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))-n**2*y+x**2*y+y**3)/(2*(x**2+y**2))
        else:
            a = (m**2*x**2-y*np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2-2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))-n**2*x**2+x**4+x**2*y**2)/(2*x*(x**2+y**2))
            b = (m**2*y+np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2-2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))-n**2*y+x**2*y+y**3)/(2*(x**2+y**2))
        return a, b, z/2
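Geometrically, _midpos places the elbow on the intersection of two spheres (radius m about the shoulder, radius n about the hand) with the elbow's z-coordinate pinned to z/2. If the closed form is right, the elbow is exactly distance m from the origin and distance n from the target; a quick numeric check of the x > 0 branch with hypothetical link lengths:

import numpy as np

def midpos(target, m=20.0, n=20.0):
    x, y, z = target
    disc = np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2
                          -2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))
    a = (m**2*x**2+y*disc-n**2*x**2+x**4+x**2*y**2)/(2*x*(x**2+y**2))
    b = (m**2*y-disc-n**2*y+x**2*y+y**3)/(2*(x**2+y**2))
    return np.array([a, b, z/2])

target = np.array([10., 5., 8.])          # reachable: |target| < m + n
elbow = midpos(target)
assert abs(np.linalg.norm(elbow) - 20.0) < 1e-6          # upper arm length
assert abs(np.linalg.norm(target - elbow) - 20.0) < 1e-6  # forearm length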


class RobotArm2J2D(RobotArmGen2D):
    def drive(self, decoder):
        raise NotImplementedError("deal with the state bounding stuff!")
        # elif self.decoder.ssm == train.joint_2D_state_space:
        #     self.set_arm_joints(self.decoder['sh_pabd', 'el_pflex'])

        #     # Force the arm to a joint configuration where the cursor is on-screen
        #     pos = self.get_arm_endpoint()
        #     pos = self.apply_cursor_bounds(pos)
        #     self.set_arm_endpoint(pos)

        #     # Reset the decoder state to match the joint configuration of the arm
        #     joint_pos = self.get_arm_joints()
        #     self.decoder['sh_pabd', 'el_pflex'] = joint_pos

class TwoJoint(object):
    '''
    Models a two-joint IK system (arm, leg, etc). Constrains the system by
    always pinning the middle joint's z-coordinate halfway between the origin and the target.
    '''
    def __init__(self, origin_bone, target_bone, lengths=(20, 20)):
        '''Takes two Model objects for the "upperarm" bone and the "lowerarm" bone.
        Assumes the models start at the origin, with bone vectors pointing to (0,0,1).'''
        self.upperarm = origin_bone
        self.forearm = target_bone
        self.lengths = lengths
        self.tlen = lengths[0] + lengths[1]
        self.curr_vecs = np.zeros([2, 3])
        self.curr_angles = np.zeros(2)

    def _midpos(self, target):
        m, n = self.lengths
        x, y, z = target

        # this heinous equation brought to you by Wolfram Alpha
        # it is ONE of the solutions to this system of equations:
        #   a^2 + b^2 + (z/2)^2 = m^2
        #   (x-a)^2 + (y-b)^2 + (z/2)^2 = n^2
        if x > 0:
            a = (m**2*x**2+y*np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2-2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))-n**2*x**2+x**4+x**2*y**2)/(2*x*(x**2+y**2))
            b = (m**2*y-np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2-2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))-n**2*y+x**2*y+y**3)/(2*(x**2+y**2))
        else:
            a = (m**2*x**2-y*np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2-2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))-n**2*x**2+x**4+x**2*y**2)/(2*x*(x**2+y**2))
            b = (m**2*y+np.sqrt(-x**2*(m**4-2*m**2*n**2-2*m**2*x**2-2*m**2*y**2+n**4-2*n**2*x**2-2*n**2*y**2+x**4+2*x**2*y**2+x**2*z**2+y**4+y**2*z**2))-n**2*y+x**2*y+y**3)/(2*(x**2+y**2))
        return a, b, z/2

    def set_endpoint_3D(self, target):
        '''Sets the endpoint coordinate for the two-joint system'''
        # Make sure the target is actually achievable
        if np.linalg.norm(target) > self.tlen:
            # Out of reach: point the whole arm straight at the target
            self.upperarm.xfm.rotate = Quaternion.rotate_vecs((0,0,1), target).norm()
            self.forearm.xfm.rotate = Quaternion()
        else:
            elbow = np.array(self._midpos(target))

            # rotate the upperarm to the elbow
            self.upperarm.xfm.rotate = Quaternion.rotate_vecs((0,0,1), elbow).norm()

            # this broke my mind for 2 hours at least, so I cheated
            # Rotate first to (0,0,1), then rotate to the target-elbow
            self.forearm.xfm.rotate = (Quaternion.rotate_vecs(elbow, (0,0,1)) *
                                       Quaternion.rotate_vecs((0,0,1), target-elbow)).norm()

        self.upperarm._recache_xfm()
        upperarm_affine_xform = self.upperarm.xfm.rotate.to_mat()
        forearm_affine_xform = (self.upperarm.xfm * self.forearm.xfm).rotate.to_mat()
        self.curr_vecs[0] = np.dot(upperarm_affine_xform, np.array([0., 0, self.lengths[0], 1]))[:-1]
        self.curr_vecs[1] = np.dot(forearm_affine_xform, np.array([0, 0, self.lengths[1], 1]))[:-1]
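The forearm rotation composes two rotations because the forearm's frame is a child of the upperarm's: first undo the parent's rotation (elbow direction back to (0,0,1)), then rotate into the world-frame forearm direction. A minimal check of that reasoning with rotation matrices (hypothetical Rodrigues helper, not the library's Quaternion class, and assuming parent-then-child composition order as the to_mat usage above suggests):

import numpy as np

def rot_between(a, b):
    """Rotation matrix taking unit(a) to unit(b) (Rodrigues; degenerate if a == -b)."""
    a = a/np.linalg.norm(a); b = b/np.linalg.norm(b)
    v = np.cross(a, b); c = np.dot(a, b)
    K = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    return np.eye(3) + K + K @ K / (1 + c)

z = np.array([0., 0., 1.])
elbow = np.array([3., 4., 12.])      # hypothetical elbow direction
fore = np.array([5., 0., 5.])        # hypothetical world-frame forearm direction
R_parent = rot_between(z, elbow)
R_child = rot_between(elbow, z) @ rot_between(z, fore)
# Composed, the forearm bone (initially along +z) points along `fore` in the world:
world = R_parent @ R_child @ z
assert np.allclose(world, fore/np.linalg.norm(fore))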

    def set_endpoint_2D(self, target):
        '''Given an endpoint coordinate in the x-z plane, solves for joint positions in that plane via inverse kinematics'''
        pass

    def set_joints_2D(self, shoulder_angle, elbow_angle):
        '''Given angles for the shoulder and elbow in a plane, set the joint positions. The shoulder angle is in a fixed
        frame of reference where 0 is horizontal pointing to the right (left if viewing on screen without mirror), pi/2
        is vertical pointing up, and pi is horizontal pointing to the left. The elbow angle is relative to the upper arm
        vector, where 0 is fully extended, pi/2 is a right angle to the upper arm pointing left, and pi is fully
        overlapping with the upper arm.'''
        elbow_angle_mod = elbow_angle + np.pi/2

        #if shoulder_angle>np.pi: shoulder_angle = np.pi
        #if shoulder_angle<0.0: shoulder_angle = 0.0
        #if elbow_angle>np.pi: elbow_angle = np.pi
        #if elbow_angle<0: elbow_angle = 0

        # Find upper arm vector
        xs = self.lengths[0]*np.cos(shoulder_angle)
        ys = 0.0
        zs = self.lengths[0]*np.sin(shoulder_angle)
        self.curr_vecs[0, :] = np.array([xs, ys, zs])
        self.upperarm.xfm.rotate = Quaternion.rotate_vecs((0,0,1), (xs, 0, zs)).norm()

        # Find forearm vector (relative to upper arm)
        xe = self.lengths[1]*np.cos(elbow_angle_mod)
        ye = 0.0
        ze = self.lengths[1]*np.sin(elbow_angle_mod)
        # Find absolute vector
        xe2 = self.lengths[1]*np.cos(shoulder_angle + elbow_angle)
        ye2 = 0.0
        ze2 = self.lengths[1]*np.sin(shoulder_angle + elbow_angle)
        self.curr_vecs[1, :] = np.array([xe2, ye2, ze2])
        self.forearm.xfm.rotate = Quaternion.rotate_vecs((0,0,1), (xe, 0, ze)).norm()

        self.curr_angles[0] = shoulder_angle
        self.curr_angles[1] = elbow_angle_mod

        self.upperarm._recache_xfm()
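With this convention, the hand position follows directly from the two absolute link vectors; a worked example with hypothetical 20 cm links:

import numpy as np

lengths = (20.0, 20.0)
shoulder, elbow = np.pi/2, np.pi/2   # upper arm straight up, elbow bent 90 deg

upper = lengths[0]*np.array([np.cos(shoulder), 0, np.sin(shoulder)])
fore = lengths[1]*np.array([np.cos(shoulder + elbow), 0, np.sin(shoulder + elbow)])
hand = upper + fore
print(hand)   # upper arm ends at (0, 0, 20); forearm points along -x: hand = (-20, 0, 20)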

    # cursor_color = (.5, 0, .5, 1)
    # cursor_radius = 0.4
    # self.endpt_cursor = Sphere(radius=cursor_radius, color=cursor_color)
    # self.endpt_cursor.translate(0, 0, lengths[1])
    # self.forearm = Group([
    #     Cylinder(radius=link_radii[1], height=lengths[1], color=link_colors[1]),
    #     self.endpt_cursor])
    # self.forearm.translate(0, 0, lengths[0])

    # self.upperarm = Group([
    #     Cylinder(radius=link_radii[0], height=lengths[0], color=link_colors[0]),
    #     Sphere(radius=ball_radii[0], color=ball_colors[0]).translate(0, 0, lengths[0]),
    #     self.forearm])
    # self.system = TwoJoint(self.upperarm, self.forearm, lengths=(self.lengths))

class RobotArm(Plant, Group):
    def __init__(self, link_radii=(.2, .2), ball_radii=(.5, .5), lengths=(20, 20), ball_colors=((1,1,1,1), (1,1,1,1)),
                 link_colors=((1,1,1,1), (1,1,1,1)), base_loc=np.array([2., 0., -10.]), **kwargs):
        self.link_radii = link_radii
        self.ball_radii = ball_radii
        self.lengths = lengths

        self.endpt_cursor = Sphere(radius=ball_radii[1], color=(1, 0, 1, 1))  # ball_colors[1]
        self.forearm = Group([
            Cylinder(radius=link_radii[1], height=lengths[1], color=link_colors[1]),
            self.endpt_cursor.translate(0, 0, lengths[1])]).translate(0, 0, lengths[0])
        self.upperarm = Group([
            Cylinder(radius=link_radii[0], height=lengths[0], color=link_colors[0]),
            Sphere(radius=ball_radii[0], color=ball_colors[0]).translate(0, 0, lengths[0]),
            self.forearm])
        self.system = TwoJoint(self.upperarm, self.forearm, lengths=self.lengths)
        super(RobotArm, self).__init__([self.upperarm], **kwargs)

        self.num_links = len(link_radii)
        self.num_joints = 3  # abstract joints: this system is fully characterized by the endpoint position since the elbow angle is determined by IK

        self.base_loc = base_loc

        self.translate(*self.base_loc, reset=True)

    def get_endpoint_pos(self):
        return np.sum(self.system.curr_vecs, axis=0) + self.base_loc

    def set_endpoint_pos(self, pos, **kwargs):
        self.system.set_endpoint_3D(pos - self.base_loc)

    def get_intrinsic_coordinates(self):
        return self.get_endpoint_pos()

    def set_intrinsic_coordinates(self, pos):
        self.set_endpoint_pos(pos)

    # def drive(self, decoder):
    #     self.set_intrinsic_coordinates(decoder['q'])
    #     intrinsic_coords = self.get_intrinsic_coordinates()
    #     if not np.any(np.isnan(intrinsic_coords)):
    #         decoder['q'] = self.get_intrinsic_coordinates()
    #     print 'arm pos', self.get_endpoint_pos()

    # def set_endpoint_2D(self, target):
    #     self.system.set_endpoint_2D(target)

    # def set_joints_2D(self, shoulder_angle, elbow_angle):
    #     self.system.set_joints_2D(shoulder_angle, elbow_angle)

    # def get_hand_location(self, shoulder_anchor):
    #     '''returns position of the ball at the end of the forearm (hand)'''
    #     return shoulder_anchor + self.system.curr_vecs[0] + self.system.curr_vecs[1]

    # def get_joint_angles_2D(self):
    #     return self.system.curr_angles[0], self.system.curr_angles[1] - np.pi/2


# cursor_bounds = traits.Tuple((-25, 25, 0, 0, -14, 14), "Boundaries for where the cursor can travel on the screen")

chain_kwargs = dict(link_radii=.6, joint_radii=0.6, joint_colors=(181/256., 116/256., 96/256., 1), link_colors=(181/256., 116/256., 96/256., 1))

shoulder_anchor = np.array([2., 0., -15])

chain_15_15_5_5 = RobotArmGen2D(link_lengths=[15, 15, 5, 5], base_loc=shoulder_anchor, **chain_kwargs)
init_joint_pos = np.array([0.47515737, 1.1369006, 1.57079633, 0.29316668])  ## center pos coordinates: 0.63017, 1.38427, 1.69177, 0.42104
chain_15_15_5_5.set_intrinsic_coordinates(init_joint_pos)

chain_20_20 = RobotArm2J2D(link_lengths=[20, 20], base_loc=shoulder_anchor, **chain_kwargs)
starting_pos = np.array([5., 0., 5])
chain_20_20.set_endpoint_pos(starting_pos - shoulder_anchor, n_iter=10, n_particles=500)
chain_20_20.set_endpoint_pos(starting_pos, n_iter=10, n_particles=500)

cursor = CursorPlant(endpt_bounds=(-14, 14, 0., 0., -14, 14))
#cursor = CursorPlant(endpt_bounds=(-10, 10, 0., 0., -10, 10))
#cursor = CursorPlant(endpt_bounds=(-9.5, 9.5, 0., 0., -7.5, 11.5))
#cursor = CursorPlant(endpt_bounds=(-11, 11., 0., 0., -11., 11.))
#cursor = CursorPlant(endpt_bounds=(-10, 10., 0., 0., -10., 10.))
#cursor = CursorPlant(endpt_bounds=(-24., 24., 0., 0., -14., 14.))

arm_3d = RobotArm()

plants = dict(RobotArmGen2D=chain_15_15_5_5,
              RobotArm2J2D=chain_20_20,
              CursorPlant=cursor,
              Arm3D=arm_3d)
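The module-level `plants` dict acts as a registry mapping plant names to preconfigured instances; a task would presumably look one up by name, roughly like the sketch below (hypothetical calling code, since CursorPlant's exact API is not shown in this listing):

# Sketch of how a task might consume the registry:
plant = plants['CursorPlant']
plant.set_endpoint_pos(np.array([0., 0., 0.]))
print(plant.get_endpoint_pos())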
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_models_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_models_py.html
new file mode 100644
index 00000000..1a25f91a
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_models_py.html
@@ -0,0 +1,367 @@
Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\models.py: 38%
'''
Collections of shapes to render on screen
'''

import numpy as np
from OpenGL.GL import *
from OpenGL import GLUT as glut

from .xfm import Transform

class Model(object):
    def __init__(self, shader="default", color=(0.5, 0.5, 0.5, 1),
        shininess=10, specular_color=(1., 1., 1., 1.)):
        '''
        Constructor for Model

        Parameters
        ----------
        shader: string
            Name of the shader program used to render this model
        color: tuple of length 4
            (r, g, b, alpha)
        shininess: float
            Specular exponent passed to the shader
        specular_color: tuple of length 4
            Color of the specular highlight
        '''
        self.shader = shader
        self.parent = None

        # This is different from self.xfm = Transform() for children of Model implemented with multiple inheritance
        super(Model, self).__setattr__("xfm", Transform())
        self.color = color
        self.shininess = shininess
        self.spec_color = specular_color

        # The orientation of the object, in the world frame
        self._xfm = self.xfm
        self.allocated = False

    def __setattr__(self, attr, xfm):
        '''Checks if the xfm was changed, and recaches the _xfm which is sent to the shader'''
        val = super(Model, self).__setattr__(attr, xfm)
        if attr == "xfm":
            self._recache_xfm()
        return val

    def _recache_xfm(self):
        '''
        For models with a parent, the transform of the current model must be cascaded with the parent model's transform.
        NOTE: this only goes one level up the graphics tree, so the transform is
        always with respect to the parent's frame, not with respect to the world frame!
        '''
        if self.parent is not None:
            self._xfm = self.parent._xfm * self.xfm
        else:
            self._xfm = self.xfm
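Since each model caches parent._xfm * xfm, updating a parent and recaching propagates world transforms down the tree. A minimal numeric analogue using plain 4x4 matrices (a sketch, not the Transform class itself):

import numpy as np

def translate_mat(x, y, z):
    m = np.eye(4); m[:3, 3] = (x, y, z); return m

parent_xfm = translate_mat(0, 0, 15)        # e.g. a link group offset up its parent bone
child_xfm = translate_mat(0, 0, 20)         # child's own offset in the parent frame
cached = parent_xfm @ child_xfm             # what _recache_xfm stores for the child
# A point at the child's origin lands at the composed offset in world coordinates:
print(cached @ np.array([0., 0., 0., 1.]))  # -> [0, 0, 35, 1]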

    def init(self):
        allocated = self.allocated
        self.allocated = True
        return allocated

    def rotate_x(self, deg, reset=False):
        self.xfm.rotate_x(np.radians(deg), reset=reset)
        self._recache_xfm()
        return self

    def rotate_y(self, deg, reset=False):
        self.xfm.rotate_y(np.radians(deg), reset=reset)
        self._recache_xfm()
        return self

    def rotate_z(self, deg, reset=False):
        self.xfm.rotate_z(np.radians(deg), reset=reset)
        self._recache_xfm()
        return self

    def translate(self, x, y, z, reset=False):
        self.xfm.translate(x, y, z, reset=reset)
        self._recache_xfm()
        return self

    def render_queue(self, shader=None):
        '''Yields (shader, partial drawfunc, texture) tuples for queueing'''
        if shader is not None:
            yield shader, self.draw, None
        else:
            yield self.shader, self.draw, None

    def draw(self, ctx, **kwargs):
        '''
        Parameters
        ----------
        ctx: active shader-program context supplying uniform locations
        kwargs: optional keyword arguments
            Can specify 'color', 'specular_color', or 'shininess' of the object to draw (overriding the model's attributes)

        Returns: None
        '''
        glUniformMatrix4fv(ctx.uniforms.xfm, 1, GL_TRUE, self._xfm.to_mat().astype(np.float32))
        glUniform4f(ctx.uniforms.basecolor, *kwargs.pop('color', self.color))
        glUniform4f(ctx.uniforms.spec_color, *kwargs.pop('specular_color', self.spec_color))
        glUniform1f(ctx.uniforms.shininess, kwargs.pop('shininess', self.shininess))

    def attach(self):
        assert self.parent is not None
        while self not in self.parent.models:
            self.parent.models.append(self)

    def detach(self):
        assert self.parent is not None
        while self in self.parent.models:
            self.parent.models.remove(self)

class Group(Model):
    def __init__(self, models=()):
        super(Group, self).__init__()
        self.models = []
        for model in models:
            self.add(model)

    def add(self, model):
        self.models.append(model)
        model.parent = self
        model._recache_xfm()

    def init(self):
        for model in self.models:
            model.init()

    def render_queue(self, xfm=np.eye(4), **kwargs):
        for model in self.models:
            for out in model.render_queue(**kwargs):
                yield out

    def draw(self, ctx, **kwargs):
        for model in self.models:
            model.draw(ctx, **kwargs)

    def __getitem__(self, idx):
        return self.models[idx]

    def _recache_xfm(self):
        super(Group, self)._recache_xfm()
        for model in self.models:
            model._recache_xfm()
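render_queue flattens the scene tree into (program, drawfunc, texture) tuples so the renderer can batch draws by shader program. A toy traversal showing the effect, with hypothetical stand-in classes:

# Toy illustration of the flattening (stand-in objects, not the real Model/Group):
class Leaf:
    def __init__(self, name, shader):
        self.name, self.shader = name, shader
    def render_queue(self, shader=None):
        yield (shader or self.shader), self.name, None

class Node:
    def __init__(self, children):
        self.children = children
    def render_queue(self, **kwargs):
        for child in self.children:
            yield from child.render_queue(**kwargs)

scene = Node([Leaf("arm", "default"), Node([Leaf("cursor", "default")])])
print(list(scene.render_queue()))
# -> [('default', 'arm', None), ('default', 'cursor', None)]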

builtins = dict([(n[9:].lower(), getattr(glut, n))
                 for n in dir(glut)
                 if "glutSolid" in n])

class Builtins(Model):
    def __init__(self, model, shader="fixedfunc", *args):
        super(Builtins, self).__init__(shader=shader)  # the original listing passed an undefined name `xfm` here
        assert model in builtins
        self.model = builtins[model]  # look up the requested builtin; the listing indexed the literal string 'model'
        self.args = args

    def draw(self, ctx, xfm=np.eye(4)):
        glPushMatrix()
        glLoadMatrixf(np.dot(xfm, self.xfm.to_mat()).ravel())
        self.model(*self.args)
        glPopMatrix()

class TriMesh(Model):
    '''Basic triangle mesh model. Houses the GL functions for making buffers and displaying triangles'''
    def __init__(self, verts, polys, normals=None, tcoords=None, **kwargs):
        super(TriMesh, self).__init__(**kwargs)
        if verts.shape[1] == 3:
            # Pad vertices to homogeneous coordinates
            verts = np.hstack([verts, np.ones((len(verts), 1))])
        if normals is not None and normals.shape[1] == 3:  # guard added: normals may be None
            normals = np.hstack([normals, np.ones((len(normals), 1))])

        self.verts = verts
        self.polys = polys
        self.tcoords = tcoords
        self.normals = normals

    def init(self):
        allocated = super(TriMesh, self).init()
        if not allocated:
            self.vbuf = glGenBuffers(1)
            self.ebuf = glGenBuffers(1)
            glBindBuffer(GL_ARRAY_BUFFER, self.vbuf)
            glBufferData(GL_ARRAY_BUFFER,
                self.verts.astype(np.float32).ravel(), GL_STATIC_DRAW)
            glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, self.ebuf)
            glBufferData(GL_ELEMENT_ARRAY_BUFFER,
                self.polys.astype(np.uint16).ravel(), GL_STATIC_DRAW)

            if self.tcoords is not None:
                self.tbuf = glGenBuffers(1)
                glBindBuffer(GL_ARRAY_BUFFER, self.tbuf)
                glBufferData(GL_ARRAY_BUFFER,
                    self.tcoords.astype(np.float32).ravel(), GL_STATIC_DRAW)

            if self.normals is not None:
                self.nbuf = glGenBuffers(1)
                glBindBuffer(GL_ARRAY_BUFFER, self.nbuf)
                glBufferData(GL_ARRAY_BUFFER,
                    self.normals.astype(np.float32).ravel(), GL_STATIC_DRAW)
        return allocated

    def draw(self, ctx):
        super(TriMesh, self).draw(ctx)
        glEnableVertexAttribArray(ctx.attributes['position'])
        glBindBuffer(GL_ARRAY_BUFFER, self.vbuf)
        glVertexAttribPointer(ctx.attributes['position'],
            4, GL_FLOAT, GL_FALSE, 4*4, GLvoidp(0))

        if self.tcoords is not None and ctx.attributes['texcoord'] != -1:
            glEnableVertexAttribArray(ctx.attributes['texcoord'])
            glBindBuffer(GL_ARRAY_BUFFER, self.tbuf)
            glVertexAttribPointer(
                ctx.attributes['texcoord'], 2,
                GL_FLOAT, GL_FALSE, 4*2, GLvoidp(0))

        if self.normals is not None and ctx.attributes['normal'] != -1:
            glEnableVertexAttribArray(ctx.attributes['normal'])
            glBindBuffer(GL_ARRAY_BUFFER, self.nbuf)
            glVertexAttribPointer(
                ctx.attributes['normal'], 4,
                GL_FLOAT, GL_FALSE, 4*4, GLvoidp(0))

        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, self.ebuf)
        glDrawElements(
            GL_TRIANGLES,           # mode
            len(self.polys)*3,      # count
            GL_UNSIGNED_SHORT,      # type
            GLvoidp(0)              # element array buffer offset
        )
        glDisableVertexAttribArray(ctx.attributes['position'])
        if self.tcoords is not None and ctx.attributes['texcoord'] != -1:
            glDisableVertexAttribArray(ctx.attributes['texcoord'])
        if self.normals is not None and ctx.attributes['normal'] != -1:
            glDisableVertexAttribArray(ctx.attributes['normal'])

class FlatMesh(TriMesh):
    '''Takes smoothed or no-normal meshes and gives them flat shading'''
    def __init__(self, verts, polys, normals=None, **kwargs):
        checked = dict()
        normals = []
        nverts = []
        npolys = []

        for i, poly in enumerate(polys):
            # The face normal is the cross product of two edge vectors
            v1 = verts[poly[1]] - verts[poly[0]]
            v2 = verts[poly[2]] - verts[poly[0]]
            nvec = tuple(np.cross(v1, v2))

            npoly = []
            for v in poly:
                # Duplicate each vertex once per distinct face normal, so the
                # shader sees a constant normal across each face
                vert = tuple(verts[v])
                if (vert, nvec) not in checked:
                    checked[(vert, nvec)] = len(nverts)
                    npoly.append(len(nverts))
                    nverts.append(vert)
                    normals.append(nvec)
                else:
                    npoly.append(checked[(vert, nvec)])

            npolys.append(npoly)

        super(FlatMesh, self).__init__(np.array(nverts), np.array(npolys),
            normals=np.array(normals), **kwargs)

class PolyMesh(TriMesh):
    def __init__(self, verts, polys, **kwargs):
        # Triangulate each polygon as a fan anchored at its first vertex
        tripoly = []
        for poly in polys:
            for p in zip(poly[1:-1], poly[2:]):
                tripoly.append((poly[0],) + p)  # fan anchor; the listing used poly[1], which degenerates the first triangle
        super(PolyMesh, self).__init__(verts, tripoly, **kwargs)


def obj_load(filename):
    '''Minimal Wavefront OBJ parser; the face branch is left unfinished in the source'''
    facesplit = lambda x: x.split('/')
    verts, polys, normals, tcoords = [], [], [], []
    objfile = open(filename)
    for line in objfile:
        el = line.split()
        if not el or el[0] == "#":
            pass
        elif el[0] == "v":
            verts.append(list(map(float, el[1:])))
        elif el[0] == "vt":
            tcoords.append(list(map(float, el[1:])))
        elif el[0] == "vn":
            normals.append(list(map(float, el[1:])))
        elif el[0] == "f":
            for v in el[1:]:
                pass
            list(map(facesplit, el[1:]))
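The face branch above splits each "v/vt/vn" token but never stores the result. A sketch of how that branch might be completed inside the loop, assuming 1-based OBJ indices and polygonal faces (hypothetical; not part of the source):

        elif el[0] == "f":
            face = []
            for v in el[1:]:
                idx = v.split('/')             # "v", "v/vt", or "v/vt/vn"
                face.append(int(idx[0]) - 1)   # OBJ indices are 1-based
            polys.append(face)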
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_primitives_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_primitives_py.html
new file mode 100644
index 00000000..714cd789
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_primitives_py.html
@@ -0,0 +1,381 @@
Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\primitives.py: 68%
'''
Basic OpenGL shapes constructed out of triangular meshes
'''

import numpy as np
from numpy import pi
try:
    import os
    os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
    import pygame
except ImportError:
    import warnings
    warnings.warn('riglib/stereo_opengl/primitives.py: not importing pygame')

from .models import TriMesh

class Plane(TriMesh):
    def __init__(self, width=1, height=1, **kwargs):
        pts = np.array([[0, 0, 0],
                        [width, 0, 0],
                        [width, height, 0],
                        [0, height, 0]])
        polys = [(0, 1, 3), (1, 2, 3)]
        tcoords = np.array([[0, 0], [1, 0], [1, 1], [0, 1]])
        normals = [(0, 0, 1)]*4
        super(Plane, self).__init__(pts, np.array(polys),
            tcoords=tcoords, normals=np.array(normals), **kwargs)

class Cube(TriMesh):
    def __init__(self, side_len=1, segments=36, **kwargs):
        self.side_len = side_len
        side_len_half = side_len/2.
        side = np.linspace(-1, 1, segments//4, endpoint=True)  # integer count required by np.linspace

        unit1 = np.hstack((side[:, np.newaxis], np.ones((len(side), 1)), np.ones((len(side), 1))))
        unit2 = np.hstack((np.ones((len(side), 1)), side[::-1, np.newaxis], np.ones((len(side), 1))))
        unit3 = np.hstack((side[::-1, np.newaxis], -1*np.ones((len(side), 1)), np.ones((len(side), 1))))
        unit4 = np.hstack((-1*np.ones((len(side), 1)), side[:, np.newaxis], np.ones((len(side), 1))))

        unit = np.vstack((unit1, unit2, unit3, unit4))

        pts = np.vstack([unit*[side_len_half, side_len_half, 0], unit*[side_len_half, side_len_half, side_len]])
        normals = np.vstack([unit*[1, 1, 0], unit*[1, 1, 0]])

        polys = []
        for i in range(segments-1):
            polys.append((i, i+1, i+segments))
            polys.append((i+segments, i+1, i+1+segments))
        polys.append((segments-1, 0, segments*2-1))
        polys.append((segments*2-1, 0, segments))

        tcoord = np.array([np.arange(segments), np.ones(segments)]).T
        n = 1./segments
        tcoord = np.vstack([tcoord*[n, 1], tcoord*[n, 0]])

        super(Cube, self).__init__(pts, np.array(polys),
            tcoords=tcoord, normals=normals, **kwargs)

class Cylinder(TriMesh):
    def __init__(self, height=1, radius=1, segments=36, **kwargs):
        self.height = height
        self.radius = radius
        theta = np.linspace(0, 2*np.pi, segments, endpoint=False)
        unit = np.array([np.cos(theta), np.sin(theta), np.ones(segments)]).T

        pts = np.vstack([unit*[radius, radius, 0], unit*[radius, radius, height]])
        normals = np.vstack([unit*[1, 1, 0], unit*[1, 1, 0]])

        polys = []
        for i in range(segments-1):
            polys.append((i, i+1, i+segments))
            polys.append((i+segments, i+1, i+1+segments))
        polys.append((segments-1, 0, segments*2-1))
        polys.append((segments*2-1, 0, segments))

        tcoord = np.array([np.arange(segments), np.ones(segments)]).T
        n = 1./segments
        tcoord = np.vstack([tcoord*[n, 1], tcoord*[n, 0]])

        super(Cylinder, self).__init__(pts, np.array(polys),
            tcoords=tcoord, normals=normals, **kwargs)
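The side wall is built as pairs of triangles between the bottom ring (indices 0..segments-1) and top ring (segments..2*segments-1), with the final pair wrapping back around to index 0. A quick check that every side quad is covered, using a small hypothetical segment count:

segments = 6
polys = []
for i in range(segments-1):
    polys.append((i, i+1, i+segments))
    polys.append((i+segments, i+1, i+1+segments))
polys.append((segments-1, 0, segments*2-1))   # wrap-around quad, lower triangle
polys.append((segments*2-1, 0, segments))     # wrap-around quad, upper triangle

assert len(polys) == 2*segments               # two triangles per side quad
assert all(0 <= idx < 2*segments for tri in polys for idx in tri)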

class Sphere(TriMesh):
    def __init__(self, radius=1, segments=36, **kwargs):
        self.radius = radius
        zvals = radius * np.cos(np.linspace(0, np.pi, num=segments))
        circlevals = np.linspace(0, 2*pi, num=segments, endpoint=False)

        vertices = np.zeros(((len(zvals)-2) * len(circlevals), 3))

        for i, z in enumerate(zvals[1:-1]):
            circlepoints = np.zeros((segments, 3))
            circlepoints[:, 2] = z
            r = np.sqrt(radius**2 - z**2)
            circlepoints[:, 0] = r*np.sin(circlevals)
            circlepoints[:, 1] = r*np.cos(circlevals)
            vertices[segments*i:segments*(i+1), :] = circlepoints

        vertices = np.vstack([vertices, (0, 0, radius), (0, 0, -radius)])
        allpointinds = np.arange(len(vertices))

        # Cap triangles around the north pole
        triangles = np.zeros((segments, 3))
        firstcirc = allpointinds[0:segments]
        triangles[0, :] = (allpointinds[-2], firstcirc[0], firstcirc[-1])
        for i in range(segments-1):
            triangles[i+1, :] = (allpointinds[-2], firstcirc[i+1], firstcirc[i])

        # Quads (two triangles each) between adjacent latitude rings
        triangles = list(triangles)
        for i in range(segments-3):
            points1 = allpointinds[i*segments:(i+1)*segments]
            points2 = allpointinds[(i+1)*segments:(i+2)*segments]
            for ind, p in enumerate(points1[:-1]):
                t1 = (p, points1[ind+1], points2[ind+1])
                t2 = (p, points2[ind+1], points2[ind])
                triangles += [t1, t2]
            triangles += [(points1[-1], points1[0], points2[0]), (points1[-1], points2[0], points2[-1])]

        # Cap triangles around the south pole
        bottom = np.zeros((segments, 3))
        lastcirc = allpointinds[-segments-2:-2]
        bottom[0, :] = (allpointinds[-1], lastcirc[-1], lastcirc[0])
        for i in range(segments-1):
            bottom[i+1, :] = (allpointinds[-1], lastcirc[i], lastcirc[i+1])
        triangles = np.vstack([triangles, bottom])

        normals = vertices/radius
        hcoord = np.arctan2(normals[:, 1], normals[:, 0])
        vcoord = np.arctan2(normals[:, 2], np.sqrt(vertices[:, 0]**2 + vertices[:, 1]**2))
        tcoord = np.array([(hcoord+pi) / (2*pi), (vcoord+pi/2) / pi]).T

        super(Sphere, self).__init__(vertices, np.array(triangles),
            tcoords=tcoord, normals=normals, **kwargs)

class Cone(TriMesh):
    def __init__(self, height=1, radius1=1, radius2=1, segments=36, **kwargs):
        self.height = height
        self.radius1 = radius1
        self.radius2 = radius2
        self.radius = radius1  # for pretending it's a cylinder
        theta = np.linspace(0, 2*np.pi, segments, endpoint=False)
        unit = np.array([np.cos(theta), np.sin(theta), np.ones(segments)]).T

        pts = np.vstack([unit*[radius1, radius1, 0], unit*[radius2, radius2, height]])
        normals = np.vstack([unit*[1, 1, 0], unit*[1, 1, 0]])

        polys = []
        for i in range(segments-1):
            polys.append((i, i+1, i+segments))
            polys.append((i+segments, i+1, i+1+segments))
        polys.append((segments-1, 0, segments*2-1))
        polys.append((segments*2-1, 0, segments))

        tcoord = np.array([np.arange(segments), np.ones(segments)]).T
        n = 1./segments
        tcoord = np.vstack([tcoord*[n, 1], tcoord*[n, 0]])

        super(Cone, self).__init__(pts, np.array(polys),
            tcoords=tcoord, normals=normals, **kwargs)

class Chain(object):
    '''
    An open chain of cylinders and cones, e.g. to simulate a stick-figure arm/robot
    '''
    def __init__(self, link_radii, joint_radii, link_lengths, joint_colors, link_colors):
        from .models import Group
        from .xfm import Quaternion
        self.num_joints = num_joints = len(link_lengths)

        self.link_radii = self.make_list(link_radii, num_joints)
        self.joint_radii = self.make_list(joint_radii, num_joints)
        self.link_lengths = self.make_list(link_lengths, num_joints)
        self.joint_colors = self.make_list(joint_colors, num_joints)
        self.link_colors = self.make_list(link_colors, num_joints)

        self.links = []

        # Create the link graphics
        for i in range(self.num_joints):
            joint = Sphere(radius=self.joint_radii[i], color=self.joint_colors[i])

            # The most distal link gets a tapered cylinder (for purely stylistic reasons)
            if i < self.num_joints - 1:
                link = Cylinder(radius=self.link_radii[i], height=self.link_lengths[i], color=self.link_colors[i])
            else:
                link = Cone(radius1=self.link_radii[-1], radius2=self.link_radii[-1]/2, height=self.link_lengths[-1], color=self.link_colors[-1])
            link_i = Group((link, joint))
            self.links.append(link_i)

        link_offsets = [0] + self.link_lengths[:-1]
        self.link_groups = [None]*self.num_joints
        for i in range(self.num_joints)[::-1]:
            if i == self.num_joints-1:
                self.link_groups[i] = self.links[i]
            else:
                self.link_groups[i] = Group([self.links[i], self.link_groups[i+1]])

            self.link_groups[i].translate(0, 0, link_offsets[i])

    def _update_link_graphics(self, curr_vecs):
        from .models import Group
        from .xfm import Quaternion

        for i in range(self.num_joints):
            # Rotate each joint to the vector specified by the corresponding row in curr_vecs.
            # The baseline orientation of the first group differs from the more distal
            # attachments, so rotations are found relative to the instantiation-time orientation.
            if i == 0:
                baseline_orientation = (0, 0, 1)
            else:
                baseline_orientation = (1, 0, 0)

            # Find the normalized quaternion that represents the desired joint rotation
            self.link_groups[i].xfm.rotate = Quaternion.rotate_vecs(baseline_orientation, curr_vecs[i]).norm()

            # Recompute any cached transformations after the change
            self.link_groups[i]._recache_xfm()

    def translate(self, *args, **kwargs):
        self.link_groups[0].translate(*args, **kwargs)

    @staticmethod
    def make_list(value, num_joints):
        '''
        Helper function to allow joint/link properties of the chain to be specified
        as one value for all joints/links or as separate values for each
        '''
        if isinstance(value, list) and len(value) == num_joints:
            return value
        else:
            return [value] * num_joints
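A Chain is assembled from per-link values that make_list broadcasts when a scalar is given. A minimal instantiation sketch (hypothetical sizes/colors; drawing requires a live GL context):

# Sketch: a 4-link stick figure, one radius broadcast to all links
link_lengths = [15, 15, 5, 5]
chain = Chain(link_radii=.6, joint_radii=.6, link_lengths=link_lengths,
              joint_colors=(1, 1, 1, 1), link_colors=(1, 1, 1, 1))
chain.translate(2., 0., -15)   # anchor the chain at a shoulder location
# chain._update_link_graphics(curr_vecs) would then pose it each frame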

##### 2-D primitives #####

class Shape2D(object):
    '''Abstract base class for shapes that live in the 2-dimensional xz-plane
    and are intended only for use with the WindowDispl2D class (not Window).
    '''

    def __init__(self, color, visible=True):
        self.color = color
        self.visible = visible

    def draw(self, surface, pos2pix_fn):
        '''Draw itself on the given pygame.Surface object using the given
        position-to-pixel-position function.'''
        raise NotImplementedError  # implement in subclasses

    def _recache_xfm(self):
        pass


class Circle(Shape2D):
    def __init__(self, center_pos, radius, *args, **kwargs):
        super(Circle, self).__init__(*args, **kwargs)
        self.center_pos = center_pos
        self.radius = radius

    def draw(self, surface, pos2pix_fn):
        if self.visible:
            color = tuple([int(255*x) for x in self.color[0:3]])

            pix_pos = pos2pix_fn(self.center_pos)
            pix_radius = pos2pix_fn([self.radius, 0])[0] - pos2pix_fn([0, 0])[0]
            pygame.draw.circle(surface, color, pix_pos, pix_radius)

        return self.visible  # return True if the object was drawn


class Sector(Shape2D):
    def __init__(self, center_pos, radius, ang_range, *args, **kwargs):
        super(Sector, self).__init__(*args, **kwargs)
        self.center_pos = center_pos
        self.radius = radius
        self.ang_range = ang_range

    def draw(self, surface, pos2pix_fn):
        if self.visible:
            color = tuple([int(255*x) for x in self.color[0:3]])

            arc_angles = np.linspace(self.ang_range[0], self.ang_range[1], 5)
            pts = list(self.center_pos + self.radius*np.c_[np.cos(arc_angles), np.sin(arc_angles)])
            pts.append(self.center_pos)

            point_list = list(map(pos2pix_fn, pts))
            pygame.draw.polygon(surface, color, point_list)

        return self.visible  # return True if the object was drawn


class Line(Shape2D):
    def __init__(self, start_pos, length, width, angle, *args, **kwargs):
        super(Line, self).__init__(*args, **kwargs)
        self.start_pos = start_pos
        self.length = length
        self.width = width  # draw a line as a thin rectangle
        self.angle = angle

    def draw(self, surface, pos2pix_fn):
        if self.visible:
            color = tuple([int(255*x) for x in self.color[0:3]])

            # create points and then rotate to the correct orientation
            pts = np.array([[0, self.width/2],
                            [0, -self.width/2],
                            [self.length, -self.width/2],
                            [self.length, self.width/2]])
            rot_mat = np.array([[np.cos(self.angle), -np.sin(self.angle)],
                                [np.sin(self.angle), np.cos(self.angle)]])
            pts = np.dot(rot_mat, pts.T).T + self.start_pos

            point_list = list(map(pos2pix_fn, pts))
            pygame.draw.polygon(surface, color, point_list)

        return self.visible  # return True if the object was drawn
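All three 2-D shapes delegate coordinate conversion to a pos2pix_fn supplied by the window class. A minimal example of such a function, assuming a centered workspace in cm mapped onto a pixel surface (hypothetical sizes, not the WindowDispl2D implementation):

import numpy as np

def make_pos2pix(screen_px=(480, 270), view_cm=(48., 27.)):
    def pos2pix(pos):
        # pos is (x, z) in cm with (0, 0) at screen center; +z maps to up
        x, z = pos[0], pos[-1]
        px = int(screen_px[0]/2 + x*screen_px[0]/view_cm[0])
        py = int(screen_px[1]/2 - z*screen_px[1]/view_cm[1])
        return px, py
    return pos2pix

pos2pix = make_pos2pix()
print(pos2pix((0., 0.)))   # -> (240, 135), the center of the surface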
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render___init___py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render___init___py.html
new file mode 100644
index 00000000..b969ca22
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render___init___py.html
@@ -0,0 +1,70 @@
Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\__init__.py: 100%
'''
Collection of graphical renderers
'''

from .render import Renderer
from . import stereo
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_fbo_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_fbo_py.html
new file mode 100644
index 00000000..3aea0640
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_fbo_py.html
@@ -0,0 +1,173 @@
Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\fbo.py: 0%
'''Framebuffer-object (FBO) wrappers for offscreen rendering'''

import numpy as np
from OpenGL.GL import *

from .render import Renderer
from ..textures import Texture

# (internal format, external format, dtype, attachment point) per attachment type
fbotypes = dict(
    depth=(GL_DEPTH_COMPONENT, GL_DEPTH_COMPONENT, GL_FLOAT, GL_DEPTH_ATTACHMENT),
    stencil=(GL_DEPTH_COMPONENT, GL_DEPTH_COMPONENT, GL_UNSIGNED_BYTE, GL_STENCIL_ATTACHMENT),
    colors=(GL_RGBA8, GL_RGBA, GL_UNSIGNED_BYTE, GL_COLOR_ATTACHMENT0)
)

class FBO(object):
    def __init__(self, attachments, size=None, ncolors=1, **kwargs):
        maxcolors = range(glGetInteger(GL_MAX_COLOR_ATTACHMENTS))
        self.names = dict(("color%d" % i, GL_COLOR_ATTACHMENT0+i) for i in maxcolors)
        self.names["depth"] = GL_DEPTH_ATTACHMENT
        self.names["stencil"] = GL_STENCIL_ATTACHMENT

        self.textures = dict()

        self.fbo = glGenFramebuffers(1)
        glBindFramebuffer(GL_FRAMEBUFFER, self.fbo)

        for attach in attachments:
            if isinstance(attach, str):
                # Named attachment: allocate a fresh texture of the right format
                if attach.startswith("color"):
                    idx = int(attach[5:])
                    attach = "colors"
                iform, exform, dtype, attachment = fbotypes[attach]
                texture = Texture(None, size=size, iformat=iform, exformat=exform, dtype=dtype)
                texture.init()
                if attach == "colors":
                    attachment += idx
                glFramebufferTexture2D(GL_FRAMEBUFFER, attachment, GL_TEXTURE_2D, texture.tex, 0)
                self.textures[attachment] = texture
            else:
                # (attachment, texture) pair: attach an existing texture, or fall
                # back to a renderbuffer when no depth texture is supplied
                attachment, texture = attach
                if attachment in self.names:
                    attachment = self.names[attachment]
                if texture is None and attachment == GL_DEPTH_ATTACHMENT:
                    rb = glGenRenderbuffers(1)
                    glBindRenderbuffer(GL_RENDERBUFFER, rb)
                    glRenderbufferStorage(GL_RENDERBUFFER, GL_DEPTH_COMPONENT, size[0], size[1])
                    glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, GL_RENDERBUFFER, rb)
                else:
                    if texture.tex is None:
                        texture.init()
                    glFramebufferTexture2D(GL_FRAMEBUFFER, attachment, GL_TEXTURE_2D, texture.tex, 0)
                    self.textures[attachment] = texture

        types = [t for t in list(self.textures.keys()) if
                 t != GL_DEPTH_ATTACHMENT and
                 t != GL_STENCIL_ATTACHMENT and
                 t != GL_DEPTH_STENCIL_ATTACHMENT]
        if len(types) > 0:
            glDrawBuffers(types)
        else:
            glDrawBuffers(GL_NONE)

        assert glCheckFramebufferStatus(GL_FRAMEBUFFER) == GL_FRAMEBUFFER_COMPLETE
        glBindFramebuffer(GL_FRAMEBUFFER, 0)
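A typical construction mixes named attachments (new textures) with pre-made ones; a sketch of both styles, assuming a live GL context and the Texture class imported above:

# Sketch (requires an active OpenGL context):
size = (1920, 1080)

# Named attachments: the FBO allocates its own color0 + depth textures
offscreen = FBO(["color0", "depth"], size=size)

# Or attach an existing texture and let depth fall back to a renderbuffer:
# tex = Texture(None, size=size)
# offscreen = FBO([("color0", tex), ("depth", None)], size=size)

offscreen.clear()
print(offscreen["color0"])   # textures are retrievable by attachment name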

    def __getitem__(self, idx):
        if isinstance(idx, str):
            return self.textures[self.names[idx]]
        return self.textures[idx]

    def clear(self):
        glBindFramebuffer(GL_FRAMEBUFFER, self.fbo)
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT | GL_STENCIL_BUFFER_BIT)
        glBindFramebuffer(GL_FRAMEBUFFER, 0)

class FBOrender(Renderer):
    def draw_fsquad(self, shader, **kwargs):
        '''Draw a full-screen quad with the given shader, binding any Texture kwargs to texture units'''
        ctx = self.programs[shader]
        glUseProgram(ctx.program)
        for name, v in list(kwargs.items()):
            if isinstance(v, Texture):
                ctx.uniforms[name] = self.get_texunit(v)
            else:
                ctx.uniforms[name] = v

        glEnableVertexAttribArray(ctx.attributes['position'])
        glBindBuffer(GL_ARRAY_BUFFER, self.fsquad_buf[0])
        glVertexAttribPointer(ctx.attributes['position'],
            4, GL_FLOAT, GL_FALSE, 4*4, GLvoidp(0))

        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, self.fsquad_buf[1])
        glDrawElements(GL_TRIANGLES, 6, GL_UNSIGNED_SHORT, GLvoidp(0))
        glDisableVertexAttribArray(ctx.attributes['position'])

    def draw_fsquad_to_fbo(self, fbo, shader, **kwargs):
        glBindFramebuffer(GL_FRAMEBUFFER, fbo.fbo)
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT | GL_STENCIL_BUFFER_BIT)
        self.draw_fsquad(shader, **kwargs)
        glBindFramebuffer(GL_FRAMEBUFFER, 0)

    def draw_to_fbo(self, fbo, root, **kwargs):
        glBindFramebuffer(GL_FRAMEBUFFER, fbo.fbo)
        # Erase old buffer info
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT | GL_STENCIL_BUFFER_BIT)
        super(FBOrender, self).draw(root, **kwargs)
        glBindFramebuffer(GL_FRAMEBUFFER, 0)
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_render_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_render_py.html
new file mode 100644
index 00000000..6647aa77
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_render_py.html
@@ -0,0 +1,227 @@
Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\render.py: 18%
'''Core renderer: compiles shader programs and draws a model tree with them'''

import os
import operator
import numpy as np
from OpenGL.GL import *

from ..utils import perspective
from .shader import ShaderProgram

cwd = os.path.join(os.path.abspath(os.path.split(__file__)[0]), "..")

class _textrack(object):
    pass

class Renderer(object):
    def __init__(self, window_size, fov, near, far, shaders=None, programs=None):
        self.render_queue = None
        self.size = window_size
        self.drawpos = 0, 0
        w, h = window_size
        self.projection = perspective(fov, w / h, near, far)

        # Add the default shaders
        if shaders is None:
            shaders = dict()
        if programs is None:
            programs = dict()
        shaders['passthru'] = GL_VERTEX_SHADER, "passthrough.v.glsl"
        shaders['default'] = GL_FRAGMENT_SHADER, "default.f.glsl", "phong.f.glsl"
        programs['default'] = "passthru", "default"

        # compile the given shaders and the programs
        self.shaders = dict()
        for k, v in list(shaders.items()):
            print("Compiling shader %s..." % k)
            self.add_shader(k, *v)

        self.programs = dict()
        for name, shaders in list(programs.items()):
            self.add_program(name, shaders)

        # Set up the texture units
        self.reset_texunits()

        # Generate the default fullscreen quad
        verts = np.array([(-1,-1,0,1), (1,-1,0,1), (1,1,0,1), (-1,1,0,1)]).astype(np.float32)
        polys = np.array([(0,1,2), (0,2,3)]).astype(np.uint16)
        vbuf = glGenBuffers(1)
        ebuf = glGenBuffers(1)
        glBindBuffer(GL_ARRAY_BUFFER, vbuf)
        glBufferData(GL_ARRAY_BUFFER, verts, GL_STATIC_DRAW)
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, ebuf)
        glBufferData(GL_ELEMENT_ARRAY_BUFFER, polys, GL_STATIC_DRAW)
        self.fsquad_buf = vbuf, ebuf

    def _queue_render(self, root, shader=None):
        '''
        Flatten the model tree into per-program, per-texture lists of draw functions.

        Parameters
        ----------
        root: stereo_opengl.models.Model instance
            The root model from which to start drawing
        shader: string, optional
            If given, force every model onto this shader program
        '''
        queue = dict((k, dict()) for k in list(self.programs.keys()))

        for pname, drawfunc, tex in root.render_queue(shader=shader):
            if tex not in queue[pname]:
                queue[pname][tex] = []
            queue[pname][tex].append(drawfunc)

        for pname in list(self.programs.keys()):
            #assert len(self.texavail) > len(queue[pname])
            for tex in list(queue[pname].keys()):
                if tex is not None:
                    self.get_texunit(tex)

        self.render_queue = queue

    def get_texunit(self, tex):
        '''Input a Texture object, output the index of the texture unit it is bound to'''
        if tex not in self.texunits:
            unit = self.texavail.pop()
            glActiveTexture(unit[1])
            if tex == "None":
                glBindTexture(GL_TEXTURE_2D, 0)
            else:
                glBindTexture(GL_TEXTURE_2D, tex.tex)
            self.texunits[tex] = unit[0]

        return self.texunits[tex]

    def reset_texunits(self):
        maxtex = glGetIntegerv(GL_MAX_TEXTURE_COORDS)
        # Use the first texture unit as the "blank" texture
        self.texavail = set((i, globals()['GL_TEXTURE%d' % i]) for i in range(maxtex))
        self.texunits = dict()
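The unit bookkeeping is just a free set plus a binding map; the behavior can be seen with plain Python (sketch with the GL calls stubbed out):

# Sketch of the allocation policy with GL calls removed:
texavail = {(0, 'GL_TEXTURE0'), (1, 'GL_TEXTURE1'), (2, 'GL_TEXTURE2')}
texunits = {}

def get_texunit(tex):
    if tex not in texunits:
        unit = texavail.pop()        # grab any free unit
        texunits[tex] = unit[0]      # remember where this texture lives
    return texunits[tex]

a = get_texunit('noise_tex')
assert get_texunit('noise_tex') == a   # repeated lookups reuse the same unit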

103 def add_shader(self, name, stype, filename, *includes): 

+

104 src = [] 

+

105 main = open(os.path.join(cwd, "shaders", filename)) 

+

106 version = main.readline().strip() 

+

107 for inc in includes: 

+

108 incfile = open(os.path.join(cwd, "shaders", inc)) 

+

109 ver = incfile.readline().strip() 

+

110 assert ver == version, "Version: %s, %s"%(ver, version) 

+

111 src.append(incfile.read()) 

+

112 incfile.close() 

+

113 src.append(main.read()) 

+

114 main.close() 

+

115 

+

116 shader = glCreateShader(stype) 

+

117 glShaderSource(shader, "\n".join(src)) 

+

118 glCompileShader(shader) 

+

119 

+

120 if not glGetShaderiv(shader, GL_COMPILE_STATUS): 

+

121 err = glGetShaderInfoLog(shader) 

+

122 glDeleteShader(shader) 

+

123 raise Exception(err) 

+

124 

+

125 self.shaders[name] = shader 

+

126 

+

127 def add_program(self, name, shaders): 

+

128 shaders = [self.shaders[i] for i in shaders] 

+

129 sp = ShaderProgram(shaders) 

+

130 self.programs[name] = sp 

+

131 

+

132 def draw(self, root, shader=None, requeue=False, **kwargs): 

+

133 if self.render_queue is None or requeue: 

+

134 self._queue_render(root) 

+

135 

+

136 if "p_matrix" not in kwargs: 

+

137 kwargs['p_matrix'] = self.projection 

+

138 if "modelview" not in kwargs: 

+

139 kwargs['modelview'] = root._xfm.to_mat() 

+

140 

+

141 if shader is not None: 

+

142 for items in list(self.render_queue.values()): 

+

143 self.programs[shader].draw(self, items, **kwargs) 

+

144 else: 

+

145 for name, program in list(self.programs.items()): 

+

146 program.draw(self, self.render_queue[name], **kwargs) 

+

147 

+

148 def draw_done(self): 

+

149 self.reset_texunits() 

+

150 

+

151def test(): 

+

152 import pygame 

+

153 pygame.init() 

+

154 pygame.display.set_mode((100,100), pygame.OPENGL | pygame.DOUBLEBUF) 

+

155 

+

156 return Renderer( 

+

157 shaders=dict( 

+

158 passthru=(GL_VERTEX_SHADER, "passthrough.v.glsl"), 

+

159 phong=(GL_FRAGMENT_SHADER, "default.f.glsl", "phong.f.glsl")), 

+

160 programs=dict( 

+

161 default=("passthru", "phong"), 

+

162 ) 

+

163 ) 

+
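The dict-of-dicts built by _queue_render above groups draw callbacks first by program name and then by texture, so each program/texture pair only needs to be bound once per frame. A minimal pure-Python sketch of just that bucketing logic (stand-in strings instead of real models and textures, no OpenGL required):

def queue_render(program_names, render_items):
    # render_items: iterable of (program_name, drawfunc, texture) triples,
    # mirroring what Model.render_queue yields
    queue = dict((k, dict()) for k in program_names)
    for pname, drawfunc, tex in render_items:
        # setdefault is equivalent to the if-not-in / append pattern above
        queue[pname].setdefault(tex, []).append(drawfunc)
    return queue

items = [("default", "draw_sphere", None),
         ("default", "draw_cube", "woodgrain"),
         ("default", "draw_plane", "woodgrain")]
q = queue_render(["default"], items)
assert sorted(q["default"]["woodgrain"]) == ["draw_cube", "draw_plane"]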
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_shader_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_shader_py.html
new file mode 100644
index 00000000..2bbf7268
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_shader_py.html
@@ -0,0 +1,181 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\shader.py: 19%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]
'''Needs docs'''


import numpy as np
from OpenGL.GL import *

from ..textures import Texture

_mattypes = {
    (4,4):"4", (3,3):"3", (2,2):"2",
    (2,3):"2x3",(3,2):"3x2",
    (2,4):"2x4",(4,2):"4x2",
    (3,4):"3x4",(4,3):"4x3"
}
_typename = { int:"i", float:"f" }
class _getter(object):
    '''Wrapper object which allows direct getting and setting of shader values'''
    def __init__(self, type, prog):
        setter = super(_getter, self).__setattr__

        setter("prog", prog)
        setter("cache", dict())
        setter("func", globals()['glGet{type}Location'.format(type=type)])

        if type == "Uniform":
            setter("info", dict())
            for i in range(glGetProgramiv(self.prog, GL_ACTIVE_UNIFORMS)):
                name, size, t = glGetActiveUniform(self.prog, i)
                self.info[name] = t

    def __getattr__(self, attr):
        if attr not in self.cache:
            self.cache[attr] = self.func(self.prog, attr)
        return self.cache[attr]

    def __getitem__(self, attr):
        if attr not in self.cache:
            self.cache[attr] = self.func(self.prog, attr)
        return self.cache[attr]

    def __setitem__(self, attr, val):
        self._set(attr, val)

    def __setattr__(self, attr, val):
        self._set(attr, val)

    def __contains__(self, attr):
        if attr not in self.cache:
            self.cache[attr] = self.func(self.prog, attr)
        return self.cache[attr] != -1

    def _set(self, attr, val):
        '''This heinously complicated function has to guess the GL setter function to use because
        there are no strong types in python, hence we just have to guess'''
        if attr not in self.cache:
            self.cache[attr] = self.func(self.prog, attr)

        if isinstance(val, np.ndarray) and len(val.shape) > 1:
            assert len(val.shape) <= 3
            if val.shape[-2:] in _mattypes:
                nmats = val.shape[0] if len(val.shape) == 3 else 1
                fname = _mattypes[val.shape[-2:]]
                func = globals()['glUniformMatrix%sfv'%fname]
                #We need to transpose all numpy matrices since numpy is row-major
                #and opengl is column-major
                func(self.cache[attr], nmats, GL_TRUE, val.astype(np.float32).ravel())
        elif isinstance(val, (list, tuple, np.ndarray)):
            #glUniform\d[if]v
            if isinstance(val[0], (tuple, list, np.ndarray)):
                assert len(val[0]) <= 4
                t = _typename[type(val[0][0])]
                func = globals()['glUniform%d%sv'%(len(val[0]), t)]
                func(self.cache[attr], len(val), np.array(val).astype(np.float32).ravel())
            else:
                t = _typename[type(val[0])]
                func = globals()['glUniform%d%s'%(len(val), t)]
                func(self.cache[attr], *val)
        elif isinstance(val, (int, float)):
            #single value, push with glUniform1[if]
            globals()['glUniform1%s'%_typename[type(val)]](self.cache[attr], val)

class ShaderProgram(object):
    def __init__(self, shaders):
        self.shaders = shaders
        self.program = glCreateProgram()
        for shader in shaders:
            glAttachShader(self.program, shader)
        glLinkProgram(self.program)

        if not glGetProgramiv(self.program, GL_LINK_STATUS):
            err = glGetProgramInfoLog(self.program)
            glDeleteProgram(self.program)
            raise Exception(err)

        self.attributes = _getter("Attrib", self.program)
        self.uniforms = _getter("Uniform", self.program)

    def draw(self, ctx, models, **kwargs):
        glUseProgram(self.program)
        for name, v in list(kwargs.items()):
            if isinstance(v, Texture):
                self.uniforms[name] = ctx.get_texunit(v)
            elif name in self.uniforms:
                self.uniforms[name] = v
            elif name in self.attributes:
                self.attributes[name] = v
            elif hasattr(v, "__call__"):
                v(self)

        for tex, funcs in list(models.items()):
            if tex is None:
                self.uniforms.texture = ctx.get_texunit("None")
            else:
                self.uniforms.texture = ctx.get_texunit(tex)

            for drawfunc in funcs:
                drawfunc(self)
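The name-guessing dispatch in _getter._set can be exercised without a GL context. A minimal sketch (the helper uniform_func_name is hypothetical, not part of the module) that reproduces just the name-selection logic and checks three representative cases:

import numpy as np

_mattypes = {(4,4):"4", (3,3):"3", (2,2):"2", (2,3):"2x3", (3,2):"3x2",
             (2,4):"2x4", (4,2):"4x2", (3,4):"3x4", (4,3):"4x3"}
_typename = {int:"i", float:"f"}

def uniform_func_name(val):
    # Mirror of _getter._set's guessing: matrices -> glUniformMatrix*fv,
    # sequences -> glUniform<N><t>[v], scalars -> glUniform1<t>
    if isinstance(val, np.ndarray) and val.ndim > 1:
        return 'glUniformMatrix%sfv' % _mattypes[val.shape[-2:]]
    if isinstance(val, (list, tuple, np.ndarray)):
        if isinstance(val[0], (list, tuple, np.ndarray)):
            return 'glUniform%d%sv' % (len(val[0]), _typename[type(val[0][0])])
        return 'glUniform%d%s' % (len(val), _typename[type(val[0])])
    return 'glUniform1%s' % _typename[type(val)]

assert uniform_func_name(np.eye(4)) == 'glUniformMatrix4fv'
assert uniform_func_name((1.0, 2.0, 3.0)) == 'glUniform3f'
assert uniform_func_name(2) == 'glUniform1i'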
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_ssao_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_ssao_py.html
new file mode 100644
index 00000000..9f71e309
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_ssao_py.html
@@ -0,0 +1,139 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\ssao.py: 0%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]
'''Needs docs'''

import numpy as np
from OpenGL.GL import *

from .render import Renderer
from .fbo import FBOrender, FBO
from ..textures import Texture

class SSAO(FBOrender):
    def __init__(self, *args, **kwargs):
        super(SSAO, self).__init__(*args, **kwargs)
        self.sf = 3
        w, h = self.size[0] / self.sf, self.size[1] / self.sf

        self.normdepth = FBO(["color0", "depth"], size=(w,h))
        self.ping = FBO(['color0'], size=(w,h))
        self.pong = FBO(["color0"], size=(w,h))

        self.add_shader("fsquad", GL_VERTEX_SHADER, "fsquad.v.glsl")
        self.add_shader("ssao_pass1", GL_FRAGMENT_SHADER, "ssao_pass1.f.glsl")
        self.add_shader("ssao_pass2", GL_FRAGMENT_SHADER, "ssao_pass2.f.glsl")
        self.add_shader("ssao_pass3", GL_FRAGMENT_SHADER, "ssao_pass3.f.glsl", "phong.f.glsl")
        self.add_shader("hblur", GL_FRAGMENT_SHADER, "hblur.f.glsl")
        self.add_shader("vblur", GL_FRAGMENT_SHADER, "vblur.f.glsl")

        #override the default shader with this passthru + ssao_pass1 to store depth
        self.add_program("ssao_pass1", ("passthru", "ssao_pass1"))
        self.add_program("ssao_pass2", ("fsquad", "ssao_pass2"))
        self.add_program("hblur", ("fsquad", "hblur"))
        self.add_program("vblur", ("fsquad", "vblur"))
        self.add_program("ssao_pass3", ("passthru", "ssao_pass3"))

        randtex = np.random.rand(3, w, h)
        randtex /= randtex.sum(0)
        self.rnm = Texture(randtex.T, wrap_x=GL_REPEAT, wrap_y=GL_REPEAT,
            magfilter=GL_NEAREST, minfilter=GL_NEAREST)
        self.rnm.init()

        self.clips = args[2], args[3]

    def draw(self, root, **kwargs):
        #First, draw the whole damned scene, but only read the normals and depth into ssao
        glPushAttrib(GL_VIEWPORT_BIT)
        glViewport(0, 0, self.size[0]/self.sf, self.size[1]/self.sf)
        self.draw_to_fbo(self.normdepth, root, shader="ssao_pass1", **kwargs)

        #Now, do the actual ssao calculations, and draw them into pong
        self.draw_fsquad_to_fbo(self.pong, "ssao_pass2", rnm=self.rnm,
            normalMap=self.normdepth['color0'], depthMap=self.normdepth['depth'],
            nearclip=self.clips[0], farclip=self.clips[1])

        #Blur the textures
        self.draw_fsquad_to_fbo(self.ping, "hblur", tex=self.pong['color0'], blur=1./(self.size[0]/self.sf))
        self.draw_fsquad_to_fbo(self.pong, "vblur", tex=self.ping['color0'], blur=1./(self.size[0]/self.sf))

        glPopAttrib()
        #Actually draw the final image to the screen
        win = glGetIntegerv(GL_VIEWPORT)
        #Why is this call necessary at all?!
        glViewport(*win)

        super(SSAO, self).draw(root, shader="ssao_pass3", shadow=self.pong['color0'],
            window=[float(i) for i in win], **kwargs)

        #self.draw_done()

    def clear(self):
        self.normdepth.clear()
        self.ping.clear()
        self.pong.clear()

    def draw_done(self):
        super(SSAO, self).draw_done()
        self.clear()
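The random normal-map texture built in SSAO.__init__ normalizes three random channels so they sum to 1 at every pixel before the transpose hands them to Texture. A quick numpy check of just that step (no GL needed; the 426x360 size is an example, matching a 1280x1080 window downscaled by sf = 3):

import numpy as np

w, h = 426, 360
randtex = np.random.rand(3, w, h)
randtex /= randtex.sum(0)            # per-pixel normalization across the 3 channels
assert np.allclose(randtex.sum(0), 1)
assert randtex.T.shape == (h, w, 3)  # transposed to (rows, cols, channels) for Texture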
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_stereo_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_stereo_py.html
new file mode 100644
index 00000000..461de5ae
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_render_stereo_py.html
@@ -0,0 +1,182 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\stereo.py: 27%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]
'''
Extensions of the render.Renderer class for stereo displays
'''

import numpy as np
from OpenGL.GL import *

from .render import Renderer
from ..utils import offaxis_frusta

class LeftRight(Renderer):
    def __init__(self, window_size, fov, near, far, focal_dist, iod, **kwargs):
        w, h = window_size
        super(LeftRight, self).__init__((w/2,h), fov, near, far, **kwargs)
        self.projections = offaxis_frusta((w/2, h), fov, near, far, focal_dist, iod)

    def draw(self, root, **kwargs):
        w, h = self.size
        glViewport(0, 0, w, h)
        super(LeftRight, self).draw(root, p_matrix=self.projections[0], **kwargs)
        glViewport(w, 0, w, h)
        super(LeftRight, self).draw(root, p_matrix=self.projections[1], **kwargs)

class RightLeft(LeftRight):
    def draw(self, root, **kwargs):
        w, h = self.size
        glViewport(w, 0, w, h)
        super(LeftRight, self).draw(root, p_matrix=self.projections[0], **kwargs)
        glViewport(0, 0, w, h)
        super(LeftRight, self).draw(root, p_matrix=self.projections[1], **kwargs)

class MirrorDisplay(Renderer):
    '''The mirror display requires a left-right flip, otherwise the sides are messed up'''
    def __init__(self, window_size, fov, near, far, focal_dist, iod, **kwargs):
        w, h = window_size
        super(MirrorDisplay, self).__init__((w/2,h), fov, near, far, **kwargs)
        flip = kwargs.pop('flip', True)
        self.projections = offaxis_frusta((w/2, h), fov, near, far, focal_dist, iod, flip=flip)

    def draw(self, root, **kwargs):
        '''
        Draw the 'root' model.

        Parameters
        ----------
        root: stereo_opengl.models.Model instance
            Root "world" model. Draws all submodels of this model.
        kwargs: optional keyword-arguments
            Optional shaders and stuff to pass to the lower-level drawing functions
        '''
        w, h = self.size
        w = int(w)
        h = int(h)

        # draw the portion of the screen with lower-left corner (0, 0), width 'w' and height 'h'
        glViewport(0, 0, w, h)
        super(MirrorDisplay, self).draw(root, p_matrix=self.projections[0], **kwargs)

        # draw the portion of the screen with lower-left corner (w, 0), width 'w' and height 'h'
        glViewport(w, 0, w, h)
        super(MirrorDisplay, self).draw(root, p_matrix=self.projections[1], **kwargs)

class DualMultisizeDisplay(Renderer):
    def __init__(self, main_window_size, mini_window_size, fov, near, far, focal_dist, iod, **kwargs):
        w, h = main_window_size
        w2, h2 = mini_window_size
        self.main_window_size = main_window_size
        self.mini_window_size = mini_window_size

        flip_main_z = kwargs.pop('flip_main_z', False)
        flip = kwargs.pop('flip', True)

        super(DualMultisizeDisplay, self).__init__((w+w2, h), fov, near, far, **kwargs)
        main_projections = offaxis_frusta(main_window_size, fov, near, far, focal_dist, iod, flip=flip, flip_z=flip_main_z)
        mini_projections = offaxis_frusta(mini_window_size, fov, near, far, focal_dist, iod, flip=flip)
        self.projections = (mini_projections[0], main_projections[0])

    def draw(self, root, **kwargs):
        '''
        Draw the 'root' model.

        Parameters
        ----------
        root: stereo_opengl.models.Model instance
            Root "world" model. Draws all submodels of this model.
        kwargs: optional keyword-arguments
            Optional shaders and stuff to pass to the lower-level drawing functions
        '''

        # draw the mini display in the portion of the screen with lower-left corner (0, 0), width 'w2' and height 'h2'
        w2, h2 = self.mini_window_size
        glViewport(0, 0, w2, h2)
        super(DualMultisizeDisplay, self).draw(root, p_matrix=self.projections[0], **kwargs)

        # draw the main display in the portion of the screen with lower-left corner (w2, 0), width 'w' and height 'h'
        w, h = self.main_window_size
        glViewport(w2, 0, w, h)
        super(DualMultisizeDisplay, self).draw(root, p_matrix=self.projections[1], **kwargs)

class Anaglyph(Renderer):
    '''
    From wikipedia: Anaglyph 3D is the name given to the stereoscopic 3D effect achieved by
    means of encoding each eye's image using filters of different (usually chromatically
    opposite) colors, typically red and cyan.
    '''
    def __init__(self, window_size, fov, near, far, focal_dist, iod, **kwargs):
        super(Anaglyph, self).__init__(window_size, fov, near, far, **kwargs)
        self.projections = offaxis_frusta(self.size, fov, near, far, focal_dist, iod)

    def draw(self, root, **kwargs):
        glViewport(0, 0, self.size[0], self.size[1])
        glColorMask(GL_TRUE, GL_FALSE, GL_FALSE, GL_TRUE)
        super(Anaglyph, self).draw(root, p_matrix=self.projections[0], **kwargs)
        glClear(GL_DEPTH_BUFFER_BIT)
        glColorMask(GL_FALSE, GL_TRUE, GL_TRUE, GL_TRUE)
        super(Anaglyph, self).draw(root, p_matrix=self.projections[1], **kwargs)
        glColorMask(GL_TRUE, GL_TRUE, GL_TRUE, GL_TRUE)
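LeftRight above relies on self.size already being the half-window (w/2, h), so its two glViewport calls tile the screen side by side. A tiny sketch of that arithmetic (the helper side_by_side_viewports is hypothetical, no GL required):

def side_by_side_viewports(window_size):
    # Returns (x, y, width, height) rects for the left and right eye,
    # mirroring LeftRight.draw: each eye gets half the window width
    w, h = window_size
    half = w // 2
    return (0, 0, half, h), (half, 0, half, h)

assert side_by_side_viewports((1920*2, 1080)) == ((0, 0, 1920, 1080), (1920, 0, 1920, 1080))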
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_textures_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_textures_py.html
new file mode 100644
index 00000000..1852573d
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_textures_py.html
@@ -0,0 +1,161 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\textures.py: 23%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]
'''Needs docs'''


import numpy as np
from OpenGL.GL import *
import pygame   # needed by the filename branch of Texture.__init__; missing from the original listing

from .models import Model

textypes = {GL_UNSIGNED_BYTE:np.uint8, GL_FLOAT:np.float32}
class Texture(object):
    def __init__(self, tex, size=None,
        magfilter=GL_LINEAR, minfilter=GL_LINEAR,
        wrap_x=GL_CLAMP_TO_EDGE, wrap_y=GL_CLAMP_TO_EDGE,
        iformat=GL_RGBA8, exformat=GL_RGBA, dtype=GL_UNSIGNED_BYTE):

        self.opts = dict(
            magfilter=magfilter, minfilter=minfilter,
            wrap_x=wrap_x, wrap_y=wrap_y,
            iformat=iformat, exformat=exformat, dtype=dtype)

        if isinstance(tex, np.ndarray):
            if tex.max() <= 1:
                tex *= 255
            if len(tex.shape) < 3:
                tex = np.tile(tex, [3, 1, 1]).T
            if tex.shape[-1] == 3:
                tex = np.dstack([tex, np.ones(tex.shape[:-1])])
            size = tex.shape[:2]
            tex = tex.astype(np.uint8).tostring()
        elif isinstance(tex, str):
            im = pygame.image.load(tex)
            size = im.get_size()   # was 'tex.get_size()', a bug: tex is the filename string here
            tex = pygame.image.tostring(im, 'RGBA')

        self.texstr = tex
        self.size = size
        self.tex = None

    def init(self):
        gltex = glGenTextures(1)
        glBindTexture(GL_TEXTURE_2D, gltex)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, self.opts['minfilter'])
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, self.opts['magfilter'])
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, self.opts['wrap_x'])
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, self.opts['wrap_y'])
        glTexImage2D(
            GL_TEXTURE_2D, 0,                           #target, level
            self.opts['iformat'],                       #internal format
            self.size[0], self.size[1], 0,              #width, height, border
            self.opts['exformat'], self.opts['dtype'],  #external format, type
            self.texstr if self.texstr is not None else 0  #pixels
        )

        self.tex = gltex

    def set(self, idx):
        glActiveTexture(GL_TEXTURE0+idx)
        glBindTexture(GL_TEXTURE_2D, self.tex)

    def get(self, filename=None):
        current = glGetInteger(GL_TEXTURE_BINDING_2D)
        glBindTexture(GL_TEXTURE_2D, self.tex)
        texstr = glGetTexImage(GL_TEXTURE_2D, 0, self.opts['exformat'], self.opts['dtype'])
        glBindTexture(GL_TEXTURE_2D, current)
        im = np.fromstring(texstr, dtype=textypes[self.opts['dtype']])
        im.shape = (self.size[1], self.size[0], -1)
        if filename is not None:
            np.save(filename, im)
        return im


class MultiTex(object):
    '''This is not ready yet!'''
    def __init__(self, textures, weights):
        raise NotImplementedError
        assert len(textures) < max_multitex
        self.texs = textures
        self.weights = weights

class TexModel(Model):
    def __init__(self, tex=None, **kwargs):
        if tex is not None:
            kwargs['color'] = (0,0,0,1)
        super(TexModel, self).__init__(**kwargs)

        self.tex = tex

    def init(self):
        super(TexModel, self).init()
        if self.tex.tex is None:
            self.tex.init()

    def render_queue(self, shader=None, **kwargs):
        if shader is not None:
            yield shader, self.draw, self.tex
        else:
            yield self.shader, self.draw, self.tex
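The array branch of Texture.__init__ above normalizes any 2-D float image into 8-bit RGBA bytes. A standalone sketch of just that conversion (numpy only, no GL), useful for checking the shapes involved:

import numpy as np

img = np.random.rand(4, 6)          # grayscale float image in [0, 1]
tex = img * 255 if img.max() <= 1 else img
tex = np.tile(tex, [3, 1, 1]).T     # grayscale -> 3 channels; note the transpose
assert tex.shape == (6, 4, 3)
tex = np.dstack([tex, np.ones(tex.shape[:-1])])   # append an alpha channel
assert tex.shape == (6, 4, 4)                     # (note: alpha is 1, not 255, after the cast)
size = tex.shape[:2]
texstr = tex.astype(np.uint8).tobytes()           # '.tostring()' in the py2-era original
assert len(texstr) == 6 * 4 * 4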
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_utils_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_utils_py.html
new file mode 100644
index 00000000..89a121c3
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_utils_py.html
@@ -0,0 +1,150 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\utils.py: 21%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]

'''Needs docs'''


import numpy as np
from .textures import Texture
from OpenGL.GL import glBindTexture, glGetTexImage, GL_TEXTURE_2D, GL_RGBA, GL_UNSIGNED_BYTE

def frustum(l, r, t, b, n, f):
    '''
    This function emulates glFrustum: https://www.opengl.org/sdk/docs/man2/xhtml/glFrustum.xml
    A frustum is a solid cut by planes, e.g., the planes representing the viewable area of a screen.

    Parameters
    ----------
    l: float
        Distance to the left plane of the screen
    r: float
        Distance to the right plane of the screen
    t: float
        Distance to the top plane of the screen
    b: float
        Distance to the bottom plane of the screen
    n: float
        Distance to the near plane of the screen
    f: float
        Distance to the far plane of the screen

    Returns
    -------
    np.ndarray of shape (4, 4)
        Projection matrix implementing the frustum
    '''
    rl, nrl = r + l, r - l
    tb, ntb = t + b, t - b
    fn, nfn = f + n, f - n
    return np.array([[2*n / nrl,         0,  rl / nrl,            0],
                     [        0, 2*n / ntb,  tb / ntb,            0],
                     [        0,         0, -fn / nfn, -2*f*n / nfn],
                     [        0,         0,        -1,            0]])

def perspective(angle, aspect, near, far):
    '''Generates a perspective transform matrix'''
    f = 1. / np.tan(np.radians(angle) / 2)
    fn = far + near
    nfn = far - near
    return np.array([[f/aspect, 0,       0,                0],
                     [       0, f,       0,                0],
                     [       0, 0, -fn/nfn, -2*far*near/nfn],
                     [       0, 0,      -1,                0]])

def offaxis_frusta(winsize, fov, near, far, focal_dist, iod, flip=False, flip_z=False):
    aspect = winsize[0] / winsize[1]
    top = near * np.tan(np.radians(fov) / 2)
    right = aspect*top
    fshift = (iod/2) * near / focal_dist

    # calculate the perspective matrix for the left eye and for the right eye
    left = frustum(-right+fshift, right+fshift, top, -top, near, far)
    right = frustum(-right-fshift, right-fshift, top, -top, near, far)

    # multiply in the iod (intraocular distance) modelview transform
    lxfm, rxfm = np.eye(4), np.eye(4)
    lxfm[:3,-1] = [0.5*iod, 0, 0]
    rxfm[:3,-1] = [-0.5*iod, 0, 0]
    flip_mat = np.eye(4)

    if flip:
        flip_mat[0,0] = -1
    if flip_z:
        flip_mat[1,1] = -1

    return np.dot(flip_mat, np.dot(left, lxfm)), np.dot(flip_mat, np.dot(right, rxfm))

    #return np.dot(left, lxfm), np.dot(right, rxfm)

def cloudy_tex(size=(512,512)):
    '''Generates 1/f distributed noise and puts it into a texture. Looks like clouds'''
    im = np.random.randn(*size)
    grid = np.mgrid[-1:1:size[0]*1j, -1:1:size[1]*1j]
    mask = 1/(grid**2).sum(0)
    fim = np.fft.fftshift(np.fft.fft2(im))
    im = np.abs(np.fft.ifft2(np.fft.fftshift(mask * fim)))
    im -= im.min()
    return Texture(im / im.max())
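A quick numeric check of the two projection helpers above, assuming the package is importable as riglib.stereo_opengl.utils (paths and the 1280x1080 window are illustrative): a symmetric frustum must reduce to an ordinary perspective projection, and the two off-axis eye matrices should differ only in the first row's skew and translation entries.

import numpy as np
from riglib.stereo_opengl.utils import frustum, perspective, offaxis_frusta

near, far, fov, aspect = 1.0, 1024.0, 45.0, 1280/1080.
top = near * np.tan(np.radians(fov) / 2)

# a symmetric frustum (l = -r, b = -t) is an ordinary perspective projection
P_sym = frustum(-aspect*top, aspect*top, top, -top, near, far)
assert np.allclose(P_sym, perspective(fov, aspect, near, far))

# the stereo pair: each eye's frustum is shifted toward the focal plane, so
# the matrices agree everywhere except the first row
l_proj, r_proj = offaxis_frusta((1280, 1080), fov, near, far, focal_dist=47.5, iod=2.5)
assert np.allclose(l_proj[1:], r_proj[1:])
print(l_proj[0, 2], r_proj[0, 2])   # equal-magnitude, opposite-sign skew terms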
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_window_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_window_py.html
new file mode 100644
index 00000000..e5fc02ba
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_window_py.html
@@ -0,0 +1,497 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\window.py: 27%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]
'''
Graphical display classes. Experimental tasks involving graphical displays
inherit from these classes.
'''
import os

import numpy as np
from OpenGL.GL import *

from riglib.experiment import LogExperiment
from riglib.experiment import traits

from .render import stereo
from .models import Group
from .xfm import Quaternion
from riglib.stereo_opengl.primitives import Sphere, Cube, Chain
from riglib.stereo_opengl.environment import Box
import time
from config import config
from .primitives import Cylinder, Sphere, Cone
import socket

try:
    os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
    import pygame
except:
    import warnings
    warnings.warn('riglib/stereo_opengl/window.py: not importing name pygame')

# for WindowDispl2D only
from riglib.stereo_opengl.primitives import Shape2D


class Window(LogExperiment):
    '''
    Generic stereo window
    '''
    status = dict(draw=dict(stop=None))
    state = "draw"
    stop = False

    #window_size = traits.Tuple((1920*2, 1080), descr='window size, in pixels')
    #XPS computer
    window_size = traits.Tuple((1280, 1080), descr='window size, in pixels')
    # window_size = (1920*2, 1080)
    background = (0,0,0,1)

    #Screen parameters, all in centimeters -- adjust for monkey
    fov = np.degrees(np.arctan(14.65/(44.5+3)))*2
    screen_dist = 44.5+3
    iod = 2.5 # intraocular distance

    show_environment = traits.Int(0)

    def __init__(self, *args, **kwargs):
        super(Window, self).__init__(*args, **kwargs)

        self.models = []
        self.world = None
        self.event = None

        # os.popen('sudo vbetool dpms on')

        if self.show_environment:
            self.add_model(Box())

    def set_os_params(self):
        os.environ['SDL_VIDEO_WINDOW_POS'] = config.display_start_pos
        #print(os.environ['SDL_VIDEO_WINDOW_POS'])
        os.environ['SDL_VIDEO_X11_WMCLASS'] = "monkey_experiment"

    def screen_init(self):
        self.set_os_params()
        pygame.init()

        pygame.display.gl_set_attribute(pygame.GL_DEPTH_SIZE, 24)
        flags = pygame.DOUBLEBUF | pygame.HWSURFACE | pygame.OPENGL | pygame.NOFRAME
        try:
            pygame.display.gl_set_attribute(pygame.GL_MULTISAMPLEBUFFERS, 1)
            self.surf = pygame.display.set_mode(self.window_size, flags)
        except:
            pygame.display.gl_set_attribute(pygame.GL_MULTISAMPLEBUFFERS, 0)
            self.surf = pygame.display.set_mode(self.window_size, flags)

        glEnable(GL_BLEND)
        glDepthFunc(GL_LESS)
        glEnable(GL_DEPTH_TEST)
        glEnable(GL_TEXTURE_2D)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        glClearColor(*self.background)
        glClearDepth(1.0)
        glDepthMask(GL_TRUE)

        self.renderer = self._get_renderer()

        #this effectively determines the modelview matrix
        self.world = Group(self.models)
        self.world.init()

        #up vector is always (0,0,1), why would I ever need to roll the camera?!
        self.set_eye((0, -self.screen_dist, 0), (0,0))

    def _get_renderer(self):
        near = 1
        far = 1024
        return stereo.MirrorDisplay(self.window_size, self.fov, near, far, self.screen_dist, self.iod)

    def set_eye(self, pos, vec, reset=True):
        '''Set the eye's position and direction. Camera starts at (0,0,0), pointing towards positive y'''
        # note: the 'reset' argument is currently ignored; translate is always called with reset=True
        self.world.translate(pos[0], pos[2], pos[1], reset=True).rotate_x(-90)
        self.world.rotate_y(vec[0]).rotate_x(vec[1])

    def add_model(self, model):
        if self.world is None:
            #world doesn't exist yet, add the model to cache
            self.models.append(model)
        else:
            #We're already running, initialize the model and add it to the world
            model.init()
            self.world.add(model)

    def show_object(self, obj, show=False):
        '''
        Show or hide an object. This function is an abstraction so that tasks don't need to know about attach/detach
        '''
        if show:
            obj.attach()
        else:
            obj.detach()

    def draw_world(self):
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        self.renderer.draw(self.world)
        pygame.display.flip()
        self.renderer.draw_done()

    def _get_event(self):
        for e in pygame.event.get(pygame.KEYDOWN):
            return (e.key, e.type)

    def _start_None(self):
        pygame.display.quit()

    def _test_stop(self, ts):
        '''
        Stop the task if the escape key is pressed, or if the super _test_stop instructs a stop
        '''
        super_stop = super(Window, self)._test_stop(ts)
        from pygame import K_ESCAPE
        return super_stop or self.event is not None and self.event[0] == K_ESCAPE

    def requeue(self):
        self.renderer._queue_render(self.world)

    def _cycle(self):
        self.requeue()
        self.draw_world()
        super(Window, self)._cycle()
        self.event = self._get_event()


class WindowWithExperimenterDisplay(Window):
    hostname = socket.gethostname()
    # This class has a hard-coded window size
    if hostname == 'lynx':
        _stereo_window_flip = True
        _stereo_main_flip_z = True
        window_size = (1280+480, 1024)
        window_size1 = (1280, 1024)
        window_size2 = (480, 270)
    else:
        _stereo_window_flip = False
        _stereo_main_flip_z = False
        window_size = (1920 + 480, 1080)
        window_size1 = (1920, 1080)
        window_size2 = (480, 270)

    def __init__(self, *args, **kwargs):
        super(WindowWithExperimenterDisplay, self).__init__(*args, **kwargs)
        # This class has a hard-coded window size
        # self.window_size = (1920 + 480, 1080)

    def set_os_params(self):
        # NOTE: in Ubuntu Unity, setting SDL_VIDEO_WINDOW_POS seems to be largely ignored.
        # You can set which screen the window appears on if you have a dual display, but you cannot set the exact position.
        # Instead, you have to hard-code a render start location in the compiz-config settings manager:
        # http://askubuntu.com/questions/452995/how-to-adjust-window-placement-in-unity-ubuntu-14-04-based-on-overlapping-top-b
        os.environ['SDL_VIDEO_WINDOW_POS'] = config.display_start_pos
        os.environ['SDL_VIDEO_X11_WMCLASS'] = "monkey_experiment_with_mini"

    def _get_renderer(self):
        near = 1
        far = 1024
        return stereo.DualMultisizeDisplay(self.window_size1, self.window_size2, self.fov, near, far, self.screen_dist, self.iod,
            flip=self._stereo_window_flip, flip_main_z=self._stereo_main_flip_z)


class WindowDispl2D(Window):
    background = (1,1,1,1)
    def __init__(self, *args, **kwargs):
        self.models = []
        self.world = None
        self.event = None
        super(WindowDispl2D, self).__init__(*args, **kwargs)

    def _set_workspace_size(self):
        '''
        By default, the workspace is 50x28 cm, centered around the origin (0,0)
        '''
        self.workspace_bottom_left = (-25., -14.)
        self.workspace_top_right = (25., 14.)

    def screen_init(self):
        os.environ['SDL_VIDEO_WINDOW_POS'] = config.display_start_pos
        os.environ['SDL_VIDEO_X11_WMCLASS'] = "monkey_experiment"
        pygame.init()
        self.clock = pygame.time.Clock()

        flags = pygame.NOFRAME
        self._set_workspace_size()

        self.workspace_x_len = self.workspace_top_right[0] - self.workspace_bottom_left[0]
        self.workspace_y_len = self.workspace_top_right[1] - self.workspace_bottom_left[1]

        self.display_border = 10

        self.size = np.array(self.window_size, dtype=np.float64)
        self.screen = pygame.display.set_mode(self.window_size, flags)
        self.screen_background = pygame.Surface(self.screen.get_size()).convert()
        self.screen_background.fill(self.background)

        x1, y1 = self.workspace_top_right
        x0, y0 = self.workspace_bottom_left
        self.normalize = np.array(np.diag([1./(x1-x0), 1./(y1-y0), 1]))
        self.center_xform = np.array([[1., 0, -x0],
                                      [0, 1., -y0],
                                      [0, 0, 1]])
        self.norm_to_screen = np.array(np.diag(np.hstack([self.size, 1])))

        # the y-coordinate in pixel space has to be swapped for some graphics convention reason
        self.flip_y_coord = np.array([[1, 0, 0],
                                      [0, -1, self.size[1]],
                                      [0, 0, 1]])

        self.pos_space_to_pixel_space = np.dot(self.flip_y_coord, np.dot(self.norm_to_screen, np.dot(self.normalize, self.center_xform)))

        self.world = Group(self.models)
        # Don't 'init' self.world in this Window; it just allocates a bunch of OpenGL state which is not necessary (and may not work in some cases)
        # self.world.init()

        #initialize surfaces for translucent markers
        TRANSPARENT = (255,0,255)
        self.surf = {}
        self.surf['0'] = pygame.Surface(self.screen.get_size())
        self.surf['0'].fill(TRANSPARENT)
        self.surf['0'].set_colorkey(TRANSPARENT)

        self.surf['1'] = pygame.Surface(self.screen.get_size())
        self.surf['1'].fill(TRANSPARENT)
        self.surf['1'].set_colorkey(TRANSPARENT)

        #values of alpha: higher = less translucent
        self.surf['0'].set_alpha(170) #Cursor
        self.surf['1'].set_alpha(130) #Targets

        self.surf_background = pygame.Surface(self.surf['0'].get_size()).convert()
        self.surf_background.fill(TRANSPARENT)

        self.i = 0

    def pos2pix(self, kfpos):
        # re-specify the point in homogenous coordinates
        pt = np.hstack([kfpos, 1]).reshape(-1, 1)

        # perform the homogenous transformation
        pix_coords = np.dot(self.pos_space_to_pixel_space, pt)

        pix_pos = np.array(pix_coords[:2,0], dtype=int)
        return pix_pos

    def get_surf(self):
        return self.surf[str(np.min([self.i, 1]))]

    def draw_model(self, model):
        '''
        Draw a single Model on the current surface, or recurse if the model is a composite model (i.e., a Group)
        '''
        color = tuple([int(255*x) for x in model.color[0:3]])
        if isinstance(model, Sphere):
            pos = model._xfm.move[[0,2]]
            pix_pos = self.pos2pix(pos)

            rad = model.radius
            pix_radius = self.pos2pix(np.array([model.radius, 0]))[0] - self.pos2pix([0,0])[0]

            #Draws cursor and targets on transparent surfaces
            pygame.draw.circle(self.get_surf(), color, pix_pos, pix_radius)

        elif isinstance(model, Shape2D):
            # model.draw() returns True if the object was drawn
            # (which happens if the object's .visible attr is True)
            if model.draw(self.get_surf(), self.pos2pix):
                pass
        elif isinstance(model, Cube):
            pos = model.xfm.move[[0,2]]
            side_len = model.side_len

            left = pos[0] - side_len/2
            right = pos[0] + side_len/2
            top = pos[1] + side_len/2
            bottom = pos[1] - side_len/2

            top_left = np.array([left, top])
            bottom_right = np.array([right, bottom])
            top_left_pix_pos = self.pos2pix(top_left)
            bottom_right_pix_pos = self.pos2pix(bottom_right)

            rect = pygame.Rect(top_left_pix_pos, bottom_right_pix_pos - top_left_pix_pos)
            color = tuple([int(255*x) for x in model.color[0:3]])

            pygame.draw.rect(self.get_surf(), color, rect)

        elif isinstance(model, (Cylinder, Cone)):
            vec_st = np.array([0., 0, 0, 1])
            vec_end = np.array([0., 0, model.height, 1])

            cyl_xform = model._xfm.to_mat()
            cyl_start = np.dot(cyl_xform, vec_st)
            cyl_end = np.dot(cyl_xform, vec_end)
            pix_radius = self.pos2pix(np.array([model.radius, 0]))[0] - self.pos2pix([0,0])[0]

            pygame.draw.line(self.get_surf(), color, self.pos2pix(cyl_start[[0,2]]), self.pos2pix(cyl_end[[0,2]]), pix_radius)

            # print cyl_start, cyl_end

        elif isinstance(model, Group):
            for mdl in model:
                self.draw_model(mdl)

    def draw_world(self):
        #Refreshes the screen with original background
        self.screen.blit(self.screen_background, (0, 0))
        self.surf['0'].blit(self.surf_background, (0,0))
        self.surf['1'].blit(self.surf_background, (0,0))

        # surface index
        self.i = 0

        for model in self.world.models:
            self.draw_model(model)
            self.i += 1

        #Renders the new surfaces
        self.screen.blit(self.surf['0'], (0,0))
        self.screen.blit(self.surf['1'], (0,0))
        pygame.display.update()

    def requeue(self):
        '''
        Simulation 'requeue' does nothing because the simulation is lazy and
        inefficient and chooses to redraw the entire screen every loop
        '''
        pass


class FakeWindow(Window):
    '''
    A dummy class to secretly avoid rendering graphics without
    the graphics-based tasks knowing about it. Used e.g. for simulation
    purposes where the graphics only slow down the simulation.
    '''
    background = (1,1,1,1)
    def __init__(self, *args, **kwargs):
        self.models = []
        self.world = None
        self.event = None
        super(FakeWindow, self).__init__(*args, **kwargs)

    def screen_init(self):
        self.world = Group(self.models)
        # self.world.init()

    def draw_world(self):
        pass

    def requeue(self):
        pass

    def _start_reward(self, *args, **kwargs):
        n_rewards = self.calc_state_occurrences('reward')
        if n_rewards % 10 == 0:
            print(n_rewards)
        super(FakeWindow, self)._start_reward(*args, **kwargs)

    def show_object(self, *args, **kwargs):
        pass

    def _get_event(self):
        pass


class FPScontrol(Window):
    '''A mixin that adds a WASD + Mouse controller to the window.
    Use WASD to move in the XY plane, q to go down, e to go up'''

    def init(self):
        super(FPScontrol, self).init()
        pygame.event.set_grab(True)
        pygame.mouse.set_visible(False)
        self.eyepos = [0, -self.screen_dist, 0]
        self.eyevec = [0, 0]
        self.wasd = [False, False, False, False, False, False]

    def _get_event(self):
        retme = None
        for e in pygame.event.get([pygame.MOUSEMOTION, pygame.KEYDOWN, pygame.KEYUP, pygame.QUIT]):
            moved = any(self.wasd)
            if e.type == pygame.MOUSEMOTION:
                self.world.xfm.rotate *= Quaternion.from_axisangle((1,0,0), np.radians(e.rel[1]*.1))
                self.world.xfm.rotate *= Quaternion.from_axisangle((0,0,1), np.radians(e.rel[0]*.1))
                self.world._recache_xfm()
            elif e.type == pygame.KEYDOWN:
                kn = pygame.key.name(e.key)
                if kn in ["escape", "q"]:
                    self.stop = True
                retme = (e.key, e.type)
            elif e.type == pygame.QUIT:
                self.stop = True

            if moved:
                self.set_eye(self.eyepos, self.eyevec, reset=True)
        return retme
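The 3x3 pipeline built in WindowDispl2D.screen_init (center, normalize, scale to pixels, flip y) can be checked on its own with numpy. The sketch below rebuilds the same matrices for the default 50x28 cm workspace and a hypothetical 1000x560 px window:

import numpy as np

window = np.array([1000., 560.])
x0, y0 = -25., -14.            # workspace_bottom_left
x1, y1 = 25., 14.              # workspace_top_right

center = np.array([[1., 0, -x0], [0, 1., -y0], [0, 0, 1]])
normalize = np.diag([1./(x1-x0), 1./(y1-y0), 1.])
to_screen = np.diag([window[0], window[1], 1.])
flip_y = np.array([[1., 0, 0], [0, -1., window[1]], [0, 0, 1]])

pos_to_pix = np.dot(flip_y, np.dot(to_screen, np.dot(normalize, center)))

def pos2pix(p):
    v = np.dot(pos_to_pix, np.array([p[0], p[1], 1.]))
    return v[:2].astype(int)

assert (pos2pix((0, 0)) == [500, 280]).all()      # workspace origin -> window center
assert (pos2pix((-25, -14)) == [0, 560]).all()    # bottom-left corner -> bottom-left pixel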
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_xfm_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_xfm_py.html
new file mode 100644
index 00000000..d202b026
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stereo_opengl_xfm_py.html
@@ -0,0 +1,310 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\stereo_opengl\xfm.py: 55%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]
'''
Quaternions and generic 3D transformations
'''


import numpy as np

class Quaternion(object):
    def __init__(self, w=1, x=0, y=0, z=0):
        if isinstance(w, (list, tuple, np.ndarray)) and not isinstance(x, np.ndarray):
            #Allows the use of a quaternion as a vector
            self.quat = np.array([0, w[0], w[1], w[2]])
        else:
            self.quat = np.array([w, x, y, z])

    def __repr__(self):
        if self.quat.ndim > 1:
            return "<Quaternion set for %d rotations>"%self.quat.shape[1]
        return "%g+%gi+%gj+%gk"%tuple(self.quat)

    def norm(self):
        self.quat = self.quat / np.sqrt((self.quat**2).sum())
        return self

    def conj(self):
        return Quaternion(self.w, *(-self.vec))

    @property
    def H(self):
        return self.conj()

    def __getattr__(self, attr):
        if attr in ["w", "scalar"]:
            return self.quat[0]
        elif attr in ["x", "i"]:
            return self.quat[1]
        elif attr in ["y", "j"]:
            return self.quat[2]
        elif attr in ["z", "k"]:
            return self.quat[3]
        elif attr in ["v", "vec", "vector"]:
            return self.quat[1:]
        else:
            # falls through to an AttributeError, since object defines no __getattr__
            super(Quaternion, self).__getattr__(self, attr)

    def __mul__(self, other):
        if isinstance(other, Quaternion):
            w = self.w*other.w - (self.vec*other.vec).sum(0)
            v = self.w*other.vec + other.w*self.vec + np.cross(self.vec.T, other.vec.T).T
            return Quaternion(w, *v).norm()
        elif isinstance(other, (np.ndarray, list, tuple)):
            if isinstance(other, (list, tuple)):
                other = np.array(other)
            #rotate a vector, will need to be implemented in GLSL eventually
            cross = np.cross(self.vec.T, other) + self.w*other
            return (other + np.cross(2*self.vec.T, cross)).squeeze()
            '''
            conj = self.H
            w = -np.dot(other, conj.vec)
            vec = np.outer(conj.w, other) + np.cross(other, conj.vec.T)
            if self.quat.ndim > 1:
                return self.w*vec.T + np.
            return self.w*vec + np.outer(w, self.vec).squeeze() + np.cross(self.vec, vec)
            '''
        else:
            raise ValueError

    def to_mat(self):
        '''
        Convert to an augmented rotation matrix if the quaternion is of unit norm
        ??? Does this function provide a sensible result for non-unit quaternions?

        Parameters
        ----------
        None

        Returns
        -------
        np.ndarray of shape (4, 4)
            Affine transformation matrix
        '''
        a, b, c, d = self.quat
        return np.array([
            [a**2+b**2-c**2-d**2,         2*b*c-2*a*d,         2*b*d+2*a*c, 0],
            [        2*b*c+2*a*d, a**2-b**2+c**2-d**2,         2*c*d-2*a*b, 0],
            [        2*b*d-2*a*c,         2*c*d+2*a*b, a**2-b**2-c**2+d**2, 0],
            [                  0,                   0,                   0, 1]])

    @classmethod
    def from_mat(cls, M):
        qw = np.sqrt(1 + M[0,0] + M[1,1] + M[2,2]) / 2
        qx = (M[2,1] - M[1,2])/(4*qw)
        qy = (M[0,2] - M[2,0])/(4*qw)
        qz = (M[1,0] - M[0,1])/(4*qw)
        return Quaternion(w=qw, x=qx, y=qy, z=qz)

    def rotate_to(self, vec):
        svec = self.vec / np.sqrt((self.vec**2).sum())
        nvec = vec / np.sqrt((vec**2).sum())   # was 'nvec = nvec = vec2 / ...'; vec2 was never defined
        rad = np.arccos(np.dot(svec, nvec))
        axis = np.cross(svec, nvec)
        self = self.from_axisangle(axis, rad)*self

    @classmethod
    def rotate_vecs(cls, vec1, vec2):
        '''
        Get the quaternion which rotates vec1 onto vec2

        Parameters
        ----------
        vec1: np.ndarray of shape (3,)
            Starting vector
        vec2: np.ndarray of shape (3,)
            Vector which defines the orientation that you want to rotate the first vector to

        Returns
        -------
        Quaternion representing the rotation
        '''
        vec1, vec2 = np.array(vec1), np.array(vec2)
        svec = vec1 / np.sqrt((vec1**2).sum())
        nvec = vec2 / np.sqrt((vec2**2).sum())
        if nvec.ndim > 1:
            if svec.ndim > 1:
                rad = (svec * nvec).sum(1)
            else:
                rad = np.arccos(np.dot(svec, nvec.T))
        else:
            rad = np.arccos(np.dot(svec, nvec))
        axis = np.cross(svec, nvec)
        return cls.from_axisangle(axis, rad)

    @classmethod
    def from_axisangle(cls, axis, rad):
        '''
        Convert from the Axis-angle representation of rotations to the quaternion representation

        Parameters
        ----------
        axis: np.ndarray of shape (3,) or ?????
            Rotation axis
        rad: float
            Angle to rotate around the specified axis in radians

        Returns
        -------
        Quaternion representing the rotation
        '''
        #normalize the axis first
        axis = np.array(axis)
        if axis.ndim > 1:
            axis = axis.T / np.sqrt((axis**2).sum(1))
        else:
            if not np.all(axis == 0):
                axis = axis / np.sqrt((axis**2).sum())
        w = np.cos(rad*0.5)
        v = axis * np.sin(rad*0.5)
        return cls(w, *v)

class Transform(object):
    '''
    Homogeneous transformation: rotate, then scale, then translate (see __repr__)
    '''
    def __init__(self, move=(0,0,0), scale=1, rotate=None):
        self.move = np.array(move, dtype=np.float)
        self.scale = scale
        self.rotate = rotate if rotate is not None else Quaternion()

    def __repr__(self):
        return "Rotate %s, then scale %s, then translate %s"%(self.rotate, self.scale, self.move)

    def __mul__(self, other):
        if isinstance(other, Transform):
            #Pre-multiply the other transform, then apply self
            move = self.move + self.rotate*other.move
            scale = self.scale * other.scale
            rot = self.rotate * other.rotate
            return Transform(move, scale, rot)

        elif isinstance(other, Quaternion):
            #Apply the quaternion directly to the current rotation
            # was 'other.rotate * self.rotate'; Quaternion has no .rotate attribute
            return Transform(self.move, self.scale, other * self.rotate)

    def __call__(self, vecs):
        return self.scale * (self.rotate * vecs) + self.move

    def translate(self, x, y, z, reset=False):
        '''
        Set the translation point of the transformation

        Parameters
        ----------
        x, y, z: float
            Coordinates representing how much to move
        reset: bool, optional, default=False
            If true, the new coordinates replace the old ones. If false, they are added on
        '''
        if reset:
            self.move[:] = x, y, z
        else:
            self.move += x, y, z
        return self

    def rotate_x(self, rad, reset=False):
        rotate = Quaternion.from_axisangle((1,0,0), rad)
        if reset:
            self.rotate = rotate
        else:
            self.rotate = (rotate * self.rotate).norm()
        return self

    def rotate_y(self, rad, reset=False):
        rotate = Quaternion.from_axisangle((0,1,0), rad)
        if reset:
            self.rotate = rotate
        else:
            self.rotate = (rotate * self.rotate).norm()
        return self

    def rotate_z(self, rad, reset=False):
        rotate = Quaternion.from_axisangle((0,0,1), rad)
        if reset:
            self.rotate = rotate
        else:
            self.rotate = (rotate * self.rotate).norm()
        return self

    def to_mat(self):
        scale = np.eye(4)
        scale[(0,1,2), (0,1,2)] = self.scale
        move = np.eye(4)
        move[:3, -1] = self.move

        return np.dot(move, np.dot(scale, self.rotate.to_mat()))

def test():
    world = Transform().rotate_x(np.radians(-90))
    eye = Transform().translate(0,35,0)
    obj = Transform().translate(0,10,5)
    assert np.allclose((world*eye*obj)((0,0,1)), [0,6,-45])
    obj.rotate_y(np.radians(-90))
    assert np.allclose((world*eye*obj)((0,0,1)), [-1, 5, -45])
    obj.rotate_z(np.radians(-90))
    assert np.allclose((world*eye*obj)((0,0,1)), [0,5,-46])
    assert np.allclose(np.dot((world*eye*obj).to_mat(), [0,0,1,1]), [0,5,-46, 1])
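A standalone check of the vector-rotation identity used in Quaternion.__mul__ (the two-cross-product form), building the unit quaternion the same way from_axisangle does: rotating x by 90 degrees about z should give y. Plain numpy, no riglib import needed:

import numpy as np

axis, rad = np.array([0., 0., 1.]), np.radians(90)
w = np.cos(rad * 0.5)
v = axis * np.sin(rad * 0.5)          # quaternion (w, v), as in from_axisangle

def qrot(w, v, x):
    # same two-cross-product identity as the ndarray branch of Quaternion.__mul__
    cross = np.cross(v, x) + w * x
    return x + np.cross(2 * v, cross)

assert np.allclose(qrot(w, v, np.array([1., 0., 0.])), [0., 1., 0.])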
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stimulus_pulse_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stimulus_pulse_py.html
new file mode 100644
index 00000000..687c04cd
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_stimulus_pulse_py.html
@@ -0,0 +1,93 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\stimulus_pulse.py: 0%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]
import os
import time
import comedi


class stimulus_pulse(object):
    com = comedi.comedi_open('/dev/comedi0')

    def __init__(self, *args, **kwargs):
        #self.com = comedi.comedi_open('/dev/comedi0')
        super(stimulus_pulse, self).__init__(*args, **kwargs)
        subdevice = 0
        write_mask = 0x800000
        val = 0x000000
        base_channel = 0
        comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel)

    def pulse(self, ts):
        #super(stimulus_pulse, self).pulse()
        subdevice = 0
        write_mask = 0x800000
        val = 0x000000
        base_channel = 0
        # NOTE: 'ts' is never updated inside the loop, so if ts < 0.4 on entry this
        # while loop never terminates; the 'else' clause only runs when the loop
        # condition is false. This reads as if it were meant to be an 'if'.
        while ts < 0.4:
            val = 0x800000
            comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel)
        else:
            val = 0x000000
            comedi.comedi_dio_bitfield2(self.com, subdevice, write_mask, val, base_channel)
diff --git a/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_touch_data_py.html b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_touch_data_py.html
new file mode 100644
index 00000000..d6b4308b
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/_GitHub_brain-python-interface_riglib_touch_data_py.html
@@ -0,0 +1,163 @@
[coverage.py HTML report page: "Coverage for C:\GitHub\brain-python-interface\riglib\touch_data.py: 0%" -- page chrome and keyboard-shortcut help panel omitted; the source listing follows]
'''
Code for getting data from the touch sensor (adapted from the kinarm client; see docstrings below)
'''

import time
import numpy as np
from riglib.source import DataSourceSystem
import serial
import struct

class TouchData(DataSourceSystem):
    '''
    Client for data streamed from the touch sensor, compatible with riglib.source.DataSource
    '''
    update_freq = 1000.

    # dtype is the numpy data type of items that will go
    # into the (multi-channel, in this case) datasource's ringbuffer
    dtype = np.dtype((np.float, (5,5)))

    def __init__(self):
        '''
        def __init__(self, addr=("192.168.0.8", 9090)):

        Constructor for Kinarmdata and connect to server

        Parameters
        ----------
        addr : tuple of length 2
            (client (self) IP address, client UDP port)


        self.conn = kinarmsocket.KinarmSocket(addr)
        self.conn.connect()
        '''
        self.conn = TouchDataInterface()
        self.conn.connect()

    def start(self):
        '''
        Start receiving data
        '''
        self.data = self.get_iterator()

    def get_iterator(self):
        assert self.conn.connected, "Socket is not connected, cannot get data"
        while self.conn.connected:
            # py2-style serial I/O; under py3, pyserial expects bytes (b't')
            self.conn.touch_port.write('t')
            tmp = self.conn.touch_port.read()
            tmp2 = np.zeros((5, 5))-1
            tmp2[0, 0] = float(tmp)
            yield tmp2

    def stop(self):
        '''
        Disconnect from the touch sensor serial port
        '''
        self.conn.touch_port.close()

    def get(self):
        '''
        Get a new touch sample
        '''
        return next(self.data)

def make(cls=DataSourceSystem, *args, **kwargs):
    '''
    This ridiculous function dynamically creates a class with a new init function

    Parameters
    ----------

    Returns
    -------
    '''
    def init(self, *args, **kwargs):
        super(self.__class__, self).__init__(*args, **kwargs)

    dtype = np.dtype((np.float, (1, )))
    return type(cls.__name__, (cls,), dict(dtype=dtype, __init__=init))

class TouchDataInterface(object):
    def __init__(self):
        self.touch_port = serial.Serial('/dev/ttyACM2', baudrate=115200)

    def _test_connected(self):
        try:
            self.touch_port.write('t')
            tmp = int(self.touch_port.read())
            self.connected = True
        except:
            print('Error in interfacing w/ touch sensor')
            self.connected = False

    def connect(self):
        self._test_connected()
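The type() trick in make() above, which clones a class at runtime with a narrower dtype, can be seen in isolation with plain objects. A minimal sketch with hypothetical names, no riglib needed:

import numpy as np

class Source(object):
    dtype = np.dtype((np.float64, (5, 5)))
    def __init__(self):
        self.buf = []

def make(cls=Source):
    # Build a subclass of cls at runtime whose only changes are the
    # dtype class attribute and a pass-through __init__
    def init(self, *args, **kwargs):
        super(self.__class__, self).__init__(*args, **kwargs)
    dtype = np.dtype((np.float64, (1,)))
    return type(cls.__name__, (cls,), dict(dtype=dtype, __init__=init))

Narrow = make()
assert Narrow.__name__ == 'Source'          # same name, different class object
assert Narrow().buf == [] and Narrow.dtype.shape == (1,)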
+ + + diff --git a/tests/unit_tests (deprecated)/htmlcov/coverage_html.js b/tests/unit_tests (deprecated)/htmlcov/coverage_html.js new file mode 100644 index 00000000..3bf04bf9 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/coverage_html.js @@ -0,0 +1,589 @@ +// Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0 +// For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt + +// Coverage.py HTML report browser code. +/*jslint browser: true, sloppy: true, vars: true, plusplus: true, maxerr: 50, indent: 4 */ +/*global coverage: true, document, window, $ */ + +coverage = {}; + +// Find all the elements with shortkey_* class, and use them to assign a shortcut key. +coverage.assign_shortkeys = function () { + $("*[class*='shortkey_']").each(function (i, e) { + $.each($(e).attr("class").split(" "), function (i, c) { + if (/^shortkey_/.test(c)) { + $(document).bind('keydown', c.substr(9), function () { + $(e).click(); + }); + } + }); + }); +}; + +// Create the events for the help panel. +coverage.wire_up_help_panel = function () { + $("#keyboard_icon").click(function () { + // Show the help panel, and position it so the keyboard icon in the + // panel is in the same place as the keyboard icon in the header. + $(".help_panel").show(); + var koff = $("#keyboard_icon").offset(); + var poff = $("#panel_icon").position(); + $(".help_panel").offset({ + top: koff.top-poff.top, + left: koff.left-poff.left + }); + }); + $("#panel_icon").click(function () { + $(".help_panel").hide(); + }); +}; + +// Create the events for the filter box. +coverage.wire_up_filter = function () { + // Cache elements. + var table = $("table.index"); + var table_rows = table.find("tbody tr"); + var table_row_names = table_rows.find("td.name a"); + var no_rows = $("#no_rows"); + + // Create a duplicate table footer that we can modify with dynamic summed values. + var table_footer = $("table.index tfoot tr"); + var table_dynamic_footer = table_footer.clone(); + table_dynamic_footer.attr('class', 'total_dynamic hidden'); + table_footer.after(table_dynamic_footer); + + // Observe filter keyevents. + $("#filter").on("keyup change", $.debounce(150, function (event) { + var filter_value = $(this).val(); + + if (filter_value === "") { + // Filter box is empty, remove all filtering. + table_rows.removeClass("hidden"); + + // Show standard footer, hide dynamic footer. + table_footer.removeClass("hidden"); + table_dynamic_footer.addClass("hidden"); + + // Hide placeholder, show table. + if (no_rows.length > 0) { + no_rows.hide(); + } + table.show(); + + } + else { + // Filter table items by value. + var hidden = 0; + var shown = 0; + + // Hide / show elements. + $.each(table_row_names, function () { + var element = $(this).parents("tr"); + + if ($(this).text().indexOf(filter_value) === -1) { + // hide + element.addClass("hidden"); + hidden++; + } + else { + // show + element.removeClass("hidden"); + shown++; + } + }); + + // Show placeholder if no rows will be displayed. + if (no_rows.length > 0) { + if (shown === 0) { + // Show placeholder, hide table. + no_rows.show(); + table.hide(); + } + else { + // Hide placeholder, show table. + no_rows.hide(); + table.show(); + } + } + + // Manage dynamic header: + if (hidden > 0) { + // Calculate new dynamic sum values based on visible rows. + for (var column = 2; column < 20; column++) { + // Calculate summed value. + var cells = table_rows.find('td:nth-child(' + column + ')'); + if (!cells.length) { + // No more columns...! 
+ break; + } + + var sum = 0, numer = 0, denom = 0; + $.each(cells.filter(':visible'), function () { + var ratio = $(this).data("ratio"); + if (ratio) { + var splitted = ratio.split(" "); + numer += parseInt(splitted[0], 10); + denom += parseInt(splitted[1], 10); + } + else { + sum += parseInt(this.innerHTML, 10); + } + }); + + // Get footer cell element. + var footer_cell = table_dynamic_footer.find('td:nth-child(' + column + ')'); + + // Set value into dynamic footer cell element. + if (cells[0].innerHTML.indexOf('%') > -1) { + // Percentage columns use the numerator and denominator, + // and adapt to the number of decimal places. + var match = /\.([0-9]+)/.exec(cells[0].innerHTML); + var places = 0; + if (match) { + places = match[1].length; + } + var pct = numer * 100 / denom; + footer_cell.text(pct.toFixed(places) + '%'); + } + else { + footer_cell.text(sum); + } + } + + // Hide standard footer, show dynamic footer. + table_footer.addClass("hidden"); + table_dynamic_footer.removeClass("hidden"); + } + else { + // Show standard footer, hide dynamic footer. + table_footer.removeClass("hidden"); + table_dynamic_footer.addClass("hidden"); + } + } + })); + + // Trigger change event on setup, to force filter on page refresh + // (filter value may still be present). + $("#filter").trigger("change"); +}; + +// Loaded on index.html +coverage.index_ready = function ($) { + // Look for a localStorage item containing previous sort settings: + var sort_list = []; + var storage_name = "COVERAGE_INDEX_SORT"; + var stored_list = undefined; + try { + stored_list = localStorage.getItem(storage_name); + } catch(err) {} + + if (stored_list) { + sort_list = JSON.parse('[[' + stored_list + ']]'); + } + + // Create a new widget which exists only to save and restore + // the sort order: + $.tablesorter.addWidget({ + id: "persistentSort", + + // Format is called by the widget before displaying: + format: function (table) { + if (table.config.sortList.length === 0 && sort_list.length > 0) { + // This table hasn't been sorted before - we'll use + // our stored settings: + $(table).trigger('sorton', [sort_list]); + } + else { + // This is not the first load - something has + // already defined sorting so we'll just update + // our stored value to match: + sort_list = table.config.sortList; + } + } + }); + + // Configure our tablesorter to handle the variable number of + // columns produced depending on report options: + var headers = []; + var col_count = $("table.index > thead > tr > th").length; + + headers[0] = { sorter: 'text' }; + for (i = 1; i < col_count-1; i++) { + headers[i] = { sorter: 'digit' }; + } + headers[col_count-1] = { sorter: 'percent' }; + + // Enable the table sorter: + $("table.index").tablesorter({ + widgets: ['persistentSort'], + headers: headers + }); + + coverage.assign_shortkeys(); + coverage.wire_up_help_panel(); + coverage.wire_up_filter(); + + // Watch for page unload events so we can save the final sort settings: + $(window).unload(function () { + try { + localStorage.setItem(storage_name, sort_list.toString()) + } catch(err) {} + }); +}; + +// -- pyfile stuff -- + +coverage.pyfile_ready = function ($) { + // If we're directed to a particular line number, highlight the line. 
+ var frag = location.hash; + if (frag.length > 2 && frag[1] === 't') { + $(frag).addClass('highlight'); + coverage.set_sel(parseInt(frag.substr(2), 10)); + } + else { + coverage.set_sel(0); + } + + $(document) + .bind('keydown', 'j', coverage.to_next_chunk_nicely) + .bind('keydown', 'k', coverage.to_prev_chunk_nicely) + .bind('keydown', '0', coverage.to_top) + .bind('keydown', '1', coverage.to_first_chunk) + ; + + $(".button_toggle_run").click(function (evt) {coverage.toggle_lines(evt.target, "run");}); + $(".button_toggle_exc").click(function (evt) {coverage.toggle_lines(evt.target, "exc");}); + $(".button_toggle_mis").click(function (evt) {coverage.toggle_lines(evt.target, "mis");}); + $(".button_toggle_par").click(function (evt) {coverage.toggle_lines(evt.target, "par");}); + + coverage.assign_shortkeys(); + coverage.wire_up_help_panel(); + + coverage.init_scroll_markers(); + + // Rebuild scroll markers when the window height changes. + $(window).resize(coverage.build_scroll_markers); +}; + +coverage.toggle_lines = function (btn, cls) { + btn = $(btn); + var show = "show_"+cls; + if (btn.hasClass(show)) { + $("#source ." + cls).removeClass(show); + btn.removeClass(show); + } + else { + $("#source ." + cls).addClass(show); + btn.addClass(show); + } + coverage.build_scroll_markers(); +}; + +// Return the nth line div. +coverage.line_elt = function (n) { + return $("#t" + n); +}; + +// Return the nth line number div. +coverage.num_elt = function (n) { + return $("#n" + n); +}; + +// Set the selection. b and e are line numbers. +coverage.set_sel = function (b, e) { + // The first line selected. + coverage.sel_begin = b; + // The next line not selected. + coverage.sel_end = (e === undefined) ? b+1 : e; +}; + +coverage.to_top = function () { + coverage.set_sel(0, 1); + coverage.scroll_window(0); +}; + +coverage.to_first_chunk = function () { + coverage.set_sel(0, 1); + coverage.to_next_chunk(); +}; + +// Return a string indicating what kind of chunk this line belongs to, +// or null if not a chunk. +coverage.chunk_indicator = function (line_elt) { + var klass = line_elt.attr('class'); + if (klass) { + var m = klass.match(/\bshow_\w+\b/); + if (m) { + return m[0]; + } + } + return null; +}; + +coverage.to_next_chunk = function () { + var c = coverage; + + // Find the start of the next colored chunk. + var probe = c.sel_end; + var chunk_indicator, probe_line; + while (true) { + probe_line = c.line_elt(probe); + if (probe_line.length === 0) { + return; + } + chunk_indicator = c.chunk_indicator(probe_line); + if (chunk_indicator) { + break; + } + probe++; + } + + // There's a next chunk, `probe` points to it. + var begin = probe; + + // Find the end of this chunk. + var next_indicator = chunk_indicator; + while (next_indicator === chunk_indicator) { + probe++; + probe_line = c.line_elt(probe); + next_indicator = c.chunk_indicator(probe_line); + } + c.set_sel(begin, probe); + c.show_selection(); +}; + +coverage.to_prev_chunk = function () { + var c = coverage; + + // Find the end of the prev colored chunk. + var probe = c.sel_begin-1; + var probe_line = c.line_elt(probe); + if (probe_line.length === 0) { + return; + } + var chunk_indicator = c.chunk_indicator(probe_line); + while (probe > 0 && !chunk_indicator) { + probe--; + probe_line = c.line_elt(probe); + if (probe_line.length === 0) { + return; + } + chunk_indicator = c.chunk_indicator(probe_line); + } + + // There's a prev chunk, `probe` points to its last line. + var end = probe+1; + + // Find the beginning of this chunk. 
+ var prev_indicator = chunk_indicator; + while (prev_indicator === chunk_indicator) { + probe--; + probe_line = c.line_elt(probe); + prev_indicator = c.chunk_indicator(probe_line); + } + c.set_sel(probe+1, end); + c.show_selection(); +}; + +// Return the line number of the line nearest pixel position pos +coverage.line_at_pos = function (pos) { + var l1 = coverage.line_elt(1), + l2 = coverage.line_elt(2), + result; + if (l1.length && l2.length) { + var l1_top = l1.offset().top, + line_height = l2.offset().top - l1_top, + nlines = (pos - l1_top) / line_height; + if (nlines < 1) { + result = 1; + } + else { + result = Math.ceil(nlines); + } + } + else { + result = 1; + } + return result; +}; + +// Returns 0, 1, or 2: how many of the two ends of the selection are on +// the screen right now? +coverage.selection_ends_on_screen = function () { + if (coverage.sel_begin === 0) { + return 0; + } + + var top = coverage.line_elt(coverage.sel_begin); + var next = coverage.line_elt(coverage.sel_end-1); + + return ( + (top.isOnScreen() ? 1 : 0) + + (next.isOnScreen() ? 1 : 0) + ); +}; + +coverage.to_next_chunk_nicely = function () { + coverage.finish_scrolling(); + if (coverage.selection_ends_on_screen() === 0) { + // The selection is entirely off the screen: select the top line on + // the screen. + var win = $(window); + coverage.select_line_or_chunk(coverage.line_at_pos(win.scrollTop())); + } + coverage.to_next_chunk(); +}; + +coverage.to_prev_chunk_nicely = function () { + coverage.finish_scrolling(); + if (coverage.selection_ends_on_screen() === 0) { + var win = $(window); + coverage.select_line_or_chunk(coverage.line_at_pos(win.scrollTop() + win.height())); + } + coverage.to_prev_chunk(); +}; + +// Select line number lineno, or if it is in a colored chunk, select the +// entire chunk +coverage.select_line_or_chunk = function (lineno) { + var c = coverage; + var probe_line = c.line_elt(lineno); + if (probe_line.length === 0) { + return; + } + var the_indicator = c.chunk_indicator(probe_line); + if (the_indicator) { + // The line is in a highlighted chunk. + // Search backward for the first line. + var probe = lineno; + var indicator = the_indicator; + while (probe > 0 && indicator === the_indicator) { + probe--; + probe_line = c.line_elt(probe); + if (probe_line.length === 0) { + break; + } + indicator = c.chunk_indicator(probe_line); + } + var begin = probe + 1; + + // Search forward for the last line. + probe = lineno; + indicator = the_indicator; + while (indicator === the_indicator) { + probe++; + probe_line = c.line_elt(probe); + indicator = c.chunk_indicator(probe_line); + } + + coverage.set_sel(begin, probe); + } + else { + coverage.set_sel(lineno); + } +}; + +coverage.show_selection = function () { + var c = coverage; + + // Highlight the lines in the chunk + $(".linenos .highlight").removeClass("highlight"); + for (var probe = c.sel_begin; probe > 0 && probe < c.sel_end; probe++) { + c.num_elt(probe).addClass("highlight"); + } + + c.scroll_to_selection(); +}; + +coverage.scroll_to_selection = function () { + // Scroll the page if the chunk isn't fully visible. + if (coverage.selection_ends_on_screen() < 2) { + // Need to move the page. 
The html,body trick makes it scroll in all + // browsers, got it from http://stackoverflow.com/questions/3042651 + var top = coverage.line_elt(coverage.sel_begin); + var top_pos = parseInt(top.offset().top, 10); + coverage.scroll_window(top_pos - 30); + } +}; + +coverage.scroll_window = function (to_pos) { + $("html,body").animate({scrollTop: to_pos}, 200); +}; + +coverage.finish_scrolling = function () { + $("html,body").stop(true, true); +}; + +coverage.init_scroll_markers = function () { + var c = coverage; + // Init some variables + c.lines_len = $('#source p').length; + c.body_h = $('body').height(); + c.header_h = $('div#header').height(); + + // Build html + c.build_scroll_markers(); +}; + +coverage.build_scroll_markers = function () { + var c = coverage, + min_line_height = 3, + max_line_height = 10, + visible_window_h = $(window).height(); + + c.lines_to_mark = $('#source').find('p.show_run, p.show_mis, p.show_exc, p.show_exc, p.show_par'); + $('#scroll_marker').remove(); + // Don't build markers if the window has no scroll bar. + if (c.body_h <= visible_window_h) { + return; + } + + $("body").append("
<div id='scroll_marker'></div>");
+    var scroll_marker = $('#scroll_marker'),
+        marker_scale = scroll_marker.height() / c.body_h,
+        line_height = scroll_marker.height() / c.lines_len;
+
+    // Line height must be between the extremes.
+    if (line_height > min_line_height) {
+        if (line_height > max_line_height) {
+            line_height = max_line_height;
+        }
+    }
+    else {
+        line_height = min_line_height;
+    }
+
+    var previous_line = -99,
+        last_mark,
+        last_top,
+        offsets = {};
+
+    // Calculate line offsets outside loop to prevent relayouts
+    c.lines_to_mark.each(function() {
+        offsets[this.id] = $(this).offset().top;
+    });
+    c.lines_to_mark.each(function () {
+        var id_name = $(this).attr('id'),
+            line_top = Math.round(offsets[id_name] * marker_scale),
+            line_number = parseInt(id_name.substring(1, id_name.length));
+
+        if (line_number === previous_line + 1) {
+            // If this solid missed block just make previous mark higher.
+            last_mark.css({
+                'height': line_top + line_height - last_top
+            });
+        }
+        else {
+            // Add colored line in scroll_marker block.
+            scroll_marker.append('<div id="m' + line_number + '"></div>');
+            last_mark = $('#m' + line_number);
+            last_mark.css({
+                'height': line_height,
+                'top': line_top
+            });
+            last_top = line_top;
+        }
+
+        previous_line = line_number;
+    });
+};
diff --git a/tests/unit_tests (deprecated)/htmlcov/index.html b/tests/unit_tests (deprecated)/htmlcov/index.html
new file mode 100644
index 00000000..1cbd6a41
--- /dev/null
+++ b/tests/unit_tests (deprecated)/htmlcov/index.html
@@ -0,0 +1,728 @@
+Coverage report
+Hide keyboard shortcuts
+Hot-keys on this page
+n s m x c   change column sorting
Module  statements  missing  excluded  coverage
Total  12020  9833  0  18%
C:\GitHub\brain-python-interface\riglib\__init__.py  0  0  0  100%
C:\GitHub\brain-python-interface\riglib\arduino_imu.py  49  49  0  0%
C:\GitHub\brain-python-interface\riglib\arduino_joystick.py  49  49  0  0%
C:\GitHub\brain-python-interface\riglib\blackrock\__init__.py  33  16  0  52%
C:\GitHub\brain-python-interface\riglib\blackrock\brMiscFxns.py  34  34  0  0%
C:\GitHub\brain-python-interface\riglib\blackrock\brpylib.py  644  644  0  0%
C:\GitHub\brain-python-interface\riglib\blackrock\cerelink.py  101  80  0  21%
C:\GitHub\brain-python-interface\riglib\bmi\__init__.py  1  0  0  100%
C:\GitHub\brain-python-interface\riglib\bmi\accumulator.py  35  24  0  31%
C:\GitHub\brain-python-interface\riglib\bmi\assist.py  47  30  0  36%
C:\GitHub\brain-python-interface\riglib\bmi\bmi.py  654  526  0  20%
C:\GitHub\brain-python-interface\riglib\bmi\clda.py  527  442  0  16%
C:\GitHub\brain-python-interface\riglib\bmi\extractor.py  648  563  0  13%
C:\GitHub\brain-python-interface\riglib\bmi\feedback_controllers.py  124  91  0  27%
C:\GitHub\brain-python-interface\riglib\bmi\goal_calculators.py  154  126  0  18%
C:\GitHub\brain-python-interface\riglib\bmi\kfdecoder.py  429  314  0  27%
C:\GitHub\brain-python-interface\riglib\bmi\kfdecoder_fcns.py  274  274  0  0%
C:\GitHub\brain-python-interface\riglib\bmi\lindecoder.py  37  28  0  24%
C:\GitHub\brain-python-interface\riglib\bmi\onedim_lfp_decoder.py  151  151  0  0%
C:\GitHub\brain-python-interface\riglib\bmi\ppfdecoder.py  233  197  0  15%
C:\GitHub\brain-python-interface\riglib\bmi\rat_bmi_decoder.py  359  359  0  0%
C:\GitHub\brain-python-interface\riglib\bmi\robot_arms.py  291  199  0  32%
C:\GitHub\brain-python-interface\riglib\bmi\sim_neurons.py  426  374  0  12%
C:\GitHub\brain-python-interface\riglib\bmi\sskfdecoder.py  61  61  0  0%
C:\GitHub\brain-python-interface\riglib\bmi\state_space_models.py  228  159  0  30%
C:\GitHub\brain-python-interface\riglib\bmi\train.py  704  655  0  7%
C:\GitHub\brain-python-interface\riglib\button.py  43  43  0  0%
C:\GitHub\brain-python-interface\riglib\calibrations.py  104  81  0  22%
C:\GitHub\brain-python-interface\riglib\dio\__init__.py  0  0  0  100%
C:\GitHub\brain-python-interface\riglib\dio\nidaq\__init__.py  39  39  0  0%
C:\GitHub\brain-python-interface\riglib\dio\parse.py  74  59  0  20%
C:\GitHub\brain-python-interface\riglib\experiment\Pygame.py  61  32  0  48%
C:\GitHub\brain-python-interface\riglib\experiment\__init__.py  44  17  0  61%
C:\GitHub\brain-python-interface\riglib\experiment\experiment.py  282  199  0  29%
C:\GitHub\brain-python-interface\riglib\experiment\generate.py  73  58  0  21%
C:\GitHub\brain-python-interface\riglib\experiment\mocks.py  115  38  0  67%
C:\GitHub\brain-python-interface\riglib\experiment\report.py  44  38  0  14%
C:\GitHub\brain-python-interface\riglib\eyetracker.py  73  73  0  0%
C:\GitHub\brain-python-interface\riglib\filter.py  14  14  0  0%
C:\GitHub\brain-python-interface\riglib\fsm\__init__.py  1  0  0  100%
C:\GitHub\brain-python-interface\riglib\fsm\fsm\__init__.py  1  0  0  100%
C:\GitHub\brain-python-interface\riglib\fsm\fsm\fsm.py  178  108  0  39%
C:\GitHub\brain-python-interface\riglib\fsm\setup.py  4  4  0  0%
C:\GitHub\brain-python-interface\riglib\hdfwriter\__init__.py  2  0  0  100%
C:\GitHub\brain-python-interface\riglib\hdfwriter\hdfwriter\__init__.py  2  0  0  100%
C:\GitHub\brain-python-interface\riglib\hdfwriter\hdfwriter\hdfwriter.py  40  27  0  32%
C:\GitHub\brain-python-interface\riglib\hdfwriter\setup.py  4  4  0  0%
C:\GitHub\brain-python-interface\riglib\kinarmdata.py  21  21  0  0%
C:\GitHub\brain-python-interface\riglib\kinarmsocket.py  37  37  0  0%
C:\GitHub\brain-python-interface\riglib\master8stimulation.py  61  61  0  0%
C:\GitHub\brain-python-interface\riglib\motiontracker.py  104  104  0  0%
C:\GitHub\brain-python-interface\riglib\mp_calc.py  102  85  0  17%
C:\GitHub\brain-python-interface\riglib\mp_proxy.py  19  0  0  100%
C:\GitHub\brain-python-interface\riglib\optitrack_client\NatNetClient.py  337  337  0  0%
C:\GitHub\brain-python-interface\riglib\optitrack_client\PythonSample.py  9  9  0  0%
C:\GitHub\brain-python-interface\riglib\optitrack_client\__init__.py  0  0  0  100%
C:\GitHub\brain-python-interface\riglib\optitrack_client\optitrack_direct_pack.py  53  53  0  0%
C:\GitHub\brain-python-interface\riglib\optitrack_client\optitrack_interface.py  68  68  0  0%
C:\GitHub\brain-python-interface\riglib\optitrack_client\test_NatNetClient_perframe.py  13  13  0  0%
C:\GitHub\brain-python-interface\riglib\optitrack_client\test_control.py  9  9  0  0%
C:\GitHub\brain-python-interface\riglib\optitrack_client\test_optitrack.py  7  7  0  0%
C:\GitHub\brain-python-interface\riglib\phidgets.py  45  45  0  0%
C:\GitHub\brain-python-interface\riglib\plants.py  319  163  0  49%
C:\GitHub\brain-python-interface\riglib\plexon\__init__.py  92  92  0  0%
C:\GitHub\brain-python-interface\riglib\plexon\checkbin.py  11  11  0  0%
C:\GitHub\brain-python-interface\riglib\plexon\plexnet.py  253  253  0  0%
C:\GitHub\brain-python-interface\riglib\plexon\plexnet_softserver_oldfiles.py  220  220  0  0%
C:\GitHub\brain-python-interface\riglib\plexon\source.py  22  22  0  0%
C:\GitHub\brain-python-interface\riglib\plexon\test_plexfile.py  15  15  0  0%
C:\GitHub\brain-python-interface\riglib\positioner\__init__.py  324  324  0  0%
C:\GitHub\brain-python-interface\riglib\positioner\calib.py  3  3  0  0%
C:\GitHub\brain-python-interface\riglib\reward.py  177  119  0  33%
C:\GitHub\brain-python-interface\riglib\serial_dio.py  42  27  0  36%
C:\GitHub\brain-python-interface\riglib\setup.py  8  8  0  0%
C:\GitHub\brain-python-interface\riglib\sink.py  99  69  0  30%
C:\GitHub\brain-python-interface\riglib\source.py  348  267  0  23%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\__init__.py  6  2  0  67%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\environment.py  15  10  0  33%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\ik.py  330  183  0  45%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\models.py  193  119  0  38%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\primitives.py  209  66  0  68%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\__init__.py  2  0  0  100%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\fbo.py  76  76  0  0%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\render.py  107  88  0  18%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\shader.py  83  67  0  19%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\ssao.py  47  47  0  0%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\render\stereo.py  67  49  0  27%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\textures.py  65  50  0  23%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\utils.py  38  30  0  21%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\window.py  272  198  0  27%
C:\GitHub\brain-python-interface\riglib\stereo_opengl\xfm.py  135  61  0  55%
C:\GitHub\brain-python-interface\riglib\stimulus_pulse.py  22  22  0  0%
C:\GitHub\brain-python-interface\riglib\touch_data.py  44  44  0  0%
No items found using the specified filter.
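The coverage column in the index above is derived from the statement counts: covered = statements - missing, and coverage = covered / statements, with modules containing no statements (0 statements, 0 missing) reported as 100%. A minimal sketch of that arithmetic, checked against two rows of the table; the `pct_covered` helper here is illustrative only, not part of coverage.py's API:

```python
def pct_covered(statements: int, missing: int) -> float:
    """Percent of statements executed, as the index table reports it."""
    if statements == 0:
        return 100.0  # modules with no statements show as 100% above
    return 100.0 * (statements - missing) / statements

# Total row above: 12020 statements, 9833 missing -> 18%
assert round(pct_covered(12020, 9833)) == 18
# riglib\bmi\train.py: 704 statements, 655 missing -> 7%
assert round(pct_covered(704, 655)) == 7
```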
+ + + diff --git a/tests/unit_tests (deprecated)/htmlcov/jquery.ba-throttle-debounce.min.js b/tests/unit_tests (deprecated)/htmlcov/jquery.ba-throttle-debounce.min.js new file mode 100644 index 00000000..648fe5d3 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/jquery.ba-throttle-debounce.min.js @@ -0,0 +1,9 @@ +/* + * jQuery throttle / debounce - v1.1 - 3/7/2010 + * http://benalman.com/projects/jquery-throttle-debounce-plugin/ + * + * Copyright (c) 2010 "Cowboy" Ben Alman + * Dual licensed under the MIT and GPL licenses. + * http://benalman.com/about/license/ + */ +(function(b,c){var $=b.jQuery||b.Cowboy||(b.Cowboy={}),a;$.throttle=a=function(e,f,j,i){var h,d=0;if(typeof f!=="boolean"){i=j;j=f;f=c}function g(){var o=this,m=+new Date()-d,n=arguments;function l(){d=+new Date();j.apply(o,n)}function k(){h=c}if(i&&!h){l()}h&&clearTimeout(h);if(i===c&&m>e){l()}else{if(f!==true){h=setTimeout(i?k:l,i===c?e-m:e)}}}if($.guid){g.guid=j.guid=j.guid||$.guid++}return g};$.debounce=function(d,e,f){return f===c?a(d,e,false):a(d,f,e!==false)}})(this); diff --git a/tests/unit_tests (deprecated)/htmlcov/jquery.hotkeys.js b/tests/unit_tests (deprecated)/htmlcov/jquery.hotkeys.js new file mode 100644 index 00000000..09b21e03 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/jquery.hotkeys.js @@ -0,0 +1,99 @@ +/* + * jQuery Hotkeys Plugin + * Copyright 2010, John Resig + * Dual licensed under the MIT or GPL Version 2 licenses. + * + * Based upon the plugin by Tzury Bar Yochay: + * http://github.com/tzuryby/hotkeys + * + * Original idea by: + * Binny V A, http://www.openjs.com/scripts/events/keyboard_shortcuts/ +*/ + +(function(jQuery){ + + jQuery.hotkeys = { + version: "0.8", + + specialKeys: { + 8: "backspace", 9: "tab", 13: "return", 16: "shift", 17: "ctrl", 18: "alt", 19: "pause", + 20: "capslock", 27: "esc", 32: "space", 33: "pageup", 34: "pagedown", 35: "end", 36: "home", + 37: "left", 38: "up", 39: "right", 40: "down", 45: "insert", 46: "del", + 96: "0", 97: "1", 98: "2", 99: "3", 100: "4", 101: "5", 102: "6", 103: "7", + 104: "8", 105: "9", 106: "*", 107: "+", 109: "-", 110: ".", 111 : "/", + 112: "f1", 113: "f2", 114: "f3", 115: "f4", 116: "f5", 117: "f6", 118: "f7", 119: "f8", + 120: "f9", 121: "f10", 122: "f11", 123: "f12", 144: "numlock", 145: "scroll", 191: "/", 224: "meta" + }, + + shiftNums: { + "`": "~", "1": "!", "2": "@", "3": "#", "4": "$", "5": "%", "6": "^", "7": "&", + "8": "*", "9": "(", "0": ")", "-": "_", "=": "+", ";": ": ", "'": "\"", ",": "<", + ".": ">", "/": "?", "\\": "|" + } + }; + + function keyHandler( handleObj ) { + // Only care when a possible input has been specified + if ( typeof handleObj.data !== "string" ) { + return; + } + + var origHandler = handleObj.handler, + keys = handleObj.data.toLowerCase().split(" "); + + handleObj.handler = function( event ) { + // Don't fire in text-accepting inputs that we didn't directly bind to + if ( this !== event.target && (/textarea|select/i.test( event.target.nodeName ) || + event.target.type === "text") ) { + return; + } + + // Keypress represents characters, not special keys + var special = event.type !== "keypress" && jQuery.hotkeys.specialKeys[ event.which ], + character = String.fromCharCode( event.which ).toLowerCase(), + key, modif = "", possible = {}; + + // check combinations (alt|ctrl|shift+anything) + if ( event.altKey && special !== "alt" ) { + modif += "alt+"; + } + + if ( event.ctrlKey && special !== "ctrl" ) { + modif += "ctrl+"; + } + + // TODO: Need to make sure this works consistently across 
platforms + if ( event.metaKey && !event.ctrlKey && special !== "meta" ) { + modif += "meta+"; + } + + if ( event.shiftKey && special !== "shift" ) { + modif += "shift+"; + } + + if ( special ) { + possible[ modif + special ] = true; + + } else { + possible[ modif + character ] = true; + possible[ modif + jQuery.hotkeys.shiftNums[ character ] ] = true; + + // "$" can be triggered as "Shift+4" or "Shift+$" or just "$" + if ( modif === "shift+" ) { + possible[ jQuery.hotkeys.shiftNums[ character ] ] = true; + } + } + + for ( var i = 0, l = keys.length; i < l; i++ ) { + if ( possible[ keys[i] ] ) { + return origHandler.apply( this, arguments ); + } + } + }; + } + + jQuery.each([ "keydown", "keyup", "keypress" ], function() { + jQuery.event.special[ this ] = { add: keyHandler }; + }); + +})( jQuery ); diff --git a/tests/unit_tests (deprecated)/htmlcov/jquery.isonscreen.js b/tests/unit_tests (deprecated)/htmlcov/jquery.isonscreen.js new file mode 100644 index 00000000..0182ebd2 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/jquery.isonscreen.js @@ -0,0 +1,53 @@ +/* Copyright (c) 2010 + * @author Laurence Wheway + * Dual licensed under the MIT (http://www.opensource.org/licenses/mit-license.php) + * and GPL (http://www.opensource.org/licenses/gpl-license.php) licenses. + * + * @version 1.2.0 + */ +(function($) { + jQuery.extend({ + isOnScreen: function(box, container) { + //ensure numbers come in as intgers (not strings) and remove 'px' is it's there + for(var i in box){box[i] = parseFloat(box[i])}; + for(var i in container){container[i] = parseFloat(container[i])}; + + if(!container){ + container = { + left: $(window).scrollLeft(), + top: $(window).scrollTop(), + width: $(window).width(), + height: $(window).height() + } + } + + if( box.left+box.width-container.left > 0 && + box.left < container.width+container.left && + box.top+box.height-container.top > 0 && + box.top < container.height+container.top + ) return true; + return false; + } + }) + + + jQuery.fn.isOnScreen = function (container) { + for(var i in container){container[i] = parseFloat(container[i])}; + + if(!container){ + container = { + left: $(window).scrollLeft(), + top: $(window).scrollTop(), + width: $(window).width(), + height: $(window).height() + } + } + + if( $(this).offset().left+$(this).width()-container.left > 0 && + $(this).offset().left < container.width+container.left && + $(this).offset().top+$(this).height()-container.top > 0 && + $(this).offset().top < container.height+container.top + ) return true; + return false; + } +})(jQuery); diff --git a/tests/unit_tests (deprecated)/htmlcov/jquery.min.js b/tests/unit_tests (deprecated)/htmlcov/jquery.min.js new file mode 100644 index 00000000..d1608e37 --- /dev/null +++ b/tests/unit_tests (deprecated)/htmlcov/jquery.min.js @@ -0,0 +1,4 @@ +/*! jQuery v1.11.1 | (c) 2005, 2014 jQuery Foundation, Inc. 
| jquery.org/license */ +!function(a,b){"object"==typeof module&&"object"==typeof module.exports?module.exports=a.document?b(a,!0):function(a){if(!a.document)throw new Error("jQuery requires a window with a document");return b(a)}:b(a)}("undefined"!=typeof window?window:this,function(a,b){var c=[],d=c.slice,e=c.concat,f=c.push,g=c.indexOf,h={},i=h.toString,j=h.hasOwnProperty,k={},l="1.11.1",m=function(a,b){return new m.fn.init(a,b)},n=/^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g,o=/^-ms-/,p=/-([\da-z])/gi,q=function(a,b){return b.toUpperCase()};m.fn=m.prototype={jquery:l,constructor:m,selector:"",length:0,toArray:function(){return d.call(this)},get:function(a){return null!=a?0>a?this[a+this.length]:this[a]:d.call(this)},pushStack:function(a){var b=m.merge(this.constructor(),a);return b.prevObject=this,b.context=this.context,b},each:function(a,b){return m.each(this,a,b)},map:function(a){return this.pushStack(m.map(this,function(b,c){return a.call(b,c,b)}))},slice:function(){return this.pushStack(d.apply(this,arguments))},first:function(){return this.eq(0)},last:function(){return this.eq(-1)},eq:function(a){var b=this.length,c=+a+(0>a?b:0);return this.pushStack(c>=0&&b>c?[this[c]]:[])},end:function(){return this.prevObject||this.constructor(null)},push:f,sort:c.sort,splice:c.splice},m.extend=m.fn.extend=function(){var a,b,c,d,e,f,g=arguments[0]||{},h=1,i=arguments.length,j=!1;for("boolean"==typeof g&&(j=g,g=arguments[h]||{},h++),"object"==typeof g||m.isFunction(g)||(g={}),h===i&&(g=this,h--);i>h;h++)if(null!=(e=arguments[h]))for(d in e)a=g[d],c=e[d],g!==c&&(j&&c&&(m.isPlainObject(c)||(b=m.isArray(c)))?(b?(b=!1,f=a&&m.isArray(a)?a:[]):f=a&&m.isPlainObject(a)?a:{},g[d]=m.extend(j,f,c)):void 0!==c&&(g[d]=c));return g},m.extend({expando:"jQuery"+(l+Math.random()).replace(/\D/g,""),isReady:!0,error:function(a){throw new Error(a)},noop:function(){},isFunction:function(a){return"function"===m.type(a)},isArray:Array.isArray||function(a){return"array"===m.type(a)},isWindow:function(a){return null!=a&&a==a.window},isNumeric:function(a){return!m.isArray(a)&&a-parseFloat(a)>=0},isEmptyObject:function(a){var b;for(b in a)return!1;return!0},isPlainObject:function(a){var b;if(!a||"object"!==m.type(a)||a.nodeType||m.isWindow(a))return!1;try{if(a.constructor&&!j.call(a,"constructor")&&!j.call(a.constructor.prototype,"isPrototypeOf"))return!1}catch(c){return!1}if(k.ownLast)for(b in a)return j.call(a,b);for(b in a);return void 0===b||j.call(a,b)},type:function(a){return null==a?a+"":"object"==typeof a||"function"==typeof a?h[i.call(a)]||"object":typeof a},globalEval:function(b){b&&m.trim(b)&&(a.execScript||function(b){a.eval.call(a,b)})(b)},camelCase:function(a){return a.replace(o,"ms-").replace(p,q)},nodeName:function(a,b){return a.nodeName&&a.nodeName.toLowerCase()===b.toLowerCase()},each:function(a,b,c){var d,e=0,f=a.length,g=r(a);if(c){if(g){for(;f>e;e++)if(d=b.apply(a[e],c),d===!1)break}else for(e in a)if(d=b.apply(a[e],c),d===!1)break}else if(g){for(;f>e;e++)if(d=b.call(a[e],e,a[e]),d===!1)break}else for(e in a)if(d=b.call(a[e],e,a[e]),d===!1)break;return a},trim:function(a){return null==a?"":(a+"").replace(n,"")},makeArray:function(a,b){var c=b||[];return null!=a&&(r(Object(a))?m.merge(c,"string"==typeof a?[a]:a):f.call(c,a)),c},inArray:function(a,b,c){var d;if(b){if(g)return g.call(b,a,c);for(d=b.length,c=c?0>c?Math.max(0,d+c):c:0;d>c;c++)if(c in b&&b[c]===a)return c}return-1},merge:function(a,b){var c=+b.length,d=0,e=a.length;while(c>d)a[e++]=b[d++];if(c!==c)while(void 0!==b[d])a[e++]=b[d++];return 
a.length=e,a},grep:function(a,b,c){for(var d,e=[],f=0,g=a.length,h=!c;g>f;f++)d=!b(a[f],f),d!==h&&e.push(a[f]);return e},map:function(a,b,c){var d,f=0,g=a.length,h=r(a),i=[];if(h)for(;g>f;f++)d=b(a[f],f,c),null!=d&&i.push(d);else for(f in a)d=b(a[f],f,c),null!=d&&i.push(d);return e.apply([],i)},guid:1,proxy:function(a,b){var c,e,f;return"string"==typeof b&&(f=a[b],b=a,a=f),m.isFunction(a)?(c=d.call(arguments,2),e=function(){return a.apply(b||this,c.concat(d.call(arguments)))},e.guid=a.guid=a.guid||m.guid++,e):void 0},now:function(){return+new Date},support:k}),m.each("Boolean Number String Function Array Date RegExp Object Error".split(" "),function(a,b){h["[object "+b+"]"]=b.toLowerCase()});function r(a){var b=a.length,c=m.type(a);return"function"===c||m.isWindow(a)?!1:1===a.nodeType&&b?!0:"array"===c||0===b||"number"==typeof b&&b>0&&b-1 in a}var s=function(a){var b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u="sizzle"+-new Date,v=a.document,w=0,x=0,y=gb(),z=gb(),A=gb(),B=function(a,b){return a===b&&(l=!0),0},C="undefined",D=1<<31,E={}.hasOwnProperty,F=[],G=F.pop,H=F.push,I=F.push,J=F.slice,K=F.indexOf||function(a){for(var b=0,c=this.length;c>b;b++)if(this[b]===a)return b;return-1},L="checked|selected|async|autofocus|autoplay|controls|defer|disabled|hidden|ismap|loop|multiple|open|readonly|required|scoped",M="[\\x20\\t\\r\\n\\f]",N="(?:\\\\.|[\\w-]|[^\\x00-\\xa0])+",O=N.replace("w","w#"),P="\\["+M+"*("+N+")(?:"+M+"*([*^$|!~]?=)"+M+"*(?:'((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\"|("+O+"))|)"+M+"*\\]",Q=":("+N+")(?:\\((('((?:\\\\.|[^\\\\'])*)'|\"((?:\\\\.|[^\\\\\"])*)\")|((?:\\\\.|[^\\\\()[\\]]|"+P+")*)|.*)\\)|)",R=new RegExp("^"+M+"+|((?:^|[^\\\\])(?:\\\\.)*)"+M+"+$","g"),S=new RegExp("^"+M+"*,"+M+"*"),T=new RegExp("^"+M+"*([>+~]|"+M+")"+M+"*"),U=new RegExp("="+M+"*([^\\]'\"]*?)"+M+"*\\]","g"),V=new RegExp(Q),W=new RegExp("^"+O+"$"),X={ID:new RegExp("^#("+N+")"),CLASS:new RegExp("^\\.("+N+")"),TAG:new RegExp("^("+N.replace("w","w*")+")"),ATTR:new RegExp("^"+P),PSEUDO:new RegExp("^"+Q),CHILD:new RegExp("^:(only|first|last|nth|nth-last)-(child|of-type)(?:\\("+M+"*(even|odd|(([+-]|)(\\d*)n|)"+M+"*(?:([+-]|)"+M+"*(\\d+)|))"+M+"*\\)|)","i"),bool:new RegExp("^(?:"+L+")$","i"),needsContext:new RegExp("^"+M+"*[>+~]|:(even|odd|eq|gt|lt|nth|first|last)(?:\\("+M+"*((?:-\\d)?\\d*)"+M+"*\\)|)(?=[^-]|$)","i")},Y=/^(?:input|select|textarea|button)$/i,Z=/^h\d$/i,$=/^[^{]+\{\s*\[native \w/,_=/^(?:#([\w-]+)|(\w+)|\.([\w-]+))$/,ab=/[+~]/,bb=/'|\\/g,cb=new RegExp("\\\\([\\da-f]{1,6}"+M+"?|("+M+")|.)","ig"),db=function(a,b,c){var d="0x"+b-65536;return d!==d||c?b:0>d?String.fromCharCode(d+65536):String.fromCharCode(d>>10|55296,1023&d|56320)};try{I.apply(F=J.call(v.childNodes),v.childNodes),F[v.childNodes.length].nodeType}catch(eb){I={apply:F.length?function(a,b){H.apply(a,J.call(b))}:function(a,b){var c=a.length,d=0;while(a[c++]=b[d++]);a.length=c-1}}}function fb(a,b,d,e){var f,h,j,k,l,o,r,s,w,x;if((b?b.ownerDocument||b:v)!==n&&m(b),b=b||n,d=d||[],!a||"string"!=typeof a)return d;if(1!==(k=b.nodeType)&&9!==k)return[];if(p&&!e){if(f=_.exec(a))if(j=f[1]){if(9===k){if(h=b.getElementById(j),!h||!h.parentNode)return d;if(h.id===j)return d.push(h),d}else if(b.ownerDocument&&(h=b.ownerDocument.getElementById(j))&&t(b,h)&&h.id===j)return d.push(h),d}else{if(f[2])return I.apply(d,b.getElementsByTagName(a)),d;if((j=f[3])&&c.getElementsByClassName&&b.getElementsByClassName)return 
I.apply(d,b.getElementsByClassName(j)),d}if(c.qsa&&(!q||!q.test(a))){if(s=r=u,w=b,x=9===k&&a,1===k&&"object"!==b.nodeName.toLowerCase()){o=g(a),(r=b.getAttribute("id"))?s=r.replace(bb,"\\$&"):b.setAttribute("id",s),s="[id='"+s+"'] ",l=o.length;while(l--)o[l]=s+qb(o[l]);w=ab.test(a)&&ob(b.parentNode)||b,x=o.join(",")}if(x)try{return I.apply(d,w.querySelectorAll(x)),d}catch(y){}finally{r||b.removeAttribute("id")}}}return i(a.replace(R,"$1"),b,d,e)}function gb(){var a=[];function b(c,e){return a.push(c+" ")>d.cacheLength&&delete b[a.shift()],b[c+" "]=e}return b}function hb(a){return a[u]=!0,a}function ib(a){var b=n.createElement("div");try{return!!a(b)}catch(c){return!1}finally{b.parentNode&&b.parentNode.removeChild(b),b=null}}function jb(a,b){var c=a.split("|"),e=a.length;while(e--)d.attrHandle[c[e]]=b}function kb(a,b){var c=b&&a,d=c&&1===a.nodeType&&1===b.nodeType&&(~b.sourceIndex||D)-(~a.sourceIndex||D);if(d)return d;if(c)while(c=c.nextSibling)if(c===b)return-1;return a?1:-1}function lb(a){return function(b){var c=b.nodeName.toLowerCase();return"input"===c&&b.type===a}}function mb(a){return function(b){var c=b.nodeName.toLowerCase();return("input"===c||"button"===c)&&b.type===a}}function nb(a){return hb(function(b){return b=+b,hb(function(c,d){var e,f=a([],c.length,b),g=f.length;while(g--)c[e=f[g]]&&(c[e]=!(d[e]=c[e]))})})}function ob(a){return a&&typeof a.getElementsByTagName!==C&&a}c=fb.support={},f=fb.isXML=function(a){var b=a&&(a.ownerDocument||a).documentElement;return b?"HTML"!==b.nodeName:!1},m=fb.setDocument=function(a){var b,e=a?a.ownerDocument||a:v,g=e.defaultView;return e!==n&&9===e.nodeType&&e.documentElement?(n=e,o=e.documentElement,p=!f(e),g&&g!==g.top&&(g.addEventListener?g.addEventListener("unload",function(){m()},!1):g.attachEvent&&g.attachEvent("onunload",function(){m()})),c.attributes=ib(function(a){return a.className="i",!a.getAttribute("className")}),c.getElementsByTagName=ib(function(a){return a.appendChild(e.createComment("")),!a.getElementsByTagName("*").length}),c.getElementsByClassName=$.test(e.getElementsByClassName)&&ib(function(a){return a.innerHTML="
",a.firstChild.className="i",2===a.getElementsByClassName("i").length}),c.getById=ib(function(a){return o.appendChild(a).id=u,!e.getElementsByName||!e.getElementsByName(u).length}),c.getById?(d.find.ID=function(a,b){if(typeof b.getElementById!==C&&p){var c=b.getElementById(a);return c&&c.parentNode?[c]:[]}},d.filter.ID=function(a){var b=a.replace(cb,db);return function(a){return a.getAttribute("id")===b}}):(delete d.find.ID,d.filter.ID=function(a){var b=a.replace(cb,db);return function(a){var c=typeof a.getAttributeNode!==C&&a.getAttributeNode("id");return c&&c.value===b}}),d.find.TAG=c.getElementsByTagName?function(a,b){return typeof b.getElementsByTagName!==C?b.getElementsByTagName(a):void 0}:function(a,b){var c,d=[],e=0,f=b.getElementsByTagName(a);if("*"===a){while(c=f[e++])1===c.nodeType&&d.push(c);return d}return f},d.find.CLASS=c.getElementsByClassName&&function(a,b){return typeof b.getElementsByClassName!==C&&p?b.getElementsByClassName(a):void 0},r=[],q=[],(c.qsa=$.test(e.querySelectorAll))&&(ib(function(a){a.innerHTML="",a.querySelectorAll("[msallowclip^='']").length&&q.push("[*^$]="+M+"*(?:''|\"\")"),a.querySelectorAll("[selected]").length||q.push("\\["+M+"*(?:value|"+L+")"),a.querySelectorAll(":checked").length||q.push(":checked")}),ib(function(a){var b=e.createElement("input");b.setAttribute("type","hidden"),a.appendChild(b).setAttribute("name","D"),a.querySelectorAll("[name=d]").length&&q.push("name"+M+"*[*^$|!~]?="),a.querySelectorAll(":enabled").length||q.push(":enabled",":disabled"),a.querySelectorAll("*,:x"),q.push(",.*:")})),(c.matchesSelector=$.test(s=o.matches||o.webkitMatchesSelector||o.mozMatchesSelector||o.oMatchesSelector||o.msMatchesSelector))&&ib(function(a){c.disconnectedMatch=s.call(a,"div"),s.call(a,"[s!='']:x"),r.push("!=",Q)}),q=q.length&&new RegExp(q.join("|")),r=r.length&&new RegExp(r.join("|")),b=$.test(o.compareDocumentPosition),t=b||$.test(o.contains)?function(a,b){var c=9===a.nodeType?a.documentElement:a,d=b&&b.parentNode;return a===d||!(!d||1!==d.nodeType||!(c.contains?c.contains(d):a.compareDocumentPosition&&16&a.compareDocumentPosition(d)))}:function(a,b){if(b)while(b=b.parentNode)if(b===a)return!0;return!1},B=b?function(a,b){if(a===b)return l=!0,0;var d=!a.compareDocumentPosition-!b.compareDocumentPosition;return d?d:(d=(a.ownerDocument||a)===(b.ownerDocument||b)?a.compareDocumentPosition(b):1,1&d||!c.sortDetached&&b.compareDocumentPosition(a)===d?a===e||a.ownerDocument===v&&t(v,a)?-1:b===e||b.ownerDocument===v&&t(v,b)?1:k?K.call(k,a)-K.call(k,b):0:4&d?-1:1)}:function(a,b){if(a===b)return l=!0,0;var c,d=0,f=a.parentNode,g=b.parentNode,h=[a],i=[b];if(!f||!g)return a===e?-1:b===e?1:f?-1:g?1:k?K.call(k,a)-K.call(k,b):0;if(f===g)return kb(a,b);c=a;while(c=c.parentNode)h.unshift(c);c=b;while(c=c.parentNode)i.unshift(c);while(h[d]===i[d])d++;return d?kb(h[d],i[d]):h[d]===v?-1:i[d]===v?1:0},e):n},fb.matches=function(a,b){return fb(a,null,null,b)},fb.matchesSelector=function(a,b){if((a.ownerDocument||a)!==n&&m(a),b=b.replace(U,"='$1']"),!(!c.matchesSelector||!p||r&&r.test(b)||q&&q.test(b)))try{var d=s.call(a,b);if(d||c.disconnectedMatch||a.document&&11!==a.document.nodeType)return d}catch(e){}return fb(b,n,null,[a]).length>0},fb.contains=function(a,b){return(a.ownerDocument||a)!==n&&m(a),t(a,b)},fb.attr=function(a,b){(a.ownerDocument||a)!==n&&m(a);var e=d.attrHandle[b.toLowerCase()],f=e&&E.call(d.attrHandle,b.toLowerCase())?e(a,b,!p):void 0;return void 
0!==f?f:c.attributes||!p?a.getAttribute(b):(f=a.getAttributeNode(b))&&f.specified?f.value:null},fb.error=function(a){throw new Error("Syntax error, unrecognized expression: "+a)},fb.uniqueSort=function(a){var b,d=[],e=0,f=0;if(l=!c.detectDuplicates,k=!c.sortStable&&a.slice(0),a.sort(B),l){while(b=a[f++])b===a[f]&&(e=d.push(f));while(e--)a.splice(d[e],1)}return k=null,a},e=fb.getText=function(a){var b,c="",d=0,f=a.nodeType;if(f){if(1===f||9===f||11===f){if("string"==typeof a.textContent)return a.textContent;for(a=a.firstChild;a;a=a.nextSibling)c+=e(a)}else if(3===f||4===f)return a.nodeValue}else while(b=a[d++])c+=e(b);return c},d=fb.selectors={cacheLength:50,createPseudo:hb,match:X,attrHandle:{},find:{},relative:{">":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(a){return a[1]=a[1].replace(cb,db),a[3]=(a[3]||a[4]||a[5]||"").replace(cb,db),"~="===a[2]&&(a[3]=" "+a[3]+" "),a.slice(0,4)},CHILD:function(a){return a[1]=a[1].toLowerCase(),"nth"===a[1].slice(0,3)?(a[3]||fb.error(a[0]),a[4]=+(a[4]?a[5]+(a[6]||1):2*("even"===a[3]||"odd"===a[3])),a[5]=+(a[7]+a[8]||"odd"===a[3])):a[3]&&fb.error(a[0]),a},PSEUDO:function(a){var b,c=!a[6]&&a[2];return X.CHILD.test(a[0])?null:(a[3]?a[2]=a[4]||a[5]||"":c&&V.test(c)&&(b=g(c,!0))&&(b=c.indexOf(")",c.length-b)-c.length)&&(a[0]=a[0].slice(0,b),a[2]=c.slice(0,b)),a.slice(0,3))}},filter:{TAG:function(a){var b=a.replace(cb,db).toLowerCase();return"*"===a?function(){return!0}:function(a){return a.nodeName&&a.nodeName.toLowerCase()===b}},CLASS:function(a){var b=y[a+" "];return b||(b=new RegExp("(^|"+M+")"+a+"("+M+"|$)"))&&y(a,function(a){return b.test("string"==typeof a.className&&a.className||typeof a.getAttribute!==C&&a.getAttribute("class")||"")})},ATTR:function(a,b,c){return function(d){var e=fb.attr(d,a);return null==e?"!="===b:b?(e+="","="===b?e===c:"!="===b?e!==c:"^="===b?c&&0===e.indexOf(c):"*="===b?c&&e.indexOf(c)>-1:"$="===b?c&&e.slice(-c.length)===c:"~="===b?(" "+e+" ").indexOf(c)>-1:"|="===b?e===c||e.slice(0,c.length+1)===c+"-":!1):!0}},CHILD:function(a,b,c,d,e){var f="nth"!==a.slice(0,3),g="last"!==a.slice(-4),h="of-type"===b;return 1===d&&0===e?function(a){return!!a.parentNode}:function(b,c,i){var j,k,l,m,n,o,p=f!==g?"nextSibling":"previousSibling",q=b.parentNode,r=h&&b.nodeName.toLowerCase(),s=!i&&!h;if(q){if(f){while(p){l=b;while(l=l[p])if(h?l.nodeName.toLowerCase()===r:1===l.nodeType)return!1;o=p="only"===a&&!o&&"nextSibling"}return!0}if(o=[g?q.firstChild:q.lastChild],g&&s){k=q[u]||(q[u]={}),j=k[a]||[],n=j[0]===w&&j[1],m=j[0]===w&&j[2],l=n&&q.childNodes[n];while(l=++n&&l&&l[p]||(m=n=0)||o.pop())if(1===l.nodeType&&++m&&l===b){k[a]=[w,n,m];break}}else if(s&&(j=(b[u]||(b[u]={}))[a])&&j[0]===w)m=j[1];else while(l=++n&&l&&l[p]||(m=n=0)||o.pop())if((h?l.nodeName.toLowerCase()===r:1===l.nodeType)&&++m&&(s&&((l[u]||(l[u]={}))[a]=[w,m]),l===b))break;return m-=e,m===d||m%d===0&&m/d>=0}}},PSEUDO:function(a,b){var c,e=d.pseudos[a]||d.setFilters[a.toLowerCase()]||fb.error("unsupported pseudo: "+a);return e[u]?e(b):e.length>1?(c=[a,a,"",b],d.setFilters.hasOwnProperty(a.toLowerCase())?hb(function(a,c){var d,f=e(a,b),g=f.length;while(g--)d=K.call(a,f[g]),a[d]=!(c[d]=f[g])}):function(a){return e(a,0,c)}):e}},pseudos:{not:hb(function(a){var b=[],c=[],d=h(a.replace(R,"$1"));return d[u]?hb(function(a,b,c,e){var f,g=d(a,null,e,[]),h=a.length;while(h--)(f=g[h])&&(a[h]=!(b[h]=f))}):function(a,e,f){return b[0]=a,d(b,null,f,c),!c.pop()}}),has:hb(function(a){return 
function(b){return fb(a,b).length>0}}),contains:hb(function(a){return function(b){return(b.textContent||b.innerText||e(b)).indexOf(a)>-1}}),lang:hb(function(a){return W.test(a||"")||fb.error("unsupported lang: "+a),a=a.replace(cb,db).toLowerCase(),function(b){var c;do if(c=p?b.lang:b.getAttribute("xml:lang")||b.getAttribute("lang"))return c=c.toLowerCase(),c===a||0===c.indexOf(a+"-");while((b=b.parentNode)&&1===b.nodeType);return!1}}),target:function(b){var c=a.location&&a.location.hash;return c&&c.slice(1)===b.id},root:function(a){return a===o},focus:function(a){return a===n.activeElement&&(!n.hasFocus||n.hasFocus())&&!!(a.type||a.href||~a.tabIndex)},enabled:function(a){return a.disabled===!1},disabled:function(a){return a.disabled===!0},checked:function(a){var b=a.nodeName.toLowerCase();return"input"===b&&!!a.checked||"option"===b&&!!a.selected},selected:function(a){return a.parentNode&&a.parentNode.selectedIndex,a.selected===!0},empty:function(a){for(a=a.firstChild;a;a=a.nextSibling)if(a.nodeType<6)return!1;return!0},parent:function(a){return!d.pseudos.empty(a)},header:function(a){return Z.test(a.nodeName)},input:function(a){return Y.test(a.nodeName)},button:function(a){var b=a.nodeName.toLowerCase();return"input"===b&&"button"===a.type||"button"===b},text:function(a){var b;return"input"===a.nodeName.toLowerCase()&&"text"===a.type&&(null==(b=a.getAttribute("type"))||"text"===b.toLowerCase())},first:nb(function(){return[0]}),last:nb(function(a,b){return[b-1]}),eq:nb(function(a,b,c){return[0>c?c+b:c]}),even:nb(function(a,b){for(var c=0;b>c;c+=2)a.push(c);return a}),odd:nb(function(a,b){for(var c=1;b>c;c+=2)a.push(c);return a}),lt:nb(function(a,b,c){for(var d=0>c?c+b:c;--d>=0;)a.push(d);return a}),gt:nb(function(a,b,c){for(var d=0>c?c+b:c;++db;b++)d+=a[b].value;return d}function rb(a,b,c){var d=b.dir,e=c&&"parentNode"===d,f=x++;return b.first?function(b,c,f){while(b=b[d])if(1===b.nodeType||e)return a(b,c,f)}:function(b,c,g){var h,i,j=[w,f];if(g){while(b=b[d])if((1===b.nodeType||e)&&a(b,c,g))return!0}else while(b=b[d])if(1===b.nodeType||e){if(i=b[u]||(b[u]={}),(h=i[d])&&h[0]===w&&h[1]===f)return j[2]=h[2];if(i[d]=j,j[2]=a(b,c,g))return!0}}}function sb(a){return a.length>1?function(b,c,d){var e=a.length;while(e--)if(!a[e](b,c,d))return!1;return!0}:a[0]}function tb(a,b,c){for(var d=0,e=b.length;e>d;d++)fb(a,b[d],c);return c}function ub(a,b,c,d,e){for(var f,g=[],h=0,i=a.length,j=null!=b;i>h;h++)(f=a[h])&&(!c||c(f,d,e))&&(g.push(f),j&&b.push(h));return g}function vb(a,b,c,d,e,f){return d&&!d[u]&&(d=vb(d)),e&&!e[u]&&(e=vb(e,f)),hb(function(f,g,h,i){var j,k,l,m=[],n=[],o=g.length,p=f||tb(b||"*",h.nodeType?[h]:h,[]),q=!a||!f&&b?p:ub(p,m,a,h,i),r=c?e||(f?a:o||d)?[]:g:q;if(c&&c(q,r,h,i),d){j=ub(r,n),d(j,[],h,i),k=j.length;while(k--)(l=j[k])&&(r[n[k]]=!(q[n[k]]=l))}if(f){if(e||a){if(e){j=[],k=r.length;while(k--)(l=r[k])&&j.push(q[k]=l);e(null,r=[],j,i)}k=r.length;while(k--)(l=r[k])&&(j=e?K.call(f,l):m[k])>-1&&(f[j]=!(g[j]=l))}}else r=ub(r===g?r.splice(o,r.length):r),e?e(null,g,r,i):I.apply(g,r)})}function wb(a){for(var b,c,e,f=a.length,g=d.relative[a[0].type],h=g||d.relative[" "],i=g?1:0,k=rb(function(a){return a===b},h,!0),l=rb(function(a){return K.call(b,a)>-1},h,!0),m=[function(a,c,d){return!g&&(d||c!==j)||((b=c).nodeType?k(a,c,d):l(a,c,d))}];f>i;i++)if(c=d.relative[a[i].type])m=[rb(sb(m),c)];else{if(c=d.filter[a[i].type].apply(null,a[i].matches),c[u]){for(e=++i;f>e;e++)if(d.relative[a[e].type])break;return vb(i>1&&sb(m),i>1&&qb(a.slice(0,i-1).concat({value:" 
"===a[i-2].type?"*":""})).replace(R,"$1"),c,e>i&&wb(a.slice(i,e)),f>e&&wb(a=a.slice(e)),f>e&&qb(a))}m.push(c)}return sb(m)}function xb(a,b){var c=b.length>0,e=a.length>0,f=function(f,g,h,i,k){var l,m,o,p=0,q="0",r=f&&[],s=[],t=j,u=f||e&&d.find.TAG("*",k),v=w+=null==t?1:Math.random()||.1,x=u.length;for(k&&(j=g!==n&&g);q!==x&&null!=(l=u[q]);q++){if(e&&l){m=0;while(o=a[m++])if(o(l,g,h)){i.push(l);break}k&&(w=v)}c&&((l=!o&&l)&&p--,f&&r.push(l))}if(p+=q,c&&q!==p){m=0;while(o=b[m++])o(r,s,g,h);if(f){if(p>0)while(q--)r[q]||s[q]||(s[q]=G.call(i));s=ub(s)}I.apply(i,s),k&&!f&&s.length>0&&p+b.length>1&&fb.uniqueSort(i)}return k&&(w=v,j=t),r};return c?hb(f):f}return h=fb.compile=function(a,b){var c,d=[],e=[],f=A[a+" "];if(!f){b||(b=g(a)),c=b.length;while(c--)f=wb(b[c]),f[u]?d.push(f):e.push(f);f=A(a,xb(e,d)),f.selector=a}return f},i=fb.select=function(a,b,e,f){var i,j,k,l,m,n="function"==typeof a&&a,o=!f&&g(a=n.selector||a);if(e=e||[],1===o.length){if(j=o[0]=o[0].slice(0),j.length>2&&"ID"===(k=j[0]).type&&c.getById&&9===b.nodeType&&p&&d.relative[j[1].type]){if(b=(d.find.ID(k.matches[0].replace(cb,db),b)||[])[0],!b)return e;n&&(b=b.parentNode),a=a.slice(j.shift().value.length)}i=X.needsContext.test(a)?0:j.length;while(i--){if(k=j[i],d.relative[l=k.type])break;if((m=d.find[l])&&(f=m(k.matches[0].replace(cb,db),ab.test(j[0].type)&&ob(b.parentNode)||b))){if(j.splice(i,1),a=f.length&&qb(j),!a)return I.apply(e,f),e;break}}}return(n||h(a,o))(f,b,!p,e,ab.test(a)&&ob(b.parentNode)||b),e},c.sortStable=u.split("").sort(B).join("")===u,c.detectDuplicates=!!l,m(),c.sortDetached=ib(function(a){return 1&a.compareDocumentPosition(n.createElement("div"))}),ib(function(a){return a.innerHTML="","#"===a.firstChild.getAttribute("href")})||jb("type|href|height|width",function(a,b,c){return c?void 0:a.getAttribute(b,"type"===b.toLowerCase()?1:2)}),c.attributes&&ib(function(a){return a.innerHTML="",a.firstChild.setAttribute("value",""),""===a.firstChild.getAttribute("value")})||jb("value",function(a,b,c){return c||"input"!==a.nodeName.toLowerCase()?void 0:a.defaultValue}),ib(function(a){return null==a.getAttribute("disabled")})||jb(L,function(a,b,c){var d;return c?void 0:a[b]===!0?b.toLowerCase():(d=a.getAttributeNode(b))&&d.specified?d.value:null}),fb}(a);m.find=s,m.expr=s.selectors,m.expr[":"]=m.expr.pseudos,m.unique=s.uniqueSort,m.text=s.getText,m.isXMLDoc=s.isXML,m.contains=s.contains;var t=m.expr.match.needsContext,u=/^<(\w+)\s*\/?>(?:<\/\1>|)$/,v=/^.[^:#\[\.,]*$/;function w(a,b,c){if(m.isFunction(b))return m.grep(a,function(a,d){return!!b.call(a,d,a)!==c});if(b.nodeType)return m.grep(a,function(a){return a===b!==c});if("string"==typeof b){if(v.test(b))return m.filter(b,a,c);b=m.filter(b,a)}return m.grep(a,function(a){return m.inArray(a,b)>=0!==c})}m.filter=function(a,b,c){var d=b[0];return c&&(a=":not("+a+")"),1===b.length&&1===d.nodeType?m.find.matchesSelector(d,a)?[d]:[]:m.find.matches(a,m.grep(b,function(a){return 1===a.nodeType}))},m.fn.extend({find:function(a){var b,c=[],d=this,e=d.length;if("string"!=typeof a)return this.pushStack(m(a).filter(function(){for(b=0;e>b;b++)if(m.contains(d[b],this))return!0}));for(b=0;e>b;b++)m.find(a,d[b],c);return c=this.pushStack(e>1?m.unique(c):c),c.selector=this.selector?this.selector+" "+a:a,c},filter:function(a){return this.pushStack(w(this,a||[],!1))},not:function(a){return this.pushStack(w(this,a||[],!0))},is:function(a){return!!w(this,"string"==typeof a&&t.test(a)?m(a):a||[],!1).length}});var 
[vendored third-party asset: minified jQuery 1.11.x (jquery.min.js), committed verbatim by a later patch in this series. The minified source is omitted here: its embedded HTML string literals (the feature-detection innerHTML fixtures and the wrapper-tag map used by buildFragment) were stripped during extraction, garbling the code beyond faithful reconstruction. The upstream file should be restored from the jQuery 1.11.x release rather than from this corrupted copy.]