# Source code for lsurf.visualization.raytracing_plots

# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
#      * Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#
#      * Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#      * Neither the name of the copyright holder nor the names of its
#      contributors may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""
Ray Tracing Visualization - Individual Axis Functions

Functions for plotting ray paths, intersections, and propagation.
Each function draws on a single axis, enabling flexible composition.
"""

from typing import TYPE_CHECKING, Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure

if TYPE_CHECKING:
    from ..utilities.ray_data import RayBatch
    from ..surfaces import Surface

from .common import (
    DEFAULT_LINEWIDTH,
    DEFAULT_MARKERSIZE,
    INTENSITY_CMAP,
    LINE_ALPHA,
    SCATTER_ALPHA,
    WAVELENGTH_CMAP,
    add_colorbar,
    get_color_mapping,
    get_projection_config,
    save_figure,
    setup_axis_grid,
)

# =============================================================================
# Single-Axis Ray Path Functions
# =============================================================================


def plot_ray_paths_projection(
    ax: Axes,
    ray_history: list["RayBatch"],
    projection: str = "xz",
    max_rays: int = 100,
    color_by: str = "wavelength",
    alpha: float = LINE_ALPHA,
    linewidth: float = DEFAULT_LINEWIDTH,
    show_colorbar: bool = True,
) -> plt.cm.ScalarMappable | None:
    """
    Plot ray paths as a 2D projection on given axis.

    Each sampled ray is drawn as a polyline through its position in every
    batch of ``ray_history`` (one vertex per time step).

    Parameters
    ----------
    ax : Axes
        Matplotlib axes to draw on.
    ray_history : List[RayBatch]
        List of ray batches at different time steps. Ray ``i`` is assumed to
        sit at row ``i`` of every batch's ``positions``.
    projection : str
        Projection plane: 'xy', 'xz', or 'yz'.
    max_rays : int
        Maximum rays to plot for performance (evenly subsampled).
    color_by : str
        Color rays by: 'wavelength', 'intensity', 'generation'; any other
        value colors by ray index.
    alpha : float
        Line transparency.
    linewidth : float
        Line width.
    show_colorbar : bool
        Whether to add colorbar.

    Returns
    -------
    sm : ScalarMappable or None
        ScalarMappable for external colorbar, or None if ``ray_history``
        is empty.
    """
    if len(ray_history) == 0:
        return None

    initial_batch = ray_history[0]
    n_rays = initial_batch.num_rays

    # Evenly subsample rays so large batches stay fast to draw
    if n_rays > max_rays:
        ray_indices = np.linspace(0, n_rays - 1, max_rays, dtype=int)
    else:
        ray_indices = np.arange(n_rays)

    idx1, idx2, xlabel, ylabel = get_projection_config(projection)

    # Color values are taken from the first batch
    if color_by == "wavelength":
        values = initial_batch.wavelengths[ray_indices] * 1e9
        cmap = WAVELENGTH_CMAP
        label = "Wavelength (nm)"
    elif color_by == "intensity":
        values = initial_batch.intensities[ray_indices]
        cmap = INTENSITY_CMAP
        label = "Intensity"
    elif color_by == "generation":
        values = initial_batch.generations[ray_indices]
        cmap = "tab10"
        label = "Generation"
    else:
        values = ray_indices.astype(float)
        cmap = "tab20"
        label = "Ray Index"

    colors, _, sm = get_color_mapping(values, cmap)

    # Gather the sampled rays' projected coordinates with one fancy-index per
    # time step (rows = time steps, columns = sampled rays) instead of one
    # scalar lookup per ray per step.
    coords1 = np.array([batch.positions[ray_indices, idx1] for batch in ray_history])
    coords2 = np.array([batch.positions[ray_indices, idx2] for batch in ray_history])

    for i in range(len(ray_indices)):
        ax.plot(
            coords1[:, i],
            coords2[:, i],
            color=colors[i],
            alpha=alpha,
            linewidth=linewidth,
        )

    setup_axis_grid(ax, xlabel, ylabel, f"{projection.upper()} Projection")
    ax.set_aspect("equal", adjustable="box")

    if show_colorbar:
        add_colorbar(ax, sm, label)

    return sm
[docs] def plot_ray_endpoints_scatter( ax: Axes, rays: "RayBatch", projection: str = "xy", color_by: str = "intensity", alpha: float = SCATTER_ALPHA, size: float = DEFAULT_MARKERSIZE, show_colorbar: bool = True, ) -> plt.cm.ScalarMappable | None: """ Plot ray endpoint positions as scatter plot. Parameters ---------- ax : Axes Matplotlib axes. rays : RayBatch Ray batch. projection : str Plane: 'xy', 'xz', 'yz'. color_by : str Color by: 'intensity', 'wavelength', 'generation', 'time'. alpha : float Point transparency. size : float Point size. show_colorbar : bool Whether to add colorbar. Returns ------- sm : ScalarMappable or None For external colorbar. """ active_mask = rays.active positions = rays.positions[active_mask] # Coordinate selection idx1, idx2, xlabel, ylabel = get_projection_config(projection) x, y = positions[:, idx1], positions[:, idx2] # Color values if color_by == "intensity": c = rays.intensities[active_mask] cmap = INTENSITY_CMAP clabel = "Intensity" elif color_by == "wavelength": c = rays.wavelengths[active_mask] * 1e9 cmap = WAVELENGTH_CMAP clabel = "Wavelength (nm)" elif color_by == "generation": c = rays.generations[active_mask] cmap = "tab10" clabel = "Generation" elif color_by == "time": c = rays.accumulated_time[active_mask] * 1e6 cmap = "coolwarm" clabel = "Time (μs)" else: c = np.arange(len(x)) cmap = "viridis" clabel = "Index" scatter = ax.scatter(x, y, c=c, s=size, alpha=alpha, cmap=cmap) setup_axis_grid(ax, xlabel, ylabel, f"Ray Endpoints - {projection.upper()}") ax.set_aspect("equal", adjustable="box") if show_colorbar and c is not None: add_colorbar(ax, scatter, clabel) return scatter
def plot_ray_endpoints_histogram(
    ax: Axes,
    rays: "RayBatch",
    projection: str = "xy",
    bins: int = 50,
    cmap: str = "hot",
) -> None:
    """
    Draw a 2D density histogram of active ray positions in one plane.

    Empty bins are left blank (``cmin=1``) and a "Count" colorbar is added.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    rays : RayBatch
        Ray batch; only rows flagged in ``rays.active`` are counted.
    projection : str
        Plane: 'xy', 'xz', 'yz'.
    bins : int
        Number of histogram bins.
    cmap : str
        Colormap name.
    """
    active_points = rays.positions[rays.active]
    col_a, col_b, xlabel, ylabel = get_projection_config(projection)

    *_, image = ax.hist2d(
        active_points[:, col_a],
        active_points[:, col_b],
        bins=bins,
        cmap=cmap,
        cmin=1,
    )

    setup_axis_grid(ax, xlabel, ylabel, "Ray Density")
    ax.set_aspect("equal", adjustable="box")
    add_colorbar(ax, image, "Count")
# ============================================================================= # Surface Intersection Visualization # =============================================================================
def plot_surface_profile(
    ax: Axes,
    surface: "Surface",
    x_range: tuple[float, float] = (-200, 200),
    y: float = 0.0,
    n_points: int = 1000,
    color: str = "blue",
    linewidth: float = 2.0,
    label: str = "Surface",
) -> None:
    """
    Draw the surface height profile z(x) along a line of constant y.

    The surface is sampled point by point through its ``_surface_z(x, y)``
    method. Surfaces without that method are drawn as the flat line z = 0.
    The region below the profile is shaded.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    surface : Surface
        Surface object, ideally exposing ``_surface_z``.
    x_range : tuple
        (x_min, x_max) range.
    y : float
        Y-coordinate for profile.
    n_points : int
        Number of sample points.
    color : str
        Line color.
    linewidth : float
        Line width.
    label : str
        Legend label.
    """
    sample_x = np.linspace(x_range[0], x_range[1], n_points)
    sample_y = np.full_like(sample_x, y)

    if not hasattr(surface, "_surface_z"):
        heights = np.zeros_like(sample_x)
    else:
        heights = np.array(
            [
                surface._surface_z(px, py)
                for px, py in zip(sample_x, sample_y, strict=False)
            ]
        )

    ax.plot(sample_x, heights, color=color, linewidth=linewidth, label=label)
    # Shade down to just below the lowest point of the profile
    ax.fill_between(sample_x, heights, heights.min() - 0.5, alpha=0.3, color=color)
    setup_axis_grid(ax, "X (m)", "Z (m)")
def plot_bounce_points(
    ax: Axes,
    bounce_positions: np.ndarray,
    bounce_number: int = 1,
    color: str | None = None,
    size: float = 20,
    alpha: float = 0.7,
    projection: str = "xz",
    label: str | None = None,
) -> None:
    """
    Plot ray bounce points on surface.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    bounce_positions : array_like
        (N, 3) bounce positions. An ndarray, or anything ``np.asarray`` turns
        into one (e.g. a list of 3-vectors). Nothing is drawn when N == 0.
    bounce_number : int
        Bounce index, used to pick the default color and label.
    color : str, optional
        Override color. The default cycles through a 5-color palette
        indexed by ``bounce_number``.
    size : float
        Marker size.
    alpha : float
        Transparency.
    projection : str
        Coordinate projection ('xz', 'xy', 'yz').
    label : str, optional
        Legend label; defaults to ``"Bounce {bounce_number}"``.
    """
    # Accept lists of points as well as arrays; 2-D indexing below needs an ndarray.
    points = np.asarray(bounce_positions)
    if len(points) == 0:
        return

    bounce_colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"]
    if color is None:
        color = bounce_colors[bounce_number % len(bounce_colors)]

    idx1, idx2, _, _ = get_projection_config(projection)
    horizontal = points[:, idx1]
    vertical = points[:, idx2]

    if label is None:
        label = f"Bounce {bounce_number}"

    ax.scatter(
        horizontal,
        vertical,
        c=color,
        s=size,
        alpha=alpha,
        label=label,
        edgecolors="none",
    )
def plot_incoming_rays(
    ax: Axes,
    rays: "RayBatch",
    surface: "Surface",
    projection: str = "xz",
    color: str = "gold",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_rays: int = 100,
) -> None:
    """
    Plot incoming ray segments from origin to surface intersection.

    Up to ``max_rays`` rays are evenly subsampled, and inactive rays are
    skipped. All sampled rays are intersected with ``surface`` in a single
    batched ``surface.intersect`` call. Only rays that hit at a positive
    distance are drawn.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    rays : RayBatch
        Ray batch before intersection.
    surface : Surface
        Surface for intersection calculation.
    projection : str
        Coordinate projection.
    color : str
        Ray color.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_rays : int
        Maximum rays to plot.
    """
    idx1, idx2, _, _ = get_projection_config(projection)

    total = rays.num_rays
    sample_idx = (
        np.linspace(0, total - 1, min(total, max_rays), dtype=int)
        if total > max_rays
        else np.arange(total)
    )
    # Only active rays get an incoming segment
    sample_idx = sample_idx[np.asarray(rays.active)[sample_idx]]
    if len(sample_idx) == 0:
        return

    positions = rays.positions[sample_idx]
    directions = rays.directions[sample_idx]

    # One batched intersection instead of one call per ray. Inputs are cast
    # to float32 as before, and every sampled ray is marked active.
    t, hit = surface.intersect(
        positions.astype(np.float32),
        directions.astype(np.float32),
        np.ones(len(sample_idx), dtype=bool),
    )
    t = np.asarray(t)
    valid = np.asarray(hit, dtype=bool) & (t > 0)

    starts = positions[valid]
    intersections = starts + t[valid, np.newaxis] * directions[valid]

    for start, end in zip(starts, intersections, strict=True):
        ax.plot(
            [start[idx1], end[idx1]],
            [start[idx2], end[idx2]],
            color=color,
            alpha=alpha,
            linewidth=linewidth,
        )
def plot_reflected_rays(
    ax: Axes,
    rays: "RayBatch",
    length: float = 100.0,
    projection: str = "xz",
    color: str = "cyan",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_rays: int = 100,
) -> None:
    """
    Draw fixed-length segments from each ray's position along its direction.

    Up to ``max_rays`` rays are evenly subsampled. Inactive rays are
    skipped.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    rays : RayBatch
        Reflected ray batch.
    length : float
        Length of ray segments to draw, in units of ``rays.directions``.
    projection : str
        Coordinate projection.
    color : str
        Ray color.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_rays : int
        Maximum rays to plot.
    """
    first_axis, second_axis, _, _ = get_projection_config(projection)

    total = rays.num_rays
    if total > max_rays:
        chosen = np.linspace(0, total - 1, min(total, max_rays), dtype=int)
    else:
        chosen = np.arange(total)

    for ray in chosen:
        if rays.active[ray]:
            origin = rays.positions[ray]
            tip = origin + length * rays.directions[ray]
            ax.plot(
                [origin[first_axis], tip[first_axis]],
                [origin[second_axis], tip[second_axis]],
                color=color,
                alpha=alpha,
                linewidth=linewidth,
            )
# ============================================================================= # Multi-Bounce Visualization # =============================================================================
def plot_multi_bounce_paths(
    ax: Axes,
    ray_paths: dict[str, list[np.ndarray]],
    projection: str = "xz",
    reflected_color: str = "cyan",
    refracted_color: str = "orange",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_paths: int = 100,
) -> None:
    """
    Plot multi-bounce ray paths from trace_rays_multi_bounce output.

    Reflected paths are drawn first, then refracted paths. Each group is
    evenly subsampled to at most ``max_paths`` paths. Paths that are ``None``
    or have fewer than two vertices are skipped.

    Parameters
    ----------
    ax : Axes
        Matplotlib axes.
    ray_paths : dict
        Dictionary with 'reflected_paths' and/or 'refracted_paths' lists.
        Each path is an (M, 3) array-like of vertices.
    projection : str
        Coordinate projection.
    reflected_color : str
        Color for reflected paths.
    refracted_color : str
        Color for refracted paths.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_paths : int
        Maximum paths to plot per group.
    """
    idx1, idx2, _, _ = get_projection_config(projection)

    for key, color in (
        ("reflected_paths", reflected_color),
        ("refracted_paths", refracted_color),
    ):
        paths = ray_paths.get(key)
        if paths is None or len(paths) == 0:
            continue

        n_total = len(paths)
        sample_idx = (
            np.linspace(0, n_total - 1, min(n_total, max_paths), dtype=int)
            if n_total > max_paths
            else range(n_total)
        )

        for i in sample_idx:
            path = paths[i]
            if path is None:
                continue
            # Lists of vertices are accepted too; 2-D slicing needs an ndarray
            path = np.asarray(path)
            if len(path) > 1:
                ax.plot(
                    path[:, idx1],
                    path[:, idx2],
                    color=color,
                    alpha=alpha,
                    linewidth=linewidth,
                )
# ============================================================================= # Composite Figure Builders # =============================================================================
def create_ray_overview_figure(
    rays: "RayBatch",
    surface: "Surface",
    reflected_rays: Optional["RayBatch"] = None,
    bounce_points: list[np.ndarray] | None = None,
    figsize: tuple[float, float] = (16, 10),
    x_range: tuple[float, float] = (-500, 500),
    title: str = "Ray Tracing Overview",
    save_path: str | None = None,
) -> Figure:
    """
    Build a 2x2 overview figure of a ray-tracing run.

    Panels:

    * top-left: XZ side view with the surface profile, incoming and
      reflected segments, and the bounce points;
    * top-right: XY scatter of the initial rays colored by intensity;
    * bottom-left: XZ density histogram of the reflected rays (or of the
      initial rays when no reflected batch is given);
    * bottom-right: zoomed surface profile (1/5 of ``x_range``) with the
      first-bounce points that fall inside that window.

    Parameters
    ----------
    rays : RayBatch
        Initial rays.
    surface : Surface
        Wave surface.
    reflected_rays : RayBatch, optional
        Reflected rays.
    bounce_points : List[ndarray], optional
        List of bounce position arrays per bounce.
    figsize : tuple
        Figure size.
    x_range : tuple
        X-axis range.
    title : str
        Figure title.
    save_path : str, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=figsize, constrained_layout=True)
    fig.suptitle(title, fontsize=14, fontweight="bold")
    (side_ax, top_ax), (density_ax, detail_ax) = axes

    # Side view: surface, incoming/reflected segments and every bounce
    plot_surface_profile(side_ax, surface, x_range=x_range)
    plot_incoming_rays(side_ax, rays, surface, projection="xz")
    if reflected_rays is not None:
        plot_reflected_rays(side_ax, reflected_rays, projection="xz", length=50)
    if bounce_points is not None:
        for number, points in enumerate(bounce_points, start=1):
            if len(points) > 0:
                plot_bounce_points(
                    side_ax, np.array(points), bounce_number=number, projection="xz"
                )
    side_ax.legend(loc="upper right")
    side_ax.set_title("XZ View (Side)")

    # Top-down view of the initial rays
    plot_ray_endpoints_scatter(top_ax, rays, projection="xy", color_by="intensity")
    top_ax.set_title("XY View (Top)")

    # Density: prefer the reflected batch when available
    density_source = rays if reflected_rays is None else reflected_rays
    plot_ray_endpoints_histogram(density_ax, density_source, projection="xz")
    density_ax.set_title("Ray Density")

    # Zoomed surface detail with first-bounce points inside the window
    window_lo = x_range[0] / 5
    window_hi = x_range[1] / 5
    plot_surface_profile(
        detail_ax, surface, x_range=(window_lo, window_hi), color="darkblue"
    )
    has_first_bounce = (
        bounce_points is not None
        and len(bounce_points) > 0
        and len(bounce_points[0]) > 0
    )
    if has_first_bounce:
        first = np.array(bounce_points[0])
        in_window = (first[:, 0] >= window_lo) & (first[:, 0] <= window_hi)
        if np.any(in_window):
            plot_bounce_points(
                detail_ax, first[in_window], bounce_number=1, projection="xz"
            )
    detail_ax.set_title("Surface Detail")

    if save_path:
        save_figure(fig, save_path)

    return fig
# ============================================================================= # Production Simulation Figure # =============================================================================
def _ordinal(n: int) -> str:
    """Return ``n`` with its English ordinal suffix (1 -> '1st', 12 -> '12th')."""
    if 10 <= n % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def _trace_production_bounces(
    surface: "Surface", rays: "RayBatch", max_bounces: int
) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]:
    """
    Trace ``rays`` through up to ``max_bounces`` specular reflections.

    Returns
    -------
    bounce_points : list of ndarray
        ``max_bounces`` arrays of hit positions, each of shape (n_hits, 3).
        Bounces that were never reached are empty (0, 3) arrays.
    first_source_idx : ndarray of int
        For every first-bounce hit, the row of ``rays`` that produced it.
    first_reflected_dirs : ndarray
        Reflected direction at every first-bounce hit, shape (n_hits, 3).
    """
    from ..utilities.ray_data import create_ray_batch

    bounce_points = [np.empty((0, 3)) for _ in range(max_bounces)]
    first_source_idx = np.empty(0, dtype=int)
    first_reflected_dirs = np.empty((0, 3))

    current = rays.clone()
    for bounce_num in range(max_bounces):
        hit_distances, hit_mask = surface.intersect(
            current.positions, current.directions
        )
        if not np.any(hit_mask):
            break

        incoming = current.directions[hit_mask]
        hit_positions = (
            current.positions[hit_mask]
            + hit_distances[hit_mask, np.newaxis] * incoming
        )

        # Specular reflection: d' = d - 2 (d . n) n
        normals = surface.normal_at(hit_positions, incoming)
        dot_prod = np.sum(incoming * normals, axis=1, keepdims=True)
        reflected_dirs = incoming - 2 * dot_prod * normals

        bounce_points[bounce_num] = np.array(hit_positions, copy=True)
        if bounce_num == 0:
            # Remember which visual ray produced each first hit so incoming
            # segments can be drawn from the correct source position.
            first_source_idx = np.flatnonzero(hit_mask)
            first_reflected_dirs = reflected_dirs

        # Only rays that hit continue to the next bounce
        n_hits = int(np.count_nonzero(hit_mask))
        next_rays = create_ray_batch(num_rays=n_hits)
        next_rays.positions[:] = hit_positions + 0.01 * reflected_dirs  # Small offset
        next_rays.directions[:] = reflected_dirs
        next_rays.wavelengths[:] = current.wavelengths[hit_mask]
        next_rays.intensities[:] = current.intensities[hit_mask]
        next_rays.active[:] = True
        current = next_rays

    return bounce_points, first_source_idx, first_reflected_dirs


def _print_bounce_statistics(bounce_points: list[np.ndarray]) -> None:
    """Print mean/std/range of X and Z and the count for each non-empty bounce."""
    print("Bounce Position Statistics:")
    for i, points in enumerate(bounce_points):
        if len(points) == 0:
            continue
        x_positions = points[:, 0]
        z_positions = points[:, 2]
        print(f" Bounce {i+1}:")
        print(
            f" X: mean={np.mean(x_positions):6.1f} m, std={np.std(x_positions):6.1f} m, "
            f"range=[{np.min(x_positions):6.1f}, {np.max(x_positions):6.1f}]"
        )
        print(
            f" Z: mean={np.mean(z_positions):6.3f} m, std={np.std(z_positions):6.3f} m, "
            f"range=[{np.min(z_positions):6.3f}, {np.max(z_positions):6.3f}]"
        )
        print(f" Count: {len(points)} rays")


def _plot_production_overview_panel(
    ax: Axes,
    rays_vis: "RayBatch",
    first_points: np.ndarray,
    first_source_idx: np.ndarray,
    first_reflected_dirs: np.ndarray,
    config: dict,
    earth_radius: float,
) -> None:
    """Draw the km-scale X-Z panel: sea surface, recording sphere, sample rays."""
    earth_center = np.array([0, 0, -earth_radius])
    theta = np.linspace(-0.01, 0.01, 100)
    earth_x = earth_radius * np.sin(theta)
    earth_z = earth_center[2] + earth_radius * np.cos(theta)
    ax.fill_between(
        earth_x / 1000, earth_z / 1000, -10, color="#4a90d9", alpha=0.3, label="Ocean"
    )
    ax.plot(earth_x / 1000, earth_z / 1000, "b-", linewidth=2, label="Sea surface")

    # Recording sphere, concentric with the same Earth model as the sea surface
    recording_radius = earth_radius + config["recording_altitude"]
    rec_x = recording_radius * np.sin(theta)
    rec_z = earth_center[2] + recording_radius * np.cos(theta)
    ax.plot(
        rec_x / 1000,
        rec_z / 1000,
        "g--",
        linewidth=1.5,
        label=f'Recording sphere ({config["recording_altitude"]/1000:.0f} km)',
    )

    if len(first_points) > 0:
        n_plot = min(200, len(first_points))
        plot_indices = np.linspace(0, len(first_points) - 1, n_plot, dtype=int)

        # Incoming rays: each sampled ray's own source position -> its first hit.
        # Only the first segment is labelled so the legend has one entry.
        for k, j in enumerate(plot_indices):
            start = rays_vis.positions[first_source_idx[j]]
            end = first_points[j]
            ax.plot(
                [start[0] / 1000, end[0] / 1000],
                [start[2] / 1000, end[2] / 1000],
                "r-",
                alpha=0.6 if k == 0 else 0.5,
                linewidth=0.8,
                label="Incoming rays" if k == 0 else "_nolegend_",
            )

        # Reflected rays: from each sampled first hit along its reflected direction
        ray_length = config["recording_altitude"] * 1.5
        for k, j in enumerate(plot_indices):
            hit_pos = first_points[j]
            end = hit_pos + first_reflected_dirs[j] * ray_length
            ax.plot(
                [hit_pos[0] / 1000, end[0] / 1000],
                [hit_pos[2] / 1000, end[2] / 1000],
                "g-",
                alpha=0.6 if k == 0 else 0.5,
                linewidth=0.8,
                label="Reflected rays" if k == 0 else "_nolegend_",
            )

    ax.set_xlabel("X (km)", fontsize=12)
    ax.set_ylabel("Z (km)", fontsize=12)
    ax.set_title("Ray Paths Overview (X-Z Plane)", fontsize=14)
    ax.legend(loc="upper right")
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)

    max_range = (
        max(config.get("source_distance", 10000), config["recording_altitude"]) * 1.5
    )
    ax.set_xlim(-max_range / 1000 * 0.5, max_range / 1000 * 1.5)
    ax.set_ylim(-5, config["recording_altitude"] / 1000 * 1.3)


def _plot_production_detail_panel(
    ax: Axes,
    surface: "Surface",
    bounce_points: list[np.ndarray],
    footprint_radius: float,
) -> None:
    """Draw the metre-scale surface profile with per-bounce hit points."""
    x_samples = np.linspace(-footprint_radius * 1.2, footprint_radius * 1.2, 400)
    if hasattr(surface, "_compute_wave_displacement"):
        z_wave = np.array(
            [
                surface._compute_wave_displacement(np.array([x]), np.array([0.0]))[2][0]
                for x in x_samples
            ]
        )
    elif hasattr(surface, "_surface_z"):
        z_wave = np.array([surface._surface_z(x, 0.0) for x in x_samples])
    else:
        z_wave = np.zeros_like(x_samples)

    ax.fill_between(x_samples, z_wave, z_wave.min() - 2, color="#4a90d9", alpha=0.5)
    ax.plot(x_samples, z_wave, "b-", linewidth=2, label="Wave surface")

    # Styles cycle / saturate so any number of bounces can be shown
    bounce_colors = ["red", "cyan", "magenta", "yellow"]
    bounce_sizes = [20, 15, 12, 10]
    for bounce_idx, points in enumerate(bounce_points):
        if len(points) == 0:
            continue
        bounce_x = points[:, 0]
        bounce_z = points[:, 2]
        mean_x = np.mean(bounce_x)
        color = bounce_colors[bounce_idx % len(bounce_colors)]
        size = bounce_sizes[min(bounce_idx, len(bounce_sizes) - 1)]
        label_text = f"{_ordinal(bounce_idx + 1)} bounce (mean X={mean_x:.1f}m)"
        ax.scatter(
            bounce_x,
            bounce_z,
            c=color,
            s=size,
            alpha=0.7,
            label=label_text,
            zorder=3 + bounce_idx,
            edgecolors="black",
            linewidths=0.5,
        )
        # Vertical line at the mean X position of this bounce
        ax.axvline(mean_x, color=color, linestyle=":", alpha=0.4, linewidth=1.5)

    # Mark beam footprint
    ax.axvline(-footprint_radius, color="orange", linestyle="--", alpha=0.5)
    ax.axvline(
        footprint_radius,
        color="orange",
        linestyle="--",
        alpha=0.5,
        label=f"Beam footprint (±{footprint_radius:.0f}m)",
    )

    ax.set_xlabel("X (m)", fontsize=12)
    ax.set_ylabel("Z (m)", fontsize=12)
    ax.set_title("Wave Surface Detail with Multi-Bounce Points", fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_production_ray_overview(
    original_rays: "RayBatch",
    surface: "Surface",
    config: dict,
    output_path: str,
    timestamp: str,
    max_bounces: int = 2,
) -> Figure:
    """
    Create production simulation ray overview with surface bounce points.

    Shows incoming rays, bounce points on wave surface (colored by bounce
    number), and reflected rays toward recording sphere. Up to 500 rays are
    randomly subsampled (with ``np.random``) and traced through up to
    ``max_bounces`` reflections. Hits farther than twice the beam footprint
    radius from the z axis are discarded. Statistics are printed, and the
    figure is saved as ``simulation_{timestamp}_overview.png`` in
    ``output_path`` and then closed.

    Parameters
    ----------
    original_rays : RayBatch
        Original rays before tracing.
    surface : Surface
        The wave surface (e.g., CurvedWaveSurface).
    config : dict
        Simulation configuration with keys:
        - grazing_angle: Beam grazing angle in degrees
        - beam_radius: Beam radius in meters
        - earth_radius: Earth radius in meters (optional; defaults to
          EARTH_RADIUS and is used for both the sea surface and the
          recording sphere)
        - recording_altitude: Recording sphere altitude in meters
        - source_distance: Source distance in meters (optional)
    output_path : str
        Directory to save figure (created if missing).
    timestamp : str
        Timestamp for filename.
    max_bounces : int
        Maximum number of bounces to visualize (default: 2).

    Returns
    -------
    Figure
        Matplotlib figure (already saved and closed).
    """
    from pathlib import Path

    from ..surfaces import EARTH_RADIUS
    from ..utilities.ray_data import create_ray_batch

    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # Subsample rays for visualization
    n_vis_rays = min(500, original_rays.num_rays)
    vis_sample_idx = np.random.choice(original_rays.num_rays, n_vis_rays, replace=False)
    rays_vis = create_ray_batch(num_rays=n_vis_rays)
    rays_vis.positions[:] = original_rays.positions[vis_sample_idx]
    rays_vis.directions[:] = original_rays.directions[vis_sample_idx]
    rays_vis.wavelengths[:] = original_rays.wavelengths[vis_sample_idx]
    rays_vis.intensities[:] = original_rays.intensities[vis_sample_idx]
    rays_vis.active[:] = True

    bounce_points, first_src, first_dirs = _trace_production_bounces(
        surface, rays_vis, max_bounces
    )

    # The beam footprint on the surface is elongated by 1/sin(grazing angle)
    grazing_angle_rad = np.radians(config["grazing_angle"])
    elongation_factor = 1.0 / np.sin(grazing_angle_rad)
    footprint_radius = config["beam_radius"] * elongation_factor
    filter_radius = footprint_radius * 2.0  # Keep points within 2x the footprint

    # Remove rays that escaped to infinity. The first-bounce bookkeeping is
    # filtered with the same mask so it stays aligned with the kept points.
    for i, points in enumerate(bounce_points):
        valid_mask = np.sqrt(points[:, 0] ** 2 + points[:, 1] ** 2) < filter_radius
        bounce_points[i] = points[valid_mask]
        if i == 0:
            first_src = first_src[valid_mask]
            first_dirs = first_dirs[valid_mask]

    _print_bounce_statistics(bounce_points)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    earth_radius = config.get("earth_radius", EARTH_RADIUS)
    first_points = bounce_points[0] if bounce_points else np.empty((0, 3))
    _plot_production_overview_panel(
        ax1, rays_vis, first_points, first_src, first_dirs, config, earth_radius
    )
    _plot_production_detail_panel(ax2, surface, bounce_points, footprint_radius)

    # Operate on this figure explicitly rather than pyplot's "current" figure
    fig.tight_layout()
    fig_path = output_path / f"simulation_{timestamp}_overview.png"
    fig.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

    return fig
def plot_wave_surface_detail(
    reflected_rays: "RayBatch",
    surface: "Surface",
    x_range: tuple[float, float] = (-200, 200),
    figsize: tuple[float, float] = (12, 6),
    save_path: str | None = None,
) -> Figure:
    """
    Plot wave surface detail with ray intersection points.

    The surface profile is sampled at y = 0. Intersection points are
    recovered from ``reflected_rays`` by undoing the 0.01 m offset along
    each ray's direction.

    Parameters
    ----------
    reflected_rays : RayBatch
        Batch of reflected rays.
    surface : Surface
        The surface object. Its ``_surface_z(x, y)`` method is used when
        present; otherwise the surface is drawn as the plane z = 0.
    x_range : tuple
        X-axis range for plotting.
    figsize : tuple
        Figure size (width, height).
    save_path : str, optional
        Path to save figure.

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Plot wave surface
    x_detail = np.linspace(x_range[0], x_range[1], 1000)
    y_detail = np.zeros_like(x_detail)
    if hasattr(surface, "_surface_z"):
        z_detail = np.asarray(surface._surface_z(x_detail, y_detail))
    else:
        # Planar surface: assume the horizontal plane z = 0
        z_detail = np.zeros_like(x_detail)

    ax.plot(x_detail, z_detail, "b-", linewidth=3, label="Wave Surface", zorder=3)
    ax.fill_between(
        x_detail,
        z_detail,
        z_detail.min() - 0.5,
        color="lightblue",
        alpha=0.4,
        zorder=1,
    )

    # Back-calculate actual hit positions (rays are offset by 0.01m along direction)
    actual_hit_positions = reflected_rays.positions - 0.01 * reflected_rays.directions
    ax.scatter(
        actual_hit_positions[:, 0],
        actual_hit_positions[:, 2],
        c="red",
        s=8,
        alpha=0.6,
        zorder=5,
        label="Intersection Points",
    )

    ax.set_xlim(x_range[0], x_range[1])
    z_low, z_high = z_detail.min(), z_detail.max()
    z_span = z_high - z_low
    if not z_span > 0:
        # Flat (or non-finite) profile: identical y-limits are singular, so
        # fall back to a fixed 1 m span for the margins.
        z_span = 1.0
    ax.set_ylim(z_low - z_span * 0.3, z_high + z_span * 0.5)

    ax.set_xlabel("X Position (m)", fontsize=11, fontweight="bold")
    ax.set_ylabel("Z Position (m)", fontsize=11, fontweight="bold")
    ax.set_title("Wave Surface Detail with Ray Intersections", fontweight="bold")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="upper right", fontsize=10)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig
def plot_ray_paths_with_surface(
    rays: "RayBatch",
    reflected_rays: "RayBatch",
    surface: "Surface",
    detector_distance: float = 1000.0,
    source_distance: float = 1000.0,
    refracted_rays: Optional["RayBatch"] = None,
    ray_paths: dict | None = None,
    figsize: tuple[float, float] = (16, 10),
    save_path: str | None = None,
) -> Figure:
    """
    Plot full ray paths (incoming, reflected, and refracted) with wave surface.

    When ``ray_paths`` contains 'reflected_paths', the multi-bounce paths are
    drawn from it. Otherwise, single-bounce segments are reconstructed from
    ``rays``, ``reflected_rays`` and ``refracted_rays``. X is shown in km
    when ``detector_distance >= 100`` m; Z is always in metres.

    Parameters
    ----------
    rays : RayBatch
        Initial ray batch (before interaction).
    reflected_rays : RayBatch
        Reflected ray batch (after interaction).
    surface : Surface
        The surface object. If it has no ``_surface_z`` method it is drawn
        as the plane z = 0.
    detector_distance : float
        Detector distance in meters.
    source_distance : float
        Source distance in meters (unused, for API compatibility).
    refracted_rays : RayBatch, optional
        Refracted ray batch (after interaction).
    ray_paths : dict, optional
        Dictionary with ray path data from trace_rays_multi_bounce containing:
        - 'reflected_paths': list of Nx3 arrays, one per ray
        - 'refracted_paths': list of Nx3 arrays for refracted rays
        - 'reflected_final_dirs': final direction for each reflected path
        - 'refracted_final_dirs': final direction for each refracted path
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure; the figure is closed after saving.

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Determine scale based on distances (use km for large distances)
    use_km = detector_distance >= 100.0
    scale_factor = 1000.0 if use_km else 1.0
    distance_label = "km" if use_km else "m"

    # Plot surface with extended range
    x_min = min(rays.positions[:, 0].min(), reflected_rays.positions[:, 0].min())
    x_max = max(rays.positions[:, 0].max(), reflected_rays.positions[:, 0].max())
    x_range = x_max - x_min
    x_surf = np.linspace(x_min - x_range * 0.2, x_max + detector_distance * 0.3, 1000)
    y_surf = np.zeros_like(x_surf)

    # Handle both wave surfaces and planar surfaces
    if hasattr(surface, "_surface_z"):
        z_surf = surface._surface_z(x_surf, y_surf)
        surface_label = "Wave Surface"
    else:
        # Planar surface - assume z=0 horizontal plane
        z_surf = np.zeros_like(x_surf)
        surface_label = "Planar Surface"

    ax.plot(
        x_surf / scale_factor,
        z_surf,
        "b-",
        linewidth=3,
        label=surface_label,
        zorder=3,
    )
    z_fill_bottom = z_surf.min() - max(abs(z_surf.max() - z_surf.min()) * 0.5, 0.01)
    ax.fill_between(
        x_surf / scale_factor,
        z_surf,
        z_fill_bottom,
        color="lightblue",
        alpha=0.3,
        zorder=1,
    )

    # Sample rays to plot (for performance and clarity)
    num_rays_to_plot = min(100, rays.num_rays)
    indices_to_plot = np.linspace(0, rays.num_rays - 1, num_rays_to_plot, dtype=int)

    multi_bounce = ray_paths is not None and "reflected_paths" in ray_paths
    # Tracks whether any refracted segment was actually drawn, so the legend
    # entry matches the plot in both modes.
    drew_refracted = False

    if multi_bounce:
        reflected_paths = ray_paths["reflected_paths"]
        reflected_final_dirs = ray_paths.get("reflected_final_dirs", [])
        refracted_paths = ray_paths.get("refracted_paths", [])
        refracted_final_dirs = ray_paths.get("refracted_final_dirs", [])

        # Sample paths to plot
        num_paths = len(reflected_paths)
        num_to_plot = min(100, num_paths)
        path_indices = (
            np.linspace(0, num_paths - 1, num_to_plot, dtype=int)
            if num_paths > 0
            else []
        )

        bounce_colors = ["red", "orange", "yellow", "pink", "purple"]
        for path_idx in path_indices:
            path = reflected_paths[path_idx]
            if path is None or len(path) < 2:
                continue
            path = np.asarray(path)

            # Plot the complete path
            ax.plot(
                path[:, 0] / scale_factor,
                path[:, 2],
                "b-",
                linewidth=0.8,
                alpha=0.4,
                zorder=2,
            )

            # Mark bounce points (all vertices after the start) in one call,
            # colored by bounce order
            bounce_pts = path[1:]
            ax.scatter(
                bounce_pts[:, 0] / scale_factor,
                bounce_pts[:, 2],
                c=[bounce_colors[k % len(bounce_colors)] for k in range(len(bounce_pts))],
                s=10,
                alpha=0.6,
                zorder=4,
            )

            # Plot final direction as an extending ray
            if (
                path_idx < len(reflected_final_dirs)
                and reflected_final_dirs[path_idx] is not None
            ):
                final_pos = path[-1]
                final_dir = np.asarray(reflected_final_dirs[path_idx])
                end_pos = final_pos + final_dir * detector_distance * 0.3
                ax.plot(
                    [final_pos[0] / scale_factor, end_pos[0] / scale_factor],
                    [final_pos[2], end_pos[2]],
                    "r-",
                    linewidth=0.8,
                    alpha=0.5,
                    zorder=2,
                )

        # Plot refracted paths (dashed green lines going into water)
        num_refr_paths = len(refracted_paths)
        num_refr_to_plot = min(50, num_refr_paths)
        refr_path_indices = (
            np.linspace(0, num_refr_paths - 1, num_refr_to_plot, dtype=int)
            if num_refr_paths > 0
            else []
        )
        for path_idx in refr_path_indices:
            path = refracted_paths[path_idx]
            if path is None or len(path) < 1:
                continue
            if (
                path_idx < len(refracted_final_dirs)
                and refracted_final_dirs[path_idx] is not None
            ):
                start_pos = np.asarray(path)[0]
                refr_dir = np.asarray(refracted_final_dirs[path_idx])
                end_pos = start_pos + refr_dir * detector_distance * 0.2
                ax.plot(
                    [start_pos[0] / scale_factor, end_pos[0] / scale_factor],
                    [start_pos[2], end_pos[2]],
                    "g--",
                    linewidth=0.6,
                    alpha=0.4,
                    zorder=2,
                )
                drew_refracted = True
    else:
        # Original single-bounce plotting
        reflect_length = detector_distance * 0.5
        for idx in indices_to_plot:
            if idx >= reflected_rays.num_rays:
                continue

            # Back-calculate actual hit position (rays are offset by 0.01m along direction)
            reflect_dir = reflected_rays.directions[idx, :]
            actual_hit_pos = reflected_rays.positions[idx, :] - 0.01 * reflect_dir
            start_pos = rays.positions[idx, :]

            # Incoming ray: source -> hit
            ax.plot(
                [start_pos[0] / scale_factor, actual_hit_pos[0] / scale_factor],
                [start_pos[2], actual_hit_pos[2]],
                "b-",
                linewidth=0.8,
                alpha=0.4,
                zorder=2,
            )

            # Reflected ray: from actual hit position outward
            end_pos = actual_hit_pos + reflect_dir * reflect_length
            ax.plot(
                [actual_hit_pos[0] / scale_factor, end_pos[0] / scale_factor],
                [actual_hit_pos[2], end_pos[2]],
                "r-",
                linewidth=0.8,
                alpha=0.5,
                zorder=2,
            )

        # Plot refracted rays (if provided)
        if refracted_rays is not None and refracted_rays.num_rays > 0:
            refr_indices = np.linspace(
                0,
                min(refracted_rays.num_rays - 1, rays.num_rays - 1),
                num_rays_to_plot,
                dtype=int,
            )
            refract_length = detector_distance * 0.3
            for idx in refr_indices:
                if idx >= refracted_rays.num_rays:
                    continue
                refract_dir = refracted_rays.directions[idx, :]
                # Back-calculate actual hit position (rays are offset by 0.01m along direction)
                actual_hit_pos = refracted_rays.positions[idx, :] - 0.01 * refract_dir
                end_pos = actual_hit_pos + refract_dir * refract_length
                ax.plot(
                    [actual_hit_pos[0] / scale_factor, end_pos[0] / scale_factor],
                    [actual_hit_pos[2], end_pos[2]],
                    "g--",
                    linewidth=0.8,
                    alpha=0.5,
                    zorder=2,
                )
                drew_refracted = True

    # Add beam source indicator
    beam_source = rays.positions[0, :]
    ax.scatter(
        [beam_source[0] / scale_factor],
        [beam_source[2]],
        c="blue",
        s=300,
        marker="*",
        edgecolors="black",
        linewidths=2,
        label="Beam Source",
        zorder=6,
    )

    ax.set_xlabel(f"X Position ({distance_label})", fontsize=13, fontweight="bold")
    ax.set_ylabel("Z Position (m)", fontsize=13, fontweight="bold")

    # Update title based on what's shown
    if multi_bounce:
        num_paths = len(ray_paths["reflected_paths"])
        title_text = f"Ray Paths: Multi-Bounce Reflection ({num_paths} ray paths)"
    elif refracted_rays is not None and refracted_rays.num_rays > 0:
        title_text = (
            f"Ray Paths: Reflection & Refraction ({num_rays_to_plot} rays shown)"
        )
    else:
        title_text = (
            f"Ray Paths: Reflection from Wave Surface ({num_rays_to_plot} rays shown)"
        )
    ax.set_title(title_text, fontweight="bold", fontsize=15)
    ax.grid(True, alpha=0.3, linewidth=0.5)

    # Build legend; "Refracted" only appears if refracted segments were drawn
    legend_elements = [
        plt.Line2D([0], [0], color="b", linewidth=2, alpha=0.5, label="Incoming"),
        plt.Line2D([0], [0], color="r", linewidth=2, alpha=0.5, label="Reflected"),
    ]
    if drew_refracted:
        legend_elements.append(
            plt.Line2D(
                [0],
                [0],
                color="g",
                linewidth=2,
                linestyle="--",
                alpha=0.5,
                label="Refracted",
            )
        )
    legend_elements.append(
        plt.Line2D(
            [0],
            [0],
            marker="*",
            color="w",
            markerfacecolor="blue",
            markersize=12,
            label="Beam Source",
        )
    )
    ax.legend(handles=legend_elements, loc="upper left", fontsize=11, framealpha=0.9)

    ax.set_xlim(
        x_min / scale_factor - x_range * 0.2 / scale_factor,
        (x_max + detector_distance * 0.3) / scale_factor,
    )
    ax.set_ylim(
        min(
            z_surf.min() - abs(z_surf.max() - z_surf.min()) * 0.5,
            -detector_distance * 0.01,
        ),
        max(z_surf.max(), detector_distance * 0.1),
    )
    if not use_km:
        ax.set_aspect("equal", adjustable="datalim")

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)

    return fig
# ============================================================================= # Legacy Convenience Functions (Backward Compatibility) # =============================================================================
def plot_ray_paths_2d(
    ray_history: list["RayBatch"],
    max_rays: int = 100,
    color_by: str = "wavelength",
    alpha: float = 0.4,
    linewidth: float = 0.8,
    figsize: tuple[float, float] = (15, 5),
    save_path: str | None = None,
) -> Figure:
    """
    Draw ray paths in the XY, XZ and YZ planes as three side-by-side panels.

    Quick-look wrapper: it builds a 1x3 figure and delegates each panel to
    plot_ray_paths_projection(). For bespoke layouts, create your own axes
    and call plot_ray_paths_projection() on each of them directly.

    Parameters
    ----------
    ray_history : list of RayBatch
        Ray batches captured at successive propagation steps.
    max_rays : int
        Upper bound on the number of rays drawn (sampled uniformly if exceeded).
    color_by : str
        Ray coloring scheme: 'wavelength', 'intensity', 'generation', 'index'.
    alpha : float
        Line transparency.
    linewidth : float
        Line width.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    save_path : str, optional
        If given, the figure is written there via save_figure().

    Returns
    -------
    Figure
        Matplotlib figure holding the XY, XZ and YZ subplots (left to right).
    """
    fig, axes = plt.subplots(1, 3, figsize=figsize, constrained_layout=True)
    fig.suptitle("Ray Paths - 2D Projections", fontsize=14, fontweight="bold")

    projections = ("xy", "xz", "yz")
    # Only the right-most (YZ) panel draws a colorbar.
    colorbar_projection = projections[-1]

    for position, projection in enumerate(projections):
        plot_ray_paths_projection(
            axes[position],
            ray_history,
            projection=projection,
            max_rays=max_rays,
            color_by=color_by,
            alpha=alpha,
            linewidth=linewidth,
            show_colorbar=projection == colorbar_projection,
        )

    if save_path:
        save_figure(fig, save_path)
    return fig
# Re-export Fresnel/Brewster functions from fresnel_plots module for backward compatibility # Re-export polarization functions from polarization_plots module for backward compatibility
def plot_ray_endpoints(
    rays: "RayBatch",
    plane: str = "xy",
    color_by: str = "wavelength",
    bins: int = 50,
    figsize: tuple[float, float] = (12, 5),
    save_path: str | None = None,
) -> Figure:
    """
    Show ray endpoints as a scatter plot next to a 2D histogram.

    Quick-look wrapper: it builds a 1x2 figure, fills the left panel with
    plot_ray_endpoints_scatter() and the right panel with
    plot_ray_endpoints_histogram(). Call those functions on your own axes
    for custom layouts.

    Parameters
    ----------
    rays : RayBatch
        Ray batch whose positions are plotted as endpoints.
    plane : str
        Projection plane: 'xy', 'xz', 'yz'. Also used (upper-cased) in the
        figure title.
    color_by : str
        Scatter coloring scheme: 'wavelength', 'intensity'.
    bins : int
        Number of histogram bins.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    save_path : str, optional
        If given, the figure is written there via save_figure().

    Returns
    -------
    Figure
        Matplotlib figure with the scatter (left) and histogram (right).
    """
    fig, (scatter_ax, hist_ax) = plt.subplots(
        1, 2, figsize=figsize, constrained_layout=True
    )
    title = f"Ray Endpoints - {plane.upper()} Plane"
    fig.suptitle(title, fontsize=14, fontweight="bold")

    plot_ray_endpoints_scatter(
        scatter_ax,
        rays,
        projection=plane,
        color_by=color_by,
        show_colorbar=True,
    )
    plot_ray_endpoints_histogram(hist_ax, rays, projection=plane, bins=bins)

    if save_path:
        save_figure(fig, save_path)
    return fig