# Source code for lsurf.visualization.ocean_simulation_plots

# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
#      * Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#
#      * Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#      * Neither the name of the copyright holder nor the names of its
#      contributors may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""
Ocean Wave Simulation Visualization

Complete visualization suite for ocean wave ray tracing simulations.
Generates comprehensive figure sets including ray overview, statistics,
intensity-angle-time plots, 3D views, and energy conservation checks.
"""

from pathlib import Path
from typing import TYPE_CHECKING

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

if TYPE_CHECKING:
    from ..surfaces import CurvedWaveSurface

# Import from specific modules to avoid circular imports
from ..surfaces import EARTH_RADIUS


def create_ocean_simulation_figures(
    original_rays,
    surface: "CurvedWaveSurface",
    recorded_rays,
    reflected_rays,
    refracted_rays,
    config: dict,
    output_dir: str,
    timestamp: str,
) -> None:
    """
    Render the full figure set for one ocean wave ray-tracing run.

    Up to eight figures are written into ``output_dir``:

    1. Ray paths overview (full scale and surface detail)
    2. Recorded rays statistics (6-panel figure)
    3. Intensity vs angle (log scale, fraction)
    4. Intensity vs angle (linear scale, fraction)
    5. Intensity density (log scale, ns⁻¹)
    6. Intensity density (linear scale, ns⁻¹)
    7. 3D visualization
    8. Energy conservation check

    Figures 2-6 are skipped (with a printed warning) when no rays reached
    the recording sphere.

    Parameters
    ----------
    original_rays : RayBatch
        Input rays before any surface interaction.
    surface : CurvedWaveSurface
        Ocean surface model used by the simulation.
    recorded_rays : RecordedRays
        Rays detected at the recording sphere.
    reflected_rays : RayBatch
        Rays reflected by the surface.
    refracted_rays : RayBatch
        Rays refracted into the water.
    config : dict
        Simulation configuration parameters.
    output_dir : str
        Destination directory; created (with parents) if missing.
    timestamp : str
        Timestamp string embedded in every output filename.
    """
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Figure 1: ray overview built from the actual simulation results.
    print(" Creating ray paths overview...")
    _create_ray_overview(
        original_rays,
        surface,
        recorded_rays,
        reflected_rays,
        refracted_rays,
        config,
        out_dir,
        timestamp,
    )

    # Figures 2-6: statistics plus the dedicated intensity/angle/time plots.
    # Both helpers take the same positional arguments.
    print(" Creating recorded rays statistics...")
    has_recorded = recorded_rays.num_rays > 0
    if not has_recorded:
        print(" WARNING: No recorded rays - skipping statistics figures")
    else:
        shared_args = (
            original_rays,
            recorded_rays,
            reflected_rays,
            config,
            out_dir,
            timestamp,
        )
        _create_statistics_figure(*shared_args)
        _create_intensity_angle_plots(*shared_args)

    # Figure 7: 3D scatter of recorded rays (handles the empty case itself).
    print(" Creating 3D visualization...")
    _create_3d_visualization(recorded_rays, config, out_dir, timestamp)

    # Figure 8: energy balance between input, reflected, refracted, recorded.
    print(" Creating energy conservation figure...")
    _create_energy_conservation(
        original_rays,
        recorded_rays,
        reflected_rays,
        refracted_rays,
        config,
        out_dir,
        timestamp,
    )
def _plot_ray_segments(ax, starts, ends, color, label, linestyle="-"):
    """Plot straight ray segments in the X-Z plane, converting metres to km.

    ``starts`` and ``ends`` are matching sequences of 3-vectors. Only the
    first segment carries ``label`` (one legend entry per ray family) and is
    drawn slightly more opaque than the rest.
    """
    for i, (start, end) in enumerate(zip(starts, ends)):
        ax.plot(
            [start[0] / 1000, end[0] / 1000],
            [start[2] / 1000, end[2] / 1000],
            color=color,
            alpha=0.6 if i == 0 else 0.3,
            linewidth=0.8,
            linestyle=linestyle,
            label=label if i == 0 else None,
        )


def _create_ray_overview(
    original_rays,
    surface,
    recorded_rays,
    reflected_rays,
    refracted_rays=None,
    config=None,
    output_path=None,
    timestamp=None,
):
    """Create the two-panel ray paths overview figure.

    Left panel (X-Z plane, km):
        - the sea surface, drawn as a ±0.01 rad arc of the Earth sphere;
        - the local recording sphere, centred on the origin with radius
          ``config["recording_altitude"]``;
        - a random subsample of incoming rays, reflected rays (upward and
          downward) and, if given, refracted rays.
    Right panel (m): a ±200 m zoom on the wave surface along y = 0, with hit
    points and short arrows showing reflected directions.

    ``config`` and ``output_path`` are required in practice despite their
    ``None`` defaults. ``recorded_rays`` is accepted for signature
    compatibility and is not used. The figure is saved as
    ``local_simulation_{timestamp}_overview.png``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    recording_altitude = config["recording_altitude"]

    # Subsample rays for visualization
    n_vis = min(500, original_rays.num_rays)
    vis_idx = np.random.choice(original_rays.num_rays, n_vis, replace=False)
    vis_positions = original_rays.positions[vis_idx]
    vis_directions = original_rays.directions[vis_idx]

    # Hit positions (where the subsampled rays hit the surface)
    distances, hit_mask = surface.intersect(vis_positions, vis_directions)
    hit_positions = (
        vis_positions[hit_mask]
        + distances[hit_mask, np.newaxis] * vis_directions[hit_mask]
    )
    # Indices into original_rays of the rays that hit, in hit_positions order
    # (computed once instead of re-running np.where on every loop iteration).
    hit_ray_idx = vis_idx[hit_mask]

    # ---- Left panel: full scale view (X-Z plane) ----
    earth_center_z = -EARTH_RADIUS
    theta = np.linspace(-0.01, 0.01, 100)
    earth_x_km = EARTH_RADIUS * np.sin(theta) / 1000
    earth_z_km = (earth_center_z + EARTH_RADIUS * np.cos(theta)) / 1000
    ax1.fill_between(
        earth_x_km, earth_z_km, -10, color="#4a90d9", alpha=0.3, label="Ocean"
    )
    ax1.plot(earth_x_km, earth_z_km, "b-", linewidth=2, label="Sea surface")

    # Recording sphere: LOCAL sphere centred at the origin.
    recording_radius_km = recording_altitude / 1000
    phi = np.linspace(0, 2 * np.pi, 100)
    ax1.plot(
        recording_radius_km * np.cos(phi),
        recording_radius_km * np.sin(phi),
        "k--",
        linewidth=1.5,
        label=f"Recording sphere ({recording_radius_km:.0f} km)",
    )

    # Incoming rays: from the source position to the surface hit point.
    n_plot = min(100, len(hit_positions))
    _plot_ray_segments(
        ax1,
        original_rays.positions[hit_ray_idx[:n_plot]],
        hit_positions[:n_plot],
        "r",
        "Incoming rays",
    )

    # Reflected rays: upward ones extended towards the recording sphere,
    # a few downward ones drawn with a short fixed length (1 km).
    ray_length = recording_altitude * 1.5
    upward_mask = reflected_rays.directions[:, 2] > 0
    up_idx = np.flatnonzero(upward_mask)[:100]
    up_start = reflected_rays.positions[up_idx]
    _plot_ray_segments(
        ax1,
        up_start,
        up_start + reflected_rays.directions[up_idx] * ray_length,
        "g",
        "Reflected rays (upward)",
    )
    down_idx = np.flatnonzero(~upward_mask)[:20]
    down_start = reflected_rays.positions[down_idx]
    _plot_ray_segments(
        ax1,
        down_start,
        down_start + reflected_rays.directions[down_idx] * 1000,
        "orange",
        "Reflected rays (downward)",
    )

    # Refracted rays (into the water) are optional.
    if refracted_rays is not None and refracted_rays.num_rays > 0:
        n_plot_refr = min(50, refracted_rays.num_rays)
        refr_idx = np.random.choice(
            refracted_rays.num_rays, n_plot_refr, replace=False
        )
        refr_start = refracted_rays.positions[refr_idx]
        _plot_ray_segments(
            ax1,
            refr_start,
            refr_start + refracted_rays.directions[refr_idx] * ray_length,
            "cyan",
            "Refracted rays (into water)",
            linestyle="--",
        )

    ax1.set_xlabel("X (km)", fontsize=12)
    ax1.set_ylabel("Z (km)", fontsize=12)
    ax1.set_title("Ray Paths Overview (X-Z Plane)", fontsize=14)
    ax1.legend(loc="upper right")
    ax1.set_aspect("equal")
    ax1.grid(True, alpha=0.3)
    max_range_km = (
        max(config.get("source_distance", 10000), recording_altitude) * 1.5 / 1000
    )
    ax1.set_xlim(-max_range_km * 0.5, max_range_km * 1.5)
    ax1.set_ylim(-5, recording_radius_km * 1.2)

    # ---- Right panel: zoom on the surface interaction ----
    x_range = np.linspace(-200, 200, 500)
    zeros = np.zeros_like(x_range)
    surface_points = surface.get_surface_point(
        np.column_stack([x_range, zeros, zeros]).astype(np.float32)
    )
    z_surface = surface_points[:, 2]
    ax2.fill_between(
        x_range, z_surface, z_surface.min() - 5, color="#4a90d9", alpha=0.3
    )
    ax2.plot(x_range, z_surface, "b-", linewidth=2, label="Wave surface")
    ax2.scatter(
        hit_positions[:, 0],
        hit_positions[:, 2],
        c="red",
        s=10,
        alpha=0.5,
        label="Hit points",
    )

    # NOTE(review): assumes reflected_rays is index-aligned with original_rays
    # (reflected ray i originates from input ray i); confirm against the tracer.
    arrow_length = 50  # m
    for idx in hit_ray_idx[:50]:
        if idx >= reflected_rays.num_rays:
            continue
        start = reflected_rays.positions[idx]
        direction = reflected_rays.directions[idx]
        color = "green" if direction[2] > 0 else "orange"
        ax2.arrow(
            start[0],
            start[2],
            direction[0] * arrow_length,
            direction[2] * arrow_length,
            head_width=2,
            head_length=1,
            fc=color,
            ec=color,
            alpha=0.5,
        )

    ax2.set_xlabel("X (m)", fontsize=12)
    ax2.set_ylabel("Z (m)", fontsize=12)
    ax2.set_title("Surface Interaction Detail", fontsize=14)
    ax2.legend(loc="upper right")
    ax2.set_aspect("equal")
    ax2.grid(True, alpha=0.3)
    ax2.set_xlim(-200, 200)
    ax2.set_ylim(-5, 10)

    fig.tight_layout()
    overview_path = output_path / f"local_simulation_{timestamp}_overview.png"
    fig.savefig(overview_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {overview_path}")


def _elevation_deg(directions):
    """Return each ray's angle above the horizontal (z = 0) plane, in degrees.

    The z-component is clipped to [-1, 1] first. Direction vectors normalised
    in float32 can have |z| marginally above 1, and ``np.arcsin`` would turn
    those into NaN, silently dropping the rays from every histogram.
    """
    return np.degrees(np.arcsin(np.clip(directions[:, 2], -1.0, 1.0)))


def _angle_bin_indices(elevation_deg, num_bins):
    """Split elevation angles into ``num_bins`` equal-width bins.

    Returns ``(edges, indices)``. ``indices`` are 1-based bin numbers in
    ``[1, num_bins]``, following the ``np.digitize`` convention.
    ``np.digitize`` maps a value equal to the last edge to ``num_bins + 1``.
    Clipping keeps the ray(s) at the maximum elevation in the last bin
    instead of dropping them; the same applies to every ray when all
    elevations are identical.
    """
    edges = np.linspace(elevation_deg.min(), elevation_deg.max(), num_bins + 1)
    indices = np.clip(np.digitize(elevation_deg, edges), 1, num_bins)
    return edges, indices


def _angle_bin_histograms(
    times_ns, intensities, bin_indices, num_bins, time_bins, time_floor=None
):
    """Yield ``(bin_idx, mask, hist)`` for every angle bin holding more than 5 rays.

    Arrival times in each bin are measured relative to that bin's earliest
    ray. When ``time_floor`` is given, relative times are raised to at least
    that value, so the first arrival (relative time 0) still lands on a
    logarithmic time axis. ``hist`` is the intensity-weighted histogram over
    ``time_bins`` and is not normalised.
    """
    for bin_idx in range(1, num_bins + 1):
        mask = bin_indices == bin_idx
        if np.count_nonzero(mask) <= 5:
            continue
        bin_times = times_ns[mask]
        relative_times = bin_times - bin_times.min()
        if time_floor is not None:
            relative_times = np.maximum(relative_times, time_floor)
        hist, _ = np.histogram(
            relative_times, bins=time_bins, weights=intensities[mask]
        )
        yield bin_idx, mask, hist


def _config_subtitle(config):
    """Return the grazing-angle / wave-parameter line used in figure titles."""
    return (
        f"Grazing angle: {config['grazing_angle']:.1f}°, "
        f"Wave amplitude: {config['wave_amplitude']:.2f} m, "
        f"Wavelength: {config['wave_wavelength']:.1f} m"
    )


def _create_statistics_figure(
    original_rays, recorded_rays, reflected_rays, config, output_path, timestamp
):
    """Create the 6-panel statistics figure for the recorded rays.

    Panels:
        1. Elevation histogram.
        2. Azimuth histogram.
        3. Relative arrival-time histogram.
        4. log10-intensity histogram.
        5. Intensity vs relative arrival time per elevation bin, normalised
           by the total input intensity.
        6. 2D azimuth/elevation histogram.

    All histograms except panel 4 are intensity-weighted.
    ``reflected_rays`` and ``config`` are accepted for signature
    compatibility and are not used. The figure is saved as
    ``local_simulation_{timestamp}_statistics.png``.
    """
    fig = plt.figure(figsize=(16, 12))
    gs = gridspec.GridSpec(2, 3, figure=fig)
    intensities = recorded_rays.intensities

    # Angular coordinates are used for the azimuth only.
    angular = recorded_rays.compute_angular_coordinates()
    azimuth_deg = np.degrees(angular["azimuth"])
    # Ray direction angle relative to the horizontal at the origin.
    elevation_deg = _elevation_deg(recorded_rays.directions)
    times_ns = recorded_rays.times * 1e9

    # Panel 1: ray angle from horizontal
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.hist(
        elevation_deg,
        bins=50,
        weights=intensities,
        color="steelblue",
        edgecolor="black",
        alpha=0.7,
    )
    ax1.set_xlabel("Ray Angle from Horizontal (degrees)", fontsize=11)
    ax1.set_ylabel("Intensity-weighted Count", fontsize=11)
    ax1.set_title("Ray Direction Angle (relative to z=0 at origin)", fontsize=12)
    ax1.grid(True, alpha=0.3)

    # Panel 2: azimuth angle distribution
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.hist(
        azimuth_deg,
        bins=50,
        weights=intensities,
        color="coral",
        edgecolor="black",
        alpha=0.7,
    )
    ax2.set_xlabel("Azimuth Angle (degrees)", fontsize=11)
    ax2.set_ylabel("Intensity-weighted Count", fontsize=11)
    ax2.set_title("Azimuth Angle Distribution", fontsize=12)
    ax2.grid(True, alpha=0.3)

    # Panel 3: time distribution (log axis when the spread exceeds 1 ns)
    ax3 = fig.add_subplot(gs[0, 2])
    relative_times_ns = times_ns - times_ns.min()
    if relative_times_ns.max() > 1.0:
        # Clamp sub-ns arrivals to 1 ns so they fall into the first log bin.
        relative_times_ns_safe = np.maximum(relative_times_ns, 1.0)
        log_bins = np.logspace(0, np.log10(relative_times_ns_safe.max()), 51)
        ax3.hist(
            relative_times_ns_safe,
            bins=log_bins,
            weights=intensities,
            color="green",
            edgecolor="black",
            alpha=0.7,
        )
        ax3.set_xscale("log")
    else:
        ax3.hist(
            relative_times_ns,
            bins=50,
            weights=intensities,
            color="green",
            edgecolor="black",
            alpha=0.7,
        )
    ax3.set_xlabel("Relative Arrival Time (ns)", fontsize=11)
    ax3.set_ylabel("Intensity-weighted Count", fontsize=11)
    ax3.set_title("Time of Arrival Distribution", fontsize=12)
    ax3.grid(True, alpha=0.3)

    # Panel 4: intensity distribution (offset avoids log10(0))
    ax4 = fig.add_subplot(gs[1, 0])
    ax4.hist(
        np.log10(intensities + 1e-20),
        bins=50,
        color="purple",
        edgecolor="black",
        alpha=0.7,
    )
    ax4.set_xlabel("log₁₀(Intensity)", fontsize=11)
    ax4.set_ylabel("Count", fontsize=11)
    ax4.set_title("Intensity Distribution", fontsize=12)
    ax4.grid(True, alpha=0.3)

    # Panel 5: relative arrival times by angle bin
    ax5 = fig.add_subplot(gs[1, 1])
    total_incident_power = np.sum(original_rays.intensities)
    num_angle_bins = 15
    _, bin_indices = _angle_bin_indices(elevation_deg, num_angle_bins)
    colors = plt.cm.turbo(np.linspace(0, 1, num_angle_bins))
    time_bins = np.logspace(-2, 4, 101)  # 0.01 ns .. 10 µs
    bin_centers = np.sqrt(time_bins[:-1] * time_bins[1:])  # geometric centres
    for bin_idx, mask, hist in _angle_bin_histograms(
        times_ns, intensities, bin_indices, num_angle_bins, time_bins, 0.01
    ):
        ax5.plot(
            bin_centers,
            hist / total_incident_power,
            alpha=0.6,
            linewidth=1.0,
            color=colors[bin_idx - 1],
            label=f"{elevation_deg[mask].mean():.1f}°",
        )
    ax5.set_xlabel("Relative Arrival Time (ns)", fontsize=11)
    ax5.set_ylabel("Normalized Intensity Fraction", fontsize=11)
    ax5.set_title("Intensity vs Time by Angle Bin", fontsize=12)
    ax5.set_xscale("log")
    ax5.grid(True, alpha=0.3)
    if ax5.get_lines():  # avoid matplotlib's "no artists" legend warning
        ax5.legend(fontsize=7, ncol=2, title="Angle")

    # Panel 6: 2D angular distribution
    ax6 = fig.add_subplot(gs[1, 2])
    h = ax6.hist2d(
        azimuth_deg,
        elevation_deg,
        bins=30,
        weights=intensities,
        cmap="hot",
    )
    fig.colorbar(h[3], ax=ax6, label="Intensity")
    ax6.set_xlabel("Azimuth (degrees)", fontsize=11)
    ax6.set_ylabel("Ray Angle from Horizontal (degrees)", fontsize=11)
    ax6.set_title("2D Angular Distribution", fontsize=12)

    fig.tight_layout()
    fig_path = output_path / f"local_simulation_{timestamp}_statistics.png"
    fig.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {fig_path}")


def _save_intensity_angle_variant(
    *,
    times_ns,
    intensities,
    elevation_deg,
    bin_indices,
    num_bins,
    total_incident_power,
    time_bins,
    bin_centers,
    bin_widths,
    time_floor,
    log_x,
    title,
    ylabel,
    config,
    file_path,
):
    """Draw one intensity-vs-arrival-time figure and save it to ``file_path``.

    The figure has one curve per angle bin with more than 5 rays. Each curve
    is the bin's intensity-weighted arrival-time histogram divided by
    ``total_incident_power``. When ``bin_widths`` is not None, the curve is
    further divided by the time-bin widths, giving a density in ns⁻¹.
    """
    colors = plt.cm.turbo(np.linspace(0, 1, num_bins))
    normalization = (
        total_incident_power
        if bin_widths is None
        else total_incident_power * bin_widths
    )

    fig, ax = plt.subplots(figsize=(12, 8))
    handles, labels = [], []
    for bin_idx, mask, hist in _angle_bin_histograms(
        times_ns, intensities, bin_indices, num_bins, time_bins, time_floor
    ):
        (line,) = ax.plot(
            bin_centers,
            hist / normalization,
            alpha=0.7,
            linewidth=1.0,
            color=colors[bin_idx - 1],
        )
        handles.append(line)
        mean_angle = elevation_deg[mask].mean()
        bin_total = np.sum(intensities[mask])
        labels.append(f"{mean_angle:.1f}° (Σ={bin_total:.2e})")

    ax.set_xlabel("Relative Arrival Time (ns)", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(f"{title}\n{_config_subtitle(config)}", fontsize=14)
    if log_x:
        ax.set_xscale("log")
    ax.grid(True, alpha=0.3)
    ax.legend(
        handles,
        labels,
        fontsize=9,
        ncol=3,
        title="Ray Angle (Total Intensity)",
        loc="upper right",
    )
    sm = plt.cm.ScalarMappable(
        cmap="turbo",
        norm=plt.Normalize(vmin=elevation_deg.min(), vmax=elevation_deg.max()),
    )
    sm.set_array([])
    fig.colorbar(sm, ax=ax, label="Ray Angle from Horizontal (°)")

    fig.tight_layout()
    fig.savefig(file_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {file_path}")


def _create_intensity_angle_plots(
    original_rays, recorded_rays, reflected_rays, config, output_path, timestamp
):
    """Create dedicated intensity vs angle bin plots (4 variants).

    Each plot uses 20 equal-width elevation bins. The variants are
    {log, linear} time axis × {intensity fraction, intensity density (ns⁻¹)}:
        - log axis: 100 log-spaced bins from 0.01 ns to 10 µs;
        - linear axis: 1 ns bins covering the recorded time spread plus
          100 ns.
    ``reflected_rays`` is accepted for signature compatibility and is not
    used.
    """
    intensities = recorded_rays.intensities
    elevation_deg = _elevation_deg(recorded_rays.directions)
    total_incident_power = np.sum(original_rays.intensities)
    times_ns = recorded_rays.times * 1e9

    num_angle_bins = 20
    _, bin_indices = _angle_bin_indices(elevation_deg, num_angle_bins)

    # Logarithmic time bins
    time_bins_log = np.logspace(-2, 4, 101)
    bin_widths_log = np.diff(time_bins_log)
    bin_centers_log = np.sqrt(time_bins_log[:-1] * time_bins_log[1:])

    # Linear 1 ns time bins
    max_time = times_ns.max() - times_ns.min() + 100
    time_bins_lin = np.arange(0, max_time, 1.0)
    bin_centers_lin = time_bins_lin[:-1] + 0.5

    common = {
        "times_ns": times_ns,
        "intensities": intensities,
        "elevation_deg": elevation_deg,
        "bin_indices": bin_indices,
        "num_bins": num_angle_bins,
        "total_incident_power": total_incident_power,
        "config": config,
    }
    # Log axes floor relative times at 0.01 ns so the earliest arrival is visible.
    log_axis = {
        "time_bins": time_bins_log,
        "bin_centers": bin_centers_log,
        "time_floor": 0.01,
        "log_x": True,
    }
    lin_axis = {
        "time_bins": time_bins_lin,
        "bin_centers": bin_centers_lin,
        "time_floor": None,
        "log_x": False,
    }
    fraction_label = "Normalized Intensity Fraction"
    density_label = "Normalized Intensity Density (ns⁻¹)"

    # (progress message, time axis, bin widths for density, y label, title, file suffix)
    variants = (
        (
            " Creating intensity vs angle bin plot (log)...",
            log_axis,
            None,
            fraction_label,
            "Intensity vs Arrival Time by Ray Angle Bin (Log Scale)",
            "intensity_angle_log",
        ),
        (
            " Creating intensity vs angle bin plot (linear)...",
            lin_axis,
            None,
            fraction_label,
            "Intensity vs Arrival Time by Ray Angle Bin (Linear Scale)",
            "intensity_angle_linear",
        ),
        (
            " Creating intensity density plot (log)...",
            log_axis,
            bin_widths_log,
            density_label,
            "Intensity Density vs Arrival Time by Ray Angle Bin (Log Scale)",
            "intensity_angle_log_density",
        ),
        (
            " Creating intensity density plot (linear)...",
            lin_axis,
            1.0,
            density_label,
            "Intensity Density vs Arrival Time by Ray Angle Bin (Linear Scale)",
            "intensity_angle_linear_density",
        ),
    )
    for message, axis, widths, ylabel, title, suffix in variants:
        print(message)
        _save_intensity_angle_variant(
            **common,
            **axis,
            bin_widths=widths,
            title=title,
            ylabel=ylabel,
            file_path=output_path / f"local_simulation_{timestamp}_{suffix}.png",
        )


def _create_3d_visualization(recorded_rays, config, output_path, timestamp):
    """Create a 3D scatter plot of recorded rays.

    Up to 500 randomly sampled recorded positions are plotted in km,
    coloured by intensity. X/Y/Z arrows of length 0.3 × recording altitude
    mark the origin. The figure is still saved, showing axes only, when no
    rays were recorded.
    """
    fig = plt.figure(figsize=(14, 10))
    ax = fig.add_subplot(111, projection="3d")

    if recorded_rays.num_rays > 0:
        n_plot = min(500, recorded_rays.num_rays)
        indices = np.random.choice(recorded_rays.num_rays, n_plot, replace=False)
        positions_km = recorded_rays.positions[indices] / 1000
        scatter = ax.scatter(
            positions_km[:, 0],
            positions_km[:, 1],
            positions_km[:, 2],
            c=recorded_rays.intensities[indices],
            cmap="hot",
            s=10,
            alpha=0.6,
        )
        fig.colorbar(scatter, ax=ax, label="Intensity", shrink=0.6)

    # Draw coordinate axes at the origin
    axis_length = config["recording_altitude"] / 1000 * 0.3
    for vector, color, name in (
        ((axis_length, 0, 0), "r", "X"),
        ((0, axis_length, 0), "g", "Y"),
        ((0, 0, axis_length), "b", "Z"),
    ):
        ax.quiver(0, 0, 0, *vector, color=color, arrow_length_ratio=0.1, label=name)

    ax.set_xlabel("X (km)", fontsize=11)
    ax.set_ylabel("Y (km)", fontsize=11)
    ax.set_zlabel("Z (km)", fontsize=11)
    ax.set_title("Recorded Rays at Detection Sphere (3D)", fontsize=14)

    fig.tight_layout()
    fig_path = output_path / f"local_simulation_{timestamp}_3d.png"
    fig.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {fig_path}")


def _create_energy_conservation(
    original_rays,
    recorded_rays,
    reflected_rays,
    refracted_rays,
    config,
    output_path,
    timestamp,
):
    """Create the energy conservation check figure.

    Left panel: bar chart of total intensity for the input rays, the upward
    reflected rays, the refracted rays and the recorded rays. Right panel: a
    text summary with ray counts, percentages of the input intensity and key
    config values.

    ``refracted_rays`` may be None, matching ``_create_ray_overview``; its
    intensity then counts as 0. The figure is saved as
    ``local_simulation_{timestamp}_energy.png``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Total intensities
    input_intensity = np.sum(original_rays.intensities)
    output_intensity = (
        np.sum(recorded_rays.intensities) if recorded_rays.num_rays > 0 else 0
    )
    if reflected_rays.num_rays > 0:
        # Only upward-going reflections can reach the recording sphere.
        up_mask = reflected_rays.directions[:, 2] > 0
        reflected_intensity = np.sum(reflected_rays.intensities[up_mask])
    else:
        reflected_intensity = 0
    if refracted_rays is not None and refracted_rays.num_rays > 0:
        refracted_intensity = np.sum(refracted_rays.intensities)
    else:
        refracted_intensity = 0

    # Bar chart
    categories = ["Input", "Reflected\n(upward)", "Refracted\n(downward)", "Recorded"]
    values = [
        input_intensity,
        reflected_intensity,
        refracted_intensity,
        output_intensity,
    ]
    colors = ["steelblue", "coral", "lightblue", "green"]
    ax1.bar(categories, values, color=colors, edgecolor="black")
    ax1.set_ylabel("Total Intensity", fontsize=11)
    ax1.set_title("Energy Balance", fontsize=12)
    ax1.grid(True, alpha=0.3, axis="y")
    label_offset = 0.02 * max(values)
    for i, val in enumerate(values):
        ax1.text(i, val + label_offset, f"{val:.3f}", ha="center", fontsize=10)

    # Efficiency text
    if input_intensity > 0:
        efficiency = output_intensity / input_intensity * 100
        reflected_frac = reflected_intensity / input_intensity * 100
        refracted_frac = refracted_intensity / input_intensity * 100
    else:
        efficiency = 0
        reflected_frac = 0
        refracted_frac = 0

    separator = "─────────────────────────────"
    info_lines = [
        "Simulation Summary",
        separator,
        f"Input rays: {original_rays.num_rays:,}",
        f"Recorded rays: {recorded_rays.num_rays:,}",
        f"Input intensity: {input_intensity:.4f}",
        f"Reflected intensity: {reflected_intensity:.4f} ({reflected_frac:.1f}%)",
        f"Refracted intensity: {refracted_intensity:.4f} ({refracted_frac:.1f}%)",
        f"Recorded intensity: {output_intensity:.4f}",
        f"Recording efficiency: {efficiency:.2f}%",
        separator,
        f"Recording altitude: {config['recording_altitude']/1000:.0f} km",
        f"Grazing angle: {config['grazing_angle']:.1f}°",
        f"Wave amplitude: {config['wave_amplitude']:.2f} m",
        f"Wave wavelength: {config['wave_wavelength']:.1f} m",
    ]
    info_text = "\n".join(info_lines) + "\n"
    ax2.text(
        0.1,
        0.5,
        info_text,
        transform=ax2.transAxes,
        fontsize=11,
        verticalalignment="center",
        fontfamily="monospace",
        bbox={"boxstyle": "round", "facecolor": "wheat", "alpha": 0.5},
    )
    ax2.axis("off")
    ax2.set_title("Simulation Statistics", fontsize=12)

    fig.tight_layout()
    fig_path = output_path / f"local_simulation_{timestamp}_energy.png"
    fig.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved: {fig_path}")