# Source code for lsurf.visualization.detector_plots

# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
#      * Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#
#      * Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#      * Neither the name of the copyright holder nor the names of its
#      contributors may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""
Detector Visualization - Individual Axis Functions

Functions for plotting detector-related data: beam profiles, wavelength distributions,
detection counts, arrival times, and scan results.
Each function draws on a single axis, enabling flexible composition.
"""

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from ..utilities.ray_data import RayBatch, RayStatistics
    from ..surfaces import Surface

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from .common import (
    add_colorbar,
    save_figure,
    setup_axis_grid,
)

# =============================================================================
# Beam Profile Functions
# =============================================================================


def plot_beam_slice(
    ax: Axes,
    rays: "RayBatch",
    axis: str = "z",
    slice_value: float = 0.0,
    slice_width: float = 0.1,
    color_by: str = "intensity",
    point_size: float = 20,
    alpha: float = 0.6,
    show_colorbar: bool = True,
) -> Any | None:
    """
    Scatter the active rays lying inside a slab perpendicular to ``axis``.

    A ray is inside the slab when its coordinate along ``axis`` is within
    ``slice_width / 2`` of ``slice_value``. The two remaining coordinates are
    plotted in cyclic order (for 'z': X vs Y, for 'x': Y vs Z, for 'y': Z vs X).

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    rays : RayBatch
        Ray batch; ``positions``, ``intensities``, ``wavelengths`` and the
        ``active`` mask are read.
    axis : str
        Slicing axis: 'x', 'y' or 'z' (case-insensitive).
    slice_value : float
        Centre of the slab along ``axis``.
    slice_width : float
        Full thickness of the slab.
    color_by : str
        'intensity' colours by intensity with the ``hot`` colormap; any other
        value colours by wavelength in nm with the ``rainbow`` colormap.
    point_size : float
        Scatter marker size.
    alpha : float
        Marker transparency.
    show_colorbar : bool
        Whether to attach a colorbar via ``add_colorbar``.

    Returns
    -------
    scatter or None
        The scatter collection (usable for an external colorbar), or None
        when no ray falls inside the slab.

    Raises
    ------
    ValueError
        If ``axis`` is not 'x', 'y' or 'z'.
    """
    active = rays.active
    pos = rays.positions[active]
    inten = rays.intensities[active]
    wl_nm = rays.wavelengths[active] * 1e9  # metres -> nanometres

    indices = {"x": 0, "y": 1, "z": 2}
    key = axis.lower()
    if key not in indices:
        raise ValueError(f"Invalid axis: {axis}. Use 'x', 'y', or 'z'")
    along = indices[key]
    h_idx, v_idx = (along + 1) % 3, (along + 2) % 3
    names = ("X", "Y", "Z")

    in_slab = np.abs(pos[:, along] - slice_value) <= slice_width / 2
    if not in_slab.any():
        ax.text(
            0.5,
            0.5,
            "No rays in slice",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return None

    if color_by == "intensity":
        values, cmap, clabel = inten[in_slab], "hot", "Intensity"
    else:
        values, cmap, clabel = wl_nm[in_slab], "rainbow", "Wavelength (nm)"

    scatter = ax.scatter(
        pos[in_slab, h_idx],
        pos[in_slab, v_idx],
        c=values,
        s=point_size,
        cmap=cmap,
        alpha=alpha,
    )
    ax.set_xlabel(f"{names[h_idx]} (m)")
    ax.set_ylabel(f"{names[v_idx]} (m)")
    ax.set_title(f"{axis.upper()}={slice_value:.3f} m")
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, alpha=0.3)
    if show_colorbar:
        add_colorbar(ax, scatter, clabel)
    return scatter
# ============================================================================= # Wavelength Distribution Functions # =============================================================================
def plot_wavelength_histogram(
    ax: Axes,
    rays: "RayBatch",
    bins: int = 50,
    alpha: float = 0.7,
    color: str = "steelblue",
    edgecolor: str = "black",
    label: str | None = None,
    weight_by_intensity: bool = False,
) -> None:
    """
    Draw a histogram of active-ray wavelengths in nanometres.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    rays : RayBatch
        Ray batch; ``wavelengths`` (metres), ``intensities`` and ``active``
        are read.
    bins : int
        Number of histogram bins.
    alpha : float
        Bar transparency.
    color : str
        Bar face colour.
    edgecolor : str
        Bar edge colour.
    label : str, optional
        Legend label.
    weight_by_intensity : bool
        If True, each ray contributes its intensity instead of a count, and
        the y label / title change accordingly.
    """
    active = rays.active
    if weight_by_intensity:
        weights = rays.intensities[active]
        ylabel, title = "Total Intensity", "Intensity per Wavelength"
    else:
        weights = None
        ylabel, title = "Count", "Wavelength Distribution"

    wavelengths_nm = rays.wavelengths[active] * 1e9  # metres -> nanometres
    ax.hist(
        wavelengths_nm,
        bins=bins,
        weights=weights,
        alpha=alpha,
        color=color,
        edgecolor=edgecolor,
        label=label,
    )
    setup_axis_grid(ax, "Wavelength (nm)", ylabel, title)
def plot_wavelength_intensity_histogram(
    ax: Axes,
    rays: "RayBatch",
    bins: int = 50,
    alpha: float = 0.7,
    color: str = "orange",
    edgecolor: str = "black",
    label: str | None = None,
) -> None:
    """
    Draw an intensity-weighted wavelength histogram (legacy wrapper).

    Deprecated: Use plot_wavelength_histogram(weight_by_intensity=True) instead.
    This wrapper forwards its arguments unchanged with
    ``weight_by_intensity=True``; it does not emit a warning.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    rays : RayBatch
        Ray batch.
    bins : int
        Number of histogram bins.
    alpha : float
        Bar transparency.
    color : str
        Bar face colour (orange by default, unlike the unweighted plot).
    edgecolor : str
        Bar edge colour.
    label : str, optional
        Legend label.
    """
    forwarded = {
        "bins": bins,
        "alpha": alpha,
        "color": color,
        "edgecolor": edgecolor,
        "label": label,
    }
    return plot_wavelength_histogram(
        ax, rays, weight_by_intensity=True, **forwarded
    )
# ============================================================================= # Detection Count and Efficiency Functions # =============================================================================
def plot_detection_counts(
    ax: Axes,
    detector_angles_deg: np.ndarray,
    detection_counts: np.ndarray,
    color: str = "blue",
    marker: str = "o",
    linewidth: float = 2,
    markersize: float = 8,
    label: str | None = None,
) -> None:
    """
    Plot detected ray count against detector angle.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    detector_angles_deg : ndarray
        Detector angles in degrees (the x axis is fixed to 0-90).
    detection_counts : ndarray
        Number of rays detected at each angle.
    color : str
        Any matplotlib colour specification (name, single letter, hex, ...).
    marker : str
        Marker style.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    label : str, optional
        Legend label.
    """
    # Colour, marker and line style are passed as separate keywords. Building a
    # format string from ``color[0]`` mapped e.g. "black" to blue and
    # "orange" to a circle marker, and failed on hex colours.
    ax.plot(
        detector_angles_deg,
        detection_counts,
        color=color,
        marker=marker,
        linestyle="-",
        linewidth=linewidth,
        markersize=markersize,
        label=label,
    )
    ax.axhline(0, color="k", linestyle="-", linewidth=0.5)
    setup_axis_grid(ax, "Detector Angle (degrees)", "Detected Rays", "Detection Count")
    ax.set_xlim(0, 90)
def plot_detection_efficiency(
    ax: Axes,
    detector_angles_deg: np.ndarray,
    detected_intensities: np.ndarray,
    total_source_intensity: float,
    color: str = "magenta",
    marker: str = "o",
    linewidth: float = 2,
    markersize: float = 8,
    label: str | None = None,
) -> None:
    """
    Plot detection efficiency (percent of source intensity) vs detector angle.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    detector_angles_deg : ndarray
        Detector angles in degrees (the x axis is fixed to 0-90).
    detected_intensities : ndarray
        Total intensity detected at each angle.
    total_source_intensity : float
        Total source intensity; efficiency = detected / source * 100.
    color : str
        Any matplotlib colour specification.
    marker : str
        Marker style.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    label : str, optional
        Legend label.
    """
    efficiency = (
        np.asarray(detected_intensities, dtype=float) / total_source_intensity * 100
    )
    # Separate colour/marker keywords: a ``color[0]`` format string
    # misinterprets most colour names and rejects hex colours.
    ax.plot(
        detector_angles_deg,
        efficiency,
        color=color,
        marker=marker,
        linestyle="-",
        linewidth=linewidth,
        markersize=markersize,
        label=label,
    )
    ax.axhline(0, color="k", linestyle="-", linewidth=0.5)
    setup_axis_grid(
        ax, "Detector Angle (degrees)", "Efficiency (%)", "Detection Efficiency"
    )
    ax.set_xlim(0, 90)

    # Base the upper limit on finite values only, so empty input, NaN, or inf
    # (e.g. zero source intensity) do not raise or produce an invalid limit.
    finite = efficiency[np.isfinite(efficiency)]
    peak = float(finite.max()) if finite.size else 0.0
    ax.set_ylim(0, peak * 1.1 if peak > 0 else 1)
# ============================================================================= # Arrival Time Functions # =============================================================================
def plot_mean_arrival_time(
    ax: Axes,
    detector_angles_deg: np.ndarray,
    mean_times: np.ndarray,
    std_times: np.ndarray | None = None,
    detection_counts: np.ndarray | None = None,
    color: str = "cyan",
    marker: str = "o",
    linewidth: float = 2,
    markersize: float = 8,
    label: str | None = None,
) -> None:
    """
    Plot mean arrival time (in microseconds) vs detector angle.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    detector_angles_deg : ndarray
        Detector angles in degrees (the x axis is fixed to 0-90).
    mean_times : ndarray
        Mean arrival times in seconds.
    std_times : ndarray, optional
        Standard deviation of arrival times in seconds; drawn as error bars.
    detection_counts : ndarray, optional
        If given, only angles with a count > 0 are plotted.
    color : str
        Any matplotlib colour specification.
    marker : str
        Marker style.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    label : str, optional
        Legend label.
    """
    # asarray allows list inputs to be masked below.
    angles = np.asarray(detector_angles_deg)
    times = np.asarray(mean_times, dtype=float)
    errors = None if std_times is None else np.asarray(std_times, dtype=float)

    if detection_counts is not None:
        valid = np.asarray(detection_counts) > 0
        angles = angles[valid]
        times = times[valid]
        if errors is not None:
            errors = errors[valid]

    # Explicit colour/marker keywords instead of an ``f"{color[0]}..."`` format
    # string, which mis-parses most colour names.
    ax.errorbar(
        angles,
        times * 1e6,  # seconds -> microseconds
        yerr=None if errors is None else errors * 1e6,
        color=color,
        marker=marker,
        linestyle="-",
        linewidth=linewidth,
        markersize=markersize,
        capsize=5,
        alpha=0.7,
        label=label,
    )
    setup_axis_grid(
        ax, "Detector Angle (degrees)", "Mean Arrival Time (μs)", "Arrival Time"
    )
    ax.set_xlim(0, 90)
def plot_timing_distribution(
    ax: Axes,
    all_time_distributions: list[tuple],
    detector_angles_deg: np.ndarray,
    log_scale: bool = True,
    show_legend: bool = True,
    max_curves: int = 10,
) -> None:
    """
    Plot intensity-weighted arrival-time curves for several detectors.

    All times are shifted by the earliest arrival across every detector and
    histogrammed into 49 log-spaced bins spanning 1 ns to 1 ms. Relative times
    outside that range, including the earliest arrival itself (relative
    time 0), are not counted by ``np.histogram``.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    all_time_distributions : list
        One ``(times, intensities, angles)`` tuple per detector, indexed in
        the same order as ``detector_angles_deg``. Entries of another length
        are skipped. Times are in seconds; lists or arrays are accepted.
    detector_angles_deg : ndarray
        Detector angles in degrees (used for legend labels).
    log_scale : bool
        Whether to use log scales on both axes.
    show_legend : bool
        Whether to label curves and draw a legend.
    max_curves : int
        Maximum number of curves; extra curves are evenly down-sampled.
    """

    def _show_no_data() -> None:
        ax.text(
            0.5, 0.5, "No timing data", ha="center", va="center", transform=ax.transAxes
        )

    # Global earliest arrival: every curve shares this time origin.
    nonempty_times = [
        np.asarray(time_data[0], dtype=float)
        for time_data in all_time_distributions
        if len(time_data) == 3 and len(time_data[0]) > 0
    ]
    if not nonempty_times:
        _show_no_data()
        return
    first_arrival = min(float(np.min(t)) for t in nonempty_times)

    log_bin_edges = np.logspace(-9, -3, 50)  # seconds

    positions_with_data = []
    for i, angle_deg in enumerate(detector_angles_deg):
        time_data = all_time_distributions[i]
        if len(time_data) != 3:
            continue
        times, intensities, _ = time_data
        if len(times) == 0:
            continue
        # asarray: subtracting a float from a plain list would raise TypeError.
        times_rel = np.asarray(times, dtype=float) - first_arrival
        counts, _ = np.histogram(times_rel, bins=log_bin_edges, weights=intensities)
        if counts.sum() > 0:
            positions_with_data.append((angle_deg, counts))

    if not positions_with_data:
        _show_no_data()
        return

    # Evenly down-sample when there are more curves than allowed.
    if len(positions_with_data) > max_curves:
        keep = np.linspace(0, len(positions_with_data) - 1, max_curves, dtype=int)
        positions_with_data = [positions_with_data[k] for k in keep]

    colors = plt.cm.turbo(np.linspace(0, 1, len(positions_with_data)))
    bin_centers = (log_bin_edges[:-1] + log_bin_edges[1:]) / 2
    for color, (angle_deg, counts) in zip(colors, positions_with_data):
        ax.plot(
            bin_centers * 1e9,  # seconds -> nanoseconds
            counts,
            color=color,
            linewidth=1.5,
            label=f"{angle_deg:.0f}°" if show_legend else "",
            alpha=0.7,
        )
    setup_axis_grid(
        ax, "Relative Arrival Time (ns)", "Intensity", "Timing Distribution"
    )
    if log_scale:
        ax.set_xscale("log")
        ax.set_yscale("log")
    if show_legend:
        ax.legend(loc="upper right", fontsize=8, ncol=2)
# ============================================================================= # Angular Distribution Functions # =============================================================================
def plot_arrival_angle_distribution(
    ax: Axes,
    detector_angles_deg: np.ndarray,
    mean_angles: np.ndarray,
    std_angles: np.ndarray | None = None,
    detection_counts: np.ndarray | None = None,
    color: str = "magenta",
    marker: str = "o",
    linewidth: float = 2,
    markersize: float = 8,
) -> None:
    """
    Plot mean arrival angle to the normal vs detector position angle.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    detector_angles_deg : ndarray
        Detector position angles in degrees (the x axis is fixed to 0-90).
    mean_angles : ndarray
        Mean arrival angles to the normal, in degrees.
    std_angles : ndarray, optional
        Standard deviation; drawn as error bars.
    detection_counts : ndarray, optional
        If given, only positions with a count > 0 are plotted.
    color : str
        Any matplotlib colour specification.
    marker : str
        Marker style.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    """
    angles = np.asarray(detector_angles_deg)
    means = np.asarray(mean_angles, dtype=float)
    errors = None if std_angles is None else np.asarray(std_angles, dtype=float)

    if detection_counts is not None:
        valid = np.asarray(detection_counts) > 0
        angles = angles[valid]
        means = means[valid]
        if errors is not None:
            errors = errors[valid]

    # Explicit colour/marker keywords: an ``f"{color[0]}..."`` format string
    # mis-parses most colour names and rejects hex colours.
    ax.errorbar(
        angles,
        means,
        yerr=errors,
        color=color,
        marker=marker,
        linestyle="-",
        linewidth=linewidth,
        markersize=markersize,
        capsize=5,
        alpha=0.7,
    )
    setup_axis_grid(
        ax,
        "Detector Angle (degrees)",
        "Mean Angle to Normal (degrees)",
        "Arrival Angles",
    )
    ax.set_xlim(0, 90)
def plot_angular_histogram(
    ax: Axes,
    all_time_distributions: list[tuple],
    detection_counts: np.ndarray,
    bins: int = 50,
    color: str = "purple",
    alpha: float = 0.7,
) -> None:
    """
    Draw an intensity-weighted histogram of every recorded arrival angle.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    all_time_distributions : list
        One ``(times, intensities, angles)`` tuple per detector. Only entries
        of length 3 whose ``detection_counts`` value is positive contribute.
    detection_counts : ndarray
        Detection counts, indexed like ``all_time_distributions``.
    bins : int
        Number of histogram bins.
    color : str
        Bar colour.
    alpha : float
        Bar transparency.
    """
    angles_acc: list = []
    weights_acc: list = []
    for idx, entry in enumerate(all_time_distributions):
        if not (detection_counts[idx] > 0 and len(entry) == 3):
            continue
        _, weights, angles = entry
        angles_acc.extend(angles)
        weights_acc.extend(weights)

    if not angles_acc:
        ax.text(
            0.5,
            0.5,
            "No angular data",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return

    ax.hist(
        np.array(angles_acc),
        bins=bins,
        weights=np.array(weights_acc),
        color=color,
        alpha=alpha,
        edgecolor="black",
    )
    setup_axis_grid(
        ax, "Angle to Normal (degrees)", "Intensity", "Angular Distribution"
    )
# ============================================================================= # Composite Figure Builders # =============================================================================
def create_wavelength_figure(
    rays: "RayBatch",
    bins: int = 50,
    figsize: tuple[float, float] = (10, 5),
    title: str = "Wavelength Distribution",
    save_path: str | None = None,
) -> Figure:
    """
    Build a two-panel figure of ray wavelength histograms.

    The left panel shows ray counts per wavelength; the right panel shows
    intensity per wavelength (orange bars).

    Parameters
    ----------
    rays : RayBatch
        Ray batch.
    bins : int
        Histogram bins for both panels.
    figsize : tuple
        Figure size.
    title : str
        Figure super-title.
    save_path : str, optional
        If truthy, the figure is written with ``save_figure``.

    Returns
    -------
    Figure
        The created figure.
    """
    fig, (count_ax, weighted_ax) = plt.subplots(
        1, 2, figsize=figsize, constrained_layout=True
    )
    fig.suptitle(title, fontsize=14, fontweight="bold")

    plot_wavelength_histogram(count_ax, rays, bins=bins)
    # Same settings as the legacy plot_wavelength_intensity_histogram wrapper.
    plot_wavelength_histogram(
        weighted_ax, rays, bins=bins, color="orange", weight_by_intensity=True
    )

    if save_path:
        save_figure(fig, save_path)
    return fig
def create_beam_profile_figure(
    rays: "RayBatch",
    axis: str = "z",
    num_slices: int = 5,
    figsize: tuple[float, float] = (15, 4),
    title: str | None = None,
    save_path: str | None = None,
) -> Figure:
    """
    Create a row of beam-profile panels at evenly spaced slices along ``axis``.

    With ``num_slices > 1`` the slice centres run from the minimum to the
    maximum active-ray coordinate along ``axis``, and each slice is as thick
    as the spacing between centres. With a single slice, every active ray is
    shown in one panel titled at the midpoint of the range.

    Parameters
    ----------
    rays : RayBatch
        Ray batch.
    axis : str
        Propagation axis ('x', 'y' or 'z'). ``plot_beam_slice`` raises
        ValueError for any other value.
    num_slices : int
        Number of slices / panels.
    figsize : tuple
        Figure size.
    title : str, optional
        Figure title; defaults to "Beam Profile Along <AXIS> Axis".
    save_path : str, optional
        If truthy, the figure is written with ``save_figure``.

    Returns
    -------
    Figure
        The created figure.
    """
    fig, axes = plt.subplots(1, num_slices, figsize=figsize, constrained_layout=True)
    if title is None:
        title = f"Beam Profile Along {axis.upper()} Axis"
    fig.suptitle(title, fontsize=14, fontweight="bold")
    if num_slices == 1:
        axes = [axes]

    axis_idx = {"x": 0, "y": 1, "z": 2}.get(axis.lower(), 2)
    axis_pos = rays.positions[rays.active][:, axis_idx]
    if axis_pos.size == 0:
        # No active rays: np.min/np.max would raise. Use a degenerate range so
        # every panel shows plot_beam_slice's "No rays in slice" notice.
        axis_min = axis_max = 0.0
    else:
        axis_min, axis_max = float(np.min(axis_pos)), float(np.max(axis_pos))

    if num_slices > 1:
        slice_centers = np.linspace(axis_min, axis_max, num_slices)
        slice_width = (axis_max - axis_min) / (num_slices - 1)
    else:
        # One slice centred on the midpoint. An unbounded width selects every
        # active ray, avoiding rounding loss at the range ends. (Centring at
        # the minimum with width = range would only cover half the rays.)
        slice_centers = np.array([(axis_min + axis_max) / 2])
        slice_width = np.inf

    for i, (ax, center) in enumerate(zip(axes, slice_centers, strict=False)):
        plot_beam_slice(
            ax,
            rays,
            axis=axis,
            slice_value=center,
            slice_width=slice_width,
            show_colorbar=(i == num_slices - 1),
        )

    if save_path:
        save_figure(fig, save_path)
    return fig
def create_detector_scan_figure(
    detector_angles_deg: np.ndarray,
    detection_counts: np.ndarray,
    detected_intensities: np.ndarray,
    total_source_intensity: float,
    mean_arrival_times: np.ndarray | None = None,
    std_arrival_times: np.ndarray | None = None,
    mean_angles_to_normal: np.ndarray | None = None,
    std_angles_to_normal: np.ndarray | None = None,
    all_time_distributions: list | None = None,
    figsize: tuple[float, float] = (16, 12),
    title: str = "Detector Scan Results",
    save_path: str | None = None,
) -> Figure:
    """
    Assemble a 3x2 detector-scan summary from the single-axis plotters.

    Layout: row 1 holds detection counts and efficiency. Row 2 holds mean
    arrival time and mean arrival angle. Row 3 holds the timing distribution
    and the angular histogram. Panels whose optional inputs are None show a
    centred placeholder message instead.

    Parameters
    ----------
    detector_angles_deg : ndarray
        Detector angles in degrees.
    detection_counts : ndarray
        Number of rays detected.
    detected_intensities : ndarray
        Total intensity detected.
    total_source_intensity : float
        Source intensity.
    mean_arrival_times : ndarray, optional
        Mean arrival times.
    std_arrival_times : ndarray, optional
        Std of arrival times.
    mean_angles_to_normal : ndarray, optional
        Mean arrival angles.
    std_angles_to_normal : ndarray, optional
        Std of arrival angles.
    all_time_distributions : list, optional
        Per-detector ``(times, intensities, angles)`` tuples.
    figsize : tuple
        Figure size.
    title : str
        Figure title.
    save_path : str, optional
        If truthy, the figure is written with ``save_figure``.

    Returns
    -------
    Figure
        The created figure.
    """
    fig, grid = plt.subplots(3, 2, figsize=figsize, constrained_layout=True)
    fig.suptitle(title, fontsize=14, fontweight="bold")

    def _placeholder(ax: Axes, message: str) -> None:
        # Centred notice for a panel whose input data was not supplied.
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)

    # Row 1: detection counts and efficiency
    plot_detection_counts(grid[0, 0], detector_angles_deg, detection_counts)
    plot_detection_efficiency(
        grid[0, 1], detector_angles_deg, detected_intensities, total_source_intensity
    )

    # Row 2: arrival times and angles
    time_ax, angle_ax = grid[1, 0], grid[1, 1]
    if mean_arrival_times is None:
        _placeholder(time_ax, "No timing data")
    else:
        plot_mean_arrival_time(
            time_ax,
            detector_angles_deg,
            mean_arrival_times,
            std_arrival_times,
            detection_counts,
        )
    if mean_angles_to_normal is None:
        _placeholder(angle_ax, "No angular data")
    else:
        plot_arrival_angle_distribution(
            angle_ax,
            detector_angles_deg,
            mean_angles_to_normal,
            std_angles_to_normal,
            detection_counts,
        )

    # Row 3: distributions
    timing_ax, hist_ax = grid[2, 0], grid[2, 1]
    if all_time_distributions is None:
        _placeholder(timing_ax, "No distribution data")
        _placeholder(hist_ax, "No distribution data")
    else:
        plot_timing_distribution(timing_ax, all_time_distributions, detector_angles_deg)
        plot_angular_histogram(hist_ax, all_time_distributions, detection_counts)

    if save_path:
        save_figure(fig, save_path)
    return fig
def plot_detector_scan_results(
    detector_angles_deg: np.ndarray,
    detection_counts: np.ndarray,
    detected_intensities: np.ndarray,
    reflected_rays: "RayBatch",
    surface: "Surface",
    total_source_intensity: float,
    mean_arrival_times: np.ndarray | None = None,
    std_arrival_times: np.ndarray | None = None,
    mean_angles_to_normal: np.ndarray | None = None,
    std_angles_to_normal: np.ndarray | None = None,
    all_time_distributions: list | None = None,
    detector_distance: float = 1000.0,
    detector_radius: float = 5.0,
    water_normal: np.ndarray | None = None,
    figsize: tuple[float, float] = (24, 14),
    save_path: str | None = None,
) -> Figure:
    """
    Create a multi-panel detector scan figure on a 4x4 grid.

    Panels (row numbers refer to the 4x4 GridSpec):

    * row 1, left: detected ray count vs detector angle;
    * row 1, right: mean angle to normal with std error bars (only if
      ``mean_angles_to_normal`` is given);
    * row 2, left: mean arrival time in μs with std error bars (only if
      ``mean_arrival_times`` is given);
    * row 2, right: detection efficiency in percent of source intensity;
    * row 3, left: intensity-weighted arrival-time curves on log-log axes
      (only if ``all_time_distributions`` is given);
    * row 3, right: intensity-weighted histogram of arrival angles (only if
      ``all_time_distributions`` is given).

    Angle/time panels only plot detector positions with a count > 0.

    NOTE(review): row 0 of the grid is never populated. ``reflected_rays``,
    ``surface``, ``detector_distance``, ``detector_radius`` and
    ``water_normal`` are accepted but not used by any panel; they are kept
    for backward compatibility.

    Parameters
    ----------
    detector_angles_deg : ndarray
        Detector angles in degrees (0-90).
    detection_counts : ndarray
        Number of rays detected at each position.
    detected_intensities : ndarray
        Total intensity detected at each position.
    reflected_rays : RayBatch
        Batch of reflected rays (currently unused).
    surface : Surface
        Surface object (currently unused).
    total_source_intensity : float
        Total intensity of source rays.
    mean_arrival_times : ndarray, optional
        Mean arrival time (seconds) at each detector position.
    std_arrival_times : ndarray, optional
        Standard deviation of arrival times (seconds).
    mean_angles_to_normal : ndarray, optional
        Mean angle to normal (degrees) at each detector.
    std_angles_to_normal : ndarray, optional
        Standard deviation of angles.
    all_time_distributions : list, optional
        One ``(times, intensities, angles)`` tuple per detector, indexed like
        ``detector_angles_deg``. Times are in seconds; lists or arrays are
        accepted.
    detector_distance : float
        Distance to detector in meters (currently unused).
    detector_radius : float
        Detector radius in meters (currently unused).
    water_normal : ndarray, optional
        Normal vector for the water surface; defaults to +z (currently unused).
    figsize : tuple
        Figure size (width, height).
    save_path : str, optional
        If truthy, saved with ``fig.savefig(dpi=150, bbox_inches="tight")``.

    Returns
    -------
    Figure
        The created figure.
    """
    import matplotlib.gridspec as gridspec

    if water_normal is None:
        water_normal = np.array([0, 0, 1])  # NOTE(review): not used below

    fig = plt.figure(figsize=figsize, constrained_layout=True)
    gs = gridspec.GridSpec(4, 4, figure=fig, hspace=0.3, wspace=0.4)

    # asarray so list inputs support boolean-mask indexing below.
    angles = np.asarray(detector_angles_deg)
    n_detected = np.asarray(detection_counts)
    valid_mask = n_detected > 0
    detection_efficiency = (
        np.asarray(detected_intensities, dtype=float) / total_source_intensity * 100
    )

    label_kw = {"fontsize": 11, "fontweight": "bold"}
    position_label = "Detector Position Angle from Surface (degrees)"

    # 1. Ray count per detector position
    ax1 = fig.add_subplot(gs[1, :2])
    ax1.plot(angles, n_detected, "bo-", linewidth=2, markersize=8)
    ax1.axhline(0, color="k", linestyle="-", linewidth=0.5)
    ax1.set_xlabel(position_label, **label_kw)
    ax1.set_ylabel("Number of Detected Rays", **label_kw)
    ax1.set_title("Ray Detection Count vs Detector Position", fontweight="bold")
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(0, 90)

    # 2. Mean arrival angle vs detector position
    if mean_angles_to_normal is not None:
        ax2 = fig.add_subplot(gs[1, 2:])
        ax2.errorbar(
            angles[valid_mask],
            np.asarray(mean_angles_to_normal)[valid_mask],
            yerr=(
                np.asarray(std_angles_to_normal)[valid_mask]
                if std_angles_to_normal is not None
                else None
            ),
            fmt="mo-",
            linewidth=2,
            markersize=8,
            capsize=5,
            alpha=0.7,
        )
        ax2.set_xlabel(position_label, **label_kw)
        ax2.set_ylabel("Mean Angle to Normal (degrees)", **label_kw)
        ax2.set_title("Mean Arrival Angle at Detector", fontweight="bold")
        ax2.grid(True, alpha=0.3)
        ax2.set_xlim(0, 90)

    # 3a. Mean arrival time vs detector position
    if mean_arrival_times is not None:
        ax3a = fig.add_subplot(gs[2, :2])
        ax3a.errorbar(
            angles[valid_mask],
            np.asarray(mean_arrival_times, dtype=float)[valid_mask] * 1e6,  # s -> μs
            yerr=(
                np.asarray(std_arrival_times, dtype=float)[valid_mask] * 1e6
                if std_arrival_times is not None
                else None
            ),
            fmt="co-",
            linewidth=2,
            markersize=8,
            capsize=5,
            alpha=0.7,
        )
        ax3a.set_xlabel(position_label, **label_kw)
        ax3a.set_ylabel("Mean Arrival Time (μs)", **label_kw)
        ax3a.set_title("Mean Arrival Time at Detector", fontweight="bold")
        ax3a.grid(True, alpha=0.3)
        ax3a.set_xlim(0, 90)

    # 3b. Detection efficiency vs detector position
    ax3b = fig.add_subplot(gs[2, 2:])
    ax3b.plot(angles, detection_efficiency, "mo-", linewidth=2, markersize=8)
    ax3b.axhline(0, color="k", linestyle="-", linewidth=0.5)
    ax3b.set_xlabel(position_label, **label_kw)
    ax3b.set_ylabel("Detection Efficiency (%)", **label_kw)
    ax3b.set_title("Detection Efficiency (Detected/Source)", fontweight="bold")
    ax3b.grid(True, alpha=0.3)
    ax3b.set_xlim(0, 90)
    # Finite values only: empty input, NaN or inf must not break the limit.
    finite_eff = detection_efficiency[np.isfinite(detection_efficiency)]
    peak_eff = float(finite_eff.max()) if finite_eff.size else 0.0
    ax3b.set_ylim(0, peak_eff * 1.1 if peak_eff > 0 else 1)

    # 4. Time distribution
    if all_time_distributions is not None:
        ax4 = fig.add_subplot(gs[3, 0:2])

        # Collect (angle, times, intensities) for every detector with data.
        per_detector = []
        for i, angle_deg in enumerate(angles):
            time_data = all_time_distributions[i]
            if len(time_data) == 3 and len(time_data[0]) > 0:
                times_raw, intensities_raw, _ = time_data
                per_detector.append(
                    (angle_deg, np.asarray(times_raw, dtype=float), intensities_raw)
                )

        if per_detector:
            # Common origin: earliest arrival over all detectors.
            first_arrival = min(float(np.min(t)) for _, t, _ in per_detector)
            # 1 ns .. 1 ms log bins; relative times outside (incl. 0) are dropped.
            log_bin_edges = np.logspace(-9, -3, 50)
            bin_centers = (log_bin_edges[:-1] + log_bin_edges[1:]) / 2

            positions_with_data = []
            for angle_deg, times, weights in per_detector:
                hist, _ = np.histogram(
                    times - first_arrival, bins=log_bin_edges, weights=weights
                )
                if hist.sum() > 0:
                    positions_with_data.append((angle_deg, hist))

            colors = plt.cm.turbo(np.linspace(0, 1, len(positions_with_data)))
            for color_idx, (angle_deg, hist) in enumerate(positions_with_data):
                ax4.plot(
                    bin_centers * 1e9,  # seconds -> nanoseconds
                    hist,
                    color=colors[color_idx],
                    linewidth=1.5,
                    # Label only every 10th curve to keep the legend short.
                    label=f"{angle_deg:.0f}°" if color_idx % 10 == 0 else "",
                    alpha=0.7,
                )
            ax4.set_xlabel("Relative Arrival Time (ns)", **label_kw)
            ax4.set_ylabel("Intensity (weighted counts)", **label_kw)
            ax4.set_title("Timing Distribution (intensity-weighted)", fontweight="bold")
            ax4.set_xscale("log")
            ax4.set_yscale("log")
            ax4.grid(True, alpha=0.3, which="both")
            if positions_with_data:
                ax4.legend(loc="upper right", fontsize=8, ncol=2)

    # 5. Angular distribution
    if all_time_distributions is not None:
        ax5 = fig.add_subplot(gs[3, 2:])
        all_angles = []
        intensities_for_angles = []
        for i in range(len(angles)):
            if n_detected[i] > 0:
                time_data = all_time_distributions[i]
                if len(time_data) == 3:
                    _, intensities_raw, angles_raw = time_data
                    all_angles.extend(angles_raw)
                    intensities_for_angles.extend(intensities_raw)

        if all_angles:
            ax5.hist(
                np.array(all_angles),
                bins=50,
                weights=np.array(intensities_for_angles),
                color="purple",
                alpha=0.7,
                edgecolor="black",
            )
            ax5.set_xlabel("Angle to Water Normal (degrees)", **label_kw)
            ax5.set_ylabel("Intensity (weighted counts)", **label_kw)
            ax5.set_title(
                "Angular Distribution of Detected Rays (intensity-weighted)",
                fontweight="bold",
            )
            ax5.grid(True, alpha=0.3)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig
# ============================================================================= # Legacy Convenience Functions (Backward Compatibility) # =============================================================================
def plot_statistics_evolution(
    stats_history: list["RayStatistics"],
    figsize: tuple[float, float] = (15, 10),
    save_path: str | None = None,
) -> Figure:
    """
    Create a 2x2 figure showing how ray statistics evolve over propagation.

    Panels: active ray count, total power, mean optical path, and mean
    position (X/Y/Z). Each entry of ``stats_history`` must provide
    ``active_rays``, ``total_power`` and ``mean_optical_path``. ``mean_position``
    is optional; entries without it are plotted as 0 in the position panel.

    Parameters
    ----------
    stats_history : List[RayStatistics]
        RayStatistics objects, one per propagation step.
    figsize : tuple
        Figure size.
    save_path : str, optional
        If truthy, the figure is written with ``save_figure``. This also
        applies to the placeholder figure produced for an empty history.

    Returns
    -------
    Figure
        The created figure (a single placeholder axis if the history is empty).
    """
    if len(stats_history) == 0:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        ax.text(
            0.5,
            0.5,
            "No statistics available",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        # Honour save_path here too; previously the empty case never saved.
        if save_path:
            save_figure(fig, save_path)
        return fig

    steps = np.arange(len(stats_history))
    active_rays = [s.active_rays for s in stats_history]
    total_power = [s.total_power for s in stats_history]
    mean_path = [s.mean_optical_path for s in stats_history]

    def _position_component(stat, k: int):
        # Missing mean_position is plotted as 0 rather than raising.
        return stat.mean_position[k] if hasattr(stat, "mean_position") else 0

    mean_x = [_position_component(s, 0) for s in stats_history]
    mean_y = [_position_component(s, 1) for s in stats_history]
    mean_z = [_position_component(s, 2) for s in stats_history]

    fig, axes = plt.subplots(2, 2, figsize=figsize, constrained_layout=True)
    fig.suptitle("Ray Statistics Evolution", fontsize=14, fontweight="bold")

    # (axes slot, data, format, y label, title) for the three scalar panels
    scalar_panels = (
        (axes[0, 0], active_rays, "b-o", "Active Rays", "Active Ray Count"),
        (axes[0, 1], total_power, "g-o", "Power (W)", "Total Power"),
        (axes[1, 0], mean_path, "r-o", "Path Length (m)", "Mean Optical Path"),
    )
    for ax, values, fmt, ylabel, panel_title in scalar_panels:
        ax.plot(steps, values, fmt, markersize=4)
        ax.set_xlabel("Step")
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        ax.grid(True, alpha=0.3)

    ax = axes[1, 1]
    ax.plot(steps, mean_z, "purple", label="Z", marker="o", markersize=4)
    ax.plot(steps, mean_x, "orange", label="X", marker="s", markersize=4, alpha=0.7)
    ax.plot(steps, mean_y, "cyan", label="Y", marker="^", markersize=4, alpha=0.7)
    ax.set_xlabel("Step")
    ax.set_ylabel("Position (m)")
    ax.set_title("Mean Position")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # save_figure is already imported at module level; no local re-import.
    if save_path:
        save_figure(fig, save_path)
    return fig
def plot_beam_profile(
    rays: "RayBatch",
    axis: str = "z",
    num_slices: int = 5,
    figsize: tuple[float, float] = (15, 4),
    save_path: str | None = None,
) -> Figure:
    """
    Backward-compatible alias of ``create_beam_profile_figure``.

    Always uses that function's default title.

    Parameters
    ----------
    rays : RayBatch
        Ray batch.
    axis : str
        Propagation axis: 'x', 'y' or 'z'.
    num_slices : int
        Number of slices to show.
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure.

    Returns
    -------
    Figure
        The figure built by ``create_beam_profile_figure``.
    """
    figure = create_beam_profile_figure(
        rays,
        axis=axis,
        num_slices=num_slices,
        figsize=figsize,
        title=None,
        save_path=save_path,
    )
    return figure
def plot_wavelength_distribution(
    rays: "RayBatch",
    bins: int = 50,
    figsize: tuple[float, float] = (10, 6),
    save_path: str | None = None,
) -> Figure:
    """
    Build a one-panel figure with the (unweighted) wavelength histogram.

    Convenience wrapper; for custom layouts call
    ``plot_wavelength_histogram`` on your own axis.

    Parameters
    ----------
    rays : RayBatch
        Ray batch.
    bins : int
        Number of histogram bins.
    figsize : tuple
        Figure size.
    save_path : str, optional
        If truthy, the figure is written with ``save_figure``.

    Returns
    -------
    Figure
        The created figure.
    """
    figure, hist_ax = plt.subplots(figsize=figsize, constrained_layout=True)
    plot_wavelength_histogram(hist_ax, rays, bins=bins)
    figure.suptitle("Wavelength Distribution", fontsize=14, fontweight="bold")

    # save_figure is available from the module-level ``.common`` import.
    if save_path:
        save_figure(figure, save_path)
    return figure