# Source code for pupeyes.apps.aoi_drawer

"""
Interactive AOI Drawing Tool using Dash

This module provides an interactive web-based tool for drawing Areas of Interest (AOIs)
that can be used with the EyeMovementVisualizer.
"""

import json
import numpy as np
import plotly.graph_objects as go
from dash import Dash, html, dcc, callback_context
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from PIL import Image
import dash

class AOIDrawer:
    """
    An interactive web-based tool for drawing Areas of Interest (AOIs).

    This class provides a Dash-based web interface for drawing and managing
    Areas of Interest (AOIs) on stimulus images. It supports multiple drawing
    tools (freeform, rectangle, circle), editing capabilities, and export
    functionality.

    Parameters
    ----------
    screen_dims : tuple, default=(1920, 1080)
        Screen dimensions in pixels (width, height). Used to set the drawing
        canvas size and scale background images.
    stimuli : str or numpy.ndarray, optional
        Path to the stimulus image or a numpy array containing the image.
        Supports various image formats and both RGB and grayscale images.
    stimuli_name : str, optional
        Name of the stimulus image, used for display and as default save
        filename. If not provided, defaults to "AOIs".

    Attributes
    ----------
    aois : dict
        Dictionary storing AOI data, where keys are AOI names and values are
        lists of (x, y) coordinate tuples defining the AOI vertices. Vertex
        lists produced by this tool are closed: the last vertex repeats the
        first.
    app : dash.Dash
        The Dash application instance.
    screen_dims : tuple
        The dimensions of the drawing canvas.
    """

    # Number of distinct polygon vertices used to approximate a circle/oval
    # (the closing vertex is added on top of these).
    _CIRCLE_POINTS = 32

    def __init__(self, screen_dims=(1920, 1080), stimuli=None, stimuli_name=None):
        """
        Initialize the AOI drawer.

        Parameters
        ----------
        screen_dims : tuple, default=(1920, 1080)
            Screen dimensions in pixels (width, height)
        stimuli : str or numpy.ndarray, optional
            Path to the stimulus image or a numpy array containing the image
        stimuli_name : str, optional
            Name of the stimulus image, used for display and as default save filename
        """
        self.screen_dims = screen_dims
        self.stimuli = stimuli
        self.stimuli_name = stimuli_name or "AOIs"
        self._stimuli_cache = None
        self._temp_shape = None  # Shape drawn but not yet named in the modal

        # Initialize Dash app
        self.app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

        # Initialize AOI storage
        self.aois = {}

        # Create app layout
        self.app.layout = self._create_layout()

        # Setup callbacks
        self._setup_callbacks()

    # ------------------------------------------------------------------
    # Stimulus loading
    # ------------------------------------------------------------------
    def _get_stimuli_image(self):
        """
        Load, resize to ``screen_dims`` and cache the stimulus image.

        Returns
        -------
        PIL.Image.Image or None
            The cached stimulus image, or None if no stimulus was given or it
            could not be loaded (the reason is printed).
        """
        if self._stimuli_cache is not None or self.stimuli is None:
            return self._stimuli_cache

        # PIL reports sizes as tuples, so compare against a tuple even if the
        # caller passed screen_dims as a list.
        target_size = tuple(self.screen_dims)
        try:
            if isinstance(self.stimuli, str):
                img = self._load_image_file(self.stimuli)
            elif isinstance(self.stimuli, np.ndarray):
                img = self._array_to_image(self.stimuli)
            else:
                print("Invalid background image type. Must be a file path or numpy array.")
                return None
            if img is None:
                return None

            if img.size != target_size:
                print('Original size:', img.size, 'Resized size:', self.screen_dims)
                img = img.resize(target_size)
            self._stimuli_cache = img
        except Exception as e:
            print(f"Failed to load stimuli: {str(e)}")
            return None
        return self._stimuli_cache

    @staticmethod
    def _load_image_file(path):
        """Open, validate and fully load an image file; return None if it is invalid."""
        with Image.open(path) as img:
            try:
                img.verify()
            except Exception as e:
                print(f"Invalid image file: {str(e)}")
                return None
        # verify() leaves the image object unusable, so reopen it. copy() loads
        # the pixel data, which lets the file handle be closed by the `with`.
        with Image.open(path) as img:
            return img.copy()

    @staticmethod
    def _array_to_image(array):
        """Convert a grayscale/RGB/RGBA numpy array to a PIL image (None if unsupported)."""
        img_array = array
        if img_array.dtype != np.uint8:
            # Float images in [0, 1] are rescaled; anything else is cast directly.
            if img_array.max() <= 1:
                img_array = (img_array * 255).astype(np.uint8)
            else:
                img_array = img_array.astype(np.uint8)

        is_gray = img_array.ndim == 2
        is_color = img_array.ndim == 3 and img_array.shape[2] in (3, 4)
        if not (is_gray or is_color):
            print(f"Unsupported array shape: {img_array.shape}")
            return None
        # For uint8 data Pillow infers 'L', 'RGB' or 'RGBA' from the shape.
        return Image.fromarray(img_array)

    # ------------------------------------------------------------------
    # Layout
    # ------------------------------------------------------------------
    def _create_layout(self):
        """Create the Dash app layout."""
        return dbc.Container([
            dbc.Row([
                dbc.Col([
                    html.H1("AOI Drawer", className="text-center mb-2"),
                    html.H6(f"Stimulus: {self.stimuli_name}", className="text-center text-muted mb-2")
                ])
            ]),
            # Instructions Card
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Instructions", className="py-1"),
                        dbc.CardBody([
                            html.Ul([
                                html.Li("Draw: Use the modebar to draw shapes (freeform, rectangle, or circle)"),
                                html.Li("Edit: Click on a shape's border to move it or adjust its vertices"),
                                html.Li("Erase: Click on a shape's border and select 'Erase Active Shape'"),
                            ], className="mb-0 small"),
                        ], className="p-2")
                    ], className="mb-2")
                ], width={"size": 8, "offset": 2})
            ]),
            # Control Panel
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardHeader("Controls", className="py-1"),
                        dbc.CardBody([
                            dbc.Row([
                                # Save and Load Buttons
                                dbc.Col([
                                    html.Div([
                                        dbc.Button(
                                            "Save AOIs",
                                            id="save-button",
                                            color="primary",
                                            size="sm",
                                            className="me-2"
                                        ),
                                        dcc.Download(id="download-aois"),
                                    ], className="d-flex justify-content-start")
                                ], width=6),
                                # Background Opacity Control
                                dbc.Col([
                                    html.Div([
                                        html.Label("Opacity:", className="me-2 small", style={'whiteSpace': 'nowrap'}),
                                        html.Div([
                                            dcc.Slider(
                                                id='bg-opacity-slider',
                                                min=0,
                                                max=1,
                                                step=0.1,
                                                value=0.5,
                                                marks=None,
                                                tooltip={"placement": "bottom", "always_visible": True},
                                                className="mt-1"
                                            )
                                        ], style={'width': '150px'})
                                    ], className="d-flex align-items-center")
                                ], width=6)
                            ], className="g-1"),
                            # Line Color Control
                            dbc.Row([
                                dbc.Col([
                                    html.Div([
                                        html.Label("Line Color:", className="me-2 small"),
                                        dcc.Dropdown(
                                            id='line-color-picker',
                                            options=[
                                                {'label': 'Black', 'value': 'black'},
                                                {'label': 'Red', 'value': 'red'},
                                                {'label': 'Blue', 'value': 'blue'},
                                                {'label': 'Green', 'value': 'green'},
                                                {'label': 'Yellow', 'value': 'yellow'},
                                                {'label': 'White', 'value': 'white'}
                                            ],
                                            value='black',
                                            clearable=False,
                                            style={'width': '150px'}
                                        )
                                    ], className="d-flex align-items-center mb-2")
                                ])
                            ]),
                            # Current AOIs Table
                            html.Div([
                                html.H6("Current AOIs:", className="mb-1 mt-1 small"),
                                html.Div(id='aoi-list', className="small")
                            ])
                        ], className="p-2")
                    ], className="mb-2")
                ], width={"size": 8, "offset": 2})
            ]),
            # Drawing Area
            dbc.Row([
                dbc.Col([
                    dbc.Card([
                        dbc.CardBody([
                            dcc.Graph(
                                id='drawing-area',
                                config={
                                    'modeBarButtonsToAdd': [
                                        'drawclosedpath',
                                        'drawrect',
                                        'drawcircle',
                                        'eraseshape'
                                    ],
                                    'modeBarButtonsToRemove': [
                                        'autoScale2d',
                                        'pan2d',
                                        'zoom2d',
                                        'zoomIn2d',
                                        'zoomOut2d',
                                        'resetScale2d'
                                    ],
                                    'displaylogo': False
                                },
                                style={'height': '600px'}
                            )
                        ])
                    ])
                ])
            ]),
            # Modal for AOI naming
            dbc.Modal([
                dbc.ModalHeader("Name your AOI"),
                dbc.ModalBody([
                    dbc.Input(
                        id="aoi-name-input",
                        type="text",
                        placeholder="Enter AOI name"
                    ),
                    html.Div(
                        id="name-warning",
                        className="text-danger mt-2"
                    )
                ]),
                dbc.ModalFooter([
                    dbc.Button(
                        "Cancel",
                        id="modal-cancel",
                        className="me-2",
                        color="secondary"
                    ),
                    dbc.Button(
                        "Save",
                        id="modal-save",
                        color="primary"
                    )
                ])
            ], id="naming-modal", is_open=False)
        ], fluid=True)

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------
    def _setup_callbacks(self):
        """Register the drawing/naming callback and the AOI download callback."""

        @self.app.callback(
            [Output('drawing-area', 'figure'),
             Output('aoi-list', 'children'),
             Output('naming-modal', 'is_open'),
             Output('aoi-name-input', 'value'),
             Output('name-warning', 'children')],
            [Input('drawing-area', 'relayoutData'),
             Input('modal-save', 'n_clicks'),
             Input('modal-cancel', 'n_clicks'),
             Input('bg-opacity-slider', 'value'),
             Input('line-color-picker', 'value')],
            [State('drawing-area', 'figure'),
             State('aoi-name-input', 'value')]
        )
        def update_drawing_area(relayout_data, save_clicks, cancel_clicks,
                                opacity, line_color, figure, aoi_name):
            ctx = callback_context
            if not ctx.triggered:
                return self._response(self._create_base_figure())

            trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]

            # The handlers below mutate the figure as a plain dict, so convert
            # the freshly built go.Figure instead of using it directly.
            if figure is None:
                figure = self._create_base_figure().to_dict()

            if trigger_id == 'line-color-picker':
                self._apply_line_color(figure, line_color)
                return self._response(figure)

            if trigger_id == 'bg-opacity-slider':
                self._apply_background_opacity(figure, opacity)
                return self._response(figure)

            if trigger_id == 'drawing-area' and relayout_data is not None:
                return self._handle_relayout(figure, relayout_data)

            if trigger_id == 'modal-save' and self._temp_shape is not None:
                return self._handle_modal_save(figure, aoi_name)

            if trigger_id == 'modal-cancel':
                if self._temp_shape is not None:
                    # Discard the shape that was drawn but never named.
                    self._remove_last_shape(figure)
                self._temp_shape = None
                return self._response(figure)

            return self._response(figure)

        @self.app.callback(
            Output('download-aois', 'data'),
            Input('save-button', 'n_clicks'),
            prevent_initial_call=True
        )
        def save_aois(n_clicks):
            if n_clicks is None:
                return None

            # Convert AOIs to a JSON-serializable format
            aois_json = {
                name: [list(vertex) for vertex in vertices]
                for name, vertices in self.aois.items()
            }

            return dict(
                content=json.dumps(aois_json, indent=2),
                filename=f'{self.stimuli_name}_aois.json'
            )

    def _response(self, figure, modal_open=False, name='', warning=''):
        """Build the 5-tuple returned by the drawing callback (figure, list, modal state, name, warning)."""
        return figure, self._create_aoi_list(), modal_open, name, warning

    @staticmethod
    def _apply_line_color(figure, line_color):
        """Set ``line_color`` on the new-shape defaults and on every existing shape."""
        layout = figure.get('layout')
        if not layout:
            return
        newshape = layout.get('newshape')
        if newshape is not None:
            newshape.setdefault('line', {})['color'] = line_color
        for shape in layout.get('shapes') or []:
            shape.setdefault('line', {})['color'] = line_color

    @staticmethod
    def _apply_background_opacity(figure, opacity):
        """Set the stimulus opacity, whether it is a layout image or a grayscale heatmap."""
        images = (figure.get('layout') or {}).get('images')
        if images:
            images[0]['opacity'] = opacity
            return
        # Grayscale stimuli are drawn as a heatmap trace rather than a layout image.
        for trace in figure.get('data') or []:
            if trace.get('type') == 'heatmap':
                trace['opacity'] = opacity

    def _handle_relayout(self, figure, relayout_data):
        """Handle shapes being drawn, erased or edited in the graph."""
        layout = figure.setdefault('layout', {})

        if 'shapes' in relayout_data:
            current_shapes = relayout_data['shapes'] or []
            # Keep the figure in sync with what the browser shows, so returning
            # it neither drops new shapes nor resurrects erased ones.
            layout['shapes'] = current_shapes

            if len(current_shapes) < len(self.aois):
                self._remove_deleted_aois(current_shapes)
            elif len(current_shapes) > len(self.aois):
                # A new shape was drawn: remember it and ask for a name.
                self._temp_shape = current_shapes[-1]
                return self._response(figure, modal_open=True)
        else:
            self._apply_shape_edits(figure, relayout_data)

        self._update_figure_annotations(figure)
        return self._response(figure)

    @staticmethod
    def _vertex_key(vertices):
        """Order-independent hashable key identifying a vertex list."""
        return tuple(sorted((x, y) for x, y in vertices))

    def _remove_deleted_aois(self, current_shapes):
        """Keep only the AOIs whose shapes are still present in ``current_shapes``."""
        shape_to_aoi = {
            self._vertex_key(vertices): name
            for name, vertices in self.aois.items()
        }
        remaining_aois = {}
        for shape in current_shapes:
            vertices = self._shape_to_vertices(shape)
            if not vertices:
                continue
            name = shape_to_aoi.get(self._vertex_key(vertices))
            if name is not None:
                remaining_aois[name] = vertices
        self.aois = remaining_aois

    def _apply_shape_edits(self, figure, relayout_data):
        """
        Apply per-property shape edits (e.g. ``shapes[0].x0``) to the AOIs.

        Shapes are matched to AOIs by index, which follows the insertion order
        of ``self.aois``.
        """
        shapes = (figure.get('layout') or {}).get('shapes') or []
        edited = {}
        for key, value in relayout_data.items():
            if not key.startswith('shapes['):
                continue
            parts = key.split('.')
            idx = int(parts[0].split('[')[1].split(']')[0])
            if idx not in edited:
                edited[idx] = dict(shapes[idx]) if idx < len(shapes) else {}
            if len(parts) < 2:
                continue
            prop = parts[1]
            edited[idx][prop] = value
            if prop == 'path':
                edited[idx]['type'] = 'path'
            elif prop.startswith(('x0', 'x1', 'y0', 'y1')):
                # Rect/circle edits only carry coordinates; fall back to 'rect'
                # when the shape type is unknown.
                edited[idx].setdefault('type', 'rect')

        aoi_names = list(self.aois)
        for idx, shape in edited.items():
            if idx < len(shapes):
                # Write the edit back so the returned figure matches the browser.
                shapes[idx] = shape
            if idx >= len(aoi_names):
                continue
            if shape.get('type') == 'path':
                complete = 'path' in shape
            else:
                complete = all(k in shape for k in ('x0', 'x1', 'y0', 'y1'))
            if complete:
                vertices = self._shape_to_vertices(shape)
                if vertices:
                    self.aois[aoi_names[idx]] = vertices

    def _handle_modal_save(self, figure, aoi_name):
        """Store the pending shape under ``aoi_name``, or keep the modal open with a warning."""
        if not aoi_name or not aoi_name.strip():
            return self._response(
                figure, modal_open=True,
                warning="Please enter a name for the AOI."
            )
        if aoi_name in self.aois:
            return self._response(
                figure, modal_open=True, name=aoi_name,
                warning=f"An AOI named '{aoi_name}' already exists. Please choose a different name."
            )

        vertices = self._shape_to_vertices(self._temp_shape)
        if vertices:
            self.aois[aoi_name] = vertices
            layout = figure.setdefault('layout', {})
            # Label the shape that was actually named, not whatever happens to
            # be last in the figure's shape list.
            annotation = self._shape_to_annotation(self._temp_shape, aoi_name)
            layout['annotations'] = list(layout.get('annotations') or []) + [annotation]
        else:
            # Unsupported shape type: drop it so shape indices stay aligned
            # with the AOI order.
            self._remove_last_shape(figure)

        self._temp_shape = None
        return self._response(figure)

    @staticmethod
    def _remove_last_shape(figure):
        """Remove the most recently drawn shape from the figure, if any."""
        layout = figure.get('layout') or {}
        if layout.get('shapes'):
            layout['shapes'] = layout['shapes'][:-1]

    # ------------------------------------------------------------------
    # Figure helpers
    # ------------------------------------------------------------------
    def _create_base_figure(self):
        """Create the base figure for drawing."""
        fig = go.Figure()

        # Add background image if available
        background_img = self._get_stimuli_image() if self.stimuli is not None else None
        if background_img is not None:
            if background_img.mode != 'L':
                fig.add_layout_image(
                    dict(
                        source=background_img,
                        xref="x",
                        yref="y",
                        x=0,
                        y=0,
                        sizex=self.screen_dims[0],
                        sizey=self.screen_dims[1],
                        sizing="stretch",
                        opacity=0.5,  # Matches the slider's default value
                        layer="below"
                    )
                )
            else:
                # Grayscale images are drawn as a heatmap to show intensity values.
                img_array = np.array(background_img)
                fig.add_trace(
                    go.Heatmap(
                        z=img_array,
                        x=np.linspace(0, self.screen_dims[0], img_array.shape[1]),
                        y=np.linspace(0, self.screen_dims[1], img_array.shape[0]),
                        colorscale='gray',
                        showscale=False,
                        hoverongaps=False,
                        opacity=0.5,  # Matches the slider's default value
                        xaxis="x",
                        yaxis="y"
                    )
                )

        # Update layout
        fig.update_layout(
            autosize=False,
            margin=dict(l=20, r=20, t=40, b=20),
            xaxis=dict(
                range=[0, self.screen_dims[0]],
                showgrid=False,
                zeroline=False,
                constrain="domain"
            ),
            yaxis=dict(
                range=[self.screen_dims[1], 0],  # Invert y-axis (screen coordinates)
                showgrid=False,
                zeroline=False,
                scaleanchor="x",
                scaleratio=1,
                constrain="domain"
            ),
            dragmode='drawclosedpath',  # Default to closed path drawing
            # Default properties for newly drawn shapes
            newshape=dict(
                line=dict(
                    width=1,
                    color='black'
                ),
                fillcolor='rgba(0,0,0,0)',  # Transparent fill
                opacity=1
            ),
            # Same style for shapes added to the layout
            shapedefaults=dict(
                line=dict(
                    width=1,
                    color='black'
                ),
                fillcolor='rgba(0,0,0,0)',  # Transparent fill
                opacity=1
            )
        )

        return fig

    def _create_aoi_list(self):
        """Create the list of current AOIs."""
        if not self.aois:
            return html.P("No AOIs defined yet.", className="text-muted")

        return html.Ul([
            html.Li(
                f"{name} ({self._get_shape_type(vertices)} - {len(vertices)} vertices)"
            )
            for name, vertices in self.aois.items()
        ], className="list-unstyled")

    def _get_shape_type(self, vertices):
        """
        Guess a human-readable shape type from an AOI's vertices.

        Vertex lists built by ``_shape_to_vertices`` are closed (the last vertex
        repeats the first), so the closing vertex is ignored before counting.
        Four axis-aligned corners make a "Rectangle"; ``_CIRCLE_POINTS`` points
        make a "Circle" (extent within 1%) or "Oval"; anything else is
        "Free Form".
        """
        points = list(vertices)
        if len(points) > 1 and tuple(points[-1]) == tuple(points[0]):
            points = points[:-1]

        x_coords = [x for x, _ in points]
        y_coords = [y for _, y in points]

        if len(points) == 4 and len(set(x_coords)) <= 2 and len(set(y_coords)) <= 2:
            return "Rectangle"

        if len(points) == self._CIRCLE_POINTS:
            width = max(x_coords) - min(x_coords)
            height = max(y_coords) - min(y_coords)
            extent = max(width, height)
            # Guard against a zero-size circle before the relative comparison.
            if extent == 0 or abs(width - height) / extent < 0.01:
                return "Circle"
            return "Oval"

        return "Free Form"

    @staticmethod
    def _parse_path_points(path):
        """Parse an SVG path of the form ``'Mx,yLx,y...Z'`` into (x, y) float tuples."""
        body = path.split('M')[1].split('Z')[0]
        return [tuple(map(float, point.strip().split(','))) for point in body.split('L')]

    def _shape_to_vertices(self, shape):
        """
        Convert a Plotly shape object to a list of vertices.

        This method extracts vertex coordinates from different types of Plotly
        shapes (path, rectangle, circle) and converts them to a consistent format.

        Parameters
        ----------
        shape : dict
            Plotly shape object containing shape type and coordinate information.

        Returns
        -------
        list of tuple or None
            List of (x, y) coordinate tuples defining the shape vertices.
            Returns None if the shape type is not recognized.

        Notes
        -----
        - Handles three shape types:
            - path: Extracts vertices from the SVG path string
            - rect: Converts rectangle coordinates to 4 vertices
            - circle: Approximates circle/oval with ``_CIRCLE_POINTS`` evenly
              spaced perimeter points
        - All shapes are closed by adding the first vertex at the end
        """
        shape_type = shape.get('type', '')

        if shape_type == 'path':
            vertices = self._parse_path_points(shape['path'])
        elif shape_type == 'rect':
            x0, y0 = shape['x0'], shape['y0']
            x1, y1 = shape['x1'], shape['y1']
            vertices = [(x0, y0), (x1, y0), (x1, y1), (x0, y1)]
        elif shape_type == 'circle':
            x0, y0 = shape['x0'], shape['y0']
            x1, y1 = shape['x1'], shape['y1']
            center_x = (x0 + x1) / 2
            center_y = (y0 + y1) / 2
            radius_x = abs(x1 - x0) / 2
            radius_y = abs(y1 - y0) / 2
            # endpoint=False: 2*pi would duplicate the 0 angle, leaving only
            # N-1 distinct, unevenly spaced points.
            angles = np.linspace(0, 2 * np.pi, self._CIRCLE_POINTS, endpoint=False)
            vertices = [
                (float(center_x + radius_x * np.cos(angle)),
                 float(center_y + radius_y * np.sin(angle)))
                for angle in angles
            ]
        else:
            return None

        # Close the polygon by repeating the first vertex
        if vertices:
            vertices.append(vertices[0])
        return vertices

    def _shape_to_annotation(self, shape, name):
        """Build a Plotly annotation labelling ``shape`` with the AOI ``name``."""
        if shape.get('type') == 'path' and 'path' in shape:
            # For path shapes, anchor the label at the first point
            x, y = self._parse_path_points(shape['path'])[0]
        else:
            # For rect and circle, anchor the label at (x0, y0)
            x = shape['x0']
            y = shape['y0']

        return dict(
            x=x,
            y=y,
            xref="x",
            yref="y",
            text=name,
            showarrow=False,
            font=dict(
                size=12,
                color="white"
            ),
            bgcolor="rgba(0,0,0,0.7)",
            bordercolor="rgba(0,0,0,0)",
            borderwidth=1,
            borderpad=4,
            opacity=0.8
        )

    def _update_figure_annotations(self, figure):
        """Rebuild the figure's annotations so each AOI labels the shape at its index."""
        layout = figure.get('layout')
        if layout is None:
            return figure

        shapes = layout.get('shapes') or []
        # zip stops at the shorter sequence: AOIs without a shape get no label.
        layout['annotations'] = [
            self._shape_to_annotation(shape, name)
            for name, shape in zip(self.aois, shapes)
        ]
        return figure

    def run(self, debug=False, port=8051, **kwargs):
        """
        Start the Dash server and run the AOI drawing application.

        This method starts the web server for the AOI drawing interface. The
        application will be accessible through a web browser at the specified
        port.

        Parameters
        ----------
        debug : bool, default=False
            Whether to run the server in debug mode
        port : int, default=8051
            Port number to run the server on. Make sure the port is available
            and not blocked by firewall.
        **kwargs : dict
            Additional keyword arguments passed to ``dash.Dash.run()``.
            See Dash documentation for available options.

        Notes
        -----
        - The application will run until interrupted (Ctrl+C)
        - Access the interface at http://localhost:<port>
        - Debug mode provides additional error information
        - Default port (8051) can be changed if already in use
        """
        self.app.run(debug=debug, port=port, **kwargs)