Source code for pupeyes.pupil

# -*- coding:utf-8 -*-

"""
Pupil Data Processing Module

This module provides tools for processing pupillometry data from eye trackers.
It includes functionality for deblinking, smoothing, baseline correction, and plotting
pupil size data.
"""

import warnings
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import os
import dill
from .utils import make_mask, lowpass_filter
from .aoi import is_inside
from .external.based_noise_blinks_detection import based_noise_blinks_detection

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from .defaults import default_mpl, default_plotly

[docs] class PupilProcessor: """ A class for processing and analyzing pupillometry data. This class provides methods for preprocessing pupil size data, including blink removal, artifact rejection, smoothing, and baseline correction. It also includes tools for data visualization and analysis. Parameters ---------- data : pd.DataFrame DataFrame containing pupil size data and associated measurements trial_identifier : str or list Column name(s) identifying unique trials. If list, trials are uniquely identified by the combination of these columns. pupil_col : str Column name containing pupil size measurements time_col : str Column name containing timestamps. Must be in milliseconds and integer. x_col : str Column name containing x-coordinates of gaze position y_col : str Column name containing y-coordinates of gaze position samp_freq : float Sampling frequency of the eye tracker in Hz convert_pupil_size : bool, default=False Whether to convert pupil size from area to diameter or vice versa artificial_d : float, default=5 Artificial pupil diameter in mm, used for pupil size conversion artificial_size : float, default=5663 Artificial pupil size in arbitrary units, used for pupil size conversion recording_unit : {'diameter', 'area'}, default='diameter' Unit of the recorded pupil size device : {'eyelink', 'tobii_titta', 'tobii_prolab','smi'}, default='eyelink' Device type. At the moment, this only controls whether sampling frequency is checked. eyetracker_missing_value : int, default=0 Value for missing pupil size for the eye tracker. Different eye trackers use different values to indicate missing values. PupEyes will replace these values with 0. Other possible values are pd.NA, np.nan, -1, -999, etc. progress_bar : bool, default=True Whether to show a progress bar for preprocessing steps Attributes ---------- data : pd.DataFrame The processed pupil data summary_data : pd.DataFrame Summary statistics for each trial trials : pd.DataFrame A dataframe of unique trial identifiers params : dict Dictionary storing parameters used in processing steps all_pupil_cols : list List of column names containing pupil data at different processing stages all_steps : list List of processing steps applied to the data Notes ----- - All processing methods return self for method chaining - Most methods create new columns with processed data rather than modifying existing ones - Processing parameters are stored in the params dictionary for reproducibility - Summary statistics are automatically updated after each processing step - artificial_d is the diameter of an artificial pupil provided by Eyelink. - artificial_size was measured for the setup of our research group and may not generalize to other setups. """ def __init__(self, data, trial_identifier, pupil_col, time_col, x_col, y_col, samp_freq, convert_pupil_size=False, artificial_d=5, artificial_size=5663, recording_unit='diameter', device='eyelink', eyetracker_missing_value=0, progress_bar=True): """ Initialize PupilProcessor object. """ #### device #### self.device = device print(f'Device: {self.device}') #### data #### # x gaze position self.x_col = x_col # y gaze position self.y_col = y_col # time column self.time_col = time_col # pupil column self.pupil_col = pupil_col # trial identifier # group by column for preprocessing if isinstance(trial_identifier, str): self.trial_identifier = [trial_identifier] else: self.trial_identifier = trial_identifier # make a copy of the data self.data = data.copy().convert_dtypes() #### handle missing values #### self.eyetracker_missing_value = eyetracker_missing_value # replace eye-tracker specified missing values with 0 # check if the missing value exists in the data if pd.isna(self.eyetracker_missing_value): # handle pd.NA or np.nan if self.data[self.pupil_col].isna().any(): print(f'Eye-tracker missing value is {self.eyetracker_missing_value}. Replacing with 0.') # convert to numeric to handle mixed types, then replace NA values self.data[self.pupil_col] = pd.to_numeric(self.data[self.pupil_col], errors='coerce').fillna(0) else: # handle specific numeric values if (self.data[self.pupil_col] == self.eyetracker_missing_value).any(): if self.eyetracker_missing_value == 0: print(f'Eye-tracker missing value for pupil size is 0. No replacement needed.') else: print(f'Eye-tracker missing value is {self.eyetracker_missing_value}. Replacing with 0.') self.data[self.pupil_col] = self.data[self.pupil_col].replace({self.eyetracker_missing_value: 0}) # check for non-integer timestamps and warn time_values = self.data[time_col].dropna() if not all(isinstance(val, (int, np.integer)) for val in time_values): import warnings warnings.warn( f"Non-integer timestamps detected in column '{time_col}'. " "The preprocessing pipeline expects integer timestamps in milliseconds. " "Decimal timestamps may cause issues in some preprocessing steps.", UserWarning ) #### check sampling frequency #### # sampling frequency self.samp_freq = samp_freq # check sampling frequency self.check_sampling_frequency() #### convert pupil size #### if convert_pupil_size: self.recording_unit = recording_unit self.artificial_d = artificial_d self.artificial_size = artificial_size self.data[self.pupil_col] = convert_pupil(self.data[self.pupil_col], artificial_d=artificial_d, artificial_size=artificial_size, recording_unit=recording_unit) print(f'Pupil data converted to {recording_unit} with artificial d={artificial_d} and artificial size={artificial_size}') else: self.recording_unit = None self.artificial_d = None self.artificial_size = None #### other stuff #### # store all preprocessing steps self.all_steps = [] # store generated pupil columns self.all_pupil_cols = [pupil_col] # store parameters for each step self.params = dict() # trials self.trials = self.data[self.trial_identifier].drop_duplicates().reset_index(drop=True) # empty dataframe to store summary of preprocessing steps self.summary_data = self.data.groupby(self.trial_identifier, sort=False).size().reset_index(name='n_samples') # outlier detection by info. leave as None if not performed self.baseline_outlier_by = None self.trace_outlier_by = None # progress bar self.progress_bar = progress_bar print(f'PupilProcessor initialized with {len(self.data)} samples') print(f'Pupil column: {self.pupil_col}, Time column: {self.time_col}, X column: {self.x_col}, Y column: {self.y_col}') print(f'Trial identifier: {self.trial_identifier}, Number of trials: {len(self.trials)}')
[docs] def check_sampling_frequency(self, sampling_rate=None, data=None): """ Check if the sampling frequency is consistent. Only performed for Eyelink data. This method checks if the sampling frequency is consistent across trials. If not, it raises an error. It is automatically called when initializing the PupilProcessor. If resampling is performed, the sampling frequency is checked again. Parameters ---------- sampling_rate : int, default=None Sampling rate to check. If None, the sampling rate is checked against the current sampling rate. data : pd.DataFrame, default=None Data to check. If None, the data is checked against the current data. The time column must be in milliseconds and integer. Returns ------- check_pass : bool True if the sampling frequency is consistent, False otherwise """ if self.device not in ['eyelink']: print(f'Sampling frequency check skipped for {self.device} data.') return True data = self.data if data is None else data sampling_rate = self.samp_freq if sampling_rate is None else sampling_rate check_pass = False # check if the difference between consecutive samples is equal to a fixed value diff = data.groupby(self.trial_identifier, sort=False)[self.time_col].diff().dropna().unique() if len(diff) == 1: if 1000/diff[0] != sampling_rate: raise ValueError(f'Actual sampling frequency {1000/diff[0]}Hz does not match the provided sampling frequency {sampling_rate}Hz!') else: print(f'Sampling frequency check passed. Sampling rate: {sampling_rate}Hz') check_pass = True else: raise ValueError('Sampling frequency is not consistent!') return check_pass
[docs] def artifact_rejection(self, suffix='_ar', method='both', speed_n=16, zscore_threshold=2.5, zscore_allowp=0.1): """ Reject artifacts from pupil data using speed and/or z-score based methods. This method identifies and removes artifacts using two possible approaches: 1. Speed-based: Removes samples where pupil size changes too rapidly 2. Z-score based: Removes extreme values based on z-score thresholds The method can use either approach individually or combine both. Parameters ---------- suffix : str, default='_ar' Suffix to append to the pupil column name for the artifact-rejected data. For example, if pupil column is 'pupil', the new column will be 'pupil_ar'. method : {'speed', 'zscore', 'both'}, default='both' Method to use for artifact rejection: - 'speed': Use only speed-based rejection - 'zscore': Use only z-score based rejection - 'both': Use both methods speed_n : int, default=16 Number of MADs above median speed to use as threshold for speed-based rejection zscore_threshold : float, default=2.5 Z-score threshold for artifact rejection for z-score based rejection zscore_allowp : float, default=0.1 Proportion of mean to use as minimum standard deviation for z-score based rejection. If sd/mean < zscore_allowp, the z-score threshold is not applied. This is to avoid rejecting stable data. Returns ------- self : PupilProcessor Returns self for method chaining Notes ----- - Updates summary_data with: - run_artifact: Boolean indicating if artifact rejection was performed - pct_artifact: Percentage of samples identified as artifacts - Creates a new column with suffix appended to the current pupil column name - Updates all_pupil_cols and all_steps to track processing history - Artifact periods are replaced with NaN values - Trials with all missing pupil data are skipped and reported - Processing parameters are stored in self.params['artifact_rejection'] """ # store parameters self.params['artifact_rejection'] = {k:v for k,v in locals().items() if k != 'self'} # create new column for artifact rejected data time_col = self.time_col pupil_col = self.all_pupil_cols[-1] new_col = pupil_col + suffix self.data[new_col] = self.data[pupil_col] # default to last pupil column # initialize artifacts removed column in summary data if method in ['speed', 'both']: self.summary_data['run_speed'] = False self.summary_data['pct_speed'] = pd.NA if method in ['zscore', 'both']: self.summary_data['run_size'] = False self.summary_data['pct_size'] = pd.NA # iterate over trials if trial_identifier is provided empty_trials = [] grouped = self.data.groupby(self.trial_identifier, sort=False) for group, groupdata in tqdm(grouped, desc=f'Artifact rejection', disable=not self.progress_bar): # check if groupdata has any pupil data if np.all(groupdata[pupil_col].isna()): empty_trials.append(group) else: if method in ['speed', 'both']: speed_mask = np.zeros(len(groupdata), dtype=bool) # initialize mask pupil_speed = compute_speed(groupdata[pupil_col].values, groupdata[time_col].values) median_speed = np.nanmedian(pupil_speed) mad = np.nanmedian(np.abs(pupil_speed - median_speed)) speed_mask = pupil_speed > (median_speed + (speed_n * mad)) # select row numbers in the group data, which are then used in .loc to select values needed to be set to nan self.data.loc[groupdata.index[speed_mask], new_col] = pd.NA # calculate percentage of speed artifacts pct_speed_artifacts = speed_mask.sum()/len(speed_mask) # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_speed'] = True self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'pct_speed'] = pct_speed_artifacts if method in ['zscore', 'both']: zscore_mask = np.zeros(len(groupdata), dtype=bool) # initialize mask mean = np.nanmean(groupdata[new_col]) std = np.nanstd(groupdata[new_col]) # check if std is larger than zscore_allowp * mean if std > zscore_allowp * mean: zscore_mask = np.abs(groupdata[new_col] - mean) > zscore_threshold * std self.data.loc[groupdata.index[zscore_mask], new_col] = pd.NA # calculate percentage of size artifacts pct_size_artifacts = zscore_mask.sum()/len(zscore_mask) else: pct_size_artifacts = 0.0 # no size artifacts # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_size'] = True self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'pct_size'] = pct_size_artifacts # update latest pupil column self.all_pupil_cols.append(new_col) self.all_steps.append('Artifact Rejected') # print summary print(f"✓ Artifact rejection completed!") print(f" → New column: '{new_col}' (artifacts removed)") print(f" → Previous column '{pupil_col}' preserved.") print(f" → {len(empty_trials)} trial(s) failed.") # print empty trials if len(empty_trials) > 0: # print a list of trials with high missing values print(f"\n {pd.DataFrame(empty_trials, columns=self.trial_identifier)}") return self
[docs] def filter_position(self, vertices, suffix = '_xy'): """ Filter pupil data based on gaze position within a polygon. This method removes pupil data points where the gaze position falls outside a specified polygon. This is useful for excluding data where participants were not looking at the intended region of interest. Parameters ---------- vertices : list of tuples List of (x,y) coordinates defining the polygon vertices. Must be in screen coordinates and form a closed polygon. Example: [(0,0), (0,1080), (1920,1080), (1920,0), (0,0)] suffix : str, default='_xy' Suffix to append to the pupil column name for the filtered data. For example, if pupil column is 'pupil', the new column will be 'pupil_xy'. Returns ------- self : PupilProcessor Returns self for method chaining. Notes ----- - Updates summary_data with: - run_gaze_filter: Boolean indicating if gaze filtering was performed - pct_gaze_filter: Percentage of samples outside the polygon - avg_gaze_x: Average gaze x-coordinate for the remaining samples - avg_gaze_y: Average gaze y-coordinate for the remaining samples - Creates a new column with suffix appended to the current pupil column name - Updates all_pupil_cols and all_steps to track processing history - Samples outside the polygon are replaced with NaN values - Trials with all missing pupil data are skipped and reported - Processing parameters are stored in self.params['filter_position'] Raises ------ ValueError If vertices cannot be converted to float numpy array """ # check if vertices can be converted to float numpy array try: vertices = np.array(vertices, dtype=float) except: raise ValueError("Vertices must be convertible to float numpy array") # store parameters self.params['filter_position'] = {k:v for k,v in locals().items() if k != 'self'} # create new column for filtered gaze position data x_col = self.x_col y_col = self.y_col pupil_col = self.all_pupil_cols[-1] new_col = pupil_col + suffix self.data[new_col] = self.data[pupil_col] # initialize position removed column in summary data self.summary_data['run_gaze_filter'] = False self.summary_data['pct_gaze_filter'] = pd.NA # iterate over trials if trial_identifier is provided empty_trials = [] grouped = self.data.groupby(self.trial_identifier, sort=False) for group, groupdata in tqdm(grouped, desc=f'Filtering based on gaze position', disable=not self.progress_bar): # check if groupdata has any pupil data if np.all(groupdata[pupil_col].isna()): empty_trials.append(group) else: # get gaze position gaze_pos = np.array(groupdata[[x_col, y_col]], dtype=float) # check if gaze position is inside the specified region inside_mask = is_inside(gaze_pos, vertices) # set pupil size to NaN if gaze position is outside the specified region self.data.loc[groupdata.index[~inside_mask], new_col] = pd.NA groupdata.loc[groupdata.index[~inside_mask], new_col] = pd.NA # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_gaze_filter'] = True self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'pct_gaze_filter'] = 1 - (inside_mask.sum()/len(inside_mask)) self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'avg_gaze_x'] = np.nanmean(groupdata.loc[groupdata[new_col].notna(), x_col]) self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'avg_gaze_y'] = np.nanmean(groupdata.loc[groupdata[new_col].notna(), y_col]) # update latest pupil column self.all_pupil_cols.append(new_col) self.all_steps.append('Gaze Filtered') # print summary print(f"✓ Gaze spatial filtering completed!") print(f" → New column: '{new_col}' (gaze filtered)") print(f" → Previous column '{pupil_col}' preserved.") print(f" → {len(empty_trials)} trial(s) failed.") # print empty trials if len(empty_trials) > 0: # print a list of trials with high missing values print(f"\n {len(empty_trials)} trials not filtered due to missing pupil data:") print(f"\n {pd.DataFrame(empty_trials, columns=self.trial_identifier)}") return self
[docs] def smooth(self, suffix='_sm', method='hann', window=100, **kwargs): """ Smooth pupil data using various smoothing methods. This method applies signal smoothing to reduce noise in the pupil data. Three smoothing methods are available: 1. Rolling mean: Simple moving average 2. Hann window: Weighted moving average using Hann window 3. Butterworth filter: Low-pass filter with specified cutoff Parameters ---------- suffix : str, default='_sm' Suffix to append to the pupil column name for the smoothed data. For example, if pupil column is 'pupil', the new column will be 'pupil_sm'. method : {'rollingmean', 'hann', 'butter'}, default='hann' Method to use for smoothing: - 'rollingmean': Simple moving average - 'hann': Hann window smoothing - 'butter': Butterworth low-pass filter window : int, default=100 Window size (in number of samples) for rolling mean or Hann window smoothing. Not used for Butterworth filter. **kwargs : dict Additional arguments for specific smoothing methods. - For rolling mean and hann window: Check `pandas.DataFrame.rolling <https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.rolling.html>`_ documentation for additional arguments. - For Butterworth filter: cutoff_freq : float Cutoff frequency in Hz. Default is 4 Hz. order : int Filter order. Default is 3. Returns ------- self : PupilProcessor Returns self for method chaining Notes ----- - Updates summary_data with smoothing method and parameters - Creates a new column with suffix appended to the current pupil column name - Updates all_pupil_cols and all_steps to track processing history - Missing values (NaN) are preserved """ # store parameters self.params['smooth'] = {k:v for k,v in locals().items() if k != 'self'} pupil_col = self.all_pupil_cols[-1] new_col = pupil_col + suffix if not isinstance(window, int) or window < 3: raise ValueError("Window size must be integer >= 3") if method not in ['rollingmean', 'hann', 'butter']: raise ValueError("Method must be 'rollingmean', 'hann', or 'butter'") if (method in ['rollingmean', 'hann']) and (len(self.data[pupil_col]) < window): raise ValueError('Data length smaller than window size') if method == 'butter' and ('sampling_freq' not in kwargs or 'cutoff_freq' not in kwargs): raise ValueError("For Butterworth filter, 'sampling_freq' and 'cutoff_freq' must be specified") if (method == 'butter') and (self.data[pupil_col].isnull().sum() > 0): raise ValueError("Butterworth filter does not support NaN values") # create new column for smoothed data self.data[new_col] = pd.NA # initialize summary data self.summary_data['run_smooth'] = False # iterate over trials if trial_identifier is provided empty_trials = [] grouped = self.data.groupby(self.trial_identifier, sort=False) for group, groupdata in tqdm(grouped, desc=f'Smoothing', disable=not self.progress_bar): # check if groupdata has any pupil data if np.all(groupdata[pupil_col].isna()): empty_trials.append(group) else: if method == 'rollingmean': smoothed = groupdata[pupil_col].rolling(window=window, center=True, **kwargs).mean() elif method == 'hann': smoothed = groupdata[pupil_col].rolling(window=window, win_type='hann', center=True, **kwargs).mean() elif method == 'butter': smoothed = lowpass_filter(groupdata[pupil_col], sampling_freq=self.sampling_freq, **kwargs) self.data.loc[groupdata.index, new_col] = smoothed # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_smooth'] = True # convert to Float64 since some values changed to float64 self.data[new_col] = self.data[new_col].convert_dtypes() # update latest pupil column self.all_pupil_cols.append(new_col) self.all_steps.append('Smoothed') # print summary print(f"✓ Smoothing completed!") print(f" → New column: '{new_col}' (smoothed)") print(f" → Previous column '{pupil_col}' preserved.") print(f" → {len(empty_trials)} trial(s) failed.") # print empty trials if len(empty_trials) > 0: # print a list of trials with high missing values print(f"\n {pd.DataFrame(empty_trials, columns=self.trial_identifier)}") return self
[docs] def check_missing(self, pupil_col=None, missing_value=pd.NA): """ Check for missing values in pupil data. This method calculates the percentage of missing values for each trial and updates the summary statistics. Missing values can be either NaN or a specific value. Parameters ---------- pupil_col : str, optional Column name to check for missing values. If None, uses the latest pupil column. missing_value : float or pd.NA, default=pd.NA Value to consider as missing. Can be pd.NA for NaN values or any specific value. Returns ------- self : PupilProcessor Returns self for method chaining. Notes ----- - Updates summary_data with: - run_check_missing: Boolean indicating if missing check was performed - missing: Percentage of missing values in each trial - Updates all_steps to track processing history - Trials that cannot be checked are reported - Processing parameters are stored in self.params['check_missing'] """ # store parameters self.params['check_missing'] = {k:v for k,v in locals().items() if k != 'self'} # use latest pupil column if not specified if pupil_col is None: pupil_col = self.all_pupil_cols[-1] # initialize summary data self.summary_data['run_check_missing'] = False self.summary_data['missing'] = 0.0 # iterate over trials if trial_identifier is provided skip_trials = [] missing_pct = 0.0 grouped = self.data.groupby(self.trial_identifier, sort=False) for group, groupdata in tqdm(grouped, desc=f'Checking missing values', disable=not self.progress_bar): try: if pd.isna(missing_value): missing_pct = groupdata[pupil_col].isna().sum()/len(groupdata) else: missing_pct = (groupdata[pupil_col] == missing_value).sum()/len(groupdata) # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_check_missing'] = True self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'missing'] = missing_pct except: skip_trials.append(group) # update latest step self.all_steps.append('Missing Values Checked') # print summary print(f"✓ Missing values checked!") print(f" → {len(skip_trials)} trial(s) failed.") # print failed trials if len(skip_trials) > 0: print(f"\n {pd.DataFrame(skip_trials, columns=self.trial_identifier)}") return self
[docs] def interpolate(self, suffix='_it', method='linear', missing_threshold=0.6): """ Interpolate missing values in pupil data. This method fills missing values in the pupil data using either linear or spline interpolation. Trials with too many missing values (above missing_threshold) are skipped to avoid unreliable interpolation. Parameters ---------- suffix : str, default='_it' Suffix to append to the pupil column name for the interpolated data. For example, if pupil column is 'pupil', the new column will be 'pupil_it'. method : {'linear', 'spline'}, default='linear' Method to use for interpolation: - 'linear': Linear interpolation between points - 'spline': Cubic spline interpolation missing_threshold : float, default=0.6 Maximum proportion of missing values allowed for interpolation. Trials with more missing values than this threshold are skipped. Returns ------- self : PupilProcessor Returns self for method chaining. Notes ----- - Updates summary_data with: - run_interpolate: Boolean indicating if interpolation was performed - pct_interpolate: Percentage of interpolated values in each trial - Creates a new column with suffix appended to the current pupil column name - Updates all_pupil_cols and all_steps to track processing history - Trials with too many missing values are skipped and reported - Processing parameters are stored in self.params['interpolate'] Raises ------ ValueError If method is not 'linear' or 'spline' """ if method not in ['spline', 'linear']: raise ValueError("Invalid method. Use 'linear' or 'spline'") # store parameters self.params['interpolate'] = {k:v for k,v in locals().items() if k != 'self'} # create new column for interpolated data pupil_col = self.all_pupil_cols[-1] new_col = pupil_col + suffix # initialize summary data self.summary_data['run_interpolate'] = False self.summary_data['pct_interpolate'] = 0.0 # iterate over trials if trial_identifier is provided skip_trials = [] grouped = self.data.groupby(self.trial_identifier, sort=False) for group, groupdata in tqdm(grouped, desc=f'Interpolating', disable=not self.progress_bar): # update summary data pct_missing = groupdata[pupil_col].isna().mean() self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'pct_interpolate'] = pct_missing # check for missing values if (pct_missing >= missing_threshold): skip_trials.append(group) else: if method == 'linear': interpolated = groupdata[pupil_col].interpolate(method='linear').ffill().bfill() else: interpolated = groupdata[pupil_col].interpolate(method='spline', order=3).ffill().bfill() # overwrite the new column with interpolated values self.data.loc[groupdata.index, new_col] = interpolated # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_interpolate'] = True # update latest pupil column self.all_pupil_cols.append(new_col) self.all_steps.append('Interpolated') # print summary print(f"✓ Interpolation completed!") print(f" → New column: '{new_col}' (interpolated)") print(f" → Previous column '{pupil_col}' preserved.") print(f" → {len(skip_trials)} trial(s) failed.") if len(skip_trials) > 0: # print a list of trials with high missing values print(f"\n {pd.DataFrame(skip_trials, columns=self.trial_identifier)}") return self
[docs] def downsample(self, target_samp_freq, agg_methods=None): """ Downsample pupil data to a new sampling rate. This method downsamples the data by binning into fixed time windows and aggregating values within each bin. This is useful for reducing data size or matching sampling rates between different recordings. Parameters ---------- target_samp_freq : int Target sampling frequency in Hz. agg_methods : dict, optional Dictionary mapping column names to aggregation methods. Example: {'pupil': 'mean', 'time': 'first', 'x': 'mean', 'y': 'mean'} If None, uses 'first' for all columns. Returns ------- self : PupilProcessor Returns self for method chaining. Notes ----- - Unlike other preprocessing functions, this function will replace the original .data with the downsampled data rather than creating a new column to the original .data. - Trials that cannot be downsampled are reported. - The sampling frequency is checked and updated again after downsampling. - Updates summary_data with: - run_downsample: Boolean indicating if downsampling was performed - downsampled_bin_size: Size of the downsampled time bin in milliseconds - downsampled_samp_freq: Downsampled sampling frequency in Hz - Updates all_steps to track processing history - Processing parameters are stored in self.params['downsample'] """ # store parameters self.params['downsample'] = {k: v for k, v in locals().items() if k != 'self'} # calculate new time step in milliseconds bin_size_ms = 1000/target_samp_freq # get data data = self.data time_col = self.time_col # aggregate methods for downsampling aggregation_methods = {col: 'first' for col in data.columns} if agg_methods is not None: aggregation_methods.update(agg_methods) # initialize summary data self.summary_data['run_downsample'] = False # group data grouped = self.data.groupby(self.trial_identifier, sort=False) # precompute offsets for each group offsets = grouped[time_col].transform('min') # normalize time and compute bins # this ensures that the first sample in each trial is at time 0 and that the same bin size is used for all trials normalized_time = self.data[time_col] - offsets bins = normalized_time // bin_size_ms # iterate over trials and aggregate data skip_trials = [] all_downsampled = [] for group, groupdata in tqdm(grouped, desc='Downsampling', disable=not self.progress_bar): try: # group by bins and aggregate data groupdata = groupdata.groupby(bins, as_index=False).agg(aggregation_methods) # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_downsample'] = True self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'downsampled_bin_size'] = bin_size_ms self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'downsampled_samp_freq'] = 1000/bin_size_ms except: skip_trials.append(group) # append downsampled data all_downsampled.append(groupdata) if all_downsampled: # concatenate downsampled data data = pd.concat(all_downsampled, ignore_index=True) # check sampling frequency, update samp_freq if pass check_pass = self.check_sampling_frequency(sampling_rate=target_samp_freq, data=data) if check_pass: self.samp_freq = target_samp_freq # update data self.data = data # print summary print(f"✓ Downsampling completed!") print(f" → New sampling frequency: {target_samp_freq} Hz") print(f" → {len(skip_trials)} trial(s) failed.") # print failed trials if len(skip_trials) > 0: print(f"\n {pd.DataFrame(skip_trials, columns=self.trial_identifier)}") # update latest step self.all_steps.append('Downsampled') else: raise ValueError('No trials were successfully downsampled!') return self
[docs] def upsample(self, target_samp_freq, fill_pupil=False): """ Upsample pupil data to a higher sampling rate. This method upsamples the data by inserting empty rows to meet the required sampling rate. Missing values are foward-filled for non-pupil columns. Pupil columns remain as NaN where no data exists unless fill_pupil=True. A new column 'upsampled' is added to track the inserted rows. Trials that cannot be upsampled are reported. The sampling frequency is checked and updated again after upsampling. Parameters ---------- target_samp_freq : int Target sampling frequency in Hz. Must be higher than current sampling frequency. fill_pupil : bool, default=False Whether to also fill missing values in pupil columns. If False, pupil columns remain as NaN where no data exists. This is simply a forward-fill. If you want to interpolate missing values, you can do so after upsampling. Returns ------- self : PupilProcessor Returns self for method chaining. Notes ----- - There might be slight discrepancies in the actual sampling rate from the target sampling rate because the time step between samples is rounded to the nearest integer. For example, if you supply a target sampling rate of 1001 Hz, the actual sampling rate will be 1000 Hz (round(1000/1001)= 1 ms time step). In the current implementation, this will result in an error because the actual sampling rate 1000 Hz does not match the target sampling rate 1001 Hz. - Upsampled data can be interpolated to fill missing values. You may need to set a lower missing_threshold for interpolation as the upsampling will introduce more missing values. - Updates summary_data with: - run_upsample: Boolean indicating if upsampling was performed - upsampled_bin_size: Size of the upsampled time bin in milliseconds - upsampled_samp_freq: Upsampled sampling frequency in Hz - Updates all_steps to track processing history - Processing parameters are stored in self.params['upsample'] """ # Validate target sampling frequency if target_samp_freq <= self.samp_freq: raise ValueError(f"Target sampling frequency ({target_samp_freq} Hz) must be higher than current frequency ({self.samp_freq} Hz)") # store parameters self.params['upsample'] = {k: v for k, v in locals().items() if k != 'self'} # calculate new time step in milliseconds new_time_step_ms = round(1000 / target_samp_freq) # get data data = self.data time_col = self.time_col # initialize summary data self.summary_data['run_upsample'] = False # group data grouped = self.data.groupby(self.trial_identifier, sort=False) # iterate over trials and upsample data skip_trials = [] all_upsampled = [] for group, groupdata in tqdm(grouped, desc='Upsampling', disable=not self.progress_bar): try: # get original time range min_time = int(groupdata[time_col].min()) max_time = int(groupdata[time_col].max()) # create complete time series from min_time to max_time complete_time_ms = np.arange(min_time, max_time + new_time_step_ms, new_time_step_ms) # create new dataframe with complete time series new_data = pd.DataFrame({time_col: complete_time_ms}) # add trial identifier columns for i, col in enumerate(self.trial_identifier): if isinstance(group, tuple): new_data[col] = group[i] # Use integer index for tuple else: new_data[col] = group # Single column case # merge with original data to get existing values # use outer merge to keep all time points merged = pd.merge(new_data, groupdata, on=[time_col] + self.trial_identifier, how='left') # determine which row is upsampled merged['upsampled'] = True merged.loc[merged[time_col].isin(groupdata[time_col]), 'upsampled'] = False # determine which columns to preserve vs fill if fill_pupil: preserve_cols = self.all_pupil_cols + [self.x_col, self.y_col] # pupil + gaze columns else: preserve_cols = [self.x_col, self.y_col] # only gaze columns fill_cols = [col for col in merged.columns if col not in preserve_cols and col not in [time_col] + self.trial_identifier] # fill missing values for non-preserved columns merged[fill_cols] = merged[fill_cols].ffill() # optionally fill pupil columns if requested if fill_pupil: pupil_fill_cols = [col for col in self.all_pupil_cols if col in merged.columns] merged[pupil_fill_cols] = merged[pupil_fill_cols].ffill() # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_upsample'] = True self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'upsampled_bin_size'] = new_time_step_ms self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'upsampled_samp_freq'] = target_samp_freq # append upsampled data all_upsampled.append(merged) except Exception as e: skip_trials.append(group) print(f"Failed to upsample trial {group}: {str(e)}") # concatenate upsampled data if all_upsampled: data = pd.concat(all_upsampled, ignore_index=True) # check sampling frequency, update samp_freq if pass check_pass = self.check_sampling_frequency(sampling_rate=target_samp_freq, data=data) if check_pass: self.samp_freq = target_samp_freq # update data self.data = data # print summary print(f"✓ Upsampling completed!") print(f" → New sampling frequency: {target_samp_freq} Hz") print(f" → {len(skip_trials)} trial(s) failed.") # print failed trials if len(skip_trials) > 0: print(f"\n {pd.DataFrame(skip_trials, columns=self.trial_identifier)}") # update latest step self.all_steps.append('Upsampled') else: raise ValueError("No trials were successfully upsampled!") return self
[docs] def baseline_correction(self, baseline_query, baseline_range=[None, None], suffix='_bc', method='subtractive'): """ Apply baseline correction to pupil data. Corrects pupil data by subtracting or dividing by baseline values. Creates a new column with the baseline-corrected data. Parameters ---------- baseline_query : str Query string to select baseline period data baseline_range : list, default=[None, None] Start and end indices for baseline period suffix : str, default='_bc' Suffix to append to the pupil column name for the corrected data. For example, if pupil column is 'pupil', the new column will be 'pupil_bc'. method : {'subtractive', 'divisive'}, default='subtractive' Method to use for baseline correction: - 'subtractive': Subtract baseline mean from pupil data - 'divisive': Divide pupil data by baseline mean Returns ------- self : PupilProcessor Returns self for method chaining. Notes ----- - Updates summary_data with: - run_baseline_correction: Boolean indicating if baseline correction was performed - baseline: Mean baseline value used for correction - Adds a new column with suffix appended to the current pupil column name - Updates all_pupil_cols and all_steps to track processing history """ # check for valid method if method not in ['subtractive', 'divisive']: raise ValueError("Invalid method. Use 'subtractive' or 'divisive'") # store parameters self.params['baseline_correction'] = {k:v for k,v in locals().items() if k != 'self'} # initialize summary data self.summary_data['run_baseline_correction'] = False self.summary_data['baseline'] = pd.NA # which columns to use for baseline correction pupil_col = self.all_pupil_cols[-1] new_col = pupil_col + suffix # get baseline data baseline_data = self.data.query(baseline_query) # get baseline range s, e = baseline_range # Precompute baseline means for each group baseline_means = baseline_data.groupby(self.trial_identifier)[pupil_col].apply(lambda x: x.iloc[s:e].mean()) # iterate over trials in data skip_trials = [] grouped = self.data.groupby(self.trial_identifier, sort=False) for group, groupdata in tqdm(grouped, desc=f'Baseline correction', disable=not self.progress_bar): # select baseline data for the current group baseline = baseline_means.loc[group] # check for nan if pd.isna(baseline) or np.all(groupdata[pupil_col].isna()): skip_trials.append(group) else: # do baseline correction if method == 'subtractive': self.data.loc[groupdata.index, new_col] = groupdata[pupil_col] - baseline else: self.data.loc[groupdata.index, new_col] = groupdata[pupil_col] / baseline # update summary data self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'run_baseline_correction'] = True self.summary_data.loc[np.all(self.summary_data[self.trial_identifier] == group, axis=1), 'baseline'] = baseline # print summary print(f"✓ Baseline correction completed!") print(f" → New column: '{new_col}' (baseline corrected)") print(f" → Previous column '{pupil_col}' preserved.") print(f" → {len(skip_trials)} trial(s) failed.") # print failed trials if len(skip_trials) > 0: print(f"\n {pd.DataFrame(skip_trials, columns=self.trial_identifier)}") # update latest step and latest pupil column self.all_steps.append('Baseline Corrected') self.all_pupil_cols.append(new_col) return self
[docs] def check_baseline_outliers(self, outlier_by=None, n_mad_baseline=4, plot=True, **kwargs): """ Check for outliers in baseline pupil values. Identifies outliers in baseline values using median absolute deviation (MAD). Can group data and check outliers within groups. Parameters ---------- outlier_by : str or list, optional Column(s) to group data by for outlier detection n_mad_baseline : float, default=4 Number of MADs from median to use as outlier threshold plot : bool, default=True Whether to plot the baseline distributions **kwargs : dict Additional arguments passed to plot_baseline Returns ------- self : PupilProcessor Returns self for method chaining Notes ----- - Updates summary_data with baseline outlier statistics - Updates all_steps """ # get summary data df_summary = self.summary_data.copy() # check if baseline data is available if 'baseline' not in df_summary.columns: raise ValueError("Baseline data is not available. Please run baseline correction first.") # convert outlier_by to list if it is a string if isinstance(outlier_by, str): outlier_by = [outlier_by] # initialize outlier masks df_summary['baseline_outlier'] = False df_summary['baseline_upper'] = pd.NA df_summary['baseline_lower'] = pd.NA if outlier_by is None: # calculate thresholds using pandas median_baseline = df_summary['baseline'].median() mad = (df_summary['baseline'] - median_baseline).abs().median() # calculate thresholds upper = median_baseline + n_mad_baseline*mad lower = median_baseline - n_mad_baseline*mad # mark outliers is_outlier = (df_summary['baseline'] > upper) | (df_summary['baseline'] < lower) # fill na with False is_outlier = is_outlier.fillna(False) # update summary data df_summary['baseline_outlier'] = is_outlier df_summary['baseline_upper'] = upper df_summary['baseline_lower'] = lower else: # calculate thresholds for each group for group, groupdata in tqdm(df_summary.groupby(outlier_by, sort=False), desc='Checking baseline pupil sizes for outliers', disable=not self.progress_bar): # calculate group thresholds using pandas median_baseline = groupdata['baseline'].median() mad = (groupdata['baseline'] - median_baseline).abs().median() upper = median_baseline + n_mad_baseline*mad lower = median_baseline - n_mad_baseline*mad # mark outliers for this group group_indices = groupdata.index is_outlier = (groupdata['baseline'] > upper) | (groupdata['baseline'] < lower) # fill na with False is_outlier = is_outlier.fillna(False) # update summary data df_summary.loc[group_indices, 'baseline_outlier'] = is_outlier df_summary.loc[group_indices, 'baseline_upper'] = upper df_summary.loc[group_indices, 'baseline_lower'] = lower # update summary data self.summary_data = df_summary # update outlier by self.baseline_outlier_by = outlier_by # update steps self.all_steps.append('Baseline Outlier Detection') # print summary print(f"✓ Baseline outlier detection completed!") print(f" → {df_summary['baseline_outlier'].sum()} trial(s) detected as baseline outliers.") # print outlier trials if df_summary['baseline_outlier'].any(): print(f"\n {df_summary.query('baseline_outlier==True')[self.trial_identifier]}") # plot if requested if plot: self.plot_baseline(plot_by=outlier_by, return_fig=False, **kwargs) return self
[docs] def check_trace_outliers(self, time_col=None, pupil_col=None, outlier_by=None, n_mad_trace=4, plot=True, **kwargs): """ Check for outlier trials based on pupil trace values. Detects outlier trials by comparing each trial's pupil trace against thresholds calculated from the median absolute deviation (MAD) of all trials. Outliers can be calculated globally or within specified groups. Parameters ---------- time_col : str, optional Column name for x-axis values (time). Defaults to time column. pupil_col : str, optional Column name for pupil values. Defaults to last pupil column. outlier_by : str or list, optional Column(s) to group trials by when calculating outlier thresholds. n_mad_trace : float, default=4 Number of MADs to use for outlier threshold. plot : bool, default=True Whether to plot the results. **kwargs Additional arguments passed to plotting function. Returns ------- self : object Returns self for method chaining. Notes ----- - Updates summary_data with: - run_trace_outlier: Boolean indicating if trace outlier detection was performed - trace_outlier: Boolean indicating if trial is an outlier - trace_upper: Upper threshold for outlier detection - trace_lower: Lower threshold for outlier detection - Outlier detection uses median absolute deviation (MAD) method - Can detect outliers globally or within groups specified by outlier_by """ # get data df = self.data.copy() df_summary = self.summary_data.copy() # get x and y columns if time_col is None: time_col = self.time_col if pupil_col is None: pupil_col = self.all_pupil_cols[-1] print(f'Checking trace outliers for {pupil_col}') # initialize outlier columns df_summary['trace_outlier'] = False df_summary['trace_upper'] = pd.NA df_summary['trace_lower'] = pd.NA # calculate outlier thresholds if outlier_by is None: # calculate thresholds for all trials grand_mean = df[pupil_col].mean() pupil_dist = df.groupby(self.trial_identifier)[pupil_col].apply(lambda x: (x - grand_mean).abs().max()) median_dist = pupil_dist.median() mad = (pupil_dist - median_dist).abs().median() upper = grand_mean + median_dist + (n_mad_trace * mad) lower = grand_mean - median_dist - (n_mad_trace * mad) # update summary data df_summary['trace_upper'] = upper df_summary['trace_lower'] = lower else: if not isinstance(outlier_by, list): outlier_by = [outlier_by] # calculate thresholds for each group for group, groupdata in df.groupby(outlier_by, sort=False): grand_mean = groupdata[pupil_col].mean() pupil_dist = groupdata.groupby(self.trial_identifier)[pupil_col].apply(lambda x: (x - grand_mean).abs().max()) median_dist = pupil_dist.median() mad = (pupil_dist - median_dist).abs().median() upper = grand_mean + median_dist + (n_mad_trace * mad) lower = grand_mean - median_dist - (n_mad_trace * mad) # update summary data for this group df_summary.loc[np.all(df_summary[outlier_by] == group, axis=1), 'trace_upper'] = upper df_summary.loc[np.all(df_summary[outlier_by] == group, axis=1), 'trace_lower'] = lower # mark outliers for trial, trialdata in tqdm(df.groupby(self.trial_identifier, sort=False), desc='Checking pupil traces for outliers', disable=not self.progress_bar): # get max and min values max_val = trialdata[pupil_col].max() min_val = trialdata[pupil_col].min() # get thresholds for this trial trial_mask = np.all(df_summary[self.trial_identifier] == trial, axis=1) upper_threshold = df_summary.loc[trial_mask, 'trace_upper'].iloc[0] lower_threshold = df_summary.loc[trial_mask, 'trace_lower'].iloc[0] # check for outliers and update summary data if pd.notna(max_val) and pd.notna(min_val) and pd.notna(upper_threshold) and pd.notna(lower_threshold): is_outlier = (max_val > upper_threshold) or (min_val < lower_threshold) df_summary.loc[trial_mask, 'trace_outlier'] = is_outlier # update summary data and steps self.summary_data = df_summary self.all_steps.append('Trace Outlier Detection') # update outlier by self.trace_outlier_by = outlier_by # print summary print(f"✓ Trace outlier detection completed!") print(f" → {df_summary['trace_outlier'].sum()} trial(s) detected as trace outliers.") # print outlier trials if df_summary['trace_outlier'].any(): print(f"\n {df_summary.query('trace_outlier==True')[self.trial_identifier]}") # plot if requested if plot: self.plot_spaghetti(time_col=time_col, pupil_col=pupil_col, plot_by=outlier_by, return_fig=False, **kwargs) return self
[docs] def summary(self, columns=None, level=None, agg_methods=None): """ Get summary statistics of the data. Returns summary data for specified columns, optionally grouped by level and aggregated using specified methods. Parameters ---------- columns : list, optional Columns to include in summary. Defaults to all columns. level : str or list, optional Column(s) to group by. agg_methods : dict, optional Dictionary mapping column names to aggregation methods. If None, uses mean for numeric columns. Returns ------- pandas.DataFrame Summary statistics dataframe. """ # convert dtypes self.summary_data = self.summary_data.convert_dtypes() if columns is None: columns = self.summary_data.columns if level is None: return self.summary_data[columns] else: if agg_methods is None: # get all numeric columns numeric_cols = self.summary_data[columns].select_dtypes(include=['number','boolean']).columns agg_methods = {col: 'mean' for col in numeric_cols} print(f"Using default aggregation methods: {agg_methods}") return self.summary_data.groupby(level)[columns].agg(agg_methods)
[docs] def validate_trials(self, trials_to_exclude, invert_mask=False): """ Mark trials as valid/invalid based on exclusion criteria. This method adds a 'valid' column to both the data and summary_data, marking trials as valid or invalid based on the provided exclusion criteria. Parameters ---------- trials_to_exclude : pandas.DataFrame DataFrame containing trial identifiers to exclude. Must have columns matching the trial_identifier of the PupilProcessor. invert_mask : bool, default=False If True, excludes all trials except those specified in trials_to_exclude. If False, excludes only the trials specified in trials_to_exclude. Returns ------- self : PupilProcessor Returns self for method chaining. Notes ----- - Adds 'valid' column to both summary_data and data - Valid column is boolean: True for valid trials, False for invalid trials - Trials are matched based on trial_identifier columns - Duplicate entries in trials_to_exclude are automatically removed """ # drop duplicates trials_to_exclude = trials_to_exclude.drop_duplicates() # get mask summary_mask = make_mask(self.summary_data, trials_to_exclude, invert=invert_mask) data_mask = make_mask(self.data, trials_to_exclude, invert=invert_mask) # update summary data self.summary_data['valid'] = summary_mask # update data self.data['valid'] = data_mask return self
[docs] def plot_pupil_surface(self, data=None, pupil_col=None, x_col=None, y_col=None, plot_type='count', vertices=None, nbins=64, log_counts=False, plot_by=None, show_centroid=True, save=None, plot_params=None): """ Create an interactive surface plot of pupil dilation by gaze coordinates using numpy.histogram2d. Parameters ---------- data : pandas.DataFrame, optional DataFrame containing pupil size, x-coordinates, and y-coordinates. Being able to specify data is useful for plotting a subset of the data. See examples below. If None, uses self.data. pupil_col : str, optional Column name for pupil size. Defaults to self.all_pupil_cols[-1]. x_col : str, optional Column name for x-coordinates of gaze. Defaults to self.x_col. y_col : str, optional Column name for y-coordinates of gaze. Defaults to self.y_col. plot_type : str, optional 'count' for number of measurements or 'size' for mean pupil size. Defaults to 'count'. nbins : int, optional Number of bins for the 2D histogram. Defaults to 64. log_counts : bool, default=False Whether to apply log transformation to counts (only applies when plot_type='count'). Defaults to False. plot_by : str, optional Column name to group data by for separate subplots. Defaults to None. show_centroid : bool, default=True Whether to show the centroid of the data. Defaults to True. save : str, optional Path to save plot. plot_params : dict, optional Dictionary of plotting parameters to override defaults - x_title : str, default='Gaze X' - y_title : str, default='Gaze Y' - title : str, default='Pupil Foreshortening Error Surface' - palette : str, default='Viridis' - width : int, default=400 - height : int, default=300 Examples -------- >>> # Plot a 2d histogram of the number of pupil measurements by condition >>> p.plot_pupil_surface(plot_by='condition') >>> # Plot a 2d histogram of the mean pupil size based on custom data >>> p.plot_pupil_surface(data=p.data[p.data['event'] == 'event_name']) >>> # Plot the mean pupil size rather than the count of measurements as a function of gaze coordinates >>> p.plot_pupil_surface(plot_type='size') """ # plot params if plot_params is None: plot_params = {} # get data if data is None: data = self.data.copy() # get column names if x_col is None: x_col = self.x_col if y_col is None: y_col = self.y_col if pupil_col is None: pupil_col = self.all_pupil_cols[-1] # drop nans data = data[data[x_col].notna() & data[y_col].notna() & data[pupil_col].notna()].reset_index(drop=True) # get x and y x = data[x_col] y = data[y_col] # Create groupby object if plot_by is not None: # Convert plot_by to list if it's not already if isinstance(plot_by, str): plot_by = [plot_by] grouped = data.groupby(plot_by, sort=False) else: grouped = [(None, data)] # Create figure fig = go.Figure() # Initialize lists to store all data for common color scale all_counts = [] all_pupil_means = [] # Process each group for _, group_data in grouped: group_x = group_data[x_col] group_y = group_data[y_col] group_pupil = group_data[pupil_col] # Compute histograms for this group counts, xedges, yedges = np.histogram2d(group_x, group_y, bins=nbins, range=[[x.min(), x.max()], [y.min(), y.max()]]) pupil_sum, _, _ = np.histogram2d(group_x, group_y, bins=[xedges, yedges], weights=group_pupil) pupil_mean = pupil_sum / np.maximum(counts, 1) all_counts.append(counts) all_pupil_means.append(pupil_mean) # Calculate global min/max for common color scale if plot_type == 'count': global_min = np.min([np.min(c) for c in all_counts]) global_max = np.max([np.max(c) for c in all_counts]) if log_counts: global_min = np.log1p(global_min) global_max = np.log1p(global_max) else: global_min = np.min([np.min(pm) for pm in all_pupil_means]) global_max = np.max([np.max(pm) for pm in all_pupil_means]) # Create a list to store visibility settings for each trace all_traces = [] visible_settings = [] dropdown_options = [] # Plot each group for i, (group_name, group_data) in enumerate(grouped): traces_in_group = [] # Format dropdown label if group_name is not None: # Convert group_name to list if it's not already if not isinstance(group_name, tuple): group_name = (group_name,) # Format each value with leading zeros if numeric label_parts = [f"{val:03d}" if isinstance(val, (int, float)) else str(val) for val in group_name] label = " | ".join(label_parts) else: label = "All" # Add dropdown option dropdown_options.append({ 'label': label, 'method': "update", 'args': [{"visible": []}, {}] # Will be filled later }) # get colorscale data if plot_type == 'count': colorscale_data = all_counts[i].T if log_counts: colorscale_data = np.log1p(colorscale_data) colorbar_title = 'Log Count' if log_counts else 'Count' else: colorscale_data = all_pupil_means[i].T colorbar_title = "Mean Size" # Add heatmap trace heatmap = go.Heatmap( x=(xedges[:-1] + xedges[1:]) / 2, y=(yedges[:-1] + yedges[1:]) / 2, z=colorscale_data, colorbar=dict(title=colorbar_title), colorscale=plot_params.get('palette', 'Viridis'), hoverongaps=False, hoverinfo='x+y+z', visible=(i == 0), # Only first group visible initially zmin=global_min, zmax=global_max ) fig.add_trace(heatmap) traces_in_group.append(heatmap) # Add vertices if provided if vertices is not None: vertices = np.array(vertices) aoi = go.Scatter( x=vertices[:,0], y=vertices[:,1], line=dict(color='black', width=2), name='Custom region', mode='lines', showlegend=False, visible=(i == 0) # Only first group visible initially ) fig.add_trace(aoi) traces_in_group.append(aoi) # Add centroid if show_centroid: centroid = np.array([group_data[x_col].mean(), group_data[y_col].mean()]) centroid_trace = go.Scatter( x=[centroid[0]], y=[centroid[1]], mode='markers', marker=dict(color='red', size=10, symbol='x'), showlegend=False, hoverinfo='x+y+name', name='Average Gaze', visible=(i == 0) # Only first group visible initially ) fig.add_trace(centroid_trace) traces_in_group.append(centroid_trace) all_traces.append(traces_in_group) # Create visibility settings for each dropdown option for i in range(len(all_traces)): vis = [] for j, traces in enumerate(all_traces): vis.extend([True if j == i else False] * len(traces)) visible_settings.append(vis) # Update the args for each dropdown option dropdown_options[i]['args'][0]["visible"] = vis # Update layout fig.update_layout( title=dict( text=plot_params.get('title', 'Pupil Foreshortening Error Surface'), x=0.5, y=0.95, xanchor='center', yanchor='top', font=dict(size=20, family='Arial', weight='bold') ), xaxis_title='Gaze X', yaxis_title='Gaze Y', yaxis_range=[y.max(), y.min()-1], # invert y axis xaxis_range=[x.min()-1, x.max()], width=plot_params.get('width', 800), height=plot_params.get('height', 600), updatemenus=[dict( type="dropdown", direction="down", x=1.0, # Position the dropdown at the right y=1.1, # Position slightly above the plot showactive=True, active=0, # Show first group by default buttons=dropdown_options )], margin=dict(l=80, r=80, t=100, b=80) ) # save figure if path is provided if save: if save.endswith('.html'): fig.write_html(save) else: raise ValueError(f"Interactive plots must be saved as html file. Got {save}.") return fig
def _get_plot_settings(self, x, y, plot_params=None, is_interactive=True): """ Helper method to get plot settings for both static and interactive trial plots. Parameters ---------- x : str Column name for x-axis values y : list Column name(s) for y-axis values plot_params : dict, optional Dictionary of plotting parameters to override defaults is_interactive : bool, default=True Whether settings are for interactive (Plotly) or static (Matplotlib) plot Returns ------- tuple (plot_specific_settings, kwargs) where kwargs are either matplotlib or plotly settings """ if plot_params is None: plot_params = {} # common plot specific settings plot_specific_settings = { 'layout': (len(y), 1), # number of rows, number of columns 'subplot_titles': y, # subplot titles 'x_title': x, 'y_title': '', 'showlegend': True, 'grid': False, # show grid } # update plot-specific settings if provided plot_specific_settings.update({k:v for k,v in plot_params.items() if k in plot_specific_settings}) if is_interactive: # plotly specific settings kwargs = default_plotly.copy() kwargs['width'] = plot_params.get('width', 800*plot_specific_settings['layout'][1]) kwargs['height'] = plot_params.get('height', 300*plot_specific_settings['layout'][0]) kwargs['title_text'] = plot_params.get('title_text', '') kwargs['xaxis_title'] = plot_params.get('xaxis_title', plot_specific_settings['x_title']) kwargs['yaxis_title'] = plot_params.get('yaxis_title', plot_specific_settings['y_title']) kwargs['showlegend'] = plot_params.get('showlegend', plot_specific_settings['showlegend']) kwargs['xaxis_showgrid'] = plot_params.get('xaxis_showgrid', plot_specific_settings['grid']) kwargs['yaxis_showgrid'] = plot_params.get('yaxis_showgrid', plot_specific_settings['grid']) else: # matplotlib specific settings kwargs = default_mpl.copy() kwargs['figure.figsize'] = plot_params.get('figure.figsize', (10*plot_specific_settings['layout'][1], 3*plot_specific_settings['layout'][0])) # update with any remaining valid kwargs kwargs.update({k:v for k,v in plot_params.items() if k not in plot_specific_settings}) return plot_specific_settings, kwargs
[docs] def plot_trial(self, trial, time_col=None, pupil_col=None, hue=None, save=None, interactive=True, plot_params=None): """ Plot data for a single trial. A wrapper function that calls either _plot_trial_interactive() or _plot_trial_static() depending on the interactive parameter. Parameters ---------- trial : pandas.DataFrame DataFrame containing trial identifier. time_col : str, optional Column name for x-axis values. Defaults to time column specified during initialization. pupil_col : str or list, optional Column name(s) for y-axis values. Defaults to all pupil columns. hue : str or list, optional Column(s) to group data by for separate lines. save : str, optional Path to save plot. interactive : bool, default=True Whether to create interactive plot. plot_params : dict, optional Additional plotting parameters. Returns ------- figure : matplotlib.figure.Figure or plotly.graph_objects.Figure Plot figure object. axes : matplotlib.axes.Axes, optional Plot axes object (only for static plots). """ if plot_params is None: plot_params = {} # plot using appropriate function if interactive: return self._plot_trial_interactive(trial, time_col, pupil_col, hue, save, plot_params) else: return self._plot_trial_static(trial, time_col, pupil_col, hue, save, plot_params)
def _plot_trial_static(self, trial, time_col=None, pupil_col=None, hue=None, save=None, plot_params=None): """ Create static plot of trial data using matplotlib. Parameters ---------- trial : pandas.DataFrame DataFrame containing trial identifier. time_col : str, optional Column name for x-axis values. Defaults to time column specified during initialization. pupil_col : str or list, optional Column name(s) for y-axis values. Defaults to all pupil columns. hue : str or list, optional Column(s) to group data by for separate lines save : str, optional Path to save figure plot_params : dict, optional Dictionary of plotting parameters to override defaults. Can include: - layout : tuple of (rows, cols) for subplot layout - subplot_titles : list of titles for subplots - x_title : x-axis label - y_title : y-axis label - showlegend : bool, whether to show legend - grid : bool, whether to show grid - Any matplotlib rcParams key Returns ------- tuple A tuple containing the figure and axes objects (fig, ax). Notes ----- - Uses matplotlib for static plotting - Creates subplots if multiple y columns provided - Groups data by hue variable(s) if provided - Applies default matplotlib styling that can be overridden """ if plot_params is None: plot_params = {} # get data data = self.data.copy() # get mask mask = make_mask(data, trial, invert=True) # mask data data = data[mask] # check if data is empty if data.empty: trial_info = ', '.join(f"{k}: {v}" for k, v in trial.items()) raise ValueError(f"No data found for trial with {trial_info}") # get x and y if time_col is None: time_col = self.time_col # default to time column if pupil_col is None: pupil_col = self.all_pupil_cols # default to all pupil columns if isinstance(pupil_col, str): pupil_col = [pupil_col] # make sure pupil_col is a list # get plot settings plot_specific_settings, mpl_kwargs = self._get_plot_settings(time_col, pupil_col, plot_params, is_interactive=False) # create subplots with context manager with mpl.rc_context(mpl_kwargs): fig = plt.figure() for i, col in enumerate(pupil_col): ax = fig.add_subplot(plot_specific_settings['layout'][0], plot_specific_settings['layout'][1], i+1) if hue: for trial_group, groupdata in data.groupby(hue, sort=False): # create label for legend label = ', '.join([str(k) for k in trial_group]) if isinstance(trial_group, tuple) else str(trial_group) ax.plot(groupdata[time_col], groupdata[col], label=label) else: # if no hue, plot all data together ax.plot(data[time_col], data[col]) # set labels and legend ax.set_xlabel(plot_specific_settings['x_title']) ax.set_ylabel(plot_specific_settings['y_title']) if plot_specific_settings['showlegend']: ax.legend() ax.set_title(plot_specific_settings['subplot_titles'][i]) # configure grid ax.grid(plot_specific_settings['grid']) # save figure if path is provided if save: plt.savefig(save) return fig, ax def _plot_trial_interactive(self, trial, time_col=None, pupil_col=None, hue=None, save=None, plot_params=None): """ Create an interactive plot of trial data using Plotly. Parameters ---------- trial : pandas.DataFrame DataFrame containing trial identifier. time_col : str, optional Column name for x-axis values. Defaults to time column specified during initialization. pupil_col : str or list, optional Column name(s) for y-axis values. Defaults to all pupil columns. hue : str or list, optional Column(s) to group data by for different traces save : str, optional Path to save the plot plot_params : dict, optional Dictionary of plot parameters to override defaults - layout : tuple of (rows, cols) for subplot layout - subplot_titles : list of titles for subplots - x_title : str, title of the x-axis - y_title : str, title of the y-axis - showlegend : bool, whether to show the legend - grid : bool, whether to show the grid - width : int, width of the plot - height : int, height of the plot - title_text : str, title of the plot - xaxis_showgrid : bool, whether to show the x-axis grid - yaxis_showgrid : bool, whether to show the y-axis grid Returns ------- plotly.graph_objects.Figure Interactive Plotly figure Notes ----- - Creates subplots if multiple y variables provided - Uses Plotly's default color scheme for traces - Allows customization through plot_params dictionary """ if plot_params is None: plot_params = {} # get data data = self.data.copy() # get mask mask = make_mask(data, trial, invert=True) # mask data data = data[mask] # check if data is empty if data.empty: trial_info = ', '.join(f"{k}: {v}" for k, v in trial.items()) raise ValueError(f"No data found for trial with {trial_info}") # get x and y if time_col is None: time_col = self.time_col # default to time column if pupil_col is None: pupil_col = self.all_pupil_cols # default to all pupil columns if isinstance(pupil_col, str): pupil_col = [pupil_col] # make sure pupil_col is a list # get plot settings plot_specific_settings, ply_kwargs = self._get_plot_settings(time_col, pupil_col, plot_params, is_interactive=True) # plot using plotly fig = make_subplots(rows=plot_specific_settings['layout'][0], cols=plot_specific_settings['layout'][1], start_cell="top-left", subplot_titles=plot_specific_settings['subplot_titles'], specs = np.full((plot_specific_settings['layout'][0],plot_specific_settings['layout'][1]), {}).tolist(), # remove margins horizontal_spacing = 0.1, # reduce spacing #vertical_spacing = 0.12 ) # default plotly colors cols = plotly.colors.DEFAULT_PLOTLY_COLORS # iterate over y variables for i, col in enumerate(pupil_col): # figure out row and column curr_row = int(i // plot_specific_settings['layout'][1] + 1) curr_col = int(i % plot_specific_settings['layout'][1] + 1) if hue: for g, (trial_group, groupdata) in enumerate(data.groupby(hue, sort=False)): # create label for legend label = ', '.join([str(k) for k in trial_group]) if isinstance(trial_group, tuple) else str(trial_group) # assign color but cycle through colors if more trials than colors curr_color = cols[g % len(cols)] fig.add_trace(go.Scatter(x=groupdata[time_col], y=groupdata[col], mode='lines', name=label, line=dict(color=curr_color), showlegend=ply_kwargs['showlegend'] if i==0 else False # only show legend for first plot ), row=curr_row, col=curr_col) else: curr_color = cols[0] fig.add_trace(go.Scatter(x=data[time_col], y=data[col], mode='lines', name=col, line=dict(color=curr_color), showlegend=False ), row=curr_row, col=curr_col) # update layout fig.update_xaxes(**{k[6:]:v for k, v in ply_kwargs.items() if 'xaxis' in k}) # update x-axis settings fig.update_yaxes(**{k[6:]:v for k, v in ply_kwargs.items() if 'yaxis' in k}) # update y-axis settings fig.update_layout(**{k:v for k, v in ply_kwargs.items() if 'xaxis' not in k and 'yaxis' not in k}) # update other layout settings # hack to update font and color for subplot titles for i in fig['layout']['annotations']: i['text'] = '<b>' + i['text'] + '</b>' # make subplot titles bold i['font'] = dict(size=16,color='black') # set font size and color # save figure if path is provided if save: fig.write_image(save) return fig
[docs] def plot_baseline(self, plot_by=None, show_outliers=True, save=None, interactive=True, plot_params=None, return_fig=False): """ Plot histogram of baseline pupil sizes. This is a wrapper function that calls either plot_baseline_interactive() or plot_baseline_static() depending on the interactive parameter. Parameters ---------- plot_by : str or list, optional Column(s) to group data by for separate plots. show_outliers : bool, default=True Whether to show outlier thresholds. save : str, optional Path to save plot. interactive : bool, default=True Whether to create interactive plot. plot_params : dict, optional Additional plotting parameters. return_fig : bool, default=False Whether to return the figure object. Returns ------- figure : matplotlib.figure.Figure or plotly.graph_objects.Figure Plot figure object if return_fig is True. axes : matplotlib.axes.Axes, optional Plot axes object (only for static plots). See Also -------- plot_baseline_interactive : Create interactive baseline histogram plot plot_baseline_static : Create static baseline histogram plot """ # plot if interactive: return self._plot_baseline_interactive(plot_by=plot_by, show_outliers=show_outliers, save=save, plot_params=plot_params, return_fig=return_fig) else: return self._plot_baseline_static(plot_by=plot_by, show_outliers=show_outliers, save=save, plot_params=plot_params, return_fig=return_fig)
def _plot_baseline_static(self, plot_by=None, show_outliers=True, save=None, plot_params=None, return_fig=False): """ Plot histogram of baseline pupil sizes. Parameters ---------- plot_by : str or list, optional Column(s) to group data by for separate plots. show_outliers : bool, default=True Whether to show outlier thresholds. save : str, optional Path to save plot. plot_params : dict, default={} Additional plotting parameters. return_fig : bool, default=False Whether to return the figure object. Returns ------- figure : matplotlib.figure.Figure Plot figure object. axes : matplotlib.axes.Axes Plot axes object. Notes ----- Requires baseline data and optionally baseline outlier information. """ plot_params = plot_params or {} # get summary data df_summary = self.summary_data.copy() # check if baseline data is available if ('baseline' not in df_summary.columns): raise ValueError("Baseline data is not available. Please run baseline correction first.") elif plot_by is not None: # convert plot_by to list if not already if isinstance(plot_by, str): plot_by = [plot_by] # check if plot_by columns exist if not all(col in df_summary.columns for col in plot_by): raise ValueError(f"Plot by column(s) {plot_by} not found in summary data.") elif show_outliers and not all(col in df_summary.columns for col in ['baseline_outlier', 'baseline_upper', 'baseline_lower']): raise ValueError("Outlier data is not available. Please run check_baseline_outliers first.") # check if outlier by is the same as plot_by if show_outliers and self.baseline_outlier_by is not None and self.baseline_outlier_by != plot_by: # both outlier by and plot by should be a list at this point warnings.warn(f"Outlier detection was performed by {self.baseline_outlier_by}. Plotting by {plot_by}. The plotted thresholds may be incorrect.") # number of plots if plot_by is not None: grouped = df_summary.groupby(plot_by, sort=False) n_plots = grouped.ngroups else: grouped = [(None, df_summary)] n_plots = 1 # some additional plot settings specific to histogram plots plot_specific_settings = { 'layout': [(n_plots - 1) // min(2, n_plots) + 1, min(2, n_plots)], # nrows, ncols 'title': 'Baseline Pupil Sizes', 'x_title': 'Baseline Pupil Sizes', 'y_title': 'Count', 'vline_color': 'red', 'vline_linestyle': '--', 'bins': 30, 'grid': False } # update plot-specific settings if provided plot_specific_settings.update({k:v for k,v in plot_params.items() if k in plot_specific_settings}) # update defaults settings if provided mpl_kwargs = default_mpl.copy() mpl_kwargs['figure.figsize'] = (plot_specific_settings['layout'][1]*8,plot_specific_settings['layout'][0]*3) # ncols, nrows mpl_kwargs.update({k:v for k,v in plot_params.items() if k not in plot_specific_settings}) with mpl.rc_context(mpl_kwargs): # Get unique combinations of grouping variables fig, axes = plt.subplots(plot_specific_settings['layout'][0], plot_specific_settings['layout'][1]) if n_plots == 1: axes = [axes] else: axes = axes.flatten() # Plot each group for idx, (group_name, group_data) in enumerate(grouped): # get axis ax = axes[idx] if show_outliers: # Plot histogram sns.histplot(data=group_data, x='baseline', hue='baseline_outlier', bins=plot_specific_settings['bins'], ax=ax, legend=True) # get thresholds upper_thresh = group_data['baseline_upper'].values[0] # asume all values are the same lower_thresh = group_data['baseline_lower'].values[0] # asume all values are the same # Add threshold lines ax.axvline(upper_thresh, color=plot_specific_settings['vline_color'], linestyle=plot_specific_settings['vline_linestyle']) ax.axvline(lower_thresh, color=plot_specific_settings['vline_color'], linestyle=plot_specific_settings['vline_linestyle']) # Add threshold labels ax.text(upper_thresh, ax.get_ylim()[1]*0.1, f'{upper_thresh:.2f}', rotation=90, va='bottom', ha='right') ax.text(lower_thresh, ax.get_ylim()[1]*0.1, f'{lower_thresh:.2f}', rotation=90, va='bottom', ha='right') else: # Plot histogram sns.histplot(data=group_data, x='baseline', ax=ax, bins=plot_specific_settings['bins']) # Set labels ax.set_xlabel(plot_specific_settings['x_title']) ax.set_ylabel(plot_specific_settings['y_title']) # Set title if n_plots > 1: if isinstance(group_name, tuple): title = ' | '.join([x for x in group_name]) else: title = f'{group_name}' ax.set_title(title) fig.suptitle(plot_specific_settings['title']) else: ax.set_title(plot_specific_settings['title']) # Configure grid ax.grid(plot_specific_settings['grid']) # Remove empty subplots if any for idx in range(n_plots, len(axes)): fig.delaxes(axes[idx]) fig.tight_layout() # Save figure if path is provided if save: plt.savefig(save, bbox_inches='tight', dpi=mpl_kwargs['figure.dpi']) # return figure if return_fig: return fig, axes def _plot_baseline_interactive(self, plot_by=None, show_outliers=True, save=None, plot_params=None, return_fig=True): """ Create interactive histogram plot of baseline pupil sizes using Plotly Express. Parameters ---------- plot_by : str or list, optional Column(s) to group data by for separate plots. show_outliers : bool, default=True Whether to show outlier thresholds. save : str, optional Path to save plot. plot_params : dict, default={} Additional plotting parameters. return_fig : bool, default=True Whether to return the figure object. Returns ------- figure : plotly.graph_objects.Figure Interactive Plotly figure object. """ plot_params = plot_params or {} # get summary data df_summary = self.summary_data.copy() # check if baseline data is available if ('baseline' not in df_summary.columns): raise ValueError("Baseline data is not available. Please run baseline correction first.") elif plot_by is not None: # convert plot_by to list if not already if isinstance(plot_by, str): plot_by = [plot_by] # check if plot_by columns exist if not all(col in df_summary.columns for col in plot_by): raise ValueError(f"Plot by column(s) {plot_by} not found in summary data.") elif show_outliers and not all(col in df_summary.columns for col in ['baseline_outlier', 'baseline_upper', 'baseline_lower']): raise ValueError("Outlier data is not available. Please run check_baseline_outliers first.") # check if outlier by is the same as plot_by if show_outliers and self.baseline_outlier_by is not None and self.baseline_outlier_by != plot_by: warnings.warn(f"Outlier detection was performed by {self.baseline_outlier_by}. Plotting by {plot_by}. The plotted thresholds may be incorrect.") # Plot settings plot_specific_settings = { 'title': 'Baseline Pupil Sizes', 'x_title': 'Baseline Pupil Sizes', 'y_title': 'Count', 'vline_color': 'red', 'vline_style': 'dash', 'bins': 30 } # Update plot settings if provided plot_specific_settings.update({k:v for k,v in plot_params.items() if k in plot_specific_settings}) # Get groups if plot_by is not None: grouped = df_summary.groupby(plot_by, sort=False) else: grouped = [(None, df_summary)] # Create figure fig = go.Figure() # Create dropdown menu options dropdown_options = [] # Keep track of trace indices for each group group_traces = [] # Create a temporary matplotlib figure for seaborn to plot into temp_fig, temp_ax = plt.subplots() # Add traces for each group for groupid, (group_name, group_data) in enumerate(grouped): # Format group name for display group_title = ' | '.join([str(x) for x in group_name]) if isinstance(group_name, tuple) else str(group_name) if group_name is not None else "All" # Keep track of traces for this group current_group_traces = [] # Clear the temporary axis temp_ax.clear() if show_outliers: # Use seaborn to compute histogram sns_hist = sns.histplot( data=group_data, x='baseline', hue='baseline_outlier', bins=plot_specific_settings['bins'], ax=temp_ax, legend=True ) # Get all patches and their colors all_patches = sns_hist.patches n_patches = len(all_patches) patches_per_category = n_patches // 2 # Since we have two categories # Check if there are any outliers in the data has_outliers = group_data['baseline_outlier'].any() if has_outliers: # Process non-outliers (first half of patches) and outliers (second half) categories = [ (False, '#DD8452', all_patches[:patches_per_category], 'Outliers'), (True, '#4C72B0', all_patches[patches_per_category:], 'Non-outliers') ] else: # If no outliers, use all patches with a single color categories = [ (False, '#4C72B0', all_patches, 'Non-outliers') ] for outlier_status, color, patches, label in categories: if patches: # Only add trace if there are bars # Extract x and y values from patches x = [p.get_x() + p.get_width()/2 for p in patches] y = [p.get_height() for p in patches] widths = [p.get_width() for p in patches] # Add trace to plotly figure fig.add_trace( go.Bar( x=x, y=y, width=widths[0], # All widths should be the same name=label, marker=dict(line=dict(color='black', width=1.5)), marker_color=color, opacity=0.75, visible=(groupid == 0), hovertemplate="Baseline: %{x}<br>Count: %{y}<extra></extra>" ) ) current_group_traces.append(len(fig.data) - 1) # Get thresholds upper_thresh = group_data['baseline_upper'].values[0] lower_thresh = group_data['baseline_lower'].values[0] # Get max y value from the histogram max_y = max(p.get_height() for p in all_patches) * 1.1 # Add threshold lines fig.add_trace( go.Scatter( x=[upper_thresh, upper_thresh], y=[0, max_y], mode='lines', name=f'Upper threshold: {upper_thresh:.2f}', line=dict( color=plot_specific_settings['vline_color'], dash=plot_specific_settings['vline_style'], width=2 ), visible=(groupid == 0), showlegend=True ) ) current_group_traces.append(len(fig.data) - 1) fig.add_trace( go.Scatter( x=[lower_thresh, lower_thresh], y=[0, max_y], mode='lines', name=f'Lower threshold: {lower_thresh:.2f}', line=dict( color=plot_specific_settings['vline_color'], dash=plot_specific_settings['vline_style'], width=2 ), visible=(groupid == 0), showlegend=True ) ) current_group_traces.append(len(fig.data) - 1) else: # Use seaborn to compute histogram without outlier distinction sns_hist = sns.histplot( data=group_data, x='baseline', bins=plot_specific_settings['bins'], ax=temp_ax ) # Extract histogram data patches = sns_hist.patches x = [p.get_x() + p.get_width()/2 for p in patches] y = [p.get_height() for p in patches] widths = [p.get_width() for p in patches] # Add trace to plotly figure fig.add_trace( go.Bar( x=x, y=y, width=widths[0], # All widths should be the same name='All trials', marker_color='#4C72B0', marker=dict(line=dict(color='black', width=1.5)), opacity=0.75, visible=(groupid == 0), showlegend=False, hovertemplate="Baseline: %{x}<br>Count: %{y}<extra></extra>" ) ) current_group_traces.append(len(fig.data) - 1) # Store traces for this group group_traces.append({ 'title': group_title, 'traces': current_group_traces }) # Clean up temporary matplotlib figure plt.close(temp_fig) # Create dropdown menu options for group_info in group_traces: # Create visibility settings vis = [False] * len(fig.data) for trace_idx in group_info['traces']: vis[trace_idx] = True # Add dropdown option with proper title update dropdown_options.append( dict( args=[ {"visible": vis}, { "title": { "text": f"{plot_specific_settings['title']} - {group_info['title']}", "x": 0.5, "xanchor": "center", "y": 0.95, "yanchor": "top", "font": { "size": default_plotly['title_font_size'], "family": default_plotly['title_font_family'], "weight": default_plotly['title_font_weight'] } } } ], label=group_info['title'], method="update" ) ) # Update layout with Plotly defaults and dropdown menu ply_kwargs = default_plotly.copy() ply_kwargs['width'] = plot_params.get('width', 800) ply_kwargs['height'] = plot_params.get('height', 500) # Set initial title to include the first group's name for consistency initial_group_title = group_traces[0]['title'] if group_traces else "All" ply_kwargs['title'] = { "text": f"{plot_specific_settings['title']} - {initial_group_title}", "x": 0.5, "xanchor": "center", "y": 0.95, "yanchor": "top", "font": { "size": default_plotly['title_font_size'], "family": default_plotly['title_font_family'], "weight": default_plotly['title_font_weight'] } } ply_kwargs['xaxis_title'] = plot_specific_settings['x_title'] ply_kwargs['yaxis_title'] = plot_specific_settings['y_title'] ply_kwargs['updatemenus'] = [{ 'buttons': dropdown_options, 'direction': 'down', 'showactive': True, 'x': 1.2, 'y': 1.2, 'xanchor': 'right', 'yanchor': 'top' }] ply_kwargs['yaxis_range'] = [0, max_y] ply_kwargs['barmode'] = 'overlay' ply_kwargs.update({k:v for k,v in plot_params.items() if k not in plot_specific_settings}) # update layout fig.update_layout(**ply_kwargs) # Save if requested if save: if save.endswith('.html'): fig.write_html(save) else: raise ValueError(f"Interactive plots must be saved as html file. Got {save}.") # Return or display figure if return_fig: return fig else: display(fig)
[docs] def plot_spaghetti(self, time_col=None, pupil_col=None, show_outliers=True, plot_by=None, save=False, plot_params=None, return_fig=True): """ Plot pupil traces for all trials as a spaghetti plot. Parameters ---------- time_col : str, optional Column name for x-axis. Defaults to time column specified during initialization. pupil_col : str, optional Column name for y-axis. Defaults to latest pupil column. show_outliers : bool, default=True Whether to highlight outlier traces. plot_by : str or list, optional Column(s) to group data by for separate plots. save : str, optional Path to save plot. Only supports html files. If None, plot is not saved. plot_params : dict, default={} Additional plotting parameters. return_fig : bool, default=True Whether to return the figure object. Returns ------- plotly.graph_objects.Figure Plot figure object if return_fig is True. Notes ----- Creates an interactive spaghetti plot showing pupil traces for all trials. If plot_by is specified, creates separate subplots for each group using dropdown menus. Outlier traces can be highlighted if outlier detection was performed. """ plot_params = plot_params or {} # get summary data df_summary = self.summary_data.copy() # check if trace_outlier is in summary_data if show_outliers and ('trace_outlier' not in df_summary.columns): raise ValueError("trace_outlier column not found in summary_data. Please run check_trace_outliers first.") # get x and y if time_col is None: time_col = self.time_col # default to time column if pupil_col is None: pupil_col = self.all_pupil_cols[-1] # default to last pupil column # get data df_plot = self.data.copy() if plot_by is not None: # convert plot_by to list if not already if isinstance(plot_by, str): plot_by = [plot_by] # get unique columns cols = [time_col, pupil_col] + plot_by + self.trial_identifier cols = list(set(cols)) grouped = df_plot[cols].groupby(plot_by, sort=False) else: cols = [time_col, pupil_col] + self.trial_identifier cols = list(set(cols)) grouped = [(None, df_plot[cols])] # Get overall x range for threshold lines x_min = df_plot[time_col].min() x_max = df_plot[time_col].max() # check if outlier by is the same as plot_by if show_outliers and self.trace_outlier_by is not None and self.trace_outlier_by != plot_by: # both outlier by and plot by should be a list at this point warnings.warn(f"Outlier detection was performed by {self.trace_outlier_by}. Plotting by {plot_by}. The plotted thresholds may be incorrect.") # plot # some additional plot settings specific to this plot plot_specific_settings = { 'title': 'Spaghetti plot', 'subplot_titles': [' | '.join([str(x) for x in group]) for group, _ in grouped] if plot_by is not None else None, 'x_title': time_col, 'y_title': pupil_col, # line settings 'line_width': 2, 'line_style': 'solid', # hline settings 'hline_color': 'black', 'hline_style': 'dash', 'hline_width': 2, # grid 'grid': False } # update plot-specific settings if provided plot_specific_settings.update({k:v for k,v in plot_params.items() if k in plot_specific_settings.keys()}) # keep only additional plot-specific keys # update defaults settings if provided ply_kwargs = default_plotly.copy() ply_kwargs['width'] = plot_params.get('width', 1200) ply_kwargs['height'] = plot_params.get('height', 400) ply_kwargs['title_text'] = plot_params.get('title_text', plot_specific_settings['title']) # override default title ply_kwargs['xaxis_title_text'] = plot_params.get('xaxis_title_text', plot_specific_settings['x_title']) # override default x-axis title ply_kwargs['yaxis_title_text'] = plot_params.get('yaxis_title_text', plot_specific_settings['y_title']) # override default y-axis title ply_kwargs['xaxis_showgrid'] = plot_params.get('xaxis_showgrid', plot_specific_settings['grid']) # override default x-axis grid ply_kwargs['yaxis_showgrid'] = plot_params.get('yaxis_showgrid', plot_specific_settings['grid']) # override default y-axis grid ply_kwargs.update({k:v for k,v in plot_params.items() if k not in plot_specific_settings}) # plot using plotly fig = go.Figure() # Create a list to store visibility settings for each trace all_traces = [] visible_settings = [] dropdown_options = [] # Add traces for each group for groupid, (group, groupdata) in enumerate(grouped): traces_in_group = [] dropdown_options.append({ 'label': plot_specific_settings['subplot_titles'][groupid] if group is not None else "All", 'method': "update", 'args': [{"visible": []}, {"title": ""}] # Will be filled later }) # Add traces for each trial in the group for trial_id, (trial, trialdata) in enumerate(groupdata.groupby(self.trial_identifier, sort=False)): is_outlier = False if show_outliers: is_outlier = df_summary.loc[np.all(df_summary[self.trial_identifier] == trial, axis=1), 'trace_outlier'].values[0] alpha = 1 if is_outlier or not show_outliers else 0.2 showlegend = True if is_outlier else False label = ', '.join([f"{k}: {v}" for k,v in zip(self.trial_identifier, trial)]) if is_outlier else None # downsample trialdata by selecting every 10th sample for faster plotting if self.samp_freq > 100: downsample_mask = np.arange(len(trialdata)) % 10 == 0 downsampled = trialdata[downsample_mask] else: downsampled = trialdata trace = go.Scatter( x=downsampled[time_col], y=downsampled[pupil_col], name=label, mode='lines', line=dict(width=plot_specific_settings['line_width']), line_dash=plot_specific_settings['line_style'], opacity=alpha, showlegend=showlegend, visible=(groupid == 0), # Only first group visible initially hovertemplate="x=%{x:.2f}, y=%{y:.2f}<br>" + "<br>".join([f"{k}: {v}" for k,v in dict(zip(self.trial_identifier, trial)).items()]) + "<extra></extra>" ) fig.add_trace(trace) traces_in_group.append(trace) # Add threshold lines if showing outliers if show_outliers: # Get thresholds from the first trial in the group first_trial = next(iter(groupdata.groupby(self.trial_identifier, sort=False)))[0] upper_threshold = df_summary.loc[np.all(df_summary[self.trial_identifier] == first_trial, axis=1), 'trace_upper'].values[0] lower_threshold = df_summary.loc[np.all(df_summary[self.trial_identifier] == first_trial, axis=1), 'trace_lower'].values[0] # Add upper threshold line trace_upper = go.Scatter( x=[x_min, x_max], # Use overall x range y=[upper_threshold, upper_threshold], mode='lines', line=dict(dash=plot_specific_settings['hline_style'], color=plot_specific_settings['hline_color'], width=plot_specific_settings['hline_width']), name=f'Upper threshold: {upper_threshold:.2f}', showlegend=False, visible=(groupid == 0) ) fig.add_trace(trace_upper) traces_in_group.append(trace_upper) # Add lower threshold line trace_lower = go.Scatter( x=[x_min, x_max], # Use overall x range y=[lower_threshold, lower_threshold], mode='lines', line=dict(dash=plot_specific_settings['hline_style'], color=plot_specific_settings['hline_color'], width=plot_specific_settings['hline_width']), name=f'Lower threshold: {lower_threshold:.2f}', showlegend=False, visible=(groupid == 0) ) fig.add_trace(trace_lower) traces_in_group.append(trace_lower) all_traces.append(traces_in_group) # Create visibility settings for each dropdown option for i in range(len(all_traces)): vis = [] for j, traces in enumerate(all_traces): vis.extend([True if j == i else False] * len(traces)) visible_settings.append(vis) # Update the args for each dropdown option dropdown_options[i]['args'][0]["visible"] = vis dropdown_options[i]['args'][1] = {} # Empty dict to avoid title updates # Update layout to include dropdown menu and set fixed title fig.update_layout( title=plot_specific_settings['title'], # Set fixed title updatemenus=[dict( type="dropdown", direction="down", x=1.0, # Position the dropdown at the right y=1.2, # Position slightly above the plot showactive=True, active=0, # Show first group by default buttons=dropdown_options )] ) # Update layout fig.update_xaxes(**{k[6:]:v for k, v in ply_kwargs.items() if 'xaxis' in k}) fig.update_yaxes(**{k[6:]:v for k, v in ply_kwargs.items() if 'yaxis' in k}) fig.update_layout(**{k:v for k, v in ply_kwargs.items() if 'xaxis' not in k and 'yaxis' not in k}) # Save figure if path is provided if save: if save.endswith('.html'): fig.write_html(save) else: raise ValueError(f"Interactive plots must be saved as html file. Got {save}.") # return figure if requested if return_fig: return fig else: display(fig)
[docs] def plot_evoked(self, data=None, pupil_col=None, condition=None, agg_by=None, error='ci', save=None, plot_params=None, **kwargs): """ Plot evoked pupil response. Creates plot of average pupil response across trials, optionally split by condition and aggregated by specified groups. Parameters ---------- data : str or pandas.DataFrame, optional Data to plot. If string, uses corresponding attribute. pupil_col : str, optional Column name for pupil values. condition : str or list, optional Column(s) to split data by. agg_by : str or list, optional Column(s) to aggregate data by before computing mean trace and confidence bands. For example, to compute subject-level means, use 'subject_id'. error : {'ci', 'sem', 'std', None}, default='ci' Type of error to plot: - 'ci': bootstrap confidence interval - 'sem': standard error of the mean - 'std': standard deviation - None: no error bars save : str, optional Path to save plot. plot_params : dict, default={} Additional plotting parameters. This includes all rcParams accepted by matplotlib, as well as the following: - 'title': title of plot - 'x_title': x-axis label - 'y_title': y-axis label - 'vline_color': color of vertical line - 'vline_linestyle': linestyle of vertical line - 'grid': whether to show grid - 'legend_labels': labels for legend **kwargs Additional arguments passed to confidence interval calculation. Returns ------- arrays_by_condition : dict Dictionary of arrays containing trial data for each condition. (figure, axes) : tuple Plot figure and axes objects. """ plot_params = plot_params or {} # get data if data is None: data = self.data.copy() else: data = getattr(self, data) # get samp_freq samp_freq = self.samp_freq # get column if pupil_col is None: pupil_col = self.all_pupil_cols[-1] # handle condition if condition is not None: if isinstance(condition, str): condition = [condition] # get unique values for each condition condition_values = {cond: data[cond].unique() for cond in condition} # handle agg_by if agg_by is not None: if isinstance(agg_by, str): agg_by = [agg_by] # get minimum length across all trials min_len = data.groupby(self.trial_identifier, sort=False)[pupil_col].count().min() print(f'Data will be padded to minimum length: {min_len} samples') # if no condition, process all data together if condition is None: if agg_by is not None: # First compute mean trace for each aggregation group agg_traces = [] for group, group_data in data.groupby(agg_by, sort=False): # Get all trials for this group and compute mean trials = group_data.groupby(self.trial_identifier, sort=False) group_array = np.empty((trials.ngroups, min_len)) for i, (_, trial_data) in enumerate(trials): vals = np.asarray(trial_data[pupil_col].to_list()) vals = vals[:min_len] group_array[i,:] = vals # Store mean trace for this group agg_traces.append(np.nanmean(group_array, axis=0)) # Convert to array for plotting test_array = np.array(agg_traces) n_groups = len(agg_traces) print(f'Computing average from {n_groups} {agg_by} means') else: # Process all trials without aggregation grouped = data.groupby(self.trial_identifier, sort=False) test_array = np.empty((grouped.ngroups, min_len)) for i, (_, trial_data) in enumerate(grouped): vals = np.asarray(trial_data[pupil_col].to_list()) vals = vals[:min_len] test_array[i,:] = vals print(f'Computing average from {grouped.ngroups} trials') arrays_by_condition = {'all': test_array} else: # Get actual combinations from data arrays_by_condition = {} # Get unique combinations of conditions that exist in the data condition_combinations = data[condition].drop_duplicates() for _, comb in condition_combinations.iterrows(): # create mask for this combination mask = pd.Series(True, index=data.index) for cond in condition: mask &= (data[cond] == comb[cond]) # get data for this combination subset = data[mask] if agg_by is not None: # First compute mean trace for each aggregation group agg_traces = [] for group, group_data in subset.groupby(agg_by, sort=False): # Get all trials for this group and compute mean trials = group_data.groupby(self.trial_identifier, sort=False) group_array = np.empty((trials.ngroups, min_len)) for i, (_, trial_data) in enumerate(trials): vals = np.asarray(trial_data[pupil_col].to_list()) vals = vals[:min_len] group_array[i,:] = vals # Store mean trace for this group agg_traces.append(np.nanmean(group_array, axis=0)) # Convert to array for plotting test_array = np.array(agg_traces) n_groups = len(agg_traces) print(f'Condition {comb.to_dict()}: Computing average from {n_groups} {agg_by} means') else: # Process all trials without aggregation grouped = subset.groupby(self.trial_identifier, sort=False) test_array = np.empty((grouped.ngroups, min_len)) for i, (_, trial_data) in enumerate(grouped): vals = np.asarray(trial_data[pupil_col].to_list()) vals = vals[:min_len] test_array[i,:] = vals print(f'Condition {comb.to_dict()}: Computing average from {grouped.ngroups} trials') # store array with condition name cond_name = '_'.join([f'{v}' for v in comb.values]) arrays_by_condition[cond_name] = test_array # plot settings plot_specific_settings = { 'title': 'Task Evoked Pupillary Response', 'x_title': 'Time (s)', 'y_title': 'Pupil Size Change', 'vline_color': 'red', 'vline_linestyle': '--', 'grid': False, 'legend_labels': list(arrays_by_condition.keys()) } plot_specific_settings.update({k:v for k,v in plot_params.items() if k in plot_specific_settings}) mpl_kwargs = default_mpl.copy() mpl_kwargs.update({k:v for k,v in plot_params.items() if k not in plot_specific_settings}) # create plot with context manager with mpl.rc_context(mpl_kwargs): fig, ax = plt.subplots() for i, (cond_name, test_array) in enumerate(arrays_by_condition.items()): # get time array t = np.arange(test_array.shape[1]) / samp_freq if error == 'ci': try: import mne.stats as ms ci_low, ci_high = ms.bootstrap_confidence_interval(test_array, **kwargs) except ImportError: warnings.warn("mne is not installed. Not computing confidence interval.") ci_low, ci_high = None, None elif error == 'sem': ci_low = test_array.mean(axis=0) - test_array.std(axis=0) / np.sqrt(test_array.shape[0]) ci_high = test_array.mean(axis=0) + test_array.std(axis=0) / np.sqrt(test_array.shape[0]) elif error == 'std': ci_low = test_array.mean(axis=0) - test_array.std(axis=0) ci_high = test_array.mean(axis=0) + test_array.std(axis=0) else: ci_low, ci_high = None, None ax.plot(t, test_array.mean(axis=0), label=plot_specific_settings['legend_labels'][i]) if error and ci_low is not None and ci_high is not None: ax.fill_between(t, ci_low, ci_high, alpha=0.2) if len(arrays_by_condition) > 1: ax.legend() ax.set_xlabel(plot_specific_settings['x_title']) ax.set_ylabel(plot_specific_settings['y_title']) ax.set_title(plot_specific_settings['title']) ax.grid(plot_specific_settings['grid']) if save: plt.savefig(save, bbox_inches='tight', dpi=mpl_kwargs['figure.dpi']) return arrays_by_condition, (fig, ax)
[docs] def save(self, path): """ Save PupilProcessor object to file using dill serialization. This method saves the entire PupilProcessor object, including all data and processing history, to a file for later use. Parameters ---------- path : str Path where the object should be saved. Should include the file extension (e.g., '.pkl'). Raises ------ FileExistsError If a file already exists at the specified path. """ # check if file exists if os.path.exists(path): raise FileExistsError(f"File {path} already exists.") # save data with open(path, 'wb') as f: dill.dump(self, f)
[docs] @staticmethod def load(path): """ Load PupilProcessor object from file using dill deserialization. This method loads a previously saved PupilProcessor object, restoring all data and processing history. Parameters ---------- path : str Path to the file containing the saved PupilProcessor object. Returns ------- PupilProcessor The loaded PupilProcessor object. Notes ----- - The loaded object will be an exact copy of the saved object - All data, parameters, and processing history are preserved - Make sure the file was created using the save() method """ # load data with open(path, 'rb') as f: return dill.load(f)
[docs] def copy(self): """ Create a deep copy of the PupilProcessor object. This method creates an independent copy of the PupilProcessor object, including all data and processing history. Modifications to the copy will not affect the original object. Returns ------- PupilProcessor A deep copy of the current object. Notes ----- - Creates a completely independent copy using copy.deepcopy - All data, parameters, and processing history are copied - Useful for creating alternative processing pipelines """ import copy # deepcopy return copy.deepcopy(self)
[docs] @staticmethod def combine(processors): """ Combine multiple PupilProcessor instances into a single instance. This method allows combining data from multiple processors that have gone through identical preprocessing pipelines. This is useful for: 1. Processing large datasets in chunks to manage memory 2. Adding new data to an existing processed dataset 3. Processing data from multiple participants separately Parameters ---------- processors : list of PupilProcessor List of PupilProcessor instances to combine. All processors must have identical preprocessing settings. Returns ------- PupilProcessor A new PupilProcessor instance containing combined data. Notes ----- - All processors must have identical: - Initialization parameters (pupil_col, time_col, etc.) - Data structure (column names and order) - Preprocessing steps and parameters - Outlier detection settings (if used) - Data and summary statistics are concatenated Raises ------ ValueError If processors have different preprocessing settings If processors have incompatible data structures If no processors are provided """ if not processors: raise ValueError("No processors provided") if len(processors) == 1: return processors[0].copy() # Use the first processor as reference ref = processors[0] # Check compatibility of all processors for i, proc in enumerate(processors[1:], 1): # Check initialization parameters and data structure init_attrs = [ 'pupil_col', # Original pupil column 'time_col', # Time column 'x_col', # X position column 'y_col', # Y position column 'samp_freq', # Sampling frequency 'trial_identifier', # Trial identifier columns 'recording_unit', # Recording unit 'artificial_d', # Artificial pupil diameter 'artificial_size' # Artificial pupil size ] # Check if all initialization attributes match for attr in init_attrs: if not hasattr(proc, attr) or getattr(proc, attr) != getattr(ref, attr): raise ValueError(f"Processor {i} has different {attr} than the reference processor") # Check if data columns match (both names and order) if not list(proc.data.columns) == list(ref.data.columns): raise ValueError(f"Processor {i} has different columns or column order than the reference processor") # Check preprocessing steps by comparing pupil column names and steps if not proc.all_pupil_cols == ref.all_pupil_cols: raise ValueError(f"Processor {i} has different pupil columns than the reference processor") if not proc.all_steps == ref.all_steps: raise ValueError(f"Processor {i} has different preprocessing steps than the reference processor.\nReference steps: {ref.all_steps}\nProcessor {i} steps: {proc.all_steps}") # Check if preprocessing parameters match for each step # Get all unique parameter keys all_keys = set(ref.params.keys()) | set(proc.params.keys()) diff_params = {} missing_params = set(ref.params.keys()) - set(proc.params.keys()) extra_params = set(proc.params.keys()) - set(ref.params.keys()) # Compare parameters that exist in both for k in ref.params.keys() & proc.params.keys(): ref_val = ref.params[k] proc_val = proc.params[k] # Compare dictionaries within parameters if isinstance(ref_val, dict) and isinstance(proc_val, dict): ref_dict = {key: str(val) if hasattr(val, 'shape') else val for key, val in ref_val.items() if key != 'self'} proc_dict = {key: str(val) if hasattr(val, 'shape') else val for key, val in proc_val.items() if key != 'self'} if ref_dict != proc_dict: diff_params[k] = (ref_dict, proc_dict) else: # For non-dictionary values, convert to string if they're array-like ref_str = str(ref_val) if hasattr(ref_val, 'shape') else ref_val proc_str = str(proc_val) if hasattr(proc_val, 'shape') else proc_val if ref_str != proc_str: diff_params[k] = (ref_str, proc_str) if diff_params or missing_params or extra_params: error_msg = f"Processor {i} has different preprocessing parameters than the reference processor.\n" if diff_params: error_msg += "Different parameters:\n" for param, (ref_val, proc_val) in diff_params.items(): error_msg += f" {param}: reference={ref_val}, processor{i}={proc_val}\n" if missing_params: error_msg += f"Missing parameters: {missing_params}\n" if extra_params: error_msg += f"Extra parameters: {extra_params}\n" raise ValueError(error_msg) # Check if outlier detection settings match outlier_attrs = [ 'baseline_outlier_by', # Grouping for baseline outliers 'trace_outlier_by', # Grouping for trace outliers 'baseline_query', # Baseline selection query 'baseline_range' # Baseline time range ] for attr in outlier_attrs: if hasattr(ref, attr): # Only check if reference has this attribute if not hasattr(proc, attr) or getattr(proc, attr) != getattr(ref, attr): raise ValueError(f"Processor {i} has different {attr} than the reference processor") # Check summary data structure if it exists if proc.summary_data is not None and ref.summary_data is not None: if not list(proc.summary_data.columns) == list(ref.summary_data.columns): raise ValueError(f"Processor {i} summary data has different columns or column order than the reference processor") # Create new processor with combined data combined = ref.copy() # Combine data from all processors data_frames = [p.data for p in processors] combined.data = pd.concat(data_frames, axis=0, ignore_index=True) # Update trials attribute to reflect combined data combined.trials = combined.data[combined.trial_identifier].drop_duplicates().reset_index(drop=True) # Combine summary data if it exists in all processors if all(p.summary_data is not None for p in processors): summary_frames = [p.summary_data for p in processors] combined.summary_data = pd.concat(summary_frames, axis=0, ignore_index=True) return combined
[docs] def compute_speed(x, y): """ Compute the speed of change between two arrays. This function calculates the rate of change (speed) between corresponding points in two arrays. The speed is computed as the absolute maximum of the forward and backward differences at each point, normalized by the time difference. Parameters ---------- x : array-like First array of values, typically pupil measurements. Must be numeric and same length as y. y : array-like Second array of values, typically time points. Must be numeric and same length as x. Returns ------- numpy.ndarray Array of speed values with same length as input arrays. Contains NaN values at endpoints and where division by zero or invalid values occur. Notes ----- - Uses np.diff() to compute differences between consecutive points - Takes absolute maximum of forward/backward differences at each point - Suppresses RuntimeWarnings for NaN/inf values - Sets NaN/inf values to NaN in output """ x = np.asarray(x, dtype=np.float64) y = np.asarray(y, dtype=np.float64) diff = np.diff(x) / np.diff(y) speed_diff = np.abs(np.column_stack((np.insert(diff, 0, np.nan), np.append(diff, np.nan)))) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) speed_diff = np.nanmax(speed_diff, axis=1) speed_diff[np.isnan(speed_diff) | np.isinf(speed_diff)] = np.nan return speed_diff
[docs] def convert_pupil(pupil_size, artificial_d, artificial_size, recording_unit='diameter'): """ Convert pupil measurements between different recording units. This function converts pupil measurements from raw units (arbitrary units from the eye tracker) to millimeters using calibration values from an artificial pupil. It handles both diameter and area measurements. Parameters ---------- pupil_size : float or array-like Pupil size in recording units (diameter or area). Can be a single value or an array of measurements. artificial_d : float Diameter of artificial pupil used for calibration (in mm). This is the known physical size of the calibration pupil. artificial_size : float Size of artificial pupil in recording units (diameter or area). This is the size measured by the eye tracker for the calibration pupil. recording_unit : {'diameter', 'area'}, default='diameter' Unit of the recorded measurements: - 'diameter': Linear scaling is applied - 'area': Square root is taken before scaling Returns ------- numpy.ndarray Converted pupil measurements in millimeters. Will have same shape as input pupil_size. Notes ----- - The unit of artificial_size must match the recording_unit - The unit of artificial_d is always in millimeters - For diameter recordings: output = artificial_d * pupil_size / artificial_size - For area recordings: output = artificial_d * sqrt(pupil_size / artificial_size) - Useful for standardizing pupil measurements across different setups Raises ------ ValueError If recording_unit is not 'diameter' or 'area' """ if recording_unit == 'diameter': return artificial_d * pupil_size / artificial_size elif recording_unit == 'area': return artificial_d * np.sqrt(pupil_size) / np.sqrt(artificial_size) else: raise ValueError(f"Invalid recording unit: {recording_unit}")
[docs] def prf(t, t_max=500, n=10.1): """ PRF function according to Hoeks and Levelt (1993) Parameters ---------- t : array-like Time points in milliseconds. t_max : float, optional Location of the peak (default is 500 ms). n : float, optional Scale parameter (default is 10.1). Returns ------- numpy.ndarray Normalized PRF values at each time point. """ h = (t**n)*np.exp((-n*t)/t_max) # normalize h = h / np.max(h) return h
def _generate_pupil_data(n_participants=6, n_trials=20, stim_duration_ms=2000, baseline_duration_ms=500, sampling_rate=1000, design_type='within-subject', condition_names=['A','B'], condition_effect=0.5, seed=1): """Generate fake pupillometry data for experimental designs. This function generates fake pupil size data that mimics typical task-evoked pupillary responses. It supports both between-subject and within-subject designs with two conditions. The pupil response is generated by convolving an impulse at stimulus onset with a pupil response function (PRF). Parameters ---------- n_participants : int, default=6 Number of participants to simulate. For between-subject designs, this should be even to ensure balanced groups. n_trials : int, default=20 Number of trials per participant. For within-subject designs, this will be adjusted to the nearest even number to ensure balanced conditions. stim_duration_ms : int, default=2000 Duration of the stimulus period in milliseconds. baseline_duration_ms : int, default=500 Duration of the pre-stimulus baseline period in milliseconds. sampling_rate : int, default=1000 Sampling rate in Hz. Determines the temporal resolution of the data. design_type : {'between-subject', 'within-subject'}, default='within-subject' Type of experimental design: - 'between-subject': Each participant is assigned to one condition - 'within-subject': Each participant completes trials in both conditions condition_names : list of str, default=['A', 'B'] Names of the two experimental conditions. First name is control condition, second name is experimental condition. condition_effect : float, default=0.5 Size of the experimental effect for the second condition relative to the first. For example, 0.5 means condition B has 50% larger responses than condition A. seed : int, optional, default=1 Random seed for reproducibility. Set to None for random behavior. Returns ------- pandas.DataFrame A DataFrame containing the simulated pupil data with columns: - participant: Participant identifier (e.g., 'P1', 'P2', ...) - condition: Experimental condition - trial: Trial number - event: Trial phase ('fixation' or 'stimulus') - trialtime: Time points in milliseconds - pp: Pupil size values - x: Horizontal gaze position - y: Vertical gaze position Notes ----- These features are hard-coded: - Individual differences in baseline pupil size - Trial-to-trial variability in response amplitude - Random blinks (30% probability per trial) - Measurement noise - Gaze position drift - 50% probability of no response on each trial The pupil response is generated using the following steps: - Create baseline period with participant-specific mean - Generate stimulus response by convolving an impulse with PRF - Add various sources of noise and artifacts - Combine baseline and stimulus periods Examples -------- >>> # Generate data for a within-subject design >>> data_within = generate_pupil_data( ... n_participants=4, ... n_trials=10, ... design_type='within-subject', ... condition_names=['low_load', 'high_load'], ... condition_effect=0.3 ... ) >>> # Generate data for a between-subject design >>> data_between = generate_pupil_data( ... n_participants=6, ... n_trials=10, ... design_type='between-subject', ... condition_names=['control', 'treatment'], ... condition_effect=0.5 ... ) """ # Set random seed for reproducibility if seed is not None: np.random.seed(seed) if design_type not in ['between-subject', 'within-subject']: raise ValueError("design_type must be either 'between-subject' or 'within-subject'") # For within-subject design, ensure n_trials is even if design_type == 'within-subject' and n_trials % 2 != 0: n_trials += 1 print(f"Adjusted n_trials to {n_trials} for balanced design") # Assign participants to conditions for between-subject design if design_type == 'between-subject': participant_conditions = {} for p in range(1, n_participants + 1): # Ensure balanced assignment to conditions condition = condition_names[0] if p <= n_participants // 2 else condition_names[1] participant_conditions[f'P{p}'] = condition all_data = [] # Pre-compute PRF kernel for convolution kernel_time = np.arange(stim_duration_ms) t_max = 1200 prf_kernel = prf(kernel_time, t_max=t_max) # peak at 1500ms for p in range(1, n_participants + 1): participant_id = f'P{p}' participant_baseline = np.random.normal(3.0, 0.05) # Create balanced sequence of conditions for within-subject design if design_type == 'within-subject': conditions = [condition_names[0]] * (n_trials // 2) + [condition_names[1]] * (n_trials // 2) np.random.shuffle(conditions) for t in range(1, n_trials + 1): # Determine condition for this trial if design_type == 'between-subject': condition = participant_conditions[participant_id] else: # within-subject condition = conditions[t-1] # Calculate number of samples n_samples_stim = int(stim_duration_ms * (sampling_rate/1000)) n_samples_baseline = int(baseline_duration_ms * (sampling_rate/1000)) # Generate baseline period trial_baseline = participant_baseline + np.random.normal(0, 0.05) baseline_data = trial_baseline * np.ones(n_samples_baseline) baseline_data += np.random.normal(0, 0.05, n_samples_baseline) # Add noise to baseline # Generate stimulus response # Create impulse at stimulus onset (t=0 in stimulus period) impulse = np.zeros(n_samples_stim) impulse[0] = 1.0 # Add condition effect and trial-to-trial variability amplitude = (condition_effect if condition == condition_names[1] else 1.0) amplitude *= np.random.normal(1, 0.5) # Add trial-to-trial variability # sometimes there is no response if np.random.random() < 0.5: amplitude = 0 impulse[0] *= amplitude # Convolve with PRF to get stimulus response response = np.convolve(impulse, prf_kernel, mode='full')[:n_samples_stim] stim_data = trial_baseline + response stim_data += np.random.normal(0, 0.05, n_samples_stim) # Add noise to stimulus period # Combine baseline and stimulus data pupil = np.concatenate([baseline_data, stim_data]) time = np.arange(len(pupil)) event = ['fixation']*n_samples_baseline + ['stimulus']*n_samples_stim # Add random blinks if np.random.random() < 0.3: blink_start = np.random.randint(0, len(pupil) - 200) blink_duration = np.random.randint(100, 200) blink_idx = np.arange(blink_start, blink_start + blink_duration) blink_idx = blink_idx[blink_idx < len(pupil)] pupil[blink_idx] = 0 # Add gaze position with drift drift_x = np.cumsum(np.random.normal(0, 0.01, len(pupil))) drift_y = np.cumsum(np.random.normal(0, 0.01, len(pupil))) x = np.random.normal(1920/2, 20, len(pupil)) + drift_x y = np.random.normal(1080/2, 20, len(pupil)) + drift_y trial_data = pd.DataFrame({ 'participant': participant_id, 'condition': condition, 'trial': t, 'event': event, 'trialtime': time, 'pp': pupil, 'x': x, 'y': y }) all_data.append(trial_data) samples = pd.concat(all_data, ignore_index=True).convert_dtypes() return samples