diff --git a/.gitignore b/.gitignore index 8d9f2708..796ab249 100644 --- a/.gitignore +++ b/.gitignore @@ -20,7 +20,6 @@ db/*.sql *.sql.* *.jpeg log/* -*.ipynb *.doctree tasks/debugging.txt riglib/plexon/plexfile.py @@ -46,5 +45,8 @@ riglib/fsm/fsm.egg-info/* env/* tests/*.mat tests/*.hdf +tests/unit_tests/htmlcov *.h5 riglib/plexon/Plexfile_utilities.egg-info/* +*.dot +riglib/dio/nidaq/pcidio.py diff --git a/README.md b/README.md index 4c7da146..e0f58367 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,12 @@ Getting started ```bash sudo xargs apt-get -y install < requirements.system ``` +## Linux/OS X +(none at this time) +16.04 64bit + +#python version +3.7.8 ## Windows Visual C++ Build tools (for the 'traits' package) diff --git a/utils/fa_decomp.py b/analysis/fa_decomp.py similarity index 100% rename from utils/fa_decomp.py rename to analysis/fa_decomp.py diff --git a/analysis/performance_metrics.py b/analysis/performance_metrics.py index 20b7b619..8498d4aa 100644 --- a/analysis/performance_metrics.py +++ b/analysis/performance_metrics.py @@ -6,7 +6,6 @@ import os import tables -from itertools import izip from riglib.bmi import robot_arms, train diff --git a/analysis/target_capture_task_analysis.py b/analysis/target_capture_task_analysis.py index 4965ef09..7c103811 100644 --- a/analysis/target_capture_task_analysis.py +++ b/analysis/target_capture_task_analysis.py @@ -8,10 +8,8 @@ from collections import OrderedDict, defaultdict import os import tables -from itertools import izip from riglib.bmi import robot_arms, train, kfdecoder, ppfdecoder -from db.tracker import models from db import dbfunctions from db import dbfunctions as dbfn @@ -52,43 +50,40 @@ def __init__(self, *args, **kwargs): super(ManualControlMultiTaskEntry, self).__init__(*args, **kwargs) try: - task_msgs = self.hdf.root.task_msgs[:] - # Ignore the last message if it's the "None" transition used to stop the task - if task_msgs[-1]['msg'] == 'None': - task_msgs = task_msgs[:-1] - - # ignore "update bmi" messages. These have been removed in later datasets - task_msgs = task_msgs[task_msgs['msg'] != 'update_bmi'] - - target_index = self.hdf.root.task[:]['target_index'].ravel() - task_msg_dtype = np.dtype([('msg', '|S256'), ('time', ' target_radius: + assist_cursor_pos = cursor_pos + speed*dir_to_target + else: + assist_cursor_pos = cursor_pos + speed*diff_vec/2 + + return assist_cursor_pos.ravel() + class SimpleEndpointAssisterLFC(feedback_controllers.MultiModalLFC): ''' Docstring @@ -212,19 +220,9 @@ class BMIControlMulti(BMILoop, LinearlyDecreasingAssist, ScreenTargetCapture): is_bmi_seed = False - cursor_color_adjust = traits.OptionsList(*list(target_colors.keys()), bmi3d_input_options=list(target_colors.keys())) - def __init__(self, *args, **kwargs): super(BMIControlMulti, self).__init__(*args, **kwargs) - - def init(self, *args, **kwargs): - sph = self.plant.graphics_models[0] - sph.color = target_colors[self.cursor_color_adjust] - sph.radius = self.cursor_radius - self.plant.cursor_radius = self.cursor_radius - self.plant.cursor.radius = self.cursor_radius - super(BMIControlMulti, self).init(*args, **kwargs) - + def create_assister(self): # Create the appropriate type of assister object start_level, end_level = self.assist_level diff --git a/built_in_tasks/manualcontrolmultitasks.py b/built_in_tasks/manualcontrolmultitasks.py index 3e13a214..000f2184 100644 --- a/built_in_tasks/manualcontrolmultitasks.py +++ b/built_in_tasks/manualcontrolmultitasks.py @@ -12,46 +12,69 @@ from riglib.experiment import traits from .target_graphics import * -from .target_capture_task import ScreenTargetCapture - - -class JoystickMulti(ScreenTargetCapture): +from .target_capture_task import ScreenTargetCapture, ScreenReachAngle +from riglib.stereo_opengl.window import WindowDispl2D + + +rotations = dict( + yzx = np.array( + [[0, 1, 0, 0], + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 0, 0, 1]] + ), + zyx = np.array( + [[0, 0, 1, 0], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 1]] + ), + xzy = np.array( + [[1, 0, 0, 0], + [0, 0, 1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1]] + ), + xyz = np.identity(4), +) + +class ManualControlMixin(traits.HasTraits): '''Target capture task where the subject operates a joystick to control a cursor. Targets are captured by having the cursor dwell in the screen target for the allotted time''' # Settable Traits - joystick_method = traits.Float(1,desc="1: Normal velocity, 0: Position control") - random_rewards = traits.Float(0,desc="Add randomness to reward, 1: yes, 0: no") - joystick_speed = traits.Float(20, desc="Radius of cursor") - + wait_time = traits.Float(2., desc="Time between successful trials") + velocity_control = traits.Bool(False, desc="Position or velocity control") + random_rewards = traits.Bool(False, desc="Add randomness to reward") + rotation = traits.OptionsList(*rotations, desc="Control rotation matrix", bmi3d_input_options=list(rotations.keys())) + scale = traits.Float(1.0, desc="Control scale factor") + offset = traits.Array(value=[0,0,0], desc="Control offset") is_bmi_seed = True def __init__(self, *args, **kwargs): - super(JoystickMulti, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.current_pt=np.zeros([3]) #keep track of current pt self.last_pt=np.zeros([3]) #keep track of last pt to calc. velocity + self._quality_window_size = 500 # how many cycles to accumulate quality statistics + self.reportstats['Input quality'] = "100 %" + if self.random_rewards: + self.reward_time_base = self.reward_time - def update_report_stats(self): - super(JoystickMulti, self).update_report_stats() - start_time = self.state_log[0][1] - rewardtimes=np.array([state[1] for state in self.state_log if state[0]=='reward']) - if len(rewardtimes): - rt = rewardtimes[-1]-start_time - else: - rt= np.float64("0.0") + def init(self): + self.add_dtype('manual_input', 'f8', (3,)) + super().init() + self.no_data_counter = np.zeros((self._quality_window_size,), dtype='?') - sec = str(np.int(np.mod(rt,60))) - if len(sec) < 2: - sec = '0'+sec - self.reportstats['Time Of Last Reward'] = str(np.int(np.floor(rt/60))) + ':' + sec + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause def _test_trial_complete(self, ts): if self.target_index==self.chain_length-1 : if self.random_rewards: if not self.rand_reward_set_flag: #reward time has not been set for this iteration self.reward_time = np.max([2*(np.random.rand()-0.5) + self.reward_time_base, self.reward_time_base/2]) #set randomly with min of base / 2 - self.rand_reward_set_flag =1; + self.rand_reward_set_flag =1 #print self.reward_time, self.rand_reward_set_flag return self.target_index==self.chain_length-1 @@ -59,74 +82,106 @@ def _test_reward_end(self, ts): #When finished reward, reset flag. if self.random_rewards: if ts > self.reward_time: - self.rand_reward_set_flag = 0; + self.rand_reward_set_flag = 0 #print self.reward_time, self.rand_reward_set_flag, ts return ts > self.reward_time - def move_effector(self): - ''' Returns the 3D coordinates of the cursor. For manual control, uses - motiontracker data. If no motiontracker data available, returns None''' - - #get data from phidget + def _transform_coords(self, coords): + ''' + Returns transformed coordinates based on rotation, offset, and scale traits + ''' + offset = np.array( + [[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [self.offset[0], self.offset[1], self.offset[2], 1]] + ) + scale = np.array( + [[self.scale, 0, 0, 0], + [0, self.scale, 0, 0], + [0, 0, self.scale, 0], + [0, 0, 0, 1]] + ) + old = np.concatenate((np.reshape(coords, -1), [1])) + new = np.linalg.multi_dot((old, offset, scale, rotations[self.rotation])) + return new[0:3] + + def _get_manual_position(self): + ''' + Fetches joystick position + ''' + if not hasattr(self, 'joystick'): + return pt = self.joystick.get() - #print pt - - if len(pt) > 0: + if len(pt) == 0: + return - pt = pt[-1][0] - x = pt[1] - y = 1-pt[0] + pt = pt[-1] # Use only the latest coordinate + if len(pt) == 2: + pt = np.concatenate((np.reshape(pt, -1), [0])) - pt[0]=1-pt[0]; #Switch L / R axes - calib = [0.5,0.5] #Sometimes zero point is subject to drift this is the value of the incoming joystick when at 'rest' - # calib = [ 0.487, 0. ] + return [pt] - #if self.joystick_method==0: - if self.joystick_method==0: - pos = np.array([(pt[0]-calib[0]), 0, calib[1]-pt[1]]) - pos[0] = pos[0]*36 - pos[2] = pos[2]*24 - self.current_pt = pos - - elif self.joystick_method==1: - #vel=np.array([(pt[0]-calib[0]), 0, calib[1]-pt[1]]) - vel = np.array([x-calib[0], 0., y-calib[1]]) - epsilon = 2*(10**-2) #Define epsilon to stabilize cursor movement - if sum((vel)**2) > epsilon: - self.current_pt=self.last_pt+20*vel*(1/60) #60 Hz update rate, dt = 1/60 - else: - self.current_pt = self.last_pt + def move_effector(self): + ''' + Sets the 3D coordinates of the cursor. For manual control, uses + motiontracker / joystick / mouse data. If no data available, returns None + ''' + + # Get raw input and save it as task data + raw_coords = self._get_manual_position() # array of [3x1] arrays + if raw_coords is None or len(raw_coords) < 1: + self.no_data_counter[self.cycle_count % self._quality_window_size] = 1 + self.update_report_stats() + self.task_data['manual_input'] = np.empty((3,)) + return + + self.task_data['manual_input'] = raw_coords.copy() + self.no_data_counter[self.cycle_count % self._quality_window_size] = 0 + + # Transform coordinates + coords = self._transform_coords(raw_coords) + if self.limit2d: + coords[1] = 0 + + # Set cursor position + if not self.velocity_control: + self.current_pt = coords + else: + epsilon = 2*(10**-2) # Define epsilon to stabilize cursor movement + if sum((coords)**2) > epsilon: - #self.current_pt = self.current_pt + (np.array([np.random.rand()-0.5, 0., np.random.rand()-0.5])*self.joystick_speed) + # Add the velocity (units/s) to the position (units) + self.current_pt = coords / self.fps + self.last_pt + else: + self.current_pt = self.last_pt - if self.current_pt[0] < -25: self.current_pt[0] = -25 - if self.current_pt[0] > 25: self.current_pt[0] = 25 - if self.current_pt[-1] < -14: self.current_pt[-1] = -14 - if self.current_pt[-1] > 14: self.current_pt[-1] = 14 + self.plant.set_endpoint_pos(self.current_pt) + self.last_pt = self.plant.get_endpoint_pos() - self.plant.set_endpoint_pos(self.current_pt) - self.last_pt = self.current_pt.copy() + def update_report_stats(self): + super().update_report_stats() + window_size = min(max(1, self.cycle_count), self._quality_window_size) + num_missing = np.sum(self.no_data_counter[:window_size]) + quality = 1 - num_missing / window_size + self.reportstats['Input quality'] = "{} %".format(int(100*quality)) @classmethod - def get_desc(cls, params, report): - duration = report[-1][-1] - report[0][-1] - reward_count = 0 - for item in report: - if item[0] == "reward": - reward_count += 1 - return "{} rewarded trials in {} min".format(reward_count, duration) - - -class JoystickMulti2DWindow(JoystickMulti, WindowDispl2D): - fps = 20. - def __init__(self,*args, **kwargs): - super(JoystickMulti2DWindow, self).__init__(*args, **kwargs) - - def _start_wait(self): - self.wait_time = 0. - super(JoystickMulti2DWindow, self)._start_wait() - - def _test_start_trial(self, ts): - return ts > self.wait_time and not self.pause - + def get_desc(cls, params, log_summary): + duration = round(log_summary['runtime'] / 60, 1) + return "{}/{} succesful trials in {} min".format( + log_summary['n_success_trials'], log_summary['n_trials'], duration) + + +class ManualControl(ManualControlMixin, ScreenTargetCapture): + ''' + Slightly refactored original manual control task + ''' + pass + +class ManualControlDirectionConstraint(ManualControlMixin, ScreenReachAngle): + ''' + Adds an additional constraint that the direction of travel must be within a certain angle + ''' + pass \ No newline at end of file diff --git a/built_in_tasks/othertasks.py b/built_in_tasks/othertasks.py new file mode 100644 index 00000000..df0ecef8 --- /dev/null +++ b/built_in_tasks/othertasks.py @@ -0,0 +1,197 @@ +''' Tasks which don't include any visuals or bmi, such as laser-only or camera-only tasks''' + +from riglib.experiment import LogExperiment, Sequence +from features.laser_features import DigitalWave +from riglib.experiment import traits +import itertools +import numpy as np + +MAX_EDGES = 1000 + +class Conditions(Sequence): + + status = dict( + wait = dict(start_trial="trial"), + trial = dict(end_trial="wait", stoppable=False, end_state=True), + ) + + wait_time = traits.Float(5.0, desc="Inter-trial interval (s)") + trial_time = traits.Float(1.0, desc="Trial duration (s)") + sequence_generators = ['null_sequence'] + + def init(self): + self.trial_dtype = np.dtype([('trial', 'u4'), ('index', 'u4')]) + super().init() + + def _parse_next_trial(self): + self.trial_index = self.next_trial + + # Send record of trial to sinks + self.trial_record['trial'] = self.calc_trial_num() + self.trial_record['index'] = self.trial_index + self.sinks.send("trials", self.trial_record) + + def _test_start_trial(self, ts): + return ts > self.wait_time and not self.pause + + def _test_end_trial(self, ts): + return ts > self.trial_time + + def _start_trial(self): + self.sync_event('TRIAL_START', self.trial_index) + + def _end_trial(self): + self.sync_event('TRIAL_END') + + @staticmethod + def gen_random_conditions(nreps, *args, replace=False): + ''' Generate random sequence of all combinations of the given arguments''' + unique = list(itertools.product(*args)) + conds = np.random.choice(nreps*len(unique), nreps*len(unique), replace=replace) + seq = [[i % len(unique)] + list(unique[i % len(unique)]) for i in conds] # list of [index, arg1, arg2, ..., argn] + return tuple(zip(*seq)) + + @staticmethod + def gen_conditions(nreps, *args, ascend=True): + ''' Generate a sequential sequence of all combinations of the given arguments''' + unique = list(itertools.product(*args)) + conds = np.tile(range(len(unique)), nreps) + if not ascend: # descending + conds = np.flipud(conds) + seq = [[i % len(unique)] + list(unique[i % len(unique)]) for i in conds] # list of [index, arg1, arg2, ..., argn] + return tuple(zip(*seq)) + + @staticmethod + def null_sequence(ntrials=100): + return [0 for _ in range(ntrials)] + +class LaserConditions(Conditions): + + sequence_generators = ['single_laser_pulse', 'single_laser_square_wave'] + exclude_parent_traits = ['trial_time'] + + def __init__(self, *args, **kwargs): + self.laser_threads = [] + super().__init__(*args, **kwargs) + + def init(self): + self.trial_dtype = np.dtype([ + ('trial', 'u4'), + ('index', 'u4'), + ('laser', 'S32'), + ('power', 'f8'), + ('edges', 'V', MAX_EDGES) + ]) + super(Conditions, self).init() + if not (hasattr(self, 'lasers') and len(self.lasers) > 0): + raise AttributeError("No laser feature enabled, cannot init LaserConditions") + + def _parse_next_trial(self): + self.trial_index, self.laser_powers, self.laser_edges = self.next_trial + if len(self.laser_powers) < len(self.lasers) or len(self.laser_edges) < len(self.lasers): + raise AttributeError("Not enough laser sequences for the number of lasers enabled") + + # Send record of trial to sinks + self.trial_record['trial'] = self.calc_trial_num() + self.trial_record['index'] = self.trial_index + for idx in range(len(self.lasers)): + self.trial_record['laser'] = self.lasers[idx].name + self.trial_record['power'] = self.laser_powers[idx] + self.trial_record['edges'] = np.array(self.laser_edges[idx]).tobytes() + self.sinks.send("trials", self.trial_record) + + def _start_trial(self): + super()._start_trial() + for idx in range(len(self.lasers)): + laser = self.lasers[idx] + edges = self.laser_edges[idx] + # TODO set laser power + power = self.laser_powers[idx] + # Trigger digital wave + wave = DigitalWave(laser, mask=1<>laser.port) + wave.set_edges([0], False) + wave.start() + + def _test_end_trial(self, ts): + return all([not t.is_alive() for t in self.laser_threads]) + + @staticmethod + def single_laser_pulse(nreps=100, duration=[0.005], power=[1], uniformsampling=True, ascending=False): + ''' + Generates a sequence of laser pulse trains. + + Parameters + ---------- + nreps : int + The number of repetitions of each unique condition. + duration: list of floats + The duration of each pulse. Can be a list, randomly sampled + power : list of floats + Power for each pulse. Can be a list, randomly sampled + + Returns + ------- + seq : (nreps*len(duration)*len(power) x 3) tuple of trial indices, laser powers, and edge sequences + + ''' + duration = make_list_of_float(duration) + power = make_list_of_float(power) + if uniformsampling: + idx, dur_seq, pow_seq = Conditions.gen_random_conditions(nreps, duration, power) + else: + idx, dur_seq, pow_seq = Conditions.gen_conditions(nreps, duration, power, ascend=ascending) + edge_seq = map(lambda dur: [0, dur], dur_seq) + return list(zip(idx, [[p] for p in pow_seq], [[e] for e in edge_seq])) + + @staticmethod + def single_laser_square_wave(nreps=100, freq=[20], duration=[0.005], power=[1], uniformsampling=True, ascending=False): + ''' + Generates a sequence of laser square waves. + + Parameters + ---------- + nreps : int + The number of repetitions of each unique condition. + freq : list of floats + The frequency for each square wave. Can be a list, randomly sampled + duration: list of floats + The duration of each square wave. Can be a list, randomly sampled + power : list of floats + Power for each square wave. Can be a list, randomly sampled + + Returns + ------- + seq : (nreps*len(duration)*len(power)*len(freq) x 3) tuple of trial indices, laser powers, and edge sequences + + ''' + freq = make_list_of_float(freq) + duration = make_list_of_float(duration) + power = make_list_of_float(power) + if uniformsampling: + idx, freq_seq, dur_seq, pow_seq = Conditions.gen_random_conditions(nreps, freq, duration, power) + else: + idx, freq_seq, dur_seq, pow_seq = Conditions.gen_conditions(nreps, freq, duration, power, ascend=ascending) + edge_seq = map(lambda freq, dur: DigitalWave.square_wave(freq, dur), freq_seq, dur_seq) + return list(zip(idx, [[p] for p in pow_seq], [[e] for e in edge_seq])) + + +#################### +# Helper functions # +#################### +def make_list_of_float(maybe_a_float): + try: + _ = iter(maybe_a_float) + except TypeError: + return [maybe_a_float] + else: + return maybe_a_float \ No newline at end of file diff --git a/built_in_tasks/passivetasks.py b/built_in_tasks/passivetasks.py index b7dd7c7e..a3f6d361 100644 --- a/built_in_tasks/passivetasks.py +++ b/built_in_tasks/passivetasks.py @@ -2,26 +2,20 @@ Tasks which control a plant under pure machine control. Used typically for initializing BMI decoder parameters. ''' import numpy as np -import time import os -import pdb -import multiprocessing as mp -import pickle import tables -import re -import tempfile, traceback, datetime - -import riglib.bmi -from riglib.stereo_opengl import ik -from riglib.experiment import traits, experiment -from riglib.bmi import clda, assist, extractor, train, goal_calculators, ppfdecoder -from riglib.bmi.bmi import Decoder, BMISystem, GaussianStateHMM, BMILoop, GaussianState, MachineOnlyFilter -from riglib.bmi.extractor import DummyExtractor -from riglib.stereo_opengl.window import WindowDispl2D, FakeWindow +import time + +from riglib.experiment import traits from riglib.bmi.state_space_models import StateSpaceEndptVel2D +from riglib.bmi.bmi import Decoder, BMILoop, MachineOnlyFilter +from riglib.bmi.extractor import DummyExtractor +from riglib.stereo_opengl.window import Window, WindowDispl2D -from .bmimultitasks import BMIControlMulti +from built_in_tasks.manualcontrolmultitasks import ScreenTargetCapture +from built_in_tasks.bmimultitasks import BMIControlMulti +from .target_graphics import * bmi_ssm_options = ['Endpt2D', 'Tentacle', 'Joint2L'] @@ -77,3 +71,68 @@ def get_desc(cls, params, report): return "{} rewarded trials in {} min".format(reward_count, int(np.ceil(duration / 60))) else: return "No trials" + +from .target_graphics import target_colors + +class TargetCaptureReplay(ScreenTargetCapture): + ''' + Reads the frame-by-frame cursor and trial-by-trial target positions from a saved + HDF file to display an exact copy of a previous experiment. + Doesn't really work, do not recommend using this. + ''' + + hdf_filepath = traits.String("", desc="Filepath of hdf file to replay") + + exclude_parent_traits = list(set(ScreenTargetCapture.class_traits().keys()) - \ + set(['window_size', 'fullscreen'])) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.t0 = time.perf_counter() + with tables.open_file(self.hdf_filepath, 'r') as f: + task = f.root.task.read() + state = f.root.task_msgs.read() + trial = f.root.trials.read() + params = f.root.task.attrs._f_list("user") + self.task_meta = {k : getattr(f.root.task.attrs, k) for k in params} + self.replay_state = state + self.replay_task = task + self.replay_trial = trial + for k, v in self.task_meta.items(): + if k in self.exclude_parent_traits: + print("setting {} to {}".format(k, v)) + setattr(self, k, v) + + # Have to additionally reset the targets since they are created in super().__init__() + target1 = VirtualCircularTarget(target_radius=self.target_radius, target_color=target_colors[self.target_color]) + target2 = VirtualCircularTarget(target_radius=self.target_radius, target_color=target_colors[self.target_color]) + self.targets = [target1, target2] + + def _test_start_trial(self, time_in_state): + '''Wait for the state change in the HDF file in case there is autostart enabled''' + trials = self.replay_state[self.replay_state['msg'] == b'target'] + upcoming_trials = [t['time']-1 for t in trials if self.replay_task[t['time']]['trial'] >= self.calc_trial_num()] + return (np.array(upcoming_trials) <= self.cycle_count).any() + + def _parse_next_trial(self): + '''Ignore the generator''' + self.targs = [] + self.gen_indices = [] + trial_num = self.calc_trial_num() + for trial in self.replay_trial: + if trial['trial'] == trial_num: + self.targs.append(trial['target']) + self.gen_indices.append(trial['index']) + + def _cycle(self): + '''Have to fudge the cycle_count a bit in case the fps isn't exactly the same''' + super()._cycle() + t1 = time.perf_counter() - self.t0 + self.cycle_count = int(t1*self.fps) + + def move_effector(self): + current_pt = self.replay_task['cursor'][self.cycle_count] + self.plant.set_endpoint_pos(current_pt) + + def _test_stop(self, ts): + return super()._test_stop(ts) or self.cycle_count == len(self.replay_task) diff --git a/built_in_tasks/target_capture_task.py b/built_in_tasks/target_capture_task.py index 32fff651..30af26b5 100644 --- a/built_in_tasks/target_capture_task.py +++ b/built_in_tasks/target_capture_task.py @@ -12,57 +12,39 @@ from riglib.stereo_opengl import ik from riglib import plants +from riglib.stereo_opengl.window import Window from .target_graphics import * - -####### CONSTANTS -sec_per_min = 60.0 -RED = (1,0,0,.5) -GREEN = (0,1,0,0.5) -GOLD = (1., 0.843, 0., 0.5) -mm_per_cm = 1./10 - - ## Plants # List of possible "plants" that a subject could control either during manual or brain control -cursor_14x14 = plants.CursorPlant(endpt_bounds=(-14, 14, 0., 0., -14, 14)) - +cursor = plants.CursorPlant() shoulder_anchor = np.array([2., 0., -15]) chain_kwargs = dict(link_radii=.6, joint_radii=0.6, joint_colors=(181/256., 116/256., 96/256., 1), link_colors=(181/256., 116/256., 96/256., 1)) - chain_20_20_endpt = plants.EndptControlled2LArm(link_lengths=[20, 20], base_loc=shoulder_anchor, **chain_kwargs) -init_pos = np.array([0, 0, 0], np.float64) -chain_20_20_endpt.set_intrinsic_coordinates(init_pos) - chain_20_20 = plants.RobotArmGen2D(link_lengths=[20, 20], base_loc=shoulder_anchor, **chain_kwargs) -init_pos = np.array([ 0.38118002, 2.08145271]) -chain_20_20.set_intrinsic_coordinates(init_pos) plantlist = dict( - cursor_14x14=cursor_14x14, + cursor=cursor, chain_20_20=chain_20_20, chain_20_20_endpt=chain_20_20_endpt) - - class TargetCapture(Sequence): ''' This is a generic cued target capture skeleton, to form as a common ancestor to the most common type of motor control task. ''' - status = FSMTable( - wait = StateTransitions(start_trial="target"), - target = StateTransitions(enter_target="hold", timeout="timeout_penalty"), - hold = StateTransitions(leave_early="hold_penalty", hold_complete="targ_transition"), - targ_transition = StateTransitions(trial_complete="reward", trial_abort="wait", trial_incomplete="target"), - timeout_penalty = StateTransitions(timeout_penalty_end="targ_transition", end_state=True), - hold_penalty = StateTransitions(hold_penalty_end="targ_transition", end_state=True), - reward = StateTransitions(reward_end="wait", stoppable=False, end_state=True) + status = dict( + wait = dict(start_trial="target"), + target = dict(enter_target="hold", timeout="timeout_penalty"), + hold = dict(leave_target="hold_penalty", hold_complete="delay"), + delay = dict(leave_target="delay_penalty", delay_complete="targ_transition"), + targ_transition = dict(trial_complete="reward", trial_abort="wait", trial_incomplete="target"), + timeout_penalty = dict(timeout_penalty_end="targ_transition", end_state=True), + hold_penalty = dict(hold_penalty_end="targ_transition", end_state=True), + delay_penalty = dict(delay_penalty_end="targ_transition", end_state=True), + reward = dict(reward_end="wait", stoppable=False, end_state=True) ) - trial_end_states = ['reward', 'timeout_penalty', 'hold_penalty'] - - # initial state state = "wait" @@ -72,13 +54,19 @@ class TargetCapture(Sequence): sequence_generators = [] reward_time = traits.Float(.5, desc="Length of reward dispensation") - hold_time = traits.Float(.2, desc="Length of hold required at targets") + hold_time = traits.Float(.2, desc="Length of hold required at targets before next target appears") hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error") + delay_time = traits.Float(0, desc="Length of time after a hold while the next target is on before the go cue") + delay_penalty_time = traits.Float(1, desc="Length of penalty time for delay error") timeout_time = traits.Float(10, desc="Time allowed to go between targets") timeout_penalty_time = traits.Float(1, desc="Length of penalty time for timeout error") max_attempts = traits.Int(10, desc='The number of attempts at a target before\ skipping to the next one') + def init(self): + self.trial_dtype = np.dtype([('trial', 'u4'), ('index', 'u4'), ('target', 'f8', (3,))]) + super().init() + def _start_wait(self): # Call parent method to draw the next target capture sequence from the generator super()._start_wait() @@ -94,13 +82,18 @@ def _start_wait(self): def _parse_next_trial(self): '''Check that the generator has the required data''' - self.targs = self.next_trial - + self.gen_indices, self.targs = self.next_trial # TODO error checking + + # Update the data sinks with trial information + self.trial_record['trial'] = self.calc_trial_num() + for i in range(len(self.gen_indices)): + self.trial_record['index'] = self.gen_indices[i] + self.trial_record['target'] = self.targs[i] + self.sinks.send("trials", self.trial_record) def _start_target(self): self.target_index += 1 - self.target_location = self.targs[self.target_index] def _end_target(self): '''Nothing generic to do.''' @@ -118,6 +111,18 @@ def _end_hold(self): '''Nothing generic to do.''' pass + def _start_delay(self): + '''Nothing generic to do.''' + pass + + def _while_delay(self): + '''Nothing generic to do.''' + pass + + def _end_delay(self): + '''Nothing generic to do.''' + pass + def _start_targ_transition(self): '''Nothing generic to do. Child class might show/hide targets''' pass @@ -130,10 +135,20 @@ def _end_targ_transition(self): '''Nothing generic to do.''' pass - def _start_timeout_penalty(self): + def _increment_tries(self): self.tries += 1 self.target_index = -1 + if self.tries < self.max_attempts: + self.trial_record['trial'] += 1 + for i in range(len(self.gen_indices)): + self.trial_record['index'] = self.gen_indices[i] + self.trial_record['target'] = self.targs[i] + self.sinks.send("trials", self.trial_record) + + def _start_timeout_penalty(self): + self._increment_tries() + def _while_timeout_penalty(self): '''Nothing generic to do.''' pass @@ -143,8 +158,7 @@ def _end_timeout_penalty(self): pass def _start_hold_penalty(self): - self.tries += 1 - self.target_index = -1 + self._increment_tries() def _while_hold_penalty(self): '''Nothing generic to do.''' @@ -154,6 +168,17 @@ def _end_hold_penalty(self): '''Nothing generic to do.''' pass + def _start_delay_penalty(self): + self._increment_tries() + + def _while_delay_penalty(self): + '''Nothing generic to do.''' + pass + + def _end_delay_penalty(self): + '''Nothing generic to do.''' + pass + def _start_reward(self): '''Nothing generic to do.''' pass @@ -175,7 +200,7 @@ def _test_start_trial(self, time_in_state): return True def _test_timeout(self, time_in_state): - return time_in_state > self.timeout_time + return time_in_state > self.timeout_time or self.pause def _test_hold_complete(self, time_in_state): ''' @@ -190,6 +215,14 @@ def _test_hold_complete(self, time_in_state): ''' return time_in_state > self.hold_time + def _test_delay_complete(self, time_in_state): + ''' + Test whether the delay period, when the cursor must stay in place + while another target is being presented, is over. There should be + no delay on the last target in a chain. + ''' + return self.target_index + 1 == self.chain_length or time_in_state > self.delay_time + def _test_trial_complete(self, time_in_state): '''Test whether all targets in sequence have been acquired''' return self.target_index == self.chain_length-1 @@ -200,7 +233,7 @@ def _test_trial_abort(self, time_in_state): def _test_trial_incomplete(self, time_in_state): '''Test whether the target capture sequence needs to be restarted''' - return (not self._test_trial_complete(time_in_state)) and (self.tries self.timeout_penalty_time @@ -208,6 +241,9 @@ def _test_timeout_penalty_end(self, time_in_state): def _test_hold_penalty_end(self, time_in_state): return time_in_state > self.hold_penalty_time + def _test_delay_penalty_end(self, time_in_state): + return time_in_state > self.delay_penalty_time + def _test_reward_end(self, time_in_state): return time_in_state > self.reward_time @@ -215,9 +251,9 @@ def _test_enter_target(self, time_in_state): '''This function is task-specific and not much can be done generically''' return False - def _test_leave_early(self, time_in_state): + def _test_leave_target(self, time_in_state): '''This function is task-specific and not much can be done generically''' - return False + return self.pause def update_report_stats(self): ''' @@ -227,66 +263,42 @@ def update_report_stats(self): self.reportstats['Trial #'] = self.calc_trial_num() self.reportstats['Reward/min'] = np.round(self.calc_events_per_min('reward', 120.), decimals=2) - @classmethod - def get_desc(cls, params, report): - '''Used by the database infrasturcture to generate summary stats on this task''' - duration = report[-1][-1] - report[0][-1] - reward_count = 0 - for item in report: - if item[0] == "reward": - reward_count += 1 - return "{} rewarded trials in {} min".format(reward_count, duration) - - class ScreenTargetCapture(TargetCapture, Window): """Concrete implementation of TargetCapture task where targets are acquired by "holding" a cursor in an on-screen target""" - background = (0,0,0,1) - cursor_color = (.5,0,.5,1) - - plant_type = traits.OptionsList(*plantlist, desc='', bmi3d_input_options=list(plantlist.keys())) - - starting_pos = (5, 0, 5) - - target_color = (1,0,0,.5) - cursor_visible = False # Determines when to hide the cursor. - no_data_count = 0 # Counter for number of missing data frames in a row - scale_factor = 3.0 #scale factor for converting hand movement to screen movement (1cm hand movement = 3.5cm cursor movement) + limit2d = traits.Bool(True, desc="Limit cursor movement to 2D") - limit2d = 1 + sequence_generators = [ + 'out_2D', 'centerout_2D', 'centeroutback_2D', 'rand_target_chain_2D', 'rand_target_chain_3D', + ] - sequence_generators = ['centerout_2D_discrete'] + hidden_traits = ['cursor_color', 'target_color', 'cursor_bounds', 'cursor_radius', 'plant_hide_rate', 'starting_pos'] is_bmi_seed = True - _target_color = RED - # Runtime settable traits - reward_time = traits.Float(.5, desc="Length of juice reward") target_radius = traits.Float(2, desc="Radius of targets in cm") - - hold_time = traits.Float(.2, desc="Length of hold required at targets") - hold_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error") - timeout_time = traits.Float(10, desc="Time allowed to go between targets") - timeout_penalty_time = traits.Float(1, desc="Length of penalty time for timeout error") - max_attempts = traits.Int(10, desc='The number of attempts at a target before\ - skipping to the next one') - + target_color = traits.OptionsList("yellow", *target_colors, desc="Color of the target", bmi3d_input_options=list(target_colors.keys())) plant_hide_rate = traits.Float(0.0, desc='If the plant is visible, specifies a percentage of trials where it will be hidden') - plant_type_options = list(plantlist.keys()) plant_type = traits.OptionsList(*plantlist, bmi3d_input_options=list(plantlist.keys())) plant_visible = traits.Bool(True, desc='Specifies whether entire plant is displayed or just endpoint') - cursor_radius = traits.Float(.5, desc="Radius of cursor") + cursor_radius = traits.Float(.5, desc='Radius of cursor in cm') + cursor_color = traits.OptionsList("pink", *target_colors, desc='Color of cursor endpoint', bmi3d_input_options=list(target_colors.keys())) + cursor_bounds = traits.Tuple((-10., 10., 0., 0., -10., 10.), desc='(x min, x max, y min, y max, z min, z max)') + starting_pos = traits.Tuple((5., 0., 5.), desc='Where to initialize the cursor') def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.cursor_visible = True # Initialize the plant if not hasattr(self, 'plant'): self.plant = plantlist[self.plant_type] + self.plant.set_bounds(np.array(self.cursor_bounds)) + self.plant.set_color(target_colors[self.cursor_color]) + self.plant.set_cursor_radius(self.cursor_radius) self.plant_vis_prev = True + self.cursor_vis_prev = True # Add graphics models for the plant and targets to the window if hasattr(self.plant, 'graphics_models'): @@ -296,47 +308,41 @@ def __init__(self, *args, **kwargs): # Instantiate the targets instantiate_targets = kwargs.pop('instantiate_targets', True) if instantiate_targets: - target1 = VirtualCircularTarget(target_radius=self.target_radius, target_color=self._target_color) - target2 = VirtualCircularTarget(target_radius=self.target_radius, target_color=self._target_color) - self.targets = [target1, target2] - for target in self.targets: - for model in target.graphics_models: - self.add_model(model) + # Need two targets to have the ability for delayed holds + target1 = VirtualCircularTarget(target_radius=self.target_radius, target_color=target_colors[self.target_color]) + target2 = VirtualCircularTarget(target_radius=self.target_radius, target_color=target_colors[self.target_color]) - # Initialize target location variable - self.target_location = np.array([0, 0, 0]) + self.targets = [target1, target2] # Declare any plant attributes which must be saved to the HDF file at the _cycle rate for attr in self.plant.hdf_attrs: self.add_dtype(*attr) def init(self): - self.add_dtype('target', 'f8', (3,)) - self.add_dtype('target_index', 'i', (1,)) + self.add_dtype('trial', 'u4', (1,)) + self.add_dtype('plant_visible', '?', (1,)) super().init() + self.plant.set_endpoint_pos(np.array(self.starting_pos)) def _cycle(self): ''' - Calls any update functions necessary and redraws screen. Runs 60x per second. + Calls any update functions necessary and redraws screen ''' - self.task_data['target'] = self.target_location.copy() - self.task_data['target_index'] = self.target_index + self.move_effector() ## Run graphics commands to show/hide the plant if the visibility has changed - if self.plant_type != 'CursorPlant': - if self.plant_visible != self.plant_vis_prev: - self.plant_vis_prev = self.plant_visible - self.plant.set_visibility(self.plant_visible) - # self.show_object(self.plant, show=self.plant_visible) - - self.move_effector() + self.update_plant_visibility() + self.task_data['plant_visible'] = self.plant_visible ## Save plant status to HDF file plant_data = self.plant.get_data_to_save() for key in plant_data: self.task_data[key] = plant_data[key] + # Update the trial index + self.task_data['trial'] = self.calc_trial_num() + super()._cycle() def move_effector(self): @@ -349,24 +355,19 @@ def run(self): ''' # Fire up the plant. For virtual/simulation plants, this does little/nothing. self.plant.start() + + # Include some cleanup in case the parent class has errors try: super().run() finally: self.plant.stop() ##### HELPER AND UPDATE FUNCTIONS #### - def update_cursor_visibility(self): - ''' Update cursor visible flag to hide cursor if there has been no good data for more than 3 frames in a row''' - prev = self.cursor_visible - if self.no_data_count < 3: - self.cursor_visible = True - if prev != self.cursor_visible: - self.show_object(self.cursor, show=True) - else: - self.cursor_visible = False - if prev != self.cursor_visible: - self.show_object(self.cursor, show=False) - + def update_plant_visibility(self): + ''' Update plant visibility''' + if self.plant_visible != self.plant_vis_prev: + self.plant_vis_prev = self.plant_visible + self.plant.set_visibility(self.plant_visible) #### TEST FUNCTIONS #### def _test_enter_target(self, ts): @@ -374,105 +375,379 @@ def _test_enter_target(self, ts): return true if the distance between center of cursor and target is smaller than the cursor radius ''' cursor_pos = self.plant.get_endpoint_pos() - d = np.linalg.norm(cursor_pos - self.target_location) - return d <= (self.target_radius - self.cursor_radius) + d = np.linalg.norm(cursor_pos - self.targs[self.target_index]) + return d <= (self.target_radius - self.cursor_radius) or super()._test_enter_target(ts) - def _test_leave_early(self, ts): + def _test_leave_target(self, ts): ''' return true if cursor moves outside the exit radius ''' cursor_pos = self.plant.get_endpoint_pos() - d = np.linalg.norm(cursor_pos - self.target_location) + d = np.linalg.norm(cursor_pos - self.targs[self.target_index]) rad = self.target_radius - self.cursor_radius - return d > rad + return d > rad or super()._test_leave_target(ts) #### STATE FUNCTIONS #### def _start_wait(self): super()._start_wait() - # hide targets - for target in self.targets: - target.hide() + + if self.calc_trial_num() == 0: + + # Instantiate the targets here so they don't show up in any states that might come before "wait" + for target in self.targets: + for model in target.graphics_models: + self.add_model(model) + target.hide() def _start_target(self): super()._start_target() - # move one of the two targets to the new target location + # Show target if it is hidden (this is the first target, or previous state was a penalty) target = self.targets[self.target_index % 2] - target.move_to_position(self.target_location) - target.cue_trial_start() + if self.target_index == 0: + target.move_to_position(self.targs[self.target_index]) + target.show() + self.sync_event('TARGET_ON', self.gen_indices[self.target_index]) def _start_hold(self): - #make next target visible unless this is the final target in the trial - idx = (self.target_index + 1) - if idx < self.chain_length: - target = self.targets[idx % 2] - target.move_to_position(self.targs[idx]) + super()._start_hold() + self.sync_event('CURSOR_ENTER_TARGET', self.gen_indices[self.target_index]) + + def _start_delay(self): + super()._start_delay() + + # Make next target visible unless this is the final target in the trial + next_idx = (self.target_index + 1) + if next_idx < self.chain_length: + target = self.targets[next_idx % 2] + target.move_to_position(self.targs[next_idx]) + target.show() + self.sync_event('TARGET_ON', self.gen_indices[next_idx]) + else: + # This delay state should only last 1 cycle, don't sync anything + pass - def _end_hold(self): - # change current target color to green - self.targets[self.target_index % 2].cue_trial_end_success() + def _start_targ_transition(self): + super()._start_targ_transition() + if self.target_index == -1: + + # Came from a penalty state + pass + elif self.target_index + 1 < self.chain_length: + + # Hide the current target if there are more + self.targets[self.target_index % 2].hide() + self.sync_event('TARGET_OFF', self.gen_indices[self.target_index]) def _start_hold_penalty(self): + self.sync_event('HOLD_PENALTY') super()._start_hold_penalty() - # hide targets + # Hide targets for target in self.targets: target.hide() + target.reset() - def _start_timeout_penalty(self): - super()._start_timeout_penalty() - # hide targets + def _end_hold_penalty(self): + super()._end_hold_penalty() + self.sync_event('TRIAL_END') + + def _start_delay_penalty(self): + self.sync_event('DELAY_PENALTY') + super()._start_delay_penalty() + # Hide targets for target in self.targets: target.hide() + target.reset() - def _start_targ_transition(self): - #hide targets + def _end_delay_penalty(self): + super()._end_delay_penalty() + self.sync_event('TRIAL_END') + + def _start_timeout_penalty(self): + self.sync_event('TIMEOUT_PENALTY') + super()._start_timeout_penalty() + # Hide targets for target in self.targets: target.hide() + target.reset() + + def _end_timeout_penalty(self): + super()._end_timeout_penalty() + self.sync_event('TRIAL_END') def _start_reward(self): - self.targets[self.target_index % 2].show() + self.targets[self.target_index % 2].cue_trial_end_success() + self.sync_event('REWARD') + + def _end_reward(self): + super()._end_reward() + self.sync_event('TRIAL_END') + + # Hide targets + for target in self.targets: + target.hide() + target.reset() #### Generator functions #### + ''' + Note to self: because of the way these get into the database, the parameters don't + have human-readable descriptions like the other traits. So it is useful to define + the descriptions elsewhere, in models.py under Generator.to_json(). + + Ideally someone should take the time to reimplement generators as their own classes + rather than static methods that belong to a task. + ''' + @staticmethod + def static(pos=(0,0,0), ntrials=0): + '''Single location, finite (ntrials!=0) or infinite (ntrials==0)''' + if ntrials == 0: + while True: + yield [0], np.array(pos) + else: + for _ in range(ntrials): + yield [0], np.array(pos) + + @staticmethod + def out_2D(nblocks=100, ntargets=8, distance=10, origin=(0,0,0)): + ''' + Generates a sequence of 2D (x and z) targets at a given distance from the origin + + Parameters + ---------- + nblocks : int + The number of ntarget pairs in the sequence. + ntargets : int + The number of equally spaced targets + distance : float + The distance in cm between the center and peripheral targets. + origin : 3-tuple + Location of the central targets around which the peripheral targets span + + Returns + ------- + [nblocks*ntargets x 1] array of tuples containing trial indices and [1 x 3] target coordinates + + ''' + rng = np.random.default_rng() + for _ in range(nblocks): + order = np.arange(ntargets) + 1 # target indices, starting from 1 + rng.shuffle(order) + for t in range(ntargets): + idx = order[t] + theta = 2*np.pi*idx/ntargets + pos = np.array([ + distance*np.cos(theta), + 0, + distance*np.sin(theta) + ]).T + yield [idx], [pos + origin] + + @staticmethod + def centerout_2D(nblocks=100, ntargets=8, distance=10, origin=(0,0,0)): + ''' + Pairs of central targets at the origin and peripheral targets centered around the origin + + Returns + ------- + [nblocks*ntargets x 1] array of tuples containing trial indices and [2 x 3] target coordinates + ''' + gen = ScreenTargetCapture.out_2D(nblocks, ntargets, distance, origin) + for _ in range(nblocks*ntargets): + idx, pos = next(gen) + targs = np.zeros([2, 3]) + origin + targs[1,:] = pos[0] + indices = np.zeros([2,1]) + indices[1] = idx + yield indices, targs + + @staticmethod + def centeroutback_2D(nblocks=100, ntargets=8, distance=10, origin=(0,0,0)): + ''' + Triplets of central targets, peripheral targets, and central targets + + Returns + ------- + [nblocks*ntargets x 1] array of tuples containing trial indices and [3 x 3] target coordinates + ''' + gen = ScreenTargetCapture.out_2D(nblocks, ntargets, distance, origin) + for _ in range(nblocks*ntargets): + idx, pos = next(gen) + targs = np.zeros([3, 3]) + origin + targs[1,:] = pos[0] + indices = np.zeros([3,1]) + indices[1] = idx + yield indices, targs + @staticmethod - def centerout_2D_discrete(nblocks=100, ntargets=8, boundaries=(-18,18,-12,12), - distance=10): + def rand_target_chain_2D(ntrials=100, chain_length=1, boundaries=(-12,12,-12,12)): ''' + Generates a sequence of 2D (x and z) target pairs. - Generates a sequence of 2D (x and z) target pairs with the first target - always at the origin. + Parameters + ---------- + ntrials : int + The number of target chains in the sequence. + chain_length : int + The number of targets in each chain + boundaries: 4 element Tuple + The limits of the allowed target locations (-x, x, -z, z) + Returns + ------- + [ntrials x chain_length x 3] array of target coordinates + ''' + rng = np.random.default_rng() + idx = 0 + for t in range(ntrials): + + # Choose a random sequence of points within the boundaries + pts = rng.uniform(size=(chain_length, 3))*((boundaries[1]-boundaries[0]), + 0, (boundaries[3]-boundaries[2])) + pts = pts+(boundaries[0], 0, boundaries[2]) + yield idx+np.arange(chain_length), pts + idx += chain_length + + @staticmethod + def rand_target_chain_3D(ntrials=100, chain_length=1, boundaries=(-12,12,-10,10,-12,12)): + ''' + Generates a sequence of 3D target pairs. Parameters ---------- - length : int - The number of target pairs in the sequence. + ntrials : int + The number of target chains in the sequence. + chain_length : int + The number of targets in each chain boundaries: 6 element Tuple - The limits of the allowed target locations (-x, x, -z, z) - distance : float - The distance in cm between the targets in a pair. + The limits of the allowed target locations (-x, x, -y, y, -z, z) Returns ------- - pairs : [nblocks*ntargets x 2 x 3] array of pairs of target locations + [ntrials x chain_length x 3] array of target coordinates + ''' + rng = np.random.default_rng() + idx = 0 + for t in range(ntrials): + + # Choose a random sequence of points within the boundaries + pts = rng.uniform(size=(chain_length, 3))*((boundaries[1]-boundaries[0]), + (boundaries[3]-boundaries[2]), (boundaries[5]-boundaries[4])) + pts = pts+(boundaries[0], boundaries[2], boundaries[4]) + yield idx+np.arange(chain_length), pts + idx += chain_length + +class ScreenReachAngle(ScreenTargetCapture): + ''' + A modified task that requires the cursor to move in the right direction towards the target, + without actually needing to arrive at the target. If the maximum angle is exceeded, a reach + penalty is applied. No hold or delay period. + + Only works for sequences with 1 target in a chain. + ''' + + status = dict( + wait = dict(start_trial="target"), + target = dict(reach_success="targ_transition", timeout="timeout_penalty", leave_bounds="reach_penalty"), + targ_transition = dict(trial_complete="reward", trial_abort="wait", trial_incomplete="target"), + timeout_penalty = dict(timeout_penalty_end="targ_transition", end_state=True), + reach_penalty = dict(reach_penalty_end="targ_transition", end_state=True), + reward = dict(reward_end="wait", stoppable=False, end_state=True) + ) + + sequence_generators = [ + 'out_2D', 'rand_target_chain_2D', 'rand_target_chain_3D', 'discrete_targets_2D', + ] + + max_reach_angle = traits.Float(90., desc="Angle defining the boundaries between the starting position of the cursor and the target") + reach_penalty_time = traits.Float(1, desc="Length of penalty time for target hold error") + reach_fraction = traits.Float(0.5, desc="Fraction of the distance between the reach start and the target before a reward") + start_radius = 1. # buffer around reach start allowed in bounds + exclude_parent_traits = ['hold_time', 'hold_penalty_time', 'delay_time', 'delay_penalty_time'] + def _start_target(self): + super()._start_target() + + # Define a reach start and reach target position whenever the target appears + self.reach_start = self.plant.get_endpoint_pos().copy() + self.reach_target = self.targs[self.target_index] + + def _test_leave_bounds(self, ts): + ''' + Check whether the cursor is in the boundary defined by reach_start, target_pos, + and max_reach_angle. ''' - # Choose a random sequence of points on the edge of a circle of radius - # "distance" + # Calculate the angle between the vectors from the start pos to the current cursor and target + a = self.plant.get_endpoint_pos() - self.reach_start + b = self.reach_target - self.reach_start + cursor_target_angle = np.arccos(np.dot(a, b)/np.linalg.norm(a)/np.linalg.norm(b)) + + # If that angle is more than half the maximum, we are outside the bounds + out_of_bounds = np.degrees(cursor_target_angle) > self.max_reach_angle / 2 + + # But also allow a target radius around the reach_start + away_from_start = np.linalg.norm(self.plant.get_endpoint_pos() - self.reach_start) > self.start_radius + + return away_from_start and out_of_bounds + + def _test_reach_success(self, ts): + dist_traveled = np.linalg.norm(self.plant.get_endpoint_pos() - self.reach_start) + dist_total = np.linalg.norm(self.reach_target - self.reach_start) + dist_total -= (self.target_radius - self.cursor_radius) + return dist_traveled/dist_total > self.reach_fraction + + def _start_reach_penalty(self): + self.sync_event('OTHER_PENALTY') + self._increment_tries() + + # Hide targets + for target in self.targets: + target.hide() + target.reset() - theta = [] - for i in range(nblocks): - temp = np.arange(0, 2*np.pi, 2*np.pi/ntargets) - np.random.shuffle(temp) - theta = theta + [temp] - theta = np.hstack(theta) + def _end_reach_penalty(self): + self.sync_event('TRIAL_END') + def _test_reach_penalty_end(self, ts): + return ts > self.reach_penalty_time - x = distance*np.cos(theta) - y = np.zeros(len(theta)) - z = distance*np.sin(theta) + @staticmethod + def discrete_targets_2D(nblocks=100, ntargets=3, boundaries=(-6,6,-3,3)): + ''' + Generates a sequence of 2D (x and z) target pairs that don't overlap - pairs = np.zeros([len(theta), 2, 3]) - pairs[:,1,:] = np.vstack([x, y, z]).T + Parameters + ---------- + nblocks : int + The number of ntarget pairs in the sequence. + ntargets : int + The number of unique targets (up to 9 maximum) + boundaries: 4 element Tuple + The limits of the allowed target locations (-x, x, -z, z) - return pairs + Returns + ------- + [ntrials x ntargets x 3] array of target coordinates + ''' + targets = np.array([ + [0, 0.5], + [1, 0.5], + [1, 0], + [0, 0], + [0.25, 0.25], + [0.75, 0.25], + [0.25, 0.75], + [0.75, 0.75], + [0.5, 1], + ]) + rng = np.random.default_rng() + for _ in range(nblocks): + order = np.arange(ntargets) # target indices + rng.shuffle(order) + for t in range(ntargets): + idx = order[t] + pts = targets[idx]*((boundaries[1]-boundaries[0]), + (boundaries[3]-boundaries[2])) + pts = pts+(boundaries[0], boundaries[2]) + pos = np.array([pts[0], 0, pts[1]]) + yield [idx], [pos] \ No newline at end of file diff --git a/built_in_tasks/target_graphics.py b/built_in_tasks/target_graphics.py index c2c5c149..5df30cd1 100644 --- a/built_in_tasks/target_graphics.py +++ b/built_in_tasks/target_graphics.py @@ -9,7 +9,6 @@ import traceback from riglib.stereo_opengl.primitives import Sphere, Cube -from riglib.stereo_opengl.window import Window, FPScontrol, WindowDispl2D from riglib.stereo_opengl.primitives import Cylinder, Plane, Sphere, Cube from riglib.stereo_opengl.models import FlatMesh, Group from riglib.stereo_opengl.textures import Texture, TexModel @@ -24,10 +23,24 @@ GOLD = (1., 0.843, 0., 0.5) mm_per_cm = 1./10 -class CircularTarget(object): +target_colors = { + "red": (1,0,0,0.75), + "yellow": (1,1,0,0.75), + "green":(0., 1., 0., 0.75), + "blue":(0.,0.,1.,0.75), + "magenta": (1,0,1,0.75), + "pink": (1,0.5,1,0.75), + "purple":(0.608,0.188,1,0.75), + "teal":(0,0.502,0.502,0.75), + "olive":(0.420,0.557,0.137,.75), + "orange": (1,0.502,0.,0.75), + "hotpink":(1,0.0,0.606,.75), + "elephant":(0.5,0.5,0.5,0.5), +} + +class CircularTarget(object): def __init__(self, target_radius=2, target_color=(1, 0, 0, .5), starting_pos=np.zeros(3)): self.target_color = target_color - self.default_target_color = tuple(self.target_color) self.target_radius = target_radius self.target_color = target_color self.position = starting_pos @@ -58,7 +71,7 @@ def show(self): self.sphere.attach() def cue_trial_start(self): - self.sphere.color = RED + self.sphere.color = self.target_color self.show() def cue_trial_end_success(self): @@ -66,14 +79,6 @@ def cue_trial_end_success(self): def cue_trial_end_failure(self): self.sphere.color = RED - self.hide() - # self.sphere.color = GREEN - def turn_yellow(self): - self.sphere.color = GOLD - - def idle(self): - self.sphere.color = RED - self.hide() def pt_inside(self, pt): ''' @@ -83,7 +88,7 @@ def pt_inside(self, pt): return (np.abs(pt[0] - pos[0]) < self.target_radius) and (np.abs(pt[2] - pos[2]) < self.target_radius) def reset(self): - self.sphere.color = self.default_target_color + self.sphere.color = self.target_color def get_position(self): return self.sphere.xfm.move @@ -93,7 +98,6 @@ def __init__(self, target_width=4, target_height=4, target_color=(1, 0, 0, .5), self.target_width = target_width self.target_height = target_height self.target_color = target_color - self.default_target_color = tuple(self.target_color) self.position = starting_pos self.int_position = starting_pos self._pickle_init() @@ -101,9 +105,10 @@ def __init__(self, target_width=4, target_height=4, target_color=(1, 0, 0, .5), def _pickle_init(self): self.cube = Cube(side_len=self.target_width, color=self.target_color) self.graphics_models = [self.cube] - self.cube.translate(*self.position) #self.center_offset = np.array([self.target_width, 0, self.target_width], dtype=np.float64) / 2 self.center_offset = np.array([0, 0, self.target_width], dtype=np.float64) / 2 + corner_pos = self.position - self.center_offset + self.cube.translate(*corner_pos) def move_to_position(self, new_pos): self.int_position = new_pos self.drive_to_new_pos() @@ -149,7 +154,7 @@ def pt_inside(self, pt): return (np.abs(pt[0] - pos[0]) < self.target_width/2) and (np.abs(pt[2] - pos[2]) < self.target_height/2) def reset(self): - self.cube.color = self.default_target_color + self.cube.color = self.target_color def get_position(self): return self.cube.xfm.move diff --git a/config/bmiconfig.py b/config/bmiconfig.py index fce46b72..a402ac33 100644 --- a/config/bmiconfig.py +++ b/config/bmiconfig.py @@ -8,6 +8,7 @@ KFDecoder=bmi.train.train_KFDecoder, PPFDecoder=bmi.train.train_PPFDecoder, OneDimLFPDecoder=bmi.train.create_onedimLFP, + LinearDecoder=bmi.train.create_lindecoder, ) bmi_training_pos_vars = [ @@ -43,6 +44,7 @@ extractors = dict( spikecounts = extractor.BinnedSpikeCountsExtractor, LFPpowerMTM = extractor.LFPMTMPowerExtractor, + direct = extractor.DirectObsExtractor, ) kin_extractors = dict( diff --git a/db/__init__.py b/db/__init__.py index 18ab444a..ab439dad 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -1,8 +1,2 @@ # from . import websocket from .tracker import tasktrack - -# This will make sure the app is always imported when -# Django starts so that shared_task will use this app. -from .celery_base import app as celery_app - -__all__ = ('celery_app',) \ No newline at end of file diff --git a/db/datasource.log b/db/datasource.log new file mode 100644 index 00000000..f0bb3571 --- /dev/null +++ b/db/datasource.log @@ -0,0 +1 @@ +RPCProcess.run diff --git a/db/dbfunctions.py b/db/dbfunctions.py index 7f77b7f2..71eb6a1a 100644 --- a/db/dbfunctions.py +++ b/db/dbfunctions.py @@ -4,6 +4,7 @@ # django initialization import os os.environ['DJANGO_SETTINGS_MODULE'] = 'db.settings' +os.environ["DJANGO_ALLOW_ASYNC_UNSAFE"] = "true" import django django.setup() @@ -18,7 +19,7 @@ from collections import defaultdict, OrderedDict import db -from tracker import models +from .tracker import models # default DB, change this variable from python session to switch to other database db_name = 'default' @@ -30,7 +31,7 @@ class TaskEntry(object): can be defined for TaskEntry blocks (e.g., for analysis methods for a particular experiment) without needing to modfiy the database model. ''' - def __init__(self, task_entry_id, dbname='default', **kwargs): + def __init__(self, task_entry_id, dbname=db_name, **kwargs): ''' Constructor for TaskEntry @@ -103,55 +104,41 @@ def __init__(self, task_entry_id, dbname='default', **kwargs): try: task_msgs = self.hdf.root.task_msgs[:] # Ignore the last message if it's the "None" transition used to stop the task - if task_msgs[-1]['msg'] == 'None': + if task_msgs[-1]['msg'] == b'None': task_msgs = task_msgs[:-1] # ignore "update bmi" messages. These have been removed in later datasets - task_msgs = task_msgs[task_msgs['msg'] != 'update_bmi'] + task_msgs = task_msgs[task_msgs['msg'] != b'update_bmi'] # Try to add the target index.. these are not present in every task type try: + target = self.hdf.root.task[:]['target'].ravel() target_index = self.hdf.root.task[:]['target_index'].ravel() - task_msg_dtype = np.dtype([('msg', '|S256'), ('time', 'tr:hover { background:#DDD; @@ -160,14 +191,6 @@ th { text-align:left;} /*z-index:1;*/ } -input[type="text"] { - width: 50px; -} - -input[type="number"] { - width: 50px; -} - #report td:first-child { font-weight:bold; } @@ -197,12 +220,15 @@ fieldset { /*options class is used for div containing Features/Sequence/Parameters fieldsets*/ .options { - float:left; - width:300px; + width:350px; +} + +div#seqparams { + width: 100%; } -#parameters table,#sequence table { - /*width:100%;*/ +#parameters table,#sequence table, #seqparams table, #controls table, #metadata table { + width:100%; border-collapse:collapse; } td.param_label { @@ -213,14 +239,23 @@ td.param_label+td { } td.param_label+td select { width:100%; } input[type=text], input[type=number] { + width:50px; border:1px solid; border-color:#888; padding:2px; + margin-right:2px; border-radius:2px; transition:all 0.25s ease-in-out; -webkit-transition:all 0.25s ease-in-out; -moz-transition:all 0.25s ease-in-out; } +input.string { + width:120px; +} + +input#seqlist { + width: auto; +} /*rightside contains Report, Notes, Linked Files*/ div.rightside { @@ -241,15 +276,21 @@ input:invalid { box-shadow:0px 0px 10px red; } - - +#paramadd { + margin-top: 5px; + margin-left: 290px; +} +#seqadd { + margin-top: 5px; + margin-left: 255px; +} #report table.option { font-size:10pt; } -fieldset#report { - max-height:500px; +fieldset:not(#parameters, #bmi) { + max-height:400px; /*width:100%;*/ /*width : 460px;*/ overflow:auto; @@ -259,33 +300,28 @@ fieldset#report { /*width : 200px;*/ } +#controls button { + padding: 5px 10px; + margin-bottom: 5px; + font-size: 16px; + border: none; + color: white; +} +#controls button:enabled { + background-color: black; +} +#controls button:hover:enabled { + box-shadow: 0px 0px 5px black; +} +#controls button:disabled { + background-color: gray; +} + .report_table { font-size:10pt; width : 450px; } - - - - - - - - - - - - -input#toggle_table -{ - /* Double-sized Checkboxes */ - -ms-transform: scale(2); /* IE */ - -moz-transform: scale(2); /* FF */ - -webkit-transform: scale(2); /* Safari and Chrome */ - -o-transform: scale(2); /* Opera */ - padding: 10px; -} - .checkboxtext { /* Checkbox text */ diff --git a/db/html/static/resources/js/bmi.js b/db/html/static/resources/js/bmi.js index c0eeff70..565a2a3c 100644 --- a/db/html/static/resources/js/bmi.js +++ b/db/html/static/resources/js/bmi.js @@ -262,7 +262,7 @@ BMI.prototype._bindui = function() { }.bind(this)); $("#cellnames").blur(function(e) { - console.log($("#cellnames").val()) + debug($("#cellnames").val()) this.parse($("#cellnames").val()); }.bind(this)); diff --git a/db/html/static/resources/js/collections.js b/db/html/static/resources/js/collections.js index 80951dd1..8ac241a7 100644 --- a/db/html/static/resources/js/collections.js +++ b/db/html/static/resources/js/collections.js @@ -12,7 +12,7 @@ Collections.prototype.show = function() { Collections.prototype.select_collections = function(collections) { // clear - console.log(collections); + debug(collections); $("#collections input[type=checkbox]").each( function() { this.checked = false; diff --git a/db/html/static/resources/js/list.js b/db/html/static/resources/js/list.js index b74d9c65..cad7b052 100644 --- a/db/html/static/resources/js/list.js +++ b/db/html/static/resources/js/list.js @@ -1,5 +1,4 @@ - -var log_mode = 5 +var log_mode = 2 function log(msg, level) { if (level <= log_mode) { @@ -11,28 +10,29 @@ function debug(msg) { log(msg, 5); } - - - -var report_activation = null; +function remove_entries(start, end) { // for debugging + for (i=start; i', - { - text: taskinfo.annotations[i], - id: "annotation_btn_" + i.toString(), - click: create_annotation_callback(taskinfo.annotations[i]), - type: "button" - } - ); - - var new_break = $("
"); - - $("#annot_div").append(new_button); - $("#annot_div").append(new_break); - this.annotation_buttons.push(new_button); - this.annotation_buttons.push(new_break); - } - } - } - - this.update_from_server = function(taskid, sel_feats) { - $.getJSON("ajax/task_info/"+taskid+"/", sel_feats, - function(taskinfo) { - this.update(taskinfo); - }.bind(this) - ); - } - - this.destroy_annotation_buttons = function() { - for (var i = 0; i < this.annotation_buttons.length; i += 1) { - this.annotation_buttons[i].remove() - } - } - - this.destroy = function() { - this.destroy_annotation_buttons(); - } - - this.hide = function() { - $("#annotations").hide(); - } - - this.show = function() { - $("#annotations").show(); - } -} - - function Files() { this.neural_data_found = false; $("#file_modal_server_resp").html(""); @@ -301,29 +248,24 @@ Files.prototype.update_filelist = function(datafiles, task_entry_id) { this.filelist = document.createElement("ul"); for (var sys in datafiles) { - if (sys == "sequence") { - // Do nothing. No point in showing the sequence.. - } else { - // info.datafiles[sys] is an array of files for that system - for (var i = 0; i < datafiles[sys].length; i++) { - // Create a list element to hold the file name - var file = document.createElement("li"); - file.textContent = datafiles[sys][i]; - this.filelist.appendChild(file); - numfiles++; - } + for (var i = 0; i < datafiles[sys].length; i++) { + // Create a list element to hold the file name + var file = document.createElement("li"); + file.textContent = datafiles[sys][i]; + this.filelist.appendChild(file); + numfiles++; } } if (numfiles > 0) { // Append the files onto the #files field $("#file_list").append(this.filelist); - - for (var sys in datafiles) - if ((sys == "plexon") || (sys == "blackrock") || (sys == "tdt")) { + for (var sys in datafiles) { + if (sys == "plexon" || sys == "blackrock" || sys == "ecube") { this.neural_data_found = true; break; } + } } } @@ -336,30 +278,49 @@ function TaskEntry(idx, info) { * idx: string of format row\d\d\d where \d\d\d represents the string numbers of the database ID of the block */ - // hide short descriptions - $('.colShortDesc').hide() + // resize the window to fit the TE pane correctly + $(window).resize() // hide the old content $("#content").hide(); + // Reset HTML fields + $("#file_list").empty(); + $("#content").removeClass("error running testing") + $("#files").hide(); + $('#newentry').hide() + $('#te_table_header').unbind("click"); + $('#te_table_header').click( + function() { + if (te) te.destroy(); + te = new TaskEntry(null); + } + ) + $("#tasks").unbind("change"); + + // Make new widgets this.sequence = new Sequence(); this.params = new Parameters(); + this.metadata = new Metadata(); this.report = new Report(task_interface.trigger.bind(this)); - this.annotations = new Annotations(); this.files = new Files(); + this.controls = new Controls(); + this.controls.hide() $("#parameters").append(this.params.obj); $("#plots").empty() - console.log("JS constructing task entry", idx) + debug("JS constructing task entry", idx) if (idx) { // If the task entry which was clicked has an id (stored in the database) // No 'info' is provided--the ID is pulled from the HTML + this.status = 'completed' // parse the actual integer database ID out of the HTML object name if (typeof(idx) == "number") { this.idx = idx; + idx = "row" + idx; } else { this.idx = parseInt(idx.match(/row(\d+)/)[1]); } @@ -369,32 +330,17 @@ function TaskEntry(idx, info) { // Create a jQuery object to represent the table row this.tr = $("#"+idx); this.__date = $("#"+idx + " .colDate"); - console.log(this.__date); - - this.status = this.tr.hasClass("running") ? "running" : "completed"; - if (this.status == 'running'){ - this.report.activate(); - } else { - this.tr.addClass("rowactive active"); - } - - if (this.status == "completed") { - this.annotations.hide(); - this.report.set_mode("completed"); - this.files.show(); - } + debug(this.__date); // Show the wait wheel before sending the request for exp_info. It will be hidden once data is successfully returned and processed (see below) $('#wait_wheel').show(); - $('#tr_seqlist').hide(); $.getJSON("ajax/exp_info/"+this.idx+"/", // URL to query for data on this task entry {}, // POST data to send to the server function (expinfo) { // function to run on successful response this.notes = new Notes(this.idx); - console.log(this) + debug(this) this.update(expinfo); - this.disable(); $("#content").show("slide", "fast"); $('#wait_wheel').hide() @@ -402,8 +348,9 @@ function TaskEntry(idx, info) { // If the server responds with data, disable reacting to clicks on the current row so that things don't get reset this.tr.unbind("click"); - // console.log('setting ') + // debug('setting ') this.tr.addClass("rowactive active"); + $("#newentry").hide(); // enable editing of the notes field for a previously saved entry $("#notes textarea").removeAttr("disabled"); @@ -420,56 +367,36 @@ function TaskEntry(idx, info) { // this code block executes when you click the header of the left table (date, time, etc.) this.idx = null; $("#entry_name").val(""); - - // show the bar at the top left with drop-downs for subject and task - this.tr = $("#newentry"); - this.tr.show(); // declared in list.html this.status = "stopped"; - - // - $('#tr_seqlist').show(); + this.tr = $("#newentry"); + this.tr.show(); // make sure the task entry row is visible + + feats.clear(); + this.report.hide(); + this.files.hide(); + // task_interface.trigger.bind(this)({state:''}); + + // query the server for information about the task (which generators can be used, which parameters can be set, etc.) + this._task_query( + function() { + this.enable(); + $("#content").show("slide", "fast"); + }.bind(this), true, true + ); // Set 'change' bindings to re-run the _task_query function if the selected task or the features change $("#tasks").change(this._task_query.bind(this)); feats.bind_change_callback(this._task_query.bind(this)) - if (info) { // if info is present and the id is null, then this block is being copied from a previous block - console.log('creating a new JS TaskEntry by copy') - this.update(info); - - // update the annotation buttons - var taskid = $("#tasks").attr("value"); - var sel_feats = feats.get_checked_features(); - this.annotations.update_from_server(taskid, sel_feats); - this.enable(); - $("#content").show("slide", "fast"); - - this.files.hide(); - } else { // no id and no info suggests that the table header was clicked to create a new block - console.log('creating a brand-new JS TaskEntry') - feats.clear(); - this.annotations.hide(); - this.report.hide(); - this.files.hide(); - task_interface.trigger.bind(this)({state:''}); - - // query the server for information about the task (which generators can be used, which parameters can be set, etc.) - this._task_query( - function() { - this.enable(); - $("#content").show("slide", "fast"); - }.bind(this) - ); - } // make the notes blank and editable $("#notes textarea").val("").removeAttr("disabled"); // Disable reacting to clicks on the current row so that the interface doesn't get reset this.tr.unbind("click"); $('te_table_header').unbind("click"); - } - this.being_copied = false; + task_interface.trigger.bind(this)({status: this.status}); + } } /* Populate the 'exp_content' template with data from the 'info' object */ @@ -478,37 +405,35 @@ TaskEntry.prototype.update = function(info) { // populate the list of generators if (Object.keys(info.generators).length > 0) { - console.log('limiting generators') + debug('limiting generators') this.sequence.update_available_generators(info.generators); } else { - console.log('not limiting generators!') + debug('not limiting generators!') } + this.status = info.state; + // Update all the sub-parts of the exp_content template separately this.sequence.update(info.sequence); this.params.update(info.params); - this.report.update(info.report); + this.metadata.update(info.metadata); if (this.notes) this.notes.update(info.notes); else $("#notes").attr("value", info.notes); - + feats.unbind_change_callback(); + this.report.update(info.report); + // set the checkboxes for the "visible" and "flagged for backup" $('#hidebtn').attr('checked', info.visible); $('#backupbtn').attr('checked', info.flagged_for_backup); - - this.expinfo = info; + $('#templatebtn').attr('checked', info.template); // set the 'tasks' drop-down menu to match the 'info' $("#tasks option").each(function() { if (this.value == info.task) this.selected = true; }) - // set the 'subjects' drop-down menu to match the 'info' - $("#subjects option").each(function() { - if (this.value == info.subject) - this.selected = true; - }); feats.select_features(info.feats); @@ -530,18 +455,25 @@ TaskEntry.prototype.update = function(info) { $("#sequence").hide() } - console.log("TaskEntry.prototype.update done!"); + if (this.status != "stopped") this.disable(); + task_interface.trigger.bind(this)({status: this.status}); + + debug("TaskEntry.prototype.update done!"); } TaskEntry.prototype.reload = function() { this.files.clear(); + if (this.idx == null) return; + $.getJSON("ajax/exp_info/"+this.idx+"/", // URL to query for data on this task entry {}, // POST data to send to the server function (expinfo) { // function to run on successful response + this.notes.destroy(); this.notes = new Notes(this.idx); - console.log(this) + this.sequence.destroy(); + this.sequence = new Sequence(); + debug(this) this.update(expinfo); - this.disable(); $("#content").show("slide", "fast"); $('#wait_wheel').hide() @@ -549,7 +481,7 @@ TaskEntry.prototype.reload = function() { // If the server responds with data, disable reacting to clicks on the current row so that things don't get reset this.tr.unbind("click"); - // console.log('setting ') + // debug('setting ') this.tr.addClass("rowactive active"); // enable editing of the notes field for a previously saved entry @@ -567,79 +499,112 @@ TaskEntry.prototype.reload = function() { TaskEntry.prototype.toggle_visible = function() { debug("TaskEntry.prototype.toggle_visible") var btn = $('#hidebtn'); - if (btn.attr('checked') == 'checked') { - // uncheck the box - btn.attr('checked', false); - - // send the data - $.get("/ajax/hide_entry/"+this.idx, + if (btn.is(':checked')) { // is hidden, and we want to show + $.get("/exp_log/ajax/show_entry/"+this.idx, {}, function() { - console.log("Hiding task entry " + te.idx); - $("#row" + te.idx).css('background-color', 'gray'); + debug("Showing task entry " + te.idx); + $("#row" + te.idx).css({'background-color': 'white'}); } ); - } else { // is hidden, and we want to show - // uncheck the box - $('#hidebtn').attr('checked', true); - - // send the data - $.get("/ajax/show_entry/"+this.idx, + } else { // want to hide + $.get("/exp_log/ajax/hide_entry/"+this.idx, {}, function() { - console.log("Showing task entry " + te.idx); - $("#row" + te.idx).css('background-color', 'white'); + debug("Hiding task entry " + te.idx); + $("#row" + te.idx).css({"background-color": "gray"}); + te.destroy(); } ); } } TaskEntry.prototype.save_name = function() { - $.post("save_entry_name", {"id": this.idx, "entry_name": $("#entry_name").val()}); + var idx = this.idx; + var name = $("#entry_name").val() + var tr = this.tr; + if (this.idx) $.post("save_entry_name", {"id": this.idx, "entry_name": name}, function() { + if (name) tr.find("td.colID").html(name+" ("+idx+")"); + else tr.find("td.colID").html(idx); + }); } TaskEntry.prototype.toggle_backup = function() { debug("TaskEntry.prototype.toggle_backup") var btn = $('#backupbtn'); - if (btn.attr('checked') == 'checked') { // is flagged for backup and we want to unflag - // uncheck the box - btn.attr('checked', false); - - // send the data - $.get("/ajax/unbackup_entry/"+this.idx, + if (btn.is(':checked')) { // is hidden, and we want to show + $.get("/exp_log/ajax/backup_entry/"+te.idx, {}, function() { - console.log("Unflagging task entry for backup" + te.idx); + debug("Flagging task entry for backup" + te.idx); + }); + } else { + $.get("/exp_log/ajax/unbackup_entry/"+this.idx, + {}, + function() { + debug("Unflagging task entry for backup" + te.idx); } ); - } else { // is hidden, and we want to show - // uncheck the box - btn.attr('checked', true); + } +} - // send the data - $.get("/ajax/backup_entry/"+te.idx, +TaskEntry.prototype.toggle_template = function() { + debug("TaskEntry.prototype.toggle_template") + var btn = $('#templatebtn'); + if (btn.is(':checked')) { // is hidden, and we want to show + $.get("/exp_log/ajax/template_entry/"+this.idx, {}, function() { - console.log("Flagging task entry for backup" + te.idx); - }); + debug("Flagging task entry as a template: " + te.idx); + }); + } else { + $.get("/exp_log/ajax/untemplate_entry/"+this.idx, + {}, + function() { + debug("Unflagging task entry as a template: " + te.idx); + } + ); } } -/* callback for 'Copy Parameters' button. Note this is not a prototype function +/* callback for 'Copy Parameters' button. */ -TaskEntry.copy = function() { +TaskEntry.prototype.copy = function() { debug("TaskEntry.copy") - // start with the info saved in the current TaskEntry object - var info = te.expinfo; - info.report = {}; // clear the report data - info.datafiles = {}; // clear the datafile data - info.notes = ""; // clear the notes + // reset the task entry row + var idx = this.idx + this.tr.click( + function() { + if (te) te.destroy(); + te = new TaskEntry(idx); + } + ); + this.tr.removeClass("rowactive active error"); + this.tr = $("#newentry"); + this.tr.show(); + $('te_table_header').unbind("click"); + this.tr.unbind("click"); + + // bind callbacks to the tasks and features fieldsets + $("#tasks").change(this._task_query.bind(this)); + feats.bind_change_callback(this._task_query.bind(this)) + + // reset the task info + this.idx = null; // reset the id + this.status = "stopped"; // set the status + this.report.destroy(); // clear the report data + this.files.clear(); // clear the datafile data + this.files.hide(); + this.notes.destroy(); // clear the notes + this.report.hide(); // turn off the report pane + + // update the task info, but leave the parameters alone + this._task_query(function(){}, false, true); - te.being_copied = true; - te = new TaskEntry(null, info); - $('#report').hide(); // creating a TaskEntry with "null" goes into the "stopped" state + // go into the "stopped" state + task_interface.trigger.bind(this)({status: this.status}); } /* @@ -649,100 +614,66 @@ TaskEntry.prototype.destroy = function() { debug("TaskEntry.prototype.destroy") $("#content").hide(); - // Destruct the Report object for this TaskEntry + // Destruct objects this.report.destroy(); - - // Destruct the Sequence object for this TaskEntry - if (this.being_copied) { - // don't destroy when copying because two objects try to manipulate the - // Sequence at the same time - this.sequence.destroy_parameters(); - } else { - this.sequence.destroy(); - } - - this.annotations.destroy(); - - // Free the parameters - if (this.params) { - $(this.params.obj).remove(); - delete this.params; - } - - // Clear out list of files - $("#file_list").html("") + this.sequence.destroy(); + if (this.params) $(this.params.obj).remove(); + if (this.notes) this.notes.destroy(); + if (this.bmi) this.bmi.destroy(); + $(this.filelist).remove(); // Remove any designations that this TaskEntry is active/running/errored/etc. this.tr.removeClass("rowactive active error"); - $("#content").removeClass("error running testing") - - // Hide the 'files' field - $("#files").hide(); - $(this.filelist).remove(); - - if (this.idx != null) { - var idx = "row"+this.idx; - - // re-bind a callback to when the row is clicked - this.tr.click( - function() { - te = new TaskEntry(idx); - } - ) - - // clear the notes field - this.notes.destroy(); - // clear the BMI - if (this.bmi !== undefined) { - this.bmi.destroy(); - delete this.bmi; + // Re-bind a callback to when the row is clicked + var idx = this.idx + this.tr.unbind("click"); + this.tr.click( + function() { + if (te) te.destroy(); + te = new TaskEntry(idx); } + ); - } else { - //Remove the newentry row - $('#newentry').hide() - - //Rebind the click action to create a blank TaskEntry form - this.tr.click(function() { - te = new TaskEntry(null); - }) + this.destroyed = true; +} - $('#te_table_header').click( - function() { - te = new TaskEntry(null); - } - ) - //Clean up event bindings - feats.unbind_change_callback(); - $("#tasks").unbind("change"); - } +TaskEntry.prototype.remove = function(callback) { + debug('TaskEntry.prototype.remove') + $.getJSON("ajax/remove_entry/"+this.idx,function() { + location.reload(); + }); } -TaskEntry.prototype._task_query = function(callback) { +TaskEntry.prototype._task_query = function(callback, update_params=true, update_metadata=false) { debug('TaskEntry.prototype._task_query') var taskid = $("#tasks").attr("value"); var sel_feats = feats.get_checked_features(); $.getJSON("ajax/task_info/"+taskid+"/", sel_feats, function(taskinfo) { - console.log("Information about task received from the server"); - console.log(taskinfo); + debug("Information about task received from the server"); + debug(taskinfo); + + if (update_params) this.params.update(taskinfo.params); + if (typeof(callback) == "function") + callback(); - this.params.update(taskinfo.params); + if (taskinfo.generators) { + this.sequence.update_available_generators(taskinfo.generators); + } if (taskinfo.sequence) { $("#sequence").show() this.sequence.update(taskinfo.sequence); } else $("#sequence").hide() - if (typeof(callback) == "function") - callback(); - - this.annotations.update(taskinfo); + if (update_metadata) this.metadata.update(taskinfo.metadata); - if (taskinfo.generators) { - this.sequence.update_available_generators(taskinfo.generators); + if (taskinfo.controls) { + this.controls.update(taskinfo.controls); + } else { + this.controls.update([]); } }.bind(this) ); @@ -761,29 +692,38 @@ TaskEntry.prototype.stop = function() { $.post("stop/", csrf, task_interface.trigger.bind(this)); } -/* Callback for the 'Test' button - */ +// Callback for the 'Test' button TaskEntry.prototype.test = function() { debug("TaskEntry.prototype.test") - this.disable(); return this.run(false, true); } -/* Callback for the 'Start experiment' button - */ +// Callback for the 'Start experiment' button TaskEntry.prototype.start = function() { debug("TaskEntry.prototype.start") - this.disable(); return this.run(true, true); } +// Callback for 'Save Record' button TaskEntry.prototype.saverec = function() { - this.disable(); return this.run(true, false); } TaskEntry.prototype.run = function(save, exec) { debug("TaskEntry.run") + // make sure we're stopped + task_interface.trigger.bind(this)({status: "stopped"}); + + // check that inputs have been filled out + let valid = true; + $('[required]').each(function() { + if ($(this).is(':invalid') || !$(this).val()) valid = false; + }) + if (!valid) { + $("#experiment").trigger("submit"); // this will pop up a message to fill out the missing fields + return; + } + // activate the report; start listening to the websocket and update the 'report' field when new data is received if (this.report){ this.report.destroy(); @@ -791,9 +731,8 @@ TaskEntry.prototype.run = function(save, exec) { this.report = new Report(task_interface.trigger.bind(this)); this.report.activate(); this.report.set_mode("running"); - - this.annotations.show(); this.files.hide(); + this.disable(); var form = {}; form['csrfmiddlewaretoken'] = $("#experiment input").filter("[name=csrfmiddlewaretoken]").attr("value") @@ -834,12 +773,17 @@ TaskEntry.prototype.new_row = function(info) { // debug('TaskEntry.prototype.new_row: ' + info.idx); - this.idx = info.idx; + if (typeof(info.idx) == "number") { + this.idx = info.idx; + } else { + this.idx = parseInt(info.idx.match(/(\d+)/)[1]); + } this.tr.removeClass("running active error testing") // make the row hidden (becomes visible when the start or test buttons are pushed) this.tr.hide(); this.tr.click(function() { + if (te) te.destroy(); te = new TaskEntry(null); }) //Clean up event bindings @@ -849,10 +793,10 @@ TaskEntry.prototype.new_row = function(info) { this.tr = $(document.createElement("tr")); // add an id number to the row - this.tr.attr("id", "row"+info.idx); + this.tr.attr("id", "row"+this.idx); // Write the HTML for the table row - this.tr.html("Today" + + this.tr.html("Now" + "--" + ""+info.idx+"" + ""+info.subj+"" + @@ -862,15 +806,17 @@ TaskEntry.prototype.new_row = function(info) { // Insert the new row after the top row of the table $("#newentry").after(this.tr); this.tr.addClass("active rowactive running"); + this.tr.find('td').addClass("firstRowOfday"); + this.tr.next().find('td').removeClass("firstRowOfday"); this.notes = new Notes(this.idx); } TaskEntry.prototype.get_data = function() { var data = {}; - data['subject'] = parseInt($("#subjects").attr("value")); data['task'] = parseInt($("#tasks").attr("value")); data['feats'] = feats.get_checked_features(); data['params'] = this.params.to_json(); + data['metadata'] = this.metadata.get_data(); data['sequence'] = this.sequence.get_data(); data['entry_name'] = $("#entry_name").val(); data['date'] = $("#newentry_today").html(); @@ -879,7 +825,8 @@ TaskEntry.prototype.get_data = function() { } TaskEntry.prototype.enable = function() { debug("TaskEntry.prototype.enable"); - $("#parameters input").removeAttr("disabled"); + this.params.enable(); + this.metadata.enable(); feats.enable_entry(); if (this.sequence) this.sequence.enable(); @@ -888,7 +835,8 @@ TaskEntry.prototype.enable = function() { } TaskEntry.prototype.disable = function() { debug("TaskEntry.prototype.disable"); - $("#parameters input").attr("disabled", "disabled"); + this.params.disable(); + this.metadata.disable() feats.disable_entry(); if (this.sequence) this.sequence.disable(); @@ -904,7 +852,7 @@ TaskEntry.prototype.link_new_files = function() { var file_path = $("#file_path").val(); var new_file_path = $("#new_file_path").val(); var new_file_data = $("#new_file_raw_data").val(); - var new_file_data_format= $("new_file_data_format").val(); + var new_file_data_format= $("#new_file_data_format").val(); var browser_sel_file = document.getElementById("file_path_browser_sel").files[0]; if ($.trim(new_file_data) != "" && $.trim(new_file_path) != "") { @@ -926,28 +874,58 @@ TaskEntry.prototype.link_new_files = function() { $.post("/exp_log/link_data_files/" + this.idx + "/submit", data, function(resp) { $("#file_modal_server_resp").append(resp + "
"); - console.log("posted the file!"); + debug("posted the file!"); } ) } +// +// Metadata class +// +function Metadata() { + $("#metadata_table").html("") + var params = new Parameters(); + this.params = params; + $("#metadata_table").append(this.params.obj); + var add_new_row = $(''); + add_new_row.on("click", function() {params.add_row();}); + this.add_new_row = add_new_row; + $("#metadata_table").append(add_new_row); +} +Metadata.prototype.update = function(info) { + this.params.update(info) +} +Metadata.prototype.enable = function() { + this.params.enable(); + this.add_new_row.show(); +} +Metadata.prototype.disable = function() { + this.params.disable(); + this.add_new_row.hide(); +} +Metadata.prototype.get_data = function () { + var data = this.params.to_json(); + return data; +} + // // Notes class // function Notes(idx) { - this.last_TO = null; this.idx = idx; + $("#notes").val(""); + debug("Cleared notes") this.activate(); } Notes.prototype.update = function(notes) { - //console.log("Updating notes to \""+notes+"\""); + //debug("Updating notes to \""+notes+"\""); $("#notes textarea").attr("value", notes); } Notes.prototype.activate = function() { var notes_keydown_handler = function() { if (this.last_TO != null) clearTimeout(this.last_TO); - this.last_TO = setTimeout(this.save.bind(this), 2000); + this.last_TO = setTimeout(this.save.bind(this), 500); }.bind(this); $("#notes textarea").keydown(notes_keydown_handler); } @@ -955,15 +933,13 @@ Notes.prototype.destroy = function() { // unbind the handler to save notes to the database (see 'activate') $("#notes textarea").unbind("keydown"); - // clear the text - $("#notes").val(""); - - // clear the timeout handler + // clear the timeout handler and save if notes are changing if (this.last_TO != null) clearTimeout(this.last_TO); + this.save(); - // save right at the end - this.save(); + // reset the textarea + $("#notes textarea").val("").removeAttr("disabled"); } Notes.prototype.save = function() { this.last_TO = null; @@ -972,4 +948,107 @@ Notes.prototype.save = function() { 'csrfmiddlewaretoken' : $("#experiment input[name=csrfmiddlewaretoken]").attr("value") }; $.post("ajax/save_notes/"+this.idx+"/", notes_data); + debug("Saved notes"); +} + +// +// Controls class +// + +function create_control_callback(i, control_str, args, static=false) { + return function() {trigger_control(i, control_str, args, static)} +} + +function trigger_control(i, control, params, static) { + debug("Triggering control: " + control) + if (static) { + var data = { + "control": control, + "params": JSON.stringify(params.to_json()), + "base_class": $('#tasks').val(), + "feats": JSON.stringify(feats.get_checked_features()) + } + $.post("trigger_control", data, function(resp) { + debug("Control response", resp); + if (resp["status"] == "success") { + $('#controls_btn_' + i.toString()).css({"background-color": "green"}); + $('#controls_btn_' + i.toString()).animate({"background-color": "black"}, 500 ); + } + }) + } else { + $.post("trigger_control", {"control": control, "params": JSON.stringify(params.to_json())}, function(resp) { + debug("Control response", resp); + params.clear_all(); + if (resp["status"] == "pending") { + $('#controls_btn_' + i.toString()).css({"background-color": "yellow"}); + $('#controls_btn_' + i.toString()).animate({"background-color": "black"}, 500 ); + } + }) + } +} + +function Controls() { + this.control_list = []; + this.static_control_list = []; + this.params_list = []; + this.static_params_list = []; +} +Controls.prototype.update = function(controls) { + debug("Updating controls"); + $("#controls_table").html(''); + this.control_list = []; + this.static_control_list = []; + this.params_list = []; + this.static_params_list = []; + for (var i = 0; i < controls.length; i += 1) { + + var new_params = new Parameters(); + new_params.update(controls[i].params) + + var new_button = $('
+
+
- {{ n_blocks }} saved task entry records shown below
- -
+ {{ n_entries }} saved task entry records shown below{% if show_hidden %} ({{ n_hidden }} hidden) {% endif %} - - - + + + @@ -187,15 +243,27 @@ + + {% for e in templates %} + {% if e.visible %} + + {% else %} + + {% endif %} + {% if e == templates|first %} + + {% endif %} + + + + {% endfor %} + + + {% for e in entries %} + {% if e.visible %} - {% if e.html_date %} - - - - - - - {% else %} - - - - - - {% endif %} - - + {% else %} + + {% endif %} + {% if e.html_date %} + + + + + + + {% else %} + + + + + + {% endif %} + + {% endfor %}
DateTimeIDDateTimeID Who Task Description
Templates{{e.entry_name}}{{e.task.name}}
Today Now - +
{{e.html_date}}{{e.html_time}}{{e.ui_id}}{{e.subject.name}}{{e.task.name}}{{e.desc}}{{e.html_time}}{{e.ui_id}}{{e.subject.name}}{{e.task.name}}{{e.desc}}
{{e.html_date}}{{e.html_time}}{{e.ui_id}}{{e.subject.name}}{{e.task.name}}{{e.desc}}{{e.html_time}}{{e.ui_id}}{{e.subject.name}}{{e.task.name}}{{e.desc}}
@@ -231,18 +305,21 @@ HTML5 Icon
-
+
{% csrf_token %} -
+ + +
Visible? - Flagged for backup? + Flagged for backup? + Use as template
@@ -258,12 +335,17 @@
- Name: - + Name:
+ +
+ Metadata +
+
+
Features
    @@ -302,7 +384,10 @@
    - Parameters +
    + +
    +
    @@ -312,7 +397,7 @@
- Show All Parameters? + Show All Parameters? Parameters @@ -320,6 +405,7 @@
+
Report
@@ -336,11 +422,10 @@
-
- Annotations -
- Misc. annotation:
-
+
+ Controls + +
diff --git a/db/tracker/templates/setup_base.html b/db/tracker/templates/setup_base.html index 6af24807..03986f35 100644 --- a/db/tracker/templates/setup_base.html +++ b/db/tracker/templates/setup_base.html @@ -8,8 +8,9 @@