# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Detector Visualization - Individual Axis Functions
Functions for plotting detector-related data: beam profiles, wavelength distributions,
detection counts, arrival times, and scan results.
Each function draws on a single axis, enabling flexible composition.
"""
import warnings
from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from .common import (
    add_colorbar,
    save_figure,
    setup_axis_grid,
)

if TYPE_CHECKING:
    from ..surfaces import Surface
    from ..utilities.ray_data import RayBatch, RayStatistics
# =============================================================================
# Beam Profile Functions
# =============================================================================
def plot_beam_slice(
    ax: Axes,
    rays: "RayBatch",
    axis: str = "z",
    slice_value: float = 0.0,
    slice_width: float = 0.1,
    color_by: str = "intensity",
    point_size: float = 20,
    alpha: float = 0.6,
    show_colorbar: bool = True,
) -> Any | None:
    """
    Plot beam profile at a specific slice along propagation axis.

    Only active rays (``rays.active``) whose coordinate along ``axis`` lies
    within ``slice_value +/- slice_width / 2`` are drawn; the two remaining
    coordinates become the scatter x/y axes.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    rays : RayBatch
        Ray batch. Reads ``active``, ``positions`` (indexed ``[:, 0..2]``),
        ``intensities`` and ``wavelengths`` (scaled by 1e9 for display in nm).
    axis : str
        Propagation axis: 'x', 'y', or 'z' (case-insensitive).
    slice_value : float
        Position along axis to slice.
    slice_width : float
        Full width of the slice.
    color_by : str
        Color by: 'intensity' or 'wavelength' (case-insensitive).
    point_size : float
        Scatter point size.
    alpha : float
        Transparency.
    show_colorbar : bool
        Whether to add colorbar.

    Returns
    -------
    scatter or None
        ScalarMappable for external colorbar, or None when no ray lies in
        the slice.

    Raises
    ------
    ValueError
        If ``axis`` or ``color_by`` is not a supported value.
    """
    axis_map = {"x": 0, "y": 1, "z": 2}
    axis_key = axis.lower()
    if axis_key not in axis_map:
        raise ValueError(f"Invalid axis: {axis}. Use 'x', 'y', or 'z'")
    # Validate up front instead of silently falling back to wavelength
    # coloring for any unrecognised (e.g. capitalised) value.
    color_key = color_by.lower()
    if color_key not in ("intensity", "wavelength"):
        raise ValueError(
            f"Invalid color_by: {color_by}. Use 'intensity' or 'wavelength'"
        )

    axis_idx = axis_map[axis_key]
    # Cyclic order, so (perp1, perp2, axis) is always a right-handed triple.
    perp_idx1 = (axis_idx + 1) % 3
    perp_idx2 = (axis_idx + 2) % 3
    axis_labels = ["X", "Y", "Z"]

    active_mask = rays.active
    positions = rays.positions[active_mask]

    # Select rays in slice
    mask = np.abs(positions[:, axis_idx] - slice_value) <= slice_width / 2
    if not np.any(mask):
        ax.text(
            0.5,
            0.5,
            "No rays in slice",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return None

    x = positions[mask, perp_idx1]
    y = positions[mask, perp_idx2]
    if color_key == "intensity":
        c = rays.intensities[active_mask][mask]
        cmap = "hot"
        clabel = "Intensity"
    else:
        c = rays.wavelengths[active_mask][mask] * 1e9  # m -> nm
        cmap = "rainbow"
        clabel = "Wavelength (nm)"

    scatter = ax.scatter(x, y, c=c, s=point_size, cmap=cmap, alpha=alpha)
    ax.set_xlabel(f"{axis_labels[perp_idx1]} (m)")
    ax.set_ylabel(f"{axis_labels[perp_idx2]} (m)")
    ax.set_title(f"{axis.upper()}={slice_value:.3f} m")
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, alpha=0.3)
    if show_colorbar:
        add_colorbar(ax, scatter, clabel)
    return scatter
# =============================================================================
# Wavelength Distribution Functions
# =============================================================================
def plot_wavelength_histogram(
    ax: Axes,
    rays: "RayBatch",
    bins: int = 50,
    alpha: float = 0.7,
    color: str = "steelblue",
    edgecolor: str = "black",
    label: str | None = None,
    weight_by_intensity: bool = False,
) -> None:
    """
    Draw a histogram of the wavelengths of the active rays.

    Wavelengths are read from ``rays.wavelengths`` for rays selected by
    ``rays.active`` and multiplied by 1e9 so the x-axis reads in nm.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes to draw on.
    rays : RayBatch
        Ray batch.
    bins : int
        Number of histogram bins.
    alpha : float
        Bar transparency.
    color : str
        Bar color.
    edgecolor : str
        Edge color.
    label : str, optional
        Legend label.
    weight_by_intensity : bool
        If True, each ray contributes its intensity instead of a count of 1.
    """
    selected = rays.active
    wavelengths_nm = rays.wavelengths[selected] * 1e9

    if weight_by_intensity:
        hist_weights = rays.intensities[selected]
        ylabel, title = "Total Intensity", "Intensity per Wavelength"
    else:
        hist_weights = None
        ylabel, title = "Count", "Wavelength Distribution"

    ax.hist(
        wavelengths_nm,
        bins=bins,
        weights=hist_weights,
        alpha=alpha,
        color=color,
        edgecolor=edgecolor,
        label=label,
    )
    setup_axis_grid(ax, "Wavelength (nm)", ylabel, title)
def plot_wavelength_intensity_histogram(
    ax: Axes,
    rays: "RayBatch",
    bins: int = 50,
    alpha: float = 0.7,
    color: str = "orange",
    edgecolor: str = "black",
    label: str | None = None,
) -> None:
    """
    Plot intensity-weighted histogram of ray wavelengths.

    Deprecated: Use plot_wavelength_histogram(weight_by_intensity=True) instead.
    Calling this function emits a ``DeprecationWarning`` and then forwards
    all arguments to ``plot_wavelength_histogram`` with
    ``weight_by_intensity=True``.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    rays : RayBatch
        Ray batch.
    bins : int
        Number of histogram bins.
    alpha : float
        Bar transparency.
    color : str
        Bar color.
    edgecolor : str
        Edge color.
    label : str, optional
        Legend label.
    """
    # stacklevel=2 attributes the warning to the caller, not this shim.
    warnings.warn(
        "plot_wavelength_intensity_histogram is deprecated; use "
        "plot_wavelength_histogram(..., weight_by_intensity=True) instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    plot_wavelength_histogram(
        ax,
        rays,
        bins=bins,
        alpha=alpha,
        color=color,
        edgecolor=edgecolor,
        label=label,
        weight_by_intensity=True,
    )
# =============================================================================
# Detection Count and Efficiency Functions
# =============================================================================
def plot_detection_counts(
    ax: Axes,
    detector_angles_deg: np.ndarray,
    detection_counts: np.ndarray,
    color: str = "blue",
    marker: str = "o",
    linewidth: float = 2,
    markersize: float = 8,
    label: str | None = None,
) -> None:
    """
    Plot detection counts vs detector angle.

    Draws a solid line with markers plus a thin black zero reference line;
    the x-axis is fixed to 0-90 degrees.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    detector_angles_deg : ndarray
        Detector angles in degrees.
    detection_counts : ndarray
        Number of rays detected.
    color : str
        Line and marker color; any Matplotlib color specification (named
        color, hex string, single-letter code, ...).
    marker : str
        Marker style.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    label : str, optional
        Legend label.
    """
    # Style is passed via keywords: a format string built from ``color[0]``
    # breaks for names like "orange" ("oo-" is two markers) and for hex
    # colors, and maps "black" to 'b' (blue).
    ax.plot(
        detector_angles_deg,
        detection_counts,
        color=color,
        marker=marker,
        linestyle="-",
        linewidth=linewidth,
        markersize=markersize,
        label=label,
    )
    ax.axhline(0, color="k", linestyle="-", linewidth=0.5)
    setup_axis_grid(ax, "Detector Angle (degrees)", "Detected Rays", "Detection Count")
    ax.set_xlim(0, 90)
def plot_detection_efficiency(
    ax: Axes,
    detector_angles_deg: np.ndarray,
    detected_intensities: np.ndarray,
    total_source_intensity: float,
    color: str = "magenta",
    marker: str = "o",
    linewidth: float = 2,
    markersize: float = 8,
    label: str | None = None,
) -> None:
    """
    Plot detection efficiency vs detector angle.

    Efficiency is ``detected_intensities / total_source_intensity * 100``
    (percent). The y-axis runs from 0 to 110% of the largest finite
    efficiency, or 0-1 when there is no positive finite value.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    detector_angles_deg : ndarray
        Detector angles in degrees.
    detected_intensities : ndarray
        Total intensity detected.
    total_source_intensity : float
        Total source intensity.
    color : str
        Line and marker color; any Matplotlib color specification.
    marker : str
        Marker style.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    label : str, optional
        Legend label.
    """
    efficiency = (
        np.asarray(detected_intensities, dtype=float) / total_source_intensity * 100
    )
    # Keyword styling instead of f"{color[0]}{marker}-", which breaks for
    # colors whose first letter is not a single-letter color code.
    ax.plot(
        detector_angles_deg,
        efficiency,
        color=color,
        marker=marker,
        linestyle="-",
        linewidth=linewidth,
        markersize=markersize,
        label=label,
    )
    ax.axhline(0, color="k", linestyle="-", linewidth=0.5)
    setup_axis_grid(
        ax, "Detector Angle (degrees)", "Efficiency (%)", "Detection Efficiency"
    )
    ax.set_xlim(0, 90)
    # Ignore NaN/inf (e.g. from a zero source intensity) and tolerate empty
    # input, both of which would otherwise make set_ylim/max() fail.
    finite = efficiency[np.isfinite(efficiency)]
    peak = float(finite.max()) if finite.size else 0.0
    ax.set_ylim(0, peak * 1.1 if peak > 0 else 1)
# =============================================================================
# Arrival Time Functions
# =============================================================================
def plot_mean_arrival_time(
    ax: Axes,
    detector_angles_deg: np.ndarray,
    mean_times: np.ndarray,
    std_times: np.ndarray | None = None,
    detection_counts: np.ndarray | None = None,
    color: str = "cyan",
    marker: str = "o",
    linewidth: float = 2,
    markersize: float = 8,
    label: str | None = None,
) -> None:
    """
    Plot mean arrival time vs detector angle.

    Times are converted from seconds to microseconds for display. When
    ``detection_counts`` is given, positions with zero detections are
    dropped before plotting.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    detector_angles_deg : ndarray
        Detector angles in degrees.
    mean_times : ndarray
        Mean arrival times in seconds.
    std_times : ndarray, optional
        Standard deviation of arrival times (same units as ``mean_times``).
    detection_counts : ndarray, optional
        For masking invalid data.
    color : str
        Line, marker and error-bar color; any Matplotlib color specification.
    marker : str
        Marker style.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    label : str, optional
        Legend label.
    """
    angles = np.asarray(detector_angles_deg)
    times_us = np.asarray(mean_times) * 1e6  # s -> us
    yerr = None if std_times is None else np.asarray(std_times) * 1e6
    if detection_counts is not None:
        valid_mask = np.asarray(detection_counts) > 0
        angles = angles[valid_mask]
        times_us = times_us[valid_mask]
        if yerr is not None:
            yerr = yerr[valid_mask]

    # Keyword styling: ``color`` also drives the error-bar color (ecolor
    # defaults to the line color), and works for any color name.
    ax.errorbar(
        angles,
        times_us,
        yerr=yerr,
        color=color,
        marker=marker,
        linestyle="-",
        linewidth=linewidth,
        markersize=markersize,
        capsize=5,
        alpha=0.7,
        label=label,
    )
    setup_axis_grid(
        ax, "Detector Angle (degrees)", "Mean Arrival Time (μs)", "Arrival Time"
    )
    ax.set_xlim(0, 90)
def plot_timing_distribution(
    ax: Axes,
    all_time_distributions: list[tuple],
    detector_angles_deg: np.ndarray,
    log_scale: bool = True,
    show_legend: bool = True,
    max_curves: int = 10,
) -> None:
    """
    Plot arrival time distributions for multiple detector positions.

    Arrival times are made relative to the earliest arrival over all
    entries and histogrammed, weighted by intensity, into 49 log-spaced bins
    between 1 ns and 1 ms. Relative times outside that range -- including
    the earliest arrival itself, which is exactly 0 -- are not counted.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    all_time_distributions : list
        List of (times, intensities, angles) tuples for each detector; times
        in seconds. Entries that are not 3-tuples are skipped.
    detector_angles_deg : ndarray
        Detector angles in degrees, paired position-by-position with
        ``all_time_distributions`` (extra items in the longer one are ignored).
    log_scale : bool
        Whether to use log scales.
    show_legend : bool
        Whether to show legend.
    max_curves : int
        Maximum number of curves to plot; curves are sub-sampled evenly.
    """
    # Find first arrival over every well-formed, non-empty entry. asarray
    # lets ``times`` be any sequence, not only an ndarray.
    nonempty_times = [
        np.asarray(time_data[0], dtype=float)
        for time_data in all_time_distributions
        if len(time_data) == 3 and len(time_data[0]) > 0
    ]
    if not nonempty_times:
        ax.text(
            0.5, 0.5, "No timing data", ha="center", va="center", transform=ax.transAxes
        )
        return
    first_arrival = min(float(times.min()) for times in nonempty_times)
    log_bin_edges = np.logspace(-9, -3, 50)  # seconds

    # zip pairs each angle with its entry and stops at the shorter sequence
    # instead of raising IndexError on a length mismatch.
    positions_with_data = []
    for angle_deg, time_data in zip(detector_angles_deg, all_time_distributions):
        if len(time_data) != 3:
            continue
        times, intensities, _ = time_data
        if len(times) == 0:
            continue
        times_rel = np.asarray(times, dtype=float) - first_arrival
        counts, _ = np.histogram(times_rel, bins=log_bin_edges, weights=intensities)
        if counts.sum() > 0:
            positions_with_data.append((angle_deg, counts))

    if not positions_with_data:
        ax.text(
            0.5, 0.5, "No timing data", ha="center", va="center", transform=ax.transAxes
        )
        return

    # Sample evenly if there are too many curves
    if len(positions_with_data) > max_curves:
        indices = np.linspace(0, len(positions_with_data) - 1, max_curves, dtype=int)
        positions_with_data = [positions_with_data[i] for i in indices]

    colors = plt.cm.turbo(np.linspace(0, 1, len(positions_with_data)))
    bin_centers = (log_bin_edges[:-1] + log_bin_edges[1:]) / 2
    for idx, (angle_deg, counts) in enumerate(positions_with_data):
        ax.plot(
            bin_centers * 1e9,  # s -> ns
            counts,
            color=colors[idx],
            linewidth=1.5,
            label=f"{angle_deg:.0f}°" if show_legend else "",
            alpha=0.7,
        )
    setup_axis_grid(
        ax, "Relative Arrival Time (ns)", "Intensity", "Timing Distribution"
    )
    if log_scale:
        ax.set_xscale("log")
        ax.set_yscale("log")
    if show_legend:
        ax.legend(loc="upper right", fontsize=8, ncol=2)
# =============================================================================
# Angular Distribution Functions
# =============================================================================
def plot_arrival_angle_distribution(
    ax: Axes,
    detector_angles_deg: np.ndarray,
    mean_angles: np.ndarray,
    std_angles: np.ndarray | None = None,
    detection_counts: np.ndarray | None = None,
    color: str = "magenta",
    marker: str = "o",
    linewidth: float = 2,
    markersize: float = 8,
) -> None:
    """
    Plot mean arrival angle vs detector position.

    When ``detection_counts`` is given, positions with zero detections are
    dropped before plotting.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    detector_angles_deg : ndarray
        Detector position angles.
    mean_angles : ndarray
        Mean arrival angles to normal.
    std_angles : ndarray, optional
        Standard deviation.
    detection_counts : ndarray, optional
        For masking invalid data.
    color : str
        Line, marker and error-bar color; any Matplotlib color specification.
    marker : str
        Marker style.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    """
    angles = np.asarray(detector_angles_deg)
    means = np.asarray(mean_angles)
    yerr = None if std_angles is None else np.asarray(std_angles)
    if detection_counts is not None:
        valid_mask = np.asarray(detection_counts) > 0
        angles = angles[valid_mask]
        means = means[valid_mask]
        if yerr is not None:
            yerr = yerr[valid_mask]

    # Keyword styling instead of f"{color[0]}{marker}-", which breaks for
    # colors whose first letter is not a single-letter color code.
    ax.errorbar(
        angles,
        means,
        yerr=yerr,
        color=color,
        marker=marker,
        linestyle="-",
        linewidth=linewidth,
        markersize=markersize,
        capsize=5,
        alpha=0.7,
    )
    setup_axis_grid(
        ax,
        "Detector Angle (degrees)",
        "Mean Angle to Normal (degrees)",
        "Arrival Angles",
    )
    ax.set_xlim(0, 90)
def plot_angular_histogram(
    ax: Axes,
    all_time_distributions: list[tuple],
    detection_counts: np.ndarray,
    bins: int = 50,
    color: str = "purple",
    alpha: float = 0.7,
) -> None:
    """
    Draw an intensity-weighted histogram of arrival angles over all detectors.

    Only detectors with a positive detection count whose entry is a
    (times, intensities, angles) 3-tuple contribute.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    all_time_distributions : list
        List of (times, intensities, angles) tuples.
    detection_counts : ndarray
        Detection counts per position, indexed like ``all_time_distributions``.
    bins : int
        Number of histogram bins.
    color : str
        Bar color.
    alpha : float
        Transparency.
    """
    angle_values = []
    weight_values = []
    for position, entry in enumerate(all_time_distributions):
        # Same short-circuit order as before: count first, then entry shape.
        if not (detection_counts[position] > 0 and len(entry) == 3):
            continue
        _, entry_intensities, entry_angles = entry
        angle_values.extend(entry_angles)
        weight_values.extend(entry_intensities)

    if not angle_values:
        ax.text(
            0.5,
            0.5,
            "No angular data",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        return

    ax.hist(
        np.array(angle_values),
        bins=bins,
        weights=np.array(weight_values),
        color=color,
        alpha=alpha,
        edgecolor="black",
    )
    setup_axis_grid(
        ax, "Angle to Normal (degrees)", "Intensity", "Angular Distribution"
    )
# =============================================================================
# Composite Figure Builders
# =============================================================================
_SCAN_ANGLE_XLABEL = "Detector Position Angle from Surface (degrees)"


def _scan_axis_labels(ax: Axes, xlabel: str, ylabel: str, title: str) -> None:
    """Apply the bold label/title styling shared by all scan-figure panels."""
    ax.set_xlabel(xlabel, fontsize=11, fontweight="bold")
    ax.set_ylabel(ylabel, fontsize=11, fontweight="bold")
    ax.set_title(title, fontweight="bold")


def _scan_line_panel(
    ax: Axes, angles, values, fmt: str, ylabel: str, title: str
) -> None:
    """Draw a per-angle line plot with a zero reference line, x-range 0-90."""
    ax.plot(angles, values, fmt, linewidth=2, markersize=8)
    ax.axhline(0, color="k", linestyle="-", linewidth=0.5)
    _scan_axis_labels(ax, _SCAN_ANGLE_XLABEL, ylabel, title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 90)


def _scan_errorbar_panel(
    ax: Axes, angles, means, stds, fmt: str, ylabel: str, title: str
) -> None:
    """Draw per-angle means with optional error bars, x-range 0-90."""
    ax.errorbar(
        angles,
        means,
        yerr=stds,
        fmt=fmt,
        linewidth=2,
        markersize=8,
        capsize=5,
        alpha=0.7,
    )
    _scan_axis_labels(ax, _SCAN_ANGLE_XLABEL, ylabel, title)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 90)


def _scan_timing_panel(ax: Axes, detector_angles_deg, all_time_distributions) -> None:
    """Draw intensity-weighted arrival-time histograms, one curve per detector."""
    entries = [
        (angle_deg, time_data[0], time_data[1])
        for angle_deg, time_data in zip(detector_angles_deg, all_time_distributions)
        if len(time_data) == 3 and len(time_data[0]) > 0
    ]
    if not entries:
        ax.text(
            0.5, 0.5, "No timing data", ha="center", va="center", transform=ax.transAxes
        )
        return

    first_arrival = min(np.min(times) for _, times, _ in entries)
    log_bin_edges = np.logspace(-9, -3, 50)  # seconds: 1 ns .. 1 ms
    bin_centers = (log_bin_edges[:-1] + log_bin_edges[1:]) / 2

    curves = []
    for angle_deg, times, intensities in entries:
        # Relative times outside the bin range (incl. the first arrival at
        # exactly 0) are not counted by np.histogram.
        counts, _ = np.histogram(
            np.asarray(times, dtype=float) - first_arrival,
            bins=log_bin_edges,
            weights=intensities,
        )
        if counts.sum() > 0:
            curves.append((angle_deg, counts))

    if curves:
        colors = plt.cm.turbo(np.linspace(0, 1, len(curves)))
        for color_idx, (angle_deg, counts) in enumerate(curves):
            ax.plot(
                bin_centers * 1e9,  # s -> ns
                counts,
                color=colors[color_idx],
                linewidth=1.5,
                # Label only every 10th curve to keep the legend readable.
                label=f"{angle_deg:.0f}°" if color_idx % 10 == 0 else "",
                alpha=0.7,
            )
    _scan_axis_labels(
        ax,
        "Relative Arrival Time (ns)",
        "Intensity (weighted counts)",
        "Timing Distribution (intensity-weighted)",
    )
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.grid(True, alpha=0.3, which="both")
    if curves:
        ax.legend(loc="upper right", fontsize=8, ncol=2)


def _scan_angle_histogram_panel(
    ax: Axes, detection_counts, all_time_distributions
) -> None:
    """Draw an intensity-weighted histogram of arrival angles at all detectors."""
    all_angles = []
    angle_weights = []
    for count, time_data in zip(detection_counts, all_time_distributions):
        if count > 0 and len(time_data) == 3:
            _, intensities_raw, angles_raw = time_data
            all_angles.extend(angles_raw)
            angle_weights.extend(intensities_raw)
    if not all_angles:
        ax.text(
            0.5, 0.5, "No angular data", ha="center", va="center", transform=ax.transAxes
        )
        return

    ax.hist(
        np.array(all_angles),
        bins=50,
        weights=np.array(angle_weights),
        color="purple",
        alpha=0.7,
        edgecolor="black",
    )
    _scan_axis_labels(
        ax,
        "Angle to Water Normal (degrees)",
        "Intensity (weighted counts)",
        "Angular Distribution of Detected Rays (intensity-weighted)",
    )
    ax.grid(True, alpha=0.3)


def plot_detector_scan_results(
    detector_angles_deg: np.ndarray,
    detection_counts: np.ndarray,
    detected_intensities: np.ndarray,
    reflected_rays: "RayBatch",
    surface: "Surface",
    total_source_intensity: float,
    mean_arrival_times: np.ndarray | None = None,
    std_arrival_times: np.ndarray | None = None,
    mean_angles_to_normal: np.ndarray | None = None,
    std_angles_to_normal: np.ndarray | None = None,
    all_time_distributions: list | None = None,
    detector_distance: float = 1000.0,
    detector_radius: float = 5.0,
    water_normal: np.ndarray | None = None,
    figsize: tuple[float, float] = (24, 14),
    save_path: str | None = None,
) -> Figure:
    """
    Create comprehensive detector scan visualization with multiple subplots.

    Panels are placed on a 4x4 grid:

    * row 1, left  -- detected ray count vs detector angle
    * row 1, right -- mean arrival angle to normal (if ``mean_angles_to_normal``)
    * row 2, left  -- mean arrival time in us (if ``mean_arrival_times``)
    * row 2, right -- detection efficiency in percent
    * row 3, left  -- intensity-weighted timing distribution
      (if ``all_time_distributions``)
    * row 3, right -- intensity-weighted arrival-angle histogram
      (if ``all_time_distributions``)

    Row 0 of the grid is not populated by this function.

    NOTE(review): ``reflected_rays``, ``surface``, ``detector_distance``,
    ``detector_radius`` and ``water_normal`` are accepted for API
    compatibility but no panel here uses them -- confirm whether wave-surface
    / ray panels were intended for row 0.

    Parameters
    ----------
    detector_angles_deg : ndarray
        Detector angles in degrees (0-90).
    detection_counts : ndarray
        Number of rays detected at each position; positions with 0 are
        masked out of the mean-angle and mean-time panels.
    detected_intensities : ndarray
        Total intensity detected at each position.
    reflected_rays : RayBatch
        Batch of reflected rays (currently unused).
    surface : Surface
        The surface object (currently unused).
    total_source_intensity : float
        Total intensity of source rays.
    mean_arrival_times : ndarray, optional
        Mean arrival time at each detector position, in seconds.
    std_arrival_times : ndarray, optional
        Standard deviation of arrival times, in seconds.
    mean_angles_to_normal : ndarray, optional
        Mean angle to normal at each detector.
    std_angles_to_normal : ndarray, optional
        Standard deviation of angles.
    all_time_distributions : list, optional
        List of (times, intensities, angles) tuples for each detector.
    detector_distance : float
        Distance to detector in meters (currently unused).
    detector_radius : float
        Detector radius in meters (currently unused).
    water_normal : ndarray, optional
        Normal vector for water surface (currently unused).
    figsize : tuple
        Figure size (width, height).
    save_path : str, optional
        Path to save figure (saved at dpi=150 with a tight bounding box).

    Returns
    -------
    Figure
        Matplotlib figure with comprehensive visualization.
    """
    import matplotlib.gridspec as gridspec

    angles = np.asarray(detector_angles_deg)
    counts = np.asarray(detection_counts)
    valid_mask = counts > 0
    detection_efficiency = (
        np.asarray(detected_intensities, dtype=float) / total_source_intensity * 100
    )

    fig = plt.figure(figsize=figsize, constrained_layout=True)
    gs = gridspec.GridSpec(4, 4, figure=fig, hspace=0.3, wspace=0.4)

    # Row 1: detection count and mean arrival angle
    _scan_line_panel(
        fig.add_subplot(gs[1, :2]),
        angles,
        counts,
        "bo-",
        "Number of Detected Rays",
        "Ray Detection Count vs Detector Position",
    )
    if mean_angles_to_normal is not None:
        _scan_errorbar_panel(
            fig.add_subplot(gs[1, 2:]),
            angles[valid_mask],
            np.asarray(mean_angles_to_normal)[valid_mask],
            (
                np.asarray(std_angles_to_normal)[valid_mask]
                if std_angles_to_normal is not None
                else None
            ),
            "mo-",
            "Mean Angle to Normal (degrees)",
            "Mean Arrival Angle at Detector",
        )

    # Row 2: mean arrival time and detection efficiency
    if mean_arrival_times is not None:
        _scan_errorbar_panel(
            fig.add_subplot(gs[2, :2]),
            angles[valid_mask],
            np.asarray(mean_arrival_times)[valid_mask] * 1e6,  # s -> us
            (
                np.asarray(std_arrival_times)[valid_mask] * 1e6
                if std_arrival_times is not None
                else None
            ),
            "co-",
            "Mean Arrival Time (μs)",
            "Mean Arrival Time at Detector",
        )
    ax_efficiency = fig.add_subplot(gs[2, 2:])
    _scan_line_panel(
        ax_efficiency,
        angles,
        detection_efficiency,
        "mo-",
        "Detection Efficiency (%)",
        "Detection Efficiency (Detected/Source)",
    )
    # Ignore NaN/inf and tolerate empty input when choosing the y-range.
    finite = detection_efficiency[np.isfinite(detection_efficiency)]
    peak = float(finite.max()) if finite.size else 0.0
    ax_efficiency.set_ylim(0, peak * 1.1 if peak > 0 else 1)

    # Row 3: timing and angular distributions
    if all_time_distributions is not None:
        _scan_timing_panel(fig.add_subplot(gs[3, 0:2]), angles, all_time_distributions)
        _scan_angle_histogram_panel(
            fig.add_subplot(gs[3, 2:]), counts, all_time_distributions
        )

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig
# =============================================================================
# Legacy Convenience Functions (Backward Compatibility)
# =============================================================================
def plot_statistics_evolution(
    stats_history: list["RayStatistics"],
    figsize: tuple[float, float] = (15, 10),
    save_path: str | None = None,
) -> Figure:
    """
    Create figure showing evolution of ray statistics over propagation.

    Each entry of ``stats_history`` is one step. The 2x2 figure shows
    ``active_rays``, ``total_power``, ``mean_optical_path`` and the X/Y/Z
    components of ``mean_position``. Entries without a ``mean_position``
    attribute are plotted as 0 in the last panel.

    Parameters
    ----------
    stats_history : List[RayStatistics]
        List of RayStatistics objects at different propagation steps.
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure. It is also honoured for the placeholder figure
        produced when ``stats_history`` is empty.

    Returns
    -------
    Figure
        Matplotlib figure with statistics evolution.
    """
    if len(stats_history) == 0:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        ax.text(
            0.5,
            0.5,
            "No statistics available",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
        if save_path:
            save_figure(fig, save_path)
        return fig

    steps = np.arange(len(stats_history))
    active_rays = [s.active_rays for s in stats_history]
    total_power = [s.total_power for s in stats_history]
    mean_path = [s.mean_optical_path for s in stats_history]
    # mean_position is optional on the statistics objects; fall back to 0.
    positions = [
        s.mean_position if hasattr(s, "mean_position") else (0, 0, 0)
        for s in stats_history
    ]
    mean_x = [p[0] for p in positions]
    mean_y = [p[1] for p in positions]
    mean_z = [p[2] for p in positions]

    fig, axes = plt.subplots(2, 2, figsize=figsize, constrained_layout=True)
    fig.suptitle("Ray Statistics Evolution", fontsize=14, fontweight="bold")

    scalar_panels = (
        (axes[0, 0], active_rays, "b-o", "Active Rays", "Active Ray Count"),
        (axes[0, 1], total_power, "g-o", "Power (W)", "Total Power"),
        (axes[1, 0], mean_path, "r-o", "Path Length (m)", "Mean Optical Path"),
    )
    for ax, values, fmt, ylabel, title in scalar_panels:
        ax.plot(steps, values, fmt, markersize=4)
        ax.set_xlabel("Step")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    # Mean position
    ax = axes[1, 1]
    ax.plot(steps, mean_z, "purple", label="Z", marker="o", markersize=4)
    ax.plot(steps, mean_x, "orange", label="X", marker="s", markersize=4, alpha=0.7)
    ax.plot(steps, mean_y, "cyan", label="Y", marker="^", markersize=4, alpha=0.7)
    ax.set_xlabel("Step")
    ax.set_ylabel("Position (m)")
    ax.set_title("Mean Position")
    ax.legend()
    ax.grid(True, alpha=0.3)

    if save_path:
        # save_figure is imported at module level from .common.
        save_figure(fig, save_path)
    return fig
def plot_beam_profile(
    rays: "RayBatch",
    axis: str = "z",
    num_slices: int = 5,
    figsize: tuple[float, float] = (15, 4),
    save_path: str | None = None,
) -> Figure:
    """
    Build a figure with the beam profile at several slices (legacy name).

    Thin backward-compatible alias: every argument is forwarded by keyword to
    ``create_beam_profile_figure`` and its result is returned unchanged.

    NOTE(review): ``create_beam_profile_figure`` is not defined or imported
    near this function; confirm it is in module scope at call time.

    Parameters
    ----------
    rays : RayBatch
        Ray batch.
    axis : str
        Propagation axis: 'x', 'y', 'z'.
    num_slices : int
        Number of slices to show.
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    forwarded = {
        "rays": rays,
        "axis": axis,
        "num_slices": num_slices,
        "figsize": figsize,
        "save_path": save_path,
    }
    return create_beam_profile_figure(**forwarded)
def plot_wavelength_distribution(
    rays: "RayBatch",
    bins: int = 50,
    figsize: tuple[float, float] = (10, 6),
    save_path: str | None = None,
) -> Figure:
    """
    Create figure showing wavelength distribution of rays.

    This is a convenience function for quick visualization. For custom layouts,
    use plot_wavelength_histogram() on a single axis.

    Parameters
    ----------
    rays : RayBatch
        Ray batch.
    bins : int
        Number of histogram bins.
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure with wavelength histogram.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=True)
    plot_wavelength_histogram(ax, rays, bins=bins)
    fig.suptitle("Wavelength Distribution", fontsize=14, fontweight="bold")
    if save_path:
        # save_figure is imported at module level from .common; no local
        # re-import needed.
        save_figure(fig, save_path)
    return fig