# Source code for lsurf.visualization.detector_sphere_plots

# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
#      * Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#
#      * Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#      * Neither the name of the copyright holder nor the names of its
#      contributors may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""
Detector Sphere Visualization

Plotting functions for visualizing ray intersections with detector spheres,
energy density maps, Pareto front analysis, and geometry schematics.
"""

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np

if TYPE_CHECKING:
    from numpy.typing import NDArray

    from ..utilities.recording_sphere import RecordedRays


[docs] def plot_geometry_schematic( source_position: tuple[float, float, float], beam_direction: tuple[float, float, float] | "NDArray", grazing_angle_deg: float, detector_altitude: float, intersection_points: "NDArray", reflected_directions: "NDArray", n_rays_to_show: int = 20, save_path: str | None = None, ): """ Plot a schematic overview of the simulation geometry. Shows side view (x-z plane) with source, rays, ocean surface, and detector sphere. Parameters ---------- source_position : tuple Source (x, y, z) position in meters beam_direction : tuple or ndarray Beam direction unit vector grazing_angle_deg : float Grazing angle in degrees detector_altitude : float Detector sphere radius in meters intersection_points : ndarray, shape (N, 3) XYZ coordinates of surface intersections reflected_directions : ndarray, shape (N, 3) Direction vectors of reflected rays n_rays_to_show : int Number of rays to display save_path : str, optional Path to save figure Returns ------- fig : matplotlib.figure.Figure The created figure """ fig, axes = plt.subplots(1, 2, figsize=(16, 7)) # ========================================================================= # Left panel: Side view schematic (x-z plane) - scaled to show detector # ========================================================================= ax1 = axes[0] # Scale to km for display src_x, src_y, src_z = np.array(source_position) / 1000 det_r = detector_altitude / 1000 # Compute where reflected rays hit the detector sphere hit_positions_km = [] for i in range(len(reflected_directions)): rd = reflected_directions[i] rd_norm = rd / np.linalg.norm(rd) hit_x = det_r * rd_norm[0] hit_z = det_r * rd_norm[2] hit_positions_km.append((hit_x, hit_z)) hit_positions_km = np.array(hit_positions_km) # Get the range of detector hits to focus the view hit_x_min, hit_x_max = hit_positions_km[:, 0].min(), hit_positions_km[:, 0].max() hit_z_min, hit_z_max = hit_positions_km[:, 1].min(), hit_positions_km[:, 1].max() # Draw detector sphere 
arc (full upper hemisphere, faded) theta_arc = np.linspace(0, np.pi, 200) x_arc = det_r * np.cos(theta_arc) z_arc = det_r * np.sin(theta_arc) ax1.plot(x_arc, z_arc, "g-", linewidth=1.5, alpha=0.3) # Highlight the portion where rays hit theta_hit_min = np.arctan2(hit_z_min, hit_x_max) theta_hit_max = np.arctan2(hit_z_max, hit_x_min) theta_highlight = np.linspace(theta_hit_min - 0.02, theta_hit_max + 0.02, 50) x_highlight = det_r * np.cos(theta_highlight) z_highlight = det_r * np.sin(theta_highlight) ax1.plot( x_highlight, z_highlight, "g-", linewidth=3, label=f"Detector sphere (r={det_r:.0f} km)", ) # Draw ocean surface following Earth curvature # Earth center is at (0, 0, -EARTH_RADIUS), surface at x is at z = sqrt(R^2 - x^2) - R from ..surfaces import EARTH_RADIUS earth_r_km = EARTH_RADIUS / 1000 x_ocean_min = min(src_x - 1, -det_r * 0.15) x_ocean_max = det_r # Extend to detector sphere radius x_ocean = np.linspace(x_ocean_min, x_ocean_max, 300) # Curved Earth surface (for x in km, result in km) z_ocean_curved = np.sqrt(earth_r_km**2 - x_ocean**2) - earth_r_km # Add small decorative waves on top of curvature z_ocean = z_ocean_curved + 0.0003 * det_r * np.sin(80 * x_ocean / det_r) ax1.fill_between( x_ocean, z_ocean, z_ocean.min() - det_r * 0.01, color="lightblue", alpha=0.5, label="Ocean", ) ax1.plot(x_ocean, z_ocean, "b-", linewidth=2) # Draw source ax1.plot(src_x, src_z, "ro", markersize=10, label="Source", zorder=10) # Draw representative rays # Use minimum of intersection_points and reflected_directions sizes n_total = min(len(intersection_points), len(reflected_directions)) indices = np.linspace(0, n_total - 1, min(n_rays_to_show, n_total), dtype=int) for idx in indices: # Incoming ray: source to intersection (near origin) int_x, int_y, int_z = intersection_points[idx] / 1000 ax1.plot([src_x, int_x], [src_z, int_z], "r-", alpha=0.2, linewidth=0.8) # Reflected ray: intersection to detector sphere rd = reflected_directions[idx] rd_norm = rd / np.linalg.norm(rd) 
det_x = det_r * rd_norm[0] det_z = det_r * rd_norm[2] ax1.plot([int_x, det_x], [int_z, det_z], "orange", alpha=0.2, linewidth=0.8) # Draw chief ray more prominently mid_idx = n_total // 2 int_x, int_y, int_z = intersection_points[mid_idx] / 1000 ax1.plot([src_x, int_x], [src_z, int_z], "r-", linewidth=2, label="Incident rays") rd = reflected_directions[mid_idx] rd_norm = rd / np.linalg.norm(rd) det_x = det_r * rd_norm[0] det_z = det_r * rd_norm[2] ax1.plot( [int_x, det_x], [int_z, det_z], "orange", linewidth=2, label="Reflected rays" ) # Mark key points ax1.plot(0, 0, "k*", markersize=12, label="Reflection point", zorder=10) ax1.plot(det_x, det_z, "g*", markersize=10, label="Detection region", zorder=10) # Draw grazing angle arc (scaled to be visible) arc_r = det_r * 0.06 theta_graze = np.linspace(0, np.radians(grazing_angle_deg), 20) x_graze = -arc_r * np.cos(theta_graze) z_graze = -arc_r * np.sin(theta_graze) ax1.plot(x_graze, z_graze, "m-", linewidth=2) ax1.annotate( f"{grazing_angle_deg}°", (-arc_r * 0.7, -arc_r * 0.5), fontsize=10, color="purple", ) # Set axis limits to show full geometry including source and detector hit region x_left = min(src_x - 1, -det_r * 0.15) ax1.set_xlim(x_left, det_r * 1.05) # Adjust y-limits to show curved ocean surface z_min = min(z_ocean.min(), -det_r * 0.03) ax1.set_ylim(z_min - det_r * 0.01, det_r * 0.35) ax1.set_xlabel("X (km)", fontsize=12) ax1.set_ylabel("Z (km)", fontsize=12) ax1.set_title( "Geometry Schematic - Side View (X-Z plane)", fontsize=14, fontweight="bold" ) ax1.legend(loc="upper left", fontsize=9) ax1.grid(True, alpha=0.3) # ========================================================================= # Right panel: Top view with annotations # ========================================================================= ax2 = axes[1] # Draw intersection points footprint ix_all = intersection_points[:, 0] iy_all = intersection_points[:, 1] ax2.scatter(ix_all, iy_all, c="blue", s=1, alpha=0.3, label="Ray footprint") # Draw 
source projection ax2.plot( source_position[0], source_position[1], "ro", markersize=12, label="Source (projected)", ) # Draw reflection center ax2.plot(0, 0, "k*", markersize=15, label="Reflection center") # Draw beam spread arrows ax2.annotate( "", xy=(np.mean(ix_all), np.max(iy_all)), xytext=(source_position[0], source_position[1]), arrowprops=dict(arrowstyle="->", color="red", alpha=0.5), ) ax2.annotate( "", xy=(np.mean(ix_all), np.min(iy_all)), xytext=(source_position[0], source_position[1]), arrowprops=dict(arrowstyle="->", color="red", alpha=0.5), ) ax2.set_xlabel("X (m)", fontsize=12) ax2.set_ylabel("Y (m)", fontsize=12) ax2.set_title("Geometry - Top View (X-Y plane)", fontsize=14, fontweight="bold") ax2.legend(loc="upper right", fontsize=9) ax2.set_aspect("equal") ax2.grid(True, alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches="tight") print(f" Saved: {save_path}") return fig
[docs] def plot_ocean_intersections_top_view( intersection_points: "NDArray", wave_surface=None, n_bins: int = 100, save_path: str | None = None, ): """ Plot top-down heatmap of where rays hit the ocean surface. Parameters ---------- intersection_points : ndarray, shape (N, 3) XYZ coordinates of surface intersections wave_surface : CurvedWaveSurface, optional The ocean surface for context (not currently used) n_bins : int Number of bins for the 2D histogram save_path : str, optional Path to save figure Returns ------- fig : matplotlib.figure.Figure The created figure """ from matplotlib.colors import LogNorm fig, ax = plt.subplots(figsize=(12, 10)) x = intersection_points[:, 0] y = intersection_points[:, 1] # Create 2D histogram hist, x_edges, y_edges = np.histogram2d(x, y, bins=n_bins) # Replace zeros with small value for log scale hist_log = np.where(hist > 0, hist, np.nan) # Plot as heatmap with log scale im = ax.pcolormesh( x_edges, y_edges, hist_log.T, cmap="hot", shading="auto", norm=LogNorm(vmin=1, vmax=hist.max()), ) # Add colorbar cbar = plt.colorbar(im, ax=ax) cbar.set_label("Ray count per bin (log scale)", fontsize=11) # Add coordinate grid for reference ax.grid(True, alpha=0.3, color="white", linewidth=0.5) ax.set_xlabel("X (m)", fontsize=12) ax.set_ylabel("Y (m)", fontsize=12) ax.set_title( "Ocean Surface Intersections - Top View", fontsize=14, fontweight="bold" ) ax.set_aspect("equal") plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches="tight") print(f" Saved: {save_path}") return fig
[docs] def plot_3d_intersection_scatter( intersection_points: "NDArray", save_path: str | None = None, ): """ 3D scatter plot of ocean surface intersection points with equal axes. Parameters ---------- intersection_points : ndarray, shape (N, 3) XYZ coordinates of surface intersections save_path : str, optional Path to save figure Returns ------- fig : matplotlib.figure.Figure The created figure """ fig = plt.figure(figsize=(12, 10)) ax = fig.add_subplot(111, projection="3d") # Plot points ax.scatter( intersection_points[:, 0], intersection_points[:, 1], intersection_points[:, 2], c=intersection_points[:, 2], # Color by height cmap="viridis", s=1, alpha=0.6, ) # Set equal aspect ratio max_range = np.ptp(intersection_points, axis=0).max() / 2.0 mid_x = (intersection_points[:, 0].max() + intersection_points[:, 0].min()) / 2.0 mid_y = (intersection_points[:, 1].max() + intersection_points[:, 1].min()) / 2.0 mid_z = (intersection_points[:, 2].max() + intersection_points[:, 2].min()) / 2.0 ax.set_xlim(mid_x - max_range, mid_x + max_range) ax.set_ylim(mid_y - max_range, mid_y + max_range) ax.set_zlim(mid_z - max_range, mid_z + max_range) ax.set_xlabel("X (m)", fontsize=11) ax.set_ylabel("Y (m)", fontsize=11) ax.set_zlabel("Z (m)", fontsize=11) ax.set_title("3D Ocean Surface Intersections", fontsize=14, fontweight="bold") plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches="tight") print(f" Saved: {save_path}") return fig
[docs] def plot_pareto_front( pareto_result: dict, time_threshold_ns: float = 10.0, source_power: float = 1.0, save_path: str | None = None, ): """ Plot the Pareto front of energy density vs time spread. Parameters ---------- pareto_result : dict Result from compute_pareto_front() time_threshold_ns : float Detector time resolution threshold (ns) source_power : float Input source power for normalization (W) save_path : str, optional Path to save figure Returns ------- fig : matplotlib.figure.Figure or None The created figure, or None if no data """ bin_data = pareto_result["bin_data"] pareto_front = pareto_result["pareto_front"] if len(bin_data) == 0: print(" No bins with sufficient rays for Pareto analysis") return None # Extract data and normalize by source power all_energy = pareto_result["all_energy_densities"] / source_power all_time = pareto_result["all_time_spreads"] pareto_energy = np.array([p["energy_density"] / source_power for p in pareto_front]) pareto_time = np.array([p["time_spread_ns"] for p in pareto_front]) fig, axes = plt.subplots(1, 2, figsize=(16, 6)) # Left: Pareto front scatter plot ax1 = axes[0] # Plot all bins ax1.scatter(all_time, all_energy, c="lightgray", s=20, alpha=0.6, label="All bins") # Plot Pareto front ax1.scatter( pareto_time, pareto_energy, c="red", s=60, alpha=0.9, label="Pareto front", zorder=5, ) # Connect Pareto front points sort_idx = np.argsort(pareto_time) ax1.plot( pareto_time[sort_idx], pareto_energy[sort_idx], "r-", linewidth=1.5, alpha=0.7, zorder=4, ) # Mark threshold ax1.axvline( x=time_threshold_ns, color="blue", linestyle="--", linewidth=2, label=f"Time threshold ({time_threshold_ns} ns)", ) # Find and annotate key points # Best within threshold (highest energy where time < threshold) within_threshold = [ p for p in pareto_front if p["time_spread_ns"] < time_threshold_ns ] if within_threshold: best_within = max(within_threshold, key=lambda x: x["energy_density"]) ax1.scatter( [best_within["time_spread_ns"]], 
[best_within["energy_density"] / source_power], c="green", s=200, marker="*", edgecolor="black", linewidth=1.5, label="Best within threshold", zorder=10, ) ax1.annotate( f"({best_within['lon_deg']:.1f}, {best_within['lat_deg']:.1f})", ( best_within["time_spread_ns"], best_within["energy_density"] / source_power, ), xytext=(10, 10), textcoords="offset points", fontsize=9, arrowprops=dict(arrowstyle="->", color="green"), ) # Highest energy on Pareto front if len(pareto_front) > 0: best_energy = max(pareto_front, key=lambda x: x["energy_density"]) if not within_threshold or best_energy != best_within: ax1.scatter( [best_energy["time_spread_ns"]], [best_energy["energy_density"] / source_power], c="orange", s=150, marker="D", edgecolor="black", linewidth=1.5, label="Highest energy", zorder=9, ) ax1.set_xlabel("Time Spread (90th-10th percentile) [ns]", fontsize=12) ax1.set_ylabel("Normalized Energy Density (sr^-1)", fontsize=12) ax1.set_title( "Pareto Front: Energy Density vs Time Spread", fontsize=14, fontweight="bold" ) ax1.legend(loc="upper right", fontsize=9) ax1.grid(True, alpha=0.3) ax1.set_xlim(left=0) ax1.set_ylim(bottom=0) # Right: Spatial map of Pareto-optimal bins ax2 = axes[1] lon_all = np.array([b["lon_deg"] for b in bin_data]) lat_all = np.array([b["lat_deg"] for b in bin_data]) energy_all = np.array([b["energy_density"] / source_power for b in bin_data]) time_all = np.array([b["time_spread_ns"] for b in bin_data]) # Color by normalized energy density, size by inverse time spread sizes = 20 + 80 * (1 - time_all / max(time_all)) # Smaller time = larger marker sc2 = ax2.scatter( lon_all, lat_all, c=energy_all, s=sizes, alpha=0.7, cmap="hot", edgecolors="none", ) # Highlight Pareto front lon_pareto = np.array([p["lon_deg"] for p in pareto_front]) lat_pareto = np.array([p["lat_deg"] for p in pareto_front]) ax2.scatter( lon_pareto, lat_pareto, facecolors="none", edgecolors="red", s=100, linewidths=2, label="Pareto front", ) if within_threshold: ax2.scatter( 
[best_within["lon_deg"]], [best_within["lat_deg"]], c="green", s=200, marker="*", edgecolor="black", linewidth=1.5, label="Best within threshold", zorder=10, ) ax2.set_xlabel("Longitude (deg)", fontsize=12) ax2.set_ylabel("Latitude (deg)", fontsize=12) ax2.set_title( "Spatial Distribution (size = time resolution)", fontsize=14, fontweight="bold" ) ax2.legend(loc="upper right", fontsize=9) ax2.set_aspect("equal") ax2.grid(True, alpha=0.3) cbar2 = plt.colorbar(sc2, ax=ax2) cbar2.set_label("Normalized Energy Density (sr^-1)", fontsize=11) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches="tight") print(f" Saved: {save_path}") return fig
def plot_energy_density_map(
    peak_result: dict,
    recorded_rays: "RecordedRays",
    source_power: float = 1.0,
    detector_center=None,
    save_path: str | None = None,
):
    """
    Plot energy density map on detector sphere with peak location marked.

    The left panel shows the binned energy density from ``peak_result``,
    divided by ``source_power``. The right panel shows every recorded ray as
    a longitude/latitude point colored by intensity, with the color scale
    capped at the 95th percentile. The peak is marked with a cyan star on
    both panels.

    Parameters
    ----------
    peak_result : dict
        Result from find_peak_energy_density(). The keys read here are
        "histogram", "lon_edges" and "lat_edges" (converted with
        ``np.degrees``), and "peak_lon_deg" and "peak_lat_deg".
    recorded_rays : RecordedRays
        Recorded rays for overlay; ``positions`` and ``intensities`` are read
    source_power : float
        Input source power for normalization (W)
    detector_center : array-like, optional
        Center of detector sphere (not used currently)
    save_path : str, optional
        Path to save figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure
    """
    fig, (density_ax, rays_ax) = plt.subplots(1, 2, figsize=(16, 6))

    peak_lon = peak_result["peak_lon_deg"]
    peak_lat = peak_result["peak_lat_deg"]
    # The same star styling is used for the peak on both panels
    star_style = {"markersize": 15, "markeredgecolor": "white", "markeredgewidth": 1.5}

    # ---- Left: normalized energy density heatmap ----
    normalized_density = peak_result["histogram"] / source_power
    mesh = density_ax.pcolormesh(
        np.degrees(peak_result["lon_edges"]),
        np.degrees(peak_result["lat_edges"]),
        normalized_density.T,
        cmap="hot",
        shading="auto",
    )
    density_ax.plot(
        peak_lon,
        peak_lat,
        "c*",
        label=f"Peak: ({peak_lon:.2f}, {peak_lat:.2f})",
        **star_style,
    )
    density_ax.set_xlabel("Longitude (deg)", fontsize=12)
    density_ax.set_ylabel("Latitude (deg)", fontsize=12)
    density_ax.set_title("Normalized Energy Density", fontsize=14, fontweight="bold")
    density_ax.legend(loc="upper right")
    density_bar = plt.colorbar(mesh, ax=density_ax)
    density_bar.set_label("Energy Density / Input Power (sr^-1)", fontsize=11)

    # ---- Right: individual ray hits in lon/lat (degrees) about the origin ----
    pos = recorded_rays.positions
    px, py, pz = pos[:, 0], pos[:, 1], pos[:, 2]
    radius = np.sqrt(px**2 + py**2 + pz**2)
    ray_lat = np.degrees(np.arcsin(pz / radius))
    ray_lon = np.degrees(np.arctan2(py, px))

    ray_intensity = recorded_rays.intensities
    color_cap = np.percentile(ray_intensity, 95)
    hits = rays_ax.scatter(
        ray_lon,
        ray_lat,
        c=ray_intensity,
        s=3,
        alpha=0.6,
        cmap="hot",
        vmin=0,
        vmax=color_cap,
    )
    rays_ax.plot(peak_lon, peak_lat, "c*", **star_style)
    rays_ax.set_xlabel("Longitude (deg)", fontsize=12)
    rays_ax.set_ylabel("Latitude (deg)", fontsize=12)
    rays_ax.set_title("Ray Intersections with Peak", fontsize=14, fontweight="bold")
    rays_ax.set_aspect("equal")
    rays_bar = plt.colorbar(hits, ax=rays_ax)
    rays_bar.set_label("Intensity", fontsize=11)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved: {save_path}")
    return fig
def plot_mollweide_detector_projection(
    recorded_rays: "RecordedRays",
    save_path: str | None = None,
):
    """
    Mollweide projection showing where rays intersect the detector sphere.

    Each recorded ray position is converted to longitude/latitude about the
    coordinate origin and scattered on a Mollweide map. Points are colored by
    intensity, with the color scale saturating at the 95th percentile.

    Parameters
    ----------
    recorded_rays : RecordedRays
        Recorded rays on detection sphere; ``positions`` and ``intensities``
        are read
    save_path : str, optional
        Path to save figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure
    """
    positions = recorded_rays.positions
    x, y, z = positions[:, 0], positions[:, 1], positions[:, 2]
    r = np.sqrt(x**2 + y**2 + z**2)

    # Spherical angles in radians, which is what Mollweide axes take as data.
    # Latitude is in [-pi/2, pi/2]; longitude is in [-pi, pi].
    lat = np.arcsin(z / r)
    lon = np.arctan2(y, x)

    fig = plt.figure(figsize=(14, 8))
    ax = fig.add_subplot(111, projection="mollweide")

    intensities = recorded_rays.intensities
    sc = ax.scatter(
        lon,
        lat,
        c=intensities,
        s=2,
        alpha=0.5,
        cmap="hot",
        vmin=0,
        vmax=np.percentile(intensities, 95),  # Saturate at 95th percentile
    )

    # The data are in radians, but Mollweide tick labels are rendered in
    # degrees. The axis labels describe what the reader sees.
    ax.set_xlabel("Longitude (deg)", fontsize=12)
    ax.set_ylabel("Latitude (deg)", fontsize=12)
    ax.set_title(
        "Detector Sphere Intersections - Mollweide Projection",
        fontsize=14,
        fontweight="bold",
    )
    ax.grid(True, alpha=0.3)

    cbar = plt.colorbar(sc, ax=ax, pad=0.05, fraction=0.046)
    cbar.set_label("Intensity", fontsize=11)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved: {save_path}")
    return fig
[docs] def plot_arrival_time_distributions( recorded_rays: "RecordedRays", pareto_result: dict, n_top: int = 10, bins: int = 50, time_threshold_ns: float = 10.0, save_path: str | None = None, ): """ Plot arrival time histograms for the top N intensity bins. Shows the actual distribution of arrival times (not just percentiles) for each of the highest energy density detector bins. Parameters ---------- recorded_rays : RecordedRays Recorded rays on detection sphere pareto_result : dict Result from compute_pareto_front() n_top : int Number of top intensity bins to plot bins : int Number of histogram bins time_threshold_ns : float Time threshold to mark on plots (ns) save_path : str, optional Path to save figure Returns ------- fig : matplotlib.figure.Figure or None The created figure, or None if no data """ bin_data = pareto_result["bin_data"] if len(bin_data) == 0: print(" No bin data for arrival time distributions") return None if len(bin_data) < n_top: n_top = len(bin_data) # Sort by energy density and take top N sorted_bins = sorted(bin_data, key=lambda x: x["energy_density"], reverse=True) top_bins = sorted_bins[:n_top] # Get ray data positions = recorded_rays.positions intensities = recorded_rays.intensities times = recorded_rays.times # Convert to spherical coordinates x, y, z = positions[:, 0], positions[:, 1], positions[:, 2] r = np.sqrt(x**2 + y**2 + z**2) lat = np.arcsin(z / r) lon = np.arctan2(y, x) # Infer bin edges from pareto_result # Get the bin spacing from first bin's location all_lons = np.array([b["lon"] for b in bin_data]) all_lats = np.array([b["lat"] for b in bin_data]) unique_lons = np.unique(all_lons) unique_lats = np.unique(all_lats) dlon = np.diff(unique_lons).min() if len(unique_lons) > 1 else 0.1 dlat = np.diff(unique_lats).min() if len(unique_lats) > 1 else 0.1 # Create figure with subplots for each bin n_cols = min(5, n_top) n_rows = (n_top + n_cols - 1) // n_cols fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3.5 * 
n_rows)) if n_top == 1: axes = np.array([[axes]]) elif n_rows == 1: axes = axes.reshape(1, -1) # Find global time range for consistent x-axis first_arrival = np.min(times) for i, b in enumerate(top_bins): row, col = divmod(i, n_cols) ax = axes[row, col] # Find rays in this bin bin_lon = b["lon"] bin_lat = b["lat"] mask = ( (lon >= bin_lon - dlon / 2) & (lon < bin_lon + dlon / 2) & (lat >= bin_lat - dlat / 2) & (lat < bin_lat + dlat / 2) ) n_rays = np.sum(mask) if n_rays == 0: ax.text( 0.5, 0.5, "No rays", ha="center", va="center", transform=ax.transAxes ) ax.set_title(f"#{i+1}: ({b['lon_deg']:.1f}, {b['lat_deg']:.1f})°") continue bin_times = times[mask] bin_intensities = intensities[mask] # Convert to relative arrival time in ns times_relative_ns = (bin_times - first_arrival) * 1e9 # Create histogram (intensity-weighted) counts, edges = np.histogram( times_relative_ns, bins=bins, weights=bin_intensities ) centers = (edges[:-1] + edges[1:]) / 2 # Plot as step histogram ax.fill_between(centers, counts, alpha=0.6, color="steelblue", step="mid") ax.step(centers, counts, where="mid", color="steelblue", linewidth=1.5) # Mark the 10th and 90th percentiles from ..utilities.detector_analysis import weighted_percentile t10 = weighted_percentile(times_relative_ns, bin_intensities, 10) t90 = weighted_percentile(times_relative_ns, bin_intensities, 90) ax.axvline( t10, color="red", linestyle="--", linewidth=1.5, alpha=0.8, label=f"10th: {t10:.1f} ns", ) ax.axvline( t90, color="red", linestyle="--", linewidth=1.5, alpha=0.8, label=f"90th: {t90:.1f} ns", ) # Mark time spread spread = t90 - t10 ax.axvspan(t10, t90, alpha=0.15, color="red") ax.set_xlabel("Relative Arrival Time (ns)", fontsize=9) ax.set_ylabel("Intensity", fontsize=9) ax.set_title( f"#{i+1}: ({b['lon_deg']:.1f}, {b['lat_deg']:.1f}\n" f"n={n_rays}, spread={spread:.1f} ns", fontsize=10, ) ax.grid(True, alpha=0.3) ax.set_xlim(left=0) # Add legend only on first subplot if i == 0: ax.legend(fontsize=7, loc="upper right") 
# Hide unused subplots for i in range(n_top, n_rows * n_cols): row, col = divmod(i, n_cols) axes[row, col].axis("off") fig.suptitle( f"Arrival Time Distributions for Top {n_top} Energy Density Bins", fontsize=14, fontweight="bold", ) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches="tight") print(f" Saved: {save_path}") return fig
def _annotate_bar_tops(ax, bars, texts, **text_kwargs):
    """Write each text centered just above its bar.

    The 2-point offset keeps the gap the same at any data scale.
    """
    for bar, text in zip(bars, texts):
        ax.annotate(
            text,
            xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
            xytext=(0, 2),
            textcoords="offset points",
            ha="center",
            va="bottom",
            **text_kwargs,
        )


def plot_time_spread_comparison(
    pareto_result: dict,
    source_position: tuple[float, float, float],
    beam_direction: tuple[float, float, float],
    divergence_angle_rad: float,
    detector_altitude: float,
    surface=None,
    n_top: int = 10,
    source_power: float = 1.0,
    time_threshold_ns: float = 10.0,
    save_path: str | None = None,
):
    """
    Compare ray-traced time spread with single-bounce geometric estimate.

    The geometric estimate assumes a single reflection (source → surface →
    detector). Ray-traced values may exceed this if rays take multi-bounce
    paths due to atmospheric refraction causing rays to return to the
    surface multiple times.

    For each of the ``n_top`` highest energy density bins, the detector
    position is placed at the bin's (lon, lat) on a sphere of radius
    ``detector_altitude``. The geometric estimate is then computed with
    ``estimate_time_spread``. The left panel compares both spreads as bars;
    the right panel shows their ratio.

    Parameters
    ----------
    pareto_result : dict
        Result from compute_pareto_front(). Each "bin_data" entry is read for
        "energy_density", "lon"/"lat" (radians), "lon_deg"/"lat_deg" and
        "time_spread_ns".
    source_position : tuple
        Source (x, y, z) position in meters
    beam_direction : tuple
        Beam direction unit vector
    divergence_angle_rad : float
        Beam half-angle divergence in radians
    detector_altitude : float
        Detector sphere radius in meters
    surface : Surface, optional
        Surface object for geometric estimate. If None, uses flat surface at z=0.
    n_top : int
        Number of top intensity bins to analyze (capped at the number of bins)
    source_power : float
        Input source power for normalization (W); not used by this plot
    time_threshold_ns : float
        Time threshold for highlighting (ns)
    save_path : str, optional
        Path to save figure

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure
    """
    from ..utilities import estimate_time_spread

    bin_data = pareto_result["bin_data"]
    n_top = max(0, min(n_top, len(bin_data)))

    # Sort by energy density and take top N
    top_bins = sorted(bin_data, key=lambda b: b["energy_density"], reverse=True)[
        :n_top
    ]

    # Geometric (single-bounce) time spread estimate for each detector location
    geometric_bounds = []
    raytraced_spreads = []
    labels = []
    for b in top_bins:
        lon_rad = b["lon"]
        lat_rad = b["lat"]
        detector_position = (
            detector_altitude * np.cos(lat_rad) * np.cos(lon_rad),
            detector_altitude * np.cos(lat_rad) * np.sin(lon_rad),
            detector_altitude * np.sin(lat_rad),
        )
        result = estimate_time_spread(
            source_position=source_position,
            beam_direction=beam_direction,
            divergence_angle=divergence_angle_rad,
            detector_position=detector_position,
            surface=surface,
        )
        geometric_bounds.append(result.time_spread_ns)
        raytraced_spreads.append(b["time_spread_ns"])
        labels.append(f"({b['lon_deg']:.1f}, {b['lat_deg']:.1f})")

    geometric_bounds = np.asarray(geometric_bounds, dtype=float)
    raytraced_spreads = np.asarray(raytraced_spreads, dtype=float)

    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    x = np.arange(n_top)

    # ==========================================================================
    # Left: Bar chart comparing raytraced vs geometric
    # ==========================================================================
    width = 0.35
    bars1 = ax.bar(
        x - width / 2,
        raytraced_spreads,
        width,
        label="Ray-traced (includes multi-bounce)",
        color="steelblue",
        alpha=0.8,
    )
    bars2 = ax.bar(
        x + width / 2,
        geometric_bounds,
        width,
        label="Geometric (single-bounce only)",
        color="coral",
        alpha=0.8,
    )
    ax.axhline(
        y=time_threshold_ns,
        color="red",
        linestyle="--",
        linewidth=2,
        label=f"Threshold ({time_threshold_ns} ns)",
    )
    ax.set_xlabel("Detector Location (lon, lat) deg", fontsize=12)
    ax.set_ylabel("Time Spread (ns)", fontsize=12)
    ax.set_title(
        "Time Spread: Ray-Traced vs Single-Bounce Geometric",
        fontsize=14,
        fontweight="bold",
    )
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=9)
    ax.legend(fontsize=9, loc="upper right")
    ax.grid(True, alpha=0.3, axis="y")

    _annotate_bar_tops(ax, bars1, [f"{val:.0f}" for val in raytraced_spreads], fontsize=8)
    _annotate_bar_tops(ax, bars2, [f"{val:.1f}" for val in geometric_bounds], fontsize=8)

    # ==========================================================================
    # Right: Ratio of raytraced to geometric (shows multi-bounce contribution)
    # ==========================================================================
    # The 0.1 ns floor on the denominator avoids division by (near) zero
    ratio = raytraced_spreads / np.maximum(geometric_bounds, 0.1)
    colors = plt.cm.RdYlGn_r(np.clip(ratio / 50, 0, 1))  # Red = high ratio
    bars = ax2.bar(
        x,
        ratio,
        width=0.7,
        color=colors,
        alpha=0.8,
        edgecolor="black",
        linewidth=0.5,
    )
    ax2.axhline(
        y=1,
        color="green",
        linestyle="-",
        linewidth=2,
        label="Ratio = 1 (single-bounce)",
    )
    ax2.axhline(y=10, color="orange", linestyle="--", linewidth=1.5, label="Ratio = 10")
    ax2.set_xlabel("Detector Location (lon, lat) deg", fontsize=12)
    ax2.set_ylabel("Raytraced / Geometric Ratio", fontsize=12)
    ax2.set_title(
        "Multi-Bounce Contribution\n(ratio > 1 indicates multi-bounce paths)",
        fontsize=14,
        fontweight="bold",
    )
    ax2.set_xticks(x)
    ax2.set_xticklabels(labels, rotation=45, ha="right", fontsize=9)
    ax2.legend(fontsize=9, loc="upper right")
    ax2.grid(True, alpha=0.3, axis="y")

    # Small ratios keep one decimal so that, e.g., 0.4 is not shown as "0x"
    _annotate_bar_tops(
        ax2,
        bars,
        [f"{val:.1f}x" if val < 10 else f"{val:.0f}x" for val in ratio],
        fontsize=9,
        fontweight="bold",
    )

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" Saved: {save_path}")
    return fig