# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Detector Sphere Visualization
Plotting functions for visualizing ray intersections with detector spheres,
energy density maps, Pareto front analysis, and geometry schematics.
"""
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
import numpy as np
if TYPE_CHECKING:
from numpy.typing import NDArray
from ..utilities.recording_sphere import RecordedRays
[docs]
def plot_geometry_schematic(
source_position: tuple[float, float, float],
beam_direction: tuple[float, float, float] | "NDArray",
grazing_angle_deg: float,
detector_altitude: float,
intersection_points: "NDArray",
reflected_directions: "NDArray",
n_rays_to_show: int = 20,
save_path: str | None = None,
):
"""
Plot a schematic overview of the simulation geometry.
Shows side view (x-z plane) with source, rays, ocean surface, and detector sphere.
Parameters
----------
source_position : tuple
Source (x, y, z) position in meters
beam_direction : tuple or ndarray
Beam direction unit vector
grazing_angle_deg : float
Grazing angle in degrees
detector_altitude : float
Detector sphere radius in meters
intersection_points : ndarray, shape (N, 3)
XYZ coordinates of surface intersections
reflected_directions : ndarray, shape (N, 3)
Direction vectors of reflected rays
n_rays_to_show : int
Number of rays to display
save_path : str, optional
Path to save figure
Returns
-------
fig : matplotlib.figure.Figure
The created figure
"""
fig, axes = plt.subplots(1, 2, figsize=(16, 7))
# =========================================================================
# Left panel: Side view schematic (x-z plane) - scaled to show detector
# =========================================================================
ax1 = axes[0]
# Scale to km for display
src_x, src_y, src_z = np.array(source_position) / 1000
det_r = detector_altitude / 1000
# Compute where reflected rays hit the detector sphere
hit_positions_km = []
for i in range(len(reflected_directions)):
rd = reflected_directions[i]
rd_norm = rd / np.linalg.norm(rd)
hit_x = det_r * rd_norm[0]
hit_z = det_r * rd_norm[2]
hit_positions_km.append((hit_x, hit_z))
hit_positions_km = np.array(hit_positions_km)
# Get the range of detector hits to focus the view
hit_x_min, hit_x_max = hit_positions_km[:, 0].min(), hit_positions_km[:, 0].max()
hit_z_min, hit_z_max = hit_positions_km[:, 1].min(), hit_positions_km[:, 1].max()
# Draw detector sphere arc (full upper hemisphere, faded)
theta_arc = np.linspace(0, np.pi, 200)
x_arc = det_r * np.cos(theta_arc)
z_arc = det_r * np.sin(theta_arc)
ax1.plot(x_arc, z_arc, "g-", linewidth=1.5, alpha=0.3)
# Highlight the portion where rays hit
theta_hit_min = np.arctan2(hit_z_min, hit_x_max)
theta_hit_max = np.arctan2(hit_z_max, hit_x_min)
theta_highlight = np.linspace(theta_hit_min - 0.02, theta_hit_max + 0.02, 50)
x_highlight = det_r * np.cos(theta_highlight)
z_highlight = det_r * np.sin(theta_highlight)
ax1.plot(
x_highlight,
z_highlight,
"g-",
linewidth=3,
label=f"Detector sphere (r={det_r:.0f} km)",
)
# Draw ocean surface following Earth curvature
# Earth center is at (0, 0, -EARTH_RADIUS), surface at x is at z = sqrt(R^2 - x^2) - R
from ..surfaces import EARTH_RADIUS
earth_r_km = EARTH_RADIUS / 1000
x_ocean_min = min(src_x - 1, -det_r * 0.15)
x_ocean_max = det_r # Extend to detector sphere radius
x_ocean = np.linspace(x_ocean_min, x_ocean_max, 300)
# Curved Earth surface (for x in km, result in km)
z_ocean_curved = np.sqrt(earth_r_km**2 - x_ocean**2) - earth_r_km
# Add small decorative waves on top of curvature
z_ocean = z_ocean_curved + 0.0003 * det_r * np.sin(80 * x_ocean / det_r)
ax1.fill_between(
x_ocean,
z_ocean,
z_ocean.min() - det_r * 0.01,
color="lightblue",
alpha=0.5,
label="Ocean",
)
ax1.plot(x_ocean, z_ocean, "b-", linewidth=2)
# Draw source
ax1.plot(src_x, src_z, "ro", markersize=10, label="Source", zorder=10)
# Draw representative rays
# Use minimum of intersection_points and reflected_directions sizes
n_total = min(len(intersection_points), len(reflected_directions))
indices = np.linspace(0, n_total - 1, min(n_rays_to_show, n_total), dtype=int)
for idx in indices:
# Incoming ray: source to intersection (near origin)
int_x, int_y, int_z = intersection_points[idx] / 1000
ax1.plot([src_x, int_x], [src_z, int_z], "r-", alpha=0.2, linewidth=0.8)
# Reflected ray: intersection to detector sphere
rd = reflected_directions[idx]
rd_norm = rd / np.linalg.norm(rd)
det_x = det_r * rd_norm[0]
det_z = det_r * rd_norm[2]
ax1.plot([int_x, det_x], [int_z, det_z], "orange", alpha=0.2, linewidth=0.8)
# Draw chief ray more prominently
mid_idx = n_total // 2
int_x, int_y, int_z = intersection_points[mid_idx] / 1000
ax1.plot([src_x, int_x], [src_z, int_z], "r-", linewidth=2, label="Incident rays")
rd = reflected_directions[mid_idx]
rd_norm = rd / np.linalg.norm(rd)
det_x = det_r * rd_norm[0]
det_z = det_r * rd_norm[2]
ax1.plot(
[int_x, det_x], [int_z, det_z], "orange", linewidth=2, label="Reflected rays"
)
# Mark key points
ax1.plot(0, 0, "k*", markersize=12, label="Reflection point", zorder=10)
ax1.plot(det_x, det_z, "g*", markersize=10, label="Detection region", zorder=10)
# Draw grazing angle arc (scaled to be visible)
arc_r = det_r * 0.06
theta_graze = np.linspace(0, np.radians(grazing_angle_deg), 20)
x_graze = -arc_r * np.cos(theta_graze)
z_graze = -arc_r * np.sin(theta_graze)
ax1.plot(x_graze, z_graze, "m-", linewidth=2)
ax1.annotate(
f"{grazing_angle_deg}°",
(-arc_r * 0.7, -arc_r * 0.5),
fontsize=10,
color="purple",
)
# Set axis limits to show full geometry including source and detector hit region
x_left = min(src_x - 1, -det_r * 0.15)
ax1.set_xlim(x_left, det_r * 1.05)
# Adjust y-limits to show curved ocean surface
z_min = min(z_ocean.min(), -det_r * 0.03)
ax1.set_ylim(z_min - det_r * 0.01, det_r * 0.35)
ax1.set_xlabel("X (km)", fontsize=12)
ax1.set_ylabel("Z (km)", fontsize=12)
ax1.set_title(
"Geometry Schematic - Side View (X-Z plane)", fontsize=14, fontweight="bold"
)
ax1.legend(loc="upper left", fontsize=9)
ax1.grid(True, alpha=0.3)
# =========================================================================
# Right panel: Top view with annotations
# =========================================================================
ax2 = axes[1]
# Draw intersection points footprint
ix_all = intersection_points[:, 0]
iy_all = intersection_points[:, 1]
ax2.scatter(ix_all, iy_all, c="blue", s=1, alpha=0.3, label="Ray footprint")
# Draw source projection
ax2.plot(
source_position[0],
source_position[1],
"ro",
markersize=12,
label="Source (projected)",
)
# Draw reflection center
ax2.plot(0, 0, "k*", markersize=15, label="Reflection center")
# Draw beam spread arrows
ax2.annotate(
"",
xy=(np.mean(ix_all), np.max(iy_all)),
xytext=(source_position[0], source_position[1]),
arrowprops=dict(arrowstyle="->", color="red", alpha=0.5),
)
ax2.annotate(
"",
xy=(np.mean(ix_all), np.min(iy_all)),
xytext=(source_position[0], source_position[1]),
arrowprops=dict(arrowstyle="->", color="red", alpha=0.5),
)
ax2.set_xlabel("X (m)", fontsize=12)
ax2.set_ylabel("Y (m)", fontsize=12)
ax2.set_title("Geometry - Top View (X-Y plane)", fontsize=14, fontweight="bold")
ax2.legend(loc="upper right", fontsize=9)
ax2.set_aspect("equal")
ax2.grid(True, alpha=0.3)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches="tight")
print(f" Saved: {save_path}")
return fig
[docs]
def plot_ocean_intersections_top_view(
    intersection_points: "NDArray",
    wave_surface=None,
    n_bins: int = 100,
    save_path: str | None = None,
):
    """
    Plot top-down heatmap of where rays hit the ocean surface.

    The x/y coordinates of the intersections are binned into a 2D histogram
    and shown on a logarithmic color scale. Empty bins are left blank.

    Parameters
    ----------
    intersection_points : ndarray, shape (N, 3)
        XYZ coordinates of surface intersections
    wave_surface : CurvedWaveSurface, optional
        The ocean surface for context (not currently used)
    n_bins : int
        Number of bins for the 2D histogram
    save_path : str, optional
        Path to save figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure
    """
    from matplotlib.colors import LogNorm

    fig, ax = plt.subplots(figsize=(12, 10))
    x = intersection_points[:, 0]
    y = intersection_points[:, 1]

    hist, x_edges, y_edges = np.histogram2d(x, y, bins=n_bins)
    # Empty bins become NaN so they are left unpainted (log(0) is undefined).
    hist_masked = np.where(hist > 0, hist, np.nan)
    # Clamp vmax to >= vmin: with no hits at all, hist.max() is 0 and
    # LogNorm(vmin=1, vmax=0) would fail when the figure is drawn.
    im = ax.pcolormesh(
        x_edges,
        y_edges,
        hist_masked.T,
        cmap="hot",
        shading="auto",
        norm=LogNorm(vmin=1, vmax=max(hist.max(), 1)),
    )

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Ray count per bin (log scale)", fontsize=11)

    # Coordinate grid for reference
    ax.grid(True, alpha=0.3, color="white", linewidth=0.5)
    ax.set_xlabel("X (m)", fontsize=12)
    ax.set_ylabel("Y (m)", fontsize=12)
    ax.set_title(
        "Ocean Surface Intersections - Top View", fontsize=14, fontweight="bold"
    )
    ax.set_aspect("equal")

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved: {save_path}")
    return fig
[docs]
def plot_3d_intersection_scatter(
    intersection_points: "NDArray",
    save_path: str | None = None,
):
    """
    Scatter the surface intersection points in 3D inside a cubic view box.

    Points are colored by their z coordinate (height). All three axes get the
    same span, centred on each coordinate's bounding-box midpoint, so
    distances are visually comparable.

    Parameters
    ----------
    intersection_points : ndarray, shape (N, 3)
        XYZ coordinates of surface intersections
    save_path : str, optional
        Path to save figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure
    """
    pts = intersection_points
    xs, ys, zs = pts[:, 0], pts[:, 1], pts[:, 2]

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(xs, ys, zs, c=zs, cmap="viridis", s=1, alpha=0.6)

    # Half-width of the cube: half of the largest per-axis extent
    half_span = np.ptp(pts, axis=0).max() / 2.0
    for set_limits, coord in (
        (ax.set_xlim, xs),
        (ax.set_ylim, ys),
        (ax.set_zlim, zs),
    ):
        centre = (coord.max() + coord.min()) / 2.0
        set_limits(centre - half_span, centre + half_span)

    for set_label, text in (
        (ax.set_xlabel, "X (m)"),
        (ax.set_ylabel, "Y (m)"),
        (ax.set_zlabel, "Z (m)"),
    ):
        set_label(text, fontsize=11)
    ax.set_title("3D Ocean Surface Intersections", fontsize=14, fontweight="bold")

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved: {save_path}")
    return fig
[docs]
def plot_pareto_front(
pareto_result: dict,
time_threshold_ns: float = 10.0,
source_power: float = 1.0,
save_path: str | None = None,
):
"""
Plot the Pareto front of energy density vs time spread.
Parameters
----------
pareto_result : dict
Result from compute_pareto_front()
time_threshold_ns : float
Detector time resolution threshold (ns)
source_power : float
Input source power for normalization (W)
save_path : str, optional
Path to save figure
Returns
-------
fig : matplotlib.figure.Figure or None
The created figure, or None if no data
"""
bin_data = pareto_result["bin_data"]
pareto_front = pareto_result["pareto_front"]
if len(bin_data) == 0:
print(" No bins with sufficient rays for Pareto analysis")
return None
# Extract data and normalize by source power
all_energy = pareto_result["all_energy_densities"] / source_power
all_time = pareto_result["all_time_spreads"]
pareto_energy = np.array([p["energy_density"] / source_power for p in pareto_front])
pareto_time = np.array([p["time_spread_ns"] for p in pareto_front])
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
# Left: Pareto front scatter plot
ax1 = axes[0]
# Plot all bins
ax1.scatter(all_time, all_energy, c="lightgray", s=20, alpha=0.6, label="All bins")
# Plot Pareto front
ax1.scatter(
pareto_time,
pareto_energy,
c="red",
s=60,
alpha=0.9,
label="Pareto front",
zorder=5,
)
# Connect Pareto front points
sort_idx = np.argsort(pareto_time)
ax1.plot(
pareto_time[sort_idx],
pareto_energy[sort_idx],
"r-",
linewidth=1.5,
alpha=0.7,
zorder=4,
)
# Mark threshold
ax1.axvline(
x=time_threshold_ns,
color="blue",
linestyle="--",
linewidth=2,
label=f"Time threshold ({time_threshold_ns} ns)",
)
# Find and annotate key points
# Best within threshold (highest energy where time < threshold)
within_threshold = [
p for p in pareto_front if p["time_spread_ns"] < time_threshold_ns
]
if within_threshold:
best_within = max(within_threshold, key=lambda x: x["energy_density"])
ax1.scatter(
[best_within["time_spread_ns"]],
[best_within["energy_density"] / source_power],
c="green",
s=200,
marker="*",
edgecolor="black",
linewidth=1.5,
label="Best within threshold",
zorder=10,
)
ax1.annotate(
f"({best_within['lon_deg']:.1f}, {best_within['lat_deg']:.1f})",
(
best_within["time_spread_ns"],
best_within["energy_density"] / source_power,
),
xytext=(10, 10),
textcoords="offset points",
fontsize=9,
arrowprops=dict(arrowstyle="->", color="green"),
)
# Highest energy on Pareto front
if len(pareto_front) > 0:
best_energy = max(pareto_front, key=lambda x: x["energy_density"])
if not within_threshold or best_energy != best_within:
ax1.scatter(
[best_energy["time_spread_ns"]],
[best_energy["energy_density"] / source_power],
c="orange",
s=150,
marker="D",
edgecolor="black",
linewidth=1.5,
label="Highest energy",
zorder=9,
)
ax1.set_xlabel("Time Spread (90th-10th percentile) [ns]", fontsize=12)
ax1.set_ylabel("Normalized Energy Density (sr^-1)", fontsize=12)
ax1.set_title(
"Pareto Front: Energy Density vs Time Spread", fontsize=14, fontweight="bold"
)
ax1.legend(loc="upper right", fontsize=9)
ax1.grid(True, alpha=0.3)
ax1.set_xlim(left=0)
ax1.set_ylim(bottom=0)
# Right: Spatial map of Pareto-optimal bins
ax2 = axes[1]
lon_all = np.array([b["lon_deg"] for b in bin_data])
lat_all = np.array([b["lat_deg"] for b in bin_data])
energy_all = np.array([b["energy_density"] / source_power for b in bin_data])
time_all = np.array([b["time_spread_ns"] for b in bin_data])
# Color by normalized energy density, size by inverse time spread
sizes = 20 + 80 * (1 - time_all / max(time_all)) # Smaller time = larger marker
sc2 = ax2.scatter(
lon_all,
lat_all,
c=energy_all,
s=sizes,
alpha=0.7,
cmap="hot",
edgecolors="none",
)
# Highlight Pareto front
lon_pareto = np.array([p["lon_deg"] for p in pareto_front])
lat_pareto = np.array([p["lat_deg"] for p in pareto_front])
ax2.scatter(
lon_pareto,
lat_pareto,
facecolors="none",
edgecolors="red",
s=100,
linewidths=2,
label="Pareto front",
)
if within_threshold:
ax2.scatter(
[best_within["lon_deg"]],
[best_within["lat_deg"]],
c="green",
s=200,
marker="*",
edgecolor="black",
linewidth=1.5,
label="Best within threshold",
zorder=10,
)
ax2.set_xlabel("Longitude (deg)", fontsize=12)
ax2.set_ylabel("Latitude (deg)", fontsize=12)
ax2.set_title(
"Spatial Distribution (size = time resolution)", fontsize=14, fontweight="bold"
)
ax2.legend(loc="upper right", fontsize=9)
ax2.set_aspect("equal")
ax2.grid(True, alpha=0.3)
cbar2 = plt.colorbar(sc2, ax=ax2)
cbar2.set_label("Normalized Energy Density (sr^-1)", fontsize=11)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches="tight")
print(f" Saved: {save_path}")
return fig
[docs]
def plot_energy_density_map(
    peak_result: dict,
    recorded_rays: "RecordedRays",
    source_power: float = 1.0,
    detector_center=None,
    save_path: str | None = None,
):
    """
    Show the detector-sphere energy density with its peak highlighted.

    Left panel: the binned energy density from ``peak_result`` divided by
    ``source_power``, with a cyan star at the peak.

    Right panel: every recorded ray placed at its (lon, lat) on the sphere
    and colored by intensity, with the same star. The color scale is capped
    at the 95th percentile of intensity.

    Parameters
    ----------
    peak_result : dict
        Result from find_peak_energy_density(). Keys read:

        - ``histogram``,
        - ``lon_edges`` / ``lat_edges``: converted with ``np.degrees``, so
          presumably in radians,
        - ``peak_lon_deg`` and ``peak_lat_deg``.
    recorded_rays : RecordedRays
        Recorded rays for overlay (``positions`` and ``intensities`` are read)
    source_power : float
        Input source power for normalization (W)
    detector_center : array-like, optional
        Center of detector sphere; currently ignored
    save_path : str, optional
        Path to save figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    density_ax, rays_ax = axes

    peak_lon = peak_result["peak_lon_deg"]
    peak_lat = peak_result["peak_lat_deg"]
    star_style = dict(markersize=15, markeredgecolor="white", markeredgewidth=1.5)

    # --- Left: energy-density heatmap normalized by the input power ---
    density = peak_result["histogram"] / source_power
    mesh = density_ax.pcolormesh(
        np.degrees(peak_result["lon_edges"]),
        np.degrees(peak_result["lat_edges"]),
        density.T,
        cmap="hot",
        shading="auto",
    )
    density_ax.plot(
        peak_lon,
        peak_lat,
        "c*",
        label=f"Peak: ({peak_lon:.2f}, {peak_lat:.2f})",
        **star_style,
    )
    density_ax.set_xlabel("Longitude (deg)", fontsize=12)
    density_ax.set_ylabel("Latitude (deg)", fontsize=12)
    density_ax.set_title("Normalized Energy Density", fontsize=14, fontweight="bold")
    density_ax.legend(loc="upper right")
    density_bar = plt.colorbar(mesh, ax=density_ax)
    density_bar.set_label("Energy Density / Input Power (sr^-1)", fontsize=11)

    # --- Right: individual ray hits in (lon, lat) degrees ---
    pos = recorded_rays.positions
    px, py, pz = pos[:, 0], pos[:, 1], pos[:, 2]
    radius = np.sqrt(px**2 + py**2 + pz**2)
    ray_lat = np.degrees(np.arcsin(pz / radius))
    ray_lon = np.degrees(np.arctan2(py, px))
    weights = recorded_rays.intensities
    cloud = rays_ax.scatter(
        ray_lon,
        ray_lat,
        c=weights,
        s=3,
        alpha=0.6,
        cmap="hot",
        vmin=0,
        vmax=np.percentile(weights, 95),
    )
    rays_ax.plot(peak_lon, peak_lat, "c*", **star_style)
    rays_ax.set_xlabel("Longitude (deg)", fontsize=12)
    rays_ax.set_ylabel("Latitude (deg)", fontsize=12)
    rays_ax.set_title("Ray Intersections with Peak", fontsize=14, fontweight="bold")
    rays_ax.set_aspect("equal")
    rays_bar = plt.colorbar(cloud, ax=rays_ax)
    rays_bar.set_label("Intensity", fontsize=11)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved: {save_path}")
    return fig
[docs]
def plot_mollweide_detector_projection(
    recorded_rays: "RecordedRays",
    save_path: str | None = None,
):
    """
    Mollweide projection showing where rays intersect the detector sphere.

    Each recorded ray is placed at the longitude/latitude of its Cartesian
    position and colored by intensity. The color scale saturates at the 95th
    percentile.

    Parameters
    ----------
    recorded_rays : RecordedRays
        Recorded rays on detection sphere (``positions`` and ``intensities``
        are read)
    save_path : str, optional
        Path to save figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure
    """
    positions = recorded_rays.positions
    intensities = recorded_rays.intensities

    # Cartesian -> spherical, in radians as the Mollweide axes require:
    # latitude from arcsin in [-pi/2, pi/2], longitude from arctan2 in [-pi, pi].
    x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
    r = np.sqrt(x**2 + y**2 + z**2)
    lat = np.arcsin(z / r)
    lon = np.arctan2(y, x)

    fig = plt.figure(figsize=(14, 8))
    ax = fig.add_subplot(111, projection="mollweide")
    sc = ax.scatter(
        lon,
        lat,
        c=intensities,
        s=2,
        alpha=0.5,
        cmap="hot",
        vmin=0,
        vmax=np.percentile(intensities, 95),  # Saturate at 95th percentile
    )
    # The data goes in as radians, but Mollweide axes draw their tick labels
    # in degrees; label the axes to match what the reader sees.
    ax.set_xlabel("Longitude (deg)", fontsize=12)
    ax.set_ylabel("Latitude (deg)", fontsize=12)
    ax.set_title(
        "Detector Sphere Intersections - Mollweide Projection",
        fontsize=14,
        fontweight="bold",
    )
    ax.grid(True, alpha=0.3)

    cbar = fig.colorbar(sc, ax=ax, pad=0.05, fraction=0.046)
    cbar.set_label("Intensity", fontsize=11)

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved: {save_path}")
    return fig
[docs]
def plot_arrival_time_distributions(
recorded_rays: "RecordedRays",
pareto_result: dict,
n_top: int = 10,
bins: int = 50,
time_threshold_ns: float = 10.0,
save_path: str | None = None,
):
"""
Plot arrival time histograms for the top N intensity bins.
Shows the actual distribution of arrival times (not just percentiles)
for each of the highest energy density detector bins.
Parameters
----------
recorded_rays : RecordedRays
Recorded rays on detection sphere
pareto_result : dict
Result from compute_pareto_front()
n_top : int
Number of top intensity bins to plot
bins : int
Number of histogram bins
time_threshold_ns : float
Time threshold to mark on plots (ns)
save_path : str, optional
Path to save figure
Returns
-------
fig : matplotlib.figure.Figure or None
The created figure, or None if no data
"""
bin_data = pareto_result["bin_data"]
if len(bin_data) == 0:
print(" No bin data for arrival time distributions")
return None
if len(bin_data) < n_top:
n_top = len(bin_data)
# Sort by energy density and take top N
sorted_bins = sorted(bin_data, key=lambda x: x["energy_density"], reverse=True)
top_bins = sorted_bins[:n_top]
# Get ray data
positions = recorded_rays.positions
intensities = recorded_rays.intensities
times = recorded_rays.times
# Convert to spherical coordinates
x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
r = np.sqrt(x**2 + y**2 + z**2)
lat = np.arcsin(z / r)
lon = np.arctan2(y, x)
# Infer bin edges from pareto_result
# Get the bin spacing from first bin's location
all_lons = np.array([b["lon"] for b in bin_data])
all_lats = np.array([b["lat"] for b in bin_data])
unique_lons = np.unique(all_lons)
unique_lats = np.unique(all_lats)
dlon = np.diff(unique_lons).min() if len(unique_lons) > 1 else 0.1
dlat = np.diff(unique_lats).min() if len(unique_lats) > 1 else 0.1
# Create figure with subplots for each bin
n_cols = min(5, n_top)
n_rows = (n_top + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3.5 * n_rows))
if n_top == 1:
axes = np.array([[axes]])
elif n_rows == 1:
axes = axes.reshape(1, -1)
# Find global time range for consistent x-axis
first_arrival = np.min(times)
for i, b in enumerate(top_bins):
row, col = divmod(i, n_cols)
ax = axes[row, col]
# Find rays in this bin
bin_lon = b["lon"]
bin_lat = b["lat"]
mask = (
(lon >= bin_lon - dlon / 2)
& (lon < bin_lon + dlon / 2)
& (lat >= bin_lat - dlat / 2)
& (lat < bin_lat + dlat / 2)
)
n_rays = np.sum(mask)
if n_rays == 0:
ax.text(
0.5, 0.5, "No rays", ha="center", va="center", transform=ax.transAxes
)
ax.set_title(f"#{i+1}: ({b['lon_deg']:.1f}, {b['lat_deg']:.1f})°")
continue
bin_times = times[mask]
bin_intensities = intensities[mask]
# Convert to relative arrival time in ns
times_relative_ns = (bin_times - first_arrival) * 1e9
# Create histogram (intensity-weighted)
counts, edges = np.histogram(
times_relative_ns, bins=bins, weights=bin_intensities
)
centers = (edges[:-1] + edges[1:]) / 2
# Plot as step histogram
ax.fill_between(centers, counts, alpha=0.6, color="steelblue", step="mid")
ax.step(centers, counts, where="mid", color="steelblue", linewidth=1.5)
# Mark the 10th and 90th percentiles
from ..utilities.detector_analysis import weighted_percentile
t10 = weighted_percentile(times_relative_ns, bin_intensities, 10)
t90 = weighted_percentile(times_relative_ns, bin_intensities, 90)
ax.axvline(
t10,
color="red",
linestyle="--",
linewidth=1.5,
alpha=0.8,
label=f"10th: {t10:.1f} ns",
)
ax.axvline(
t90,
color="red",
linestyle="--",
linewidth=1.5,
alpha=0.8,
label=f"90th: {t90:.1f} ns",
)
# Mark time spread
spread = t90 - t10
ax.axvspan(t10, t90, alpha=0.15, color="red")
ax.set_xlabel("Relative Arrival Time (ns)", fontsize=9)
ax.set_ylabel("Intensity", fontsize=9)
ax.set_title(
f"#{i+1}: ({b['lon_deg']:.1f}, {b['lat_deg']:.1f})°\n"
f"n={n_rays}, spread={spread:.1f} ns",
fontsize=10,
)
ax.grid(True, alpha=0.3)
ax.set_xlim(left=0)
# Add legend only on first subplot
if i == 0:
ax.legend(fontsize=7, loc="upper right")
# Hide unused subplots
for i in range(n_top, n_rows * n_cols):
row, col = divmod(i, n_cols)
axes[row, col].axis("off")
fig.suptitle(
f"Arrival Time Distributions for Top {n_top} Energy Density Bins",
fontsize=14,
fontweight="bold",
)
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=150, bbox_inches="tight")
print(f" Saved: {save_path}")
return fig
[docs]
def plot_time_spread_comparison(
    pareto_result: dict,
    source_position: tuple[float, float, float],
    beam_direction: tuple[float, float, float],
    divergence_angle_rad: float,
    detector_altitude: float,
    surface=None,
    n_top: int = 10,
    source_power: float = 1.0,
    time_threshold_ns: float = 10.0,
    save_path: str | None = None,
):
    """
    Compare ray-traced time spread with single-bounce geometric estimate.

    The geometric estimate assumes a single reflection (source → surface →
    detector). Ray-traced values may exceed this if rays take multi-bounce
    paths, because atmospheric refraction can return rays to the surface
    multiple times.

    Left panel: side-by-side bars of both spreads for the top ``n_top`` bins
    by energy density, with the time threshold. Right panel: their ratio
    (ray-traced / geometric). The geometric value is floored at 0.1 ns to
    avoid division by zero.

    Parameters
    ----------
    pareto_result : dict
        Result from compute_pareto_front(). Only ``bin_data`` is read: dicts
        with ``energy_density``, ``time_spread_ns``, ``lon``/``lat``
        (radians) and ``lon_deg``/``lat_deg``.
    source_position : tuple
        Source (x, y, z) position in meters
    beam_direction : tuple
        Beam direction unit vector
    divergence_angle_rad : float
        Beam half-angle divergence in radians
    detector_altitude : float
        Detector sphere radius in meters
    surface : Surface, optional
        Surface object for geometric estimate. If None, uses flat surface at z=0.
    n_top : int
        Number of top intensity bins to analyze
    source_power : float
        Input source power (W). Accepted for API consistency; not used here.
    time_threshold_ns : float
        Time threshold for highlighting (ns)
    save_path : str, optional
        Path to save figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure
    """
    from ..utilities import estimate_time_spread

    def _annotate_bars(axis, bar_container, values, fmt, **text_kwargs):
        """Write each formatted value just above its bar.

        A fixed 2-point offset keeps the gap scale-independent.
        """
        for bar, val in zip(bar_container, values, strict=False):
            axis.annotate(
                fmt.format(val),
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                xytext=(0, 2),
                textcoords="offset points",
                ha="center",
                va="bottom",
                **text_kwargs,
            )

    bin_data = pareto_result["bin_data"]
    n_top = min(n_top, len(bin_data))
    # Highest energy density first
    top_bins = sorted(bin_data, key=lambda b: b["energy_density"], reverse=True)[
        :n_top
    ]

    # Geometric (single-bounce) estimate at each bin's detector location
    geometric_bounds = []
    raytraced_spreads = []
    labels = []
    for b in top_bins:
        lon_rad = b["lon"]
        lat_rad = b["lat"]
        # Point on the detector sphere at this bin's (lon, lat)
        detector_position = (
            detector_altitude * np.cos(lat_rad) * np.cos(lon_rad),
            detector_altitude * np.cos(lat_rad) * np.sin(lon_rad),
            detector_altitude * np.sin(lat_rad),
        )
        result = estimate_time_spread(
            source_position=source_position,
            beam_direction=beam_direction,
            divergence_angle=divergence_angle_rad,
            detector_position=detector_position,
            surface=surface,
        )
        geometric_bounds.append(result.time_spread_ns)
        raytraced_spreads.append(b["time_spread_ns"])
        labels.append(f"({b['lon_deg']:.1f}, {b['lat_deg']:.1f})")
    geometric_bounds = np.array(geometric_bounds)
    raytraced_spreads = np.array(raytraced_spreads)
    x = np.arange(len(top_bins))

    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # ==========================================================================
    # Left: Bar chart comparing raytraced vs geometric
    # ==========================================================================
    width = 0.35
    bars1 = ax.bar(
        x - width / 2,
        raytraced_spreads,
        width,
        label="Ray-traced (includes multi-bounce)",
        color="steelblue",
        alpha=0.8,
    )
    bars2 = ax.bar(
        x + width / 2,
        geometric_bounds,
        width,
        label="Geometric (single-bounce only)",
        color="coral",
        alpha=0.8,
    )
    ax.axhline(
        y=time_threshold_ns,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Threshold ({time_threshold_ns} ns)",
    )
    ax.set_xlabel("Detector Location (lon, lat) deg", fontsize=12)
    ax.set_ylabel("Time Spread (ns)", fontsize=12)
    ax.set_title(
        "Time Spread: Ray-Traced vs Single-Bounce Geometric",
        fontsize=14,
        fontweight="bold",
    )
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=9)
    ax.legend(fontsize=9, loc="upper right")
    ax.grid(True, alpha=0.3, axis="y")
    _annotate_bars(ax, bars1, raytraced_spreads, "{:.0f}", fontsize=8)
    _annotate_bars(ax, bars2, geometric_bounds, "{:.1f}", fontsize=8)

    # ==========================================================================
    # Right: Ratio of raytraced to geometric (shows multi-bounce contribution)
    # ==========================================================================
    # Floor the denominator at 0.1 ns to avoid division by zero
    ratio = raytraced_spreads / np.maximum(geometric_bounds, 0.1)
    colors = plt.cm.RdYlGn_r(np.clip(ratio / 50, 0, 1))  # Red = high ratio
    bars = ax2.bar(
        x,
        ratio,
        width=0.7,
        color=colors,
        alpha=0.8,
        edgecolor="black",
        linewidth=0.5,
    )
    ax2.axhline(
        y=1,
        color="green",
        linestyle="-",
        linewidth=2,
        label="Ratio = 1 (single-bounce)",
    )
    ax2.axhline(y=10, color="orange", linestyle="--", linewidth=1.5, label="Ratio = 10")
    ax2.set_xlabel("Detector Location (lon, lat) deg", fontsize=12)
    ax2.set_ylabel("Raytraced / Geometric Ratio", fontsize=12)
    ax2.set_title(
        "Multi-Bounce Contribution\n(ratio > 1 indicates multi-bounce paths)",
        fontsize=14,
        fontweight="bold",
    )
    ax2.set_xticks(x)
    ax2.set_xticklabels(labels, rotation=45, ha="right", fontsize=9)
    ax2.legend(fontsize=9, loc="upper right")
    ax2.grid(True, alpha=0.3, axis="y")
    _annotate_bars(ax2, bars, ratio, "{:.0f}x", fontsize=9, fontweight="bold")

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved: {save_path}")
    return fig