# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Ray Tracing Visualization - Individual Axis Functions
Functions for plotting ray paths, intersections, and propagation.
Each function draws on a single axis, enabling flexible composition.
"""
from typing import TYPE_CHECKING, Optional
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
if TYPE_CHECKING:
from ..utilities.ray_data import RayBatch
from ..surfaces import Surface
from .common import (
DEFAULT_LINEWIDTH,
DEFAULT_MARKERSIZE,
INTENSITY_CMAP,
LINE_ALPHA,
SCATTER_ALPHA,
WAVELENGTH_CMAP,
add_colorbar,
get_color_mapping,
get_projection_config,
save_figure,
setup_axis_grid,
)
# =============================================================================
# Single-Axis Ray Path Functions
# =============================================================================
[docs]
def plot_ray_paths_projection(
    ax: Axes,
    ray_history: list["RayBatch"],
    projection: str = "xz",
    max_rays: int = 100,
    color_by: str = "wavelength",
    alpha: float = LINE_ALPHA,
    linewidth: float = DEFAULT_LINEWIDTH,
    show_colorbar: bool = True,
) -> plt.cm.ScalarMappable | None:
    """
    Draw the 2D projection of each sampled ray's trajectory on ``ax``.

    Parameters
    ----------
    ax : Axes
        Target matplotlib axes.
    ray_history : List[RayBatch]
        Ray batches recorded at successive time steps. Ray ``k`` of the first
        batch is looked up at index ``k`` in every later batch.
    projection : str
        Projection plane: 'xy', 'xz', or 'yz'.
    max_rays : int
        Upper bound on the number of rays drawn; larger batches are sampled
        at evenly spaced indices.
    color_by : str
        Coloring quantity: 'wavelength', 'intensity', 'generation'; any other
        value colors by ray index.
    alpha : float
        Line transparency.
    linewidth : float
        Line width.
    show_colorbar : bool
        Whether to attach a colorbar to ``ax``.

    Returns
    -------
    sm : ScalarMappable or None
        Mappable usable for an external colorbar, or None when
        ``ray_history`` is empty.
    """
    if len(ray_history) == 0:
        return None

    first = ray_history[0]
    total = first.num_rays
    chosen = (
        np.linspace(0, total - 1, max_rays, dtype=int)
        if total > max_rays
        else np.arange(total)
    )

    axis_a, axis_b, xlabel, ylabel = get_projection_config(projection)

    # Values, colormap and colorbar label are taken from the first batch only.
    if color_by == "wavelength":
        values, cmap, label = (
            first.wavelengths[chosen] * 1e9,
            WAVELENGTH_CMAP,
            "Wavelength (nm)",
        )
    elif color_by == "intensity":
        values, cmap, label = first.intensities[chosen], INTENSITY_CMAP, "Intensity"
    elif color_by == "generation":
        values, cmap, label = first.generations[chosen], "tab10", "Generation"
    else:
        values, cmap, label = chosen.astype(float), "tab20", "Ray Index"

    colors, _norm, sm = get_color_mapping(values, cmap)

    for slot, ray in enumerate(chosen):
        xs = [step.positions[ray, axis_a] for step in ray_history]
        ys = [step.positions[ray, axis_b] for step in ray_history]
        ax.plot(xs, ys, color=colors[slot], alpha=alpha, linewidth=linewidth)

    setup_axis_grid(ax, xlabel, ylabel, f"{projection.upper()} Projection")
    ax.set_aspect("equal", adjustable="box")
    if show_colorbar:
        add_colorbar(ax, sm, label)
    return sm
[docs]
def plot_ray_endpoints_scatter(
    ax: Axes,
    rays: "RayBatch",
    projection: str = "xy",
    color_by: str = "intensity",
    alpha: float = SCATTER_ALPHA,
    size: float = DEFAULT_MARKERSIZE,
    show_colorbar: bool = True,
) -> plt.cm.ScalarMappable | None:
    """
    Scatter-plot the current positions of the active rays.

    Parameters
    ----------
    ax : Axes
        Target matplotlib axes.
    rays : RayBatch
        Ray batch; only entries flagged in ``rays.active`` are drawn.
    projection : str
        Plane: 'xy', 'xz', 'yz'.
    color_by : str
        Coloring quantity: 'intensity', 'wavelength' (shown in nm),
        'generation', 'time' (shown in μs); any other value colors by
        position in the filtered array.
    alpha : float
        Point transparency.
    size : float
        Point size.
    show_colorbar : bool
        Whether to attach a colorbar to ``ax``.

    Returns
    -------
    sm : ScalarMappable or None
        The scatter collection, usable for an external colorbar.
    """
    active_mask = rays.active
    positions = rays.positions[active_mask]

    h_idx, v_idx, xlabel, ylabel = get_projection_config(projection)
    x = positions[:, h_idx]
    y = positions[:, v_idx]

    # color_by -> (RayBatch attribute, unit scale or None, colormap, label).
    # Attributes are only read for the selected entry.
    color_specs = {
        "intensity": ("intensities", None, INTENSITY_CMAP, "Intensity"),
        "wavelength": ("wavelengths", 1e9, WAVELENGTH_CMAP, "Wavelength (nm)"),
        "generation": ("generations", None, "tab10", "Generation"),
        "time": ("accumulated_time", 1e6, "coolwarm", "Time (μs)"),
    }
    spec = color_specs.get(color_by)
    if spec is None:
        c = np.arange(len(x))
        cmap = "viridis"
        clabel = "Index"
    else:
        attr_name, scale, cmap, clabel = spec
        c = getattr(rays, attr_name)[active_mask]
        if scale is not None:
            c = c * scale

    scatter = ax.scatter(x, y, c=c, s=size, alpha=alpha, cmap=cmap)
    setup_axis_grid(ax, xlabel, ylabel, f"Ray Endpoints - {projection.upper()}")
    ax.set_aspect("equal", adjustable="box")
    if show_colorbar and c is not None:
        add_colorbar(ax, scatter, clabel)
    return scatter
[docs]
def plot_ray_endpoints_histogram(
    ax: Axes,
    rays: "RayBatch",
    projection: str = "xy",
    bins: int = 50,
    cmap: str = "hot",
) -> None:
    """
    Render a 2D density histogram of the active rays' current positions.

    Parameters
    ----------
    ax : Axes
        Target matplotlib axes.
    rays : RayBatch
        Ray batch; only entries flagged in ``rays.active`` are counted.
    projection : str
        Plane: 'xy', 'xz', 'yz'.
    bins : int
        Number of histogram bins (passed to ``Axes.hist2d``).
    cmap : str
        Colormap name.

    Notes
    -----
    Empty bins are left blank because ``cmin=1`` is passed to ``hist2d``.
    """
    live = rays.positions[rays.active]
    first_axis, second_axis, xlabel, ylabel = get_projection_config(projection)

    *_, mesh = ax.hist2d(
        live[:, first_axis],
        live[:, second_axis],
        bins=bins,
        cmap=cmap,
        cmin=1,
    )

    setup_axis_grid(ax, xlabel, ylabel, "Ray Density")
    ax.set_aspect("equal", adjustable="box")
    add_colorbar(ax, mesh, "Count")
# =============================================================================
# Surface Intersection Visualization
# =============================================================================
[docs]
def plot_surface_profile(
    ax: Axes,
    surface: "Surface",
    x_range: tuple[float, float] = (-200, 200),
    y: float = 0.0,
    n_points: int = 1000,
    color: str = "blue",
    linewidth: float = 2.0,
    label: str = "Surface",
) -> None:
    """
    Draw a height cut through ``surface`` along X at a fixed Y.

    The height is sampled point by point via ``surface._surface_z(x, y)``.
    Surfaces without that method are drawn as the flat line z = 0. The
    region below the profile is shaded down to 0.5 below its minimum.

    Parameters
    ----------
    ax : Axes
        Target matplotlib axes.
    surface : Surface
        Surface object, ideally exposing ``_surface_z(x, y)``.
    x_range : tuple
        (x_min, x_max) range.
    y : float
        Y-coordinate of the cut.
    n_points : int
        Number of sample points.
    color : str
        Line and fill color.
    linewidth : float
        Line width.
    label : str
        Legend label.
    """
    xs = np.linspace(x_range[0], x_range[1], n_points)
    ys = np.full_like(xs, y)

    if not hasattr(surface, "_surface_z"):
        zs = np.zeros_like(xs)
    else:
        height_fn = surface._surface_z
        zs = np.array([height_fn(xv, yv) for xv, yv in zip(xs, ys)])

    ax.plot(xs, zs, color=color, linewidth=linewidth, label=label)
    ax.fill_between(xs, zs, zs.min() - 0.5, alpha=0.3, color=color)
    setup_axis_grid(ax, "X (m)", "Z (m)")
[docs]
def plot_bounce_points(
    ax: Axes,
    bounce_positions: np.ndarray,
    bounce_number: int = 1,
    color: str | None = None,
    size: float = 20,
    alpha: float = 0.7,
    projection: str = "xz",
    label: str | None = None,
) -> None:
    """
    Scatter the positions where rays bounced, projected onto a 2D plane.

    Parameters
    ----------
    ax : Axes
        Target matplotlib axes.
    bounce_positions : ndarray
        (N, 3) array of bounce positions. Nothing is drawn when N == 0.
    bounce_number : int
        Bounce index. It selects the default color, cycling through five
        colors, and the default label "Bounce <n>".
    color : str, optional
        Explicit color overriding the cycled default.
    size : float
        Marker size.
    alpha : float
        Transparency.
    projection : str
        Coordinate projection ('xz', 'xy', 'yz').
    label : str, optional
        Explicit legend label.
    """
    if len(bounce_positions) == 0:
        return

    palette = ("#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd")
    marker_color = palette[bounce_number % len(palette)] if color is None else color
    legend_label = f"Bounce {bounce_number}" if label is None else label

    h_idx, v_idx, _, _ = get_projection_config(projection)
    ax.scatter(
        bounce_positions[:, h_idx],
        bounce_positions[:, v_idx],
        c=marker_color,
        s=size,
        alpha=alpha,
        label=legend_label,
        edgecolors="none",
    )
[docs]
def plot_incoming_rays(
    ax: Axes,
    rays: "RayBatch",
    surface: "Surface",
    projection: str = "xz",
    color: str = "gold",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_rays: int = 100,
) -> None:
    """
    Plot incoming ray segments from each ray's origin to its surface hit.

    At most ``max_rays`` rays are sampled at evenly spaced indices, and
    inactive rays are skipped. The sampled rays are intersected with
    ``surface`` in one batched call rather than one call per ray. Only
    hits with a positive distance are drawn.

    Parameters
    ----------
    ax : Axes
        Target matplotlib axes.
    rays : RayBatch
        Ray batch before intersection.
    surface : Surface
        Surface used for the intersection calculation. Its ``intersect``
        method is called with (positions, directions, active_mask) as
        float32/bool arrays.
    projection : str
        Coordinate projection.
    color : str
        Ray color.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_rays : int
        Maximum rays to plot.
    """
    idx1, idx2, _, _ = get_projection_config(projection)

    total = rays.num_rays
    sample_idx = (
        np.linspace(0, total - 1, max_rays, dtype=int)
        if total > max_rays
        else np.arange(total)
    )
    # Drop inactive rays before intersecting.
    sample_idx = sample_idx[np.asarray(rays.active[sample_idx], dtype=bool)]
    if sample_idx.size == 0:
        return

    positions = rays.positions[sample_idx]
    directions = rays.directions[sample_idx]

    # One batched intersection instead of one call per ray.
    t, hit = surface.intersect(
        positions.astype(np.float32),
        directions.astype(np.float32),
        np.ones(len(sample_idx), dtype=bool),
    )
    t = np.asarray(t).reshape(-1)
    valid = np.asarray(hit, dtype=bool).reshape(-1) & (t > 0)

    for pos, direction, dist in zip(positions[valid], directions[valid], t[valid]):
        intersection = pos + dist * direction
        ax.plot(
            [pos[idx1], intersection[idx1]],
            [pos[idx2], intersection[idx2]],
            color=color,
            alpha=alpha,
            linewidth=linewidth,
        )
[docs]
def plot_reflected_rays(
    ax: Axes,
    rays: "RayBatch",
    length: float = 100.0,
    projection: str = "xz",
    color: str = "cyan",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_rays: int = 100,
) -> None:
    """
    Draw fixed-length segments from each sampled ray's position along its
    direction.

    At most ``max_rays`` rays are sampled at evenly spaced indices, and
    inactive rays are skipped.

    Parameters
    ----------
    ax : Axes
        Target matplotlib axes.
    rays : RayBatch
        Reflected ray batch.
    length : float
        Length of each segment, in units of ``rays.directions``.
    projection : str
        Coordinate projection.
    color : str
        Ray color.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_rays : int
        Maximum rays to plot.
    """
    h_idx, v_idx, _, _ = get_projection_config(projection)

    total = rays.num_rays
    if total > max_rays:
        picked = np.linspace(0, total - 1, min(total, max_rays), dtype=int)
    else:
        picked = np.arange(total)

    starts = rays.positions[picked]
    stops = starts + length * rays.directions[picked]

    for ray, start, stop in zip(picked, starts, stops):
        if not rays.active[ray]:
            continue
        ax.plot(
            [start[h_idx], stop[h_idx]],
            [start[v_idx], stop[v_idx]],
            color=color,
            alpha=alpha,
            linewidth=linewidth,
        )
# =============================================================================
# Multi-Bounce Visualization
# =============================================================================
[docs]
def _draw_sampled_path_set(
    ax: Axes,
    paths: list,
    idx1: int,
    idx2: int,
    color: str,
    alpha: float,
    linewidth: float,
    max_paths: int,
) -> None:
    """Draw up to ``max_paths`` evenly-sampled polylines from ``paths``."""
    n_total = len(paths)
    if n_total == 0:
        return
    sample_idx = (
        np.linspace(0, n_total - 1, max_paths, dtype=int)
        if n_total > max_paths
        else range(n_total)
    )
    for i in sample_idx:
        path = paths[i]
        # Tracers may leave None for rays that produced no path.
        if path is None:
            continue
        path = np.asarray(path)
        if path.ndim != 2 or len(path) < 2:
            continue
        ax.plot(
            path[:, idx1],
            path[:, idx2],
            color=color,
            alpha=alpha,
            linewidth=linewidth,
        )


def plot_multi_bounce_paths(
    ax: Axes,
    ray_paths: dict[str, list[np.ndarray]],
    projection: str = "xz",
    reflected_color: str = "cyan",
    refracted_color: str = "orange",
    alpha: float = 0.3,
    linewidth: float = 0.5,
    max_paths: int = 100,
) -> None:
    """
    Plot multi-bounce ray paths from trace_rays_multi_bounce output.

    Reflected and refracted paths are sampled independently, with at most
    ``max_paths`` of each drawn. Entries that are ``None``, not 2-D, or
    shorter than two points are skipped.

    Parameters
    ----------
    ax : Axes
        Target matplotlib axes.
    ray_paths : dict
        Dictionary with 'reflected_paths' and/or 'refracted_paths' lists of
        (N, 3) point arrays (array-likes are accepted).
    projection : str
        Coordinate projection.
    reflected_color : str
        Color for reflected paths.
    refracted_color : str
        Color for refracted paths.
    alpha : float
        Transparency.
    linewidth : float
        Line width.
    max_paths : int
        Maximum paths to plot per path kind.
    """
    idx1, idx2, _, _ = get_projection_config(projection)

    for key, color in (
        ("reflected_paths", reflected_color),
        ("refracted_paths", refracted_color),
    ):
        if key in ray_paths:
            _draw_sampled_path_set(
                ax, ray_paths[key], idx1, idx2, color, alpha, linewidth, max_paths
            )
# =============================================================================
# Composite Figure Builders
# =============================================================================
# =============================================================================
# Production Simulation Figure
# =============================================================================
[docs]
def _ordinal(n: int) -> str:
    """Return ``n`` with its English ordinal suffix (1st, 2nd, 3rd, 4th, 11th, ...)."""
    if 10 <= n % 100 <= 20:
        return f"{n}th"
    suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def _trace_specular_bounces(
    surface: "Surface", rays: "RayBatch", max_bounces: int
) -> tuple[list[np.ndarray], np.ndarray, np.ndarray]:
    """
    Follow ``rays`` through up to ``max_bounces`` mirror reflections off ``surface``.

    Returns
    -------
    bounce_points : list of ndarray
        One (M_i, 3) array of hit positions per bounce. A bounce with no
        hits gives an empty (0, 3) array.
    first_src_idx : ndarray
        For each first-bounce hit, the index of the originating ray in ``rays``.
    first_reflected_dirs : ndarray
        Reflected direction for each first-bounce hit, in the same order.
    """
    from ..utilities.ray_data import create_ray_batch

    bounce_points = [np.empty((0, 3)) for _ in range(max_bounces)]
    first_src_idx = np.empty(0, dtype=int)
    first_reflected_dirs = np.empty((0, 3))

    current_rays = rays.clone()
    for bounce_num in range(max_bounces):
        hit_distances, hit_mask = surface.intersect(
            current_rays.positions, current_rays.directions
        )
        if not np.any(hit_mask):
            break

        incoming = current_rays.directions[hit_mask]
        hit_positions = (
            current_rays.positions[hit_mask]
            + hit_distances[hit_mask, np.newaxis] * incoming
        )
        normals = surface.normal_at(hit_positions, incoming)
        # Specular reflection: d' = d - 2 (d . n) n
        dot_prod = np.sum(incoming * normals, axis=1, keepdims=True)
        reflected_dirs = incoming - 2 * dot_prod * normals

        bounce_points[bounce_num] = np.array(hit_positions, copy=True)
        if bounce_num == 0:
            # Remember which original ray produced each first hit, so the
            # incoming segments can be drawn from the correct start point.
            first_src_idx = np.flatnonzero(hit_mask)
            first_reflected_dirs = reflected_dirs

        next_rays = create_ray_batch(num_rays=len(hit_positions))
        # Small offset along the new direction, presumably so the next
        # intersection does not re-hit the same point.
        next_rays.positions[:] = hit_positions + 0.01 * reflected_dirs
        next_rays.directions[:] = reflected_dirs
        next_rays.intensities[:] = current_rays.intensities[hit_mask]
        next_rays.active[:] = True
        current_rays = next_rays

    return bounce_points, first_src_idx, first_reflected_dirs


def _wave_height_at(surface: "Surface", x: float) -> float:
    """Surface height at (x, y=0) using whichever height API ``surface`` exposes."""
    if hasattr(surface, "_compute_wave_displacement"):
        _, _, dz = surface._compute_wave_displacement(np.array([x]), np.array([0.0]))
        return dz[0]
    if hasattr(surface, "_surface_z"):
        return surface._surface_z(x, 0.0)
    return 0.0


def _print_bounce_statistics(bounce_points: list[np.ndarray]) -> None:
    """Print the mean, std and range of X and Z for every non-empty bounce."""
    print("Bounce Position Statistics:")
    for i, points in enumerate(bounce_points):
        if len(points) == 0:
            continue
        x_positions = points[:, 0]
        z_positions = points[:, 2]
        print(f" Bounce {i+1}:")
        print(
            f" X: mean={np.mean(x_positions):6.1f} m, std={np.std(x_positions):6.1f} m, "
            f"range=[{np.min(x_positions):6.1f}, {np.max(x_positions):6.1f}]"
        )
        print(
            f" Z: mean={np.mean(z_positions):6.3f} m, std={np.std(z_positions):6.3f} m, "
            f"range=[{np.min(z_positions):6.3f}, {np.max(z_positions):6.3f}]"
        )
        print(f" Count: {len(points)} rays")


def _plot_km_segments(
    ax: Axes, starts: np.ndarray, ends: np.ndarray, fmt: str, label: str
) -> None:
    """Draw X-Z segments converted to km. Only the first segment gets ``label``."""
    for k, (start, end) in enumerate(zip(starts, ends)):
        extra = {"label": label} if k == 0 else {}
        ax.plot(
            [start[0] / 1000, end[0] / 1000],
            [start[2] / 1000, end[2] / 1000],
            fmt,
            alpha=0.6 if k == 0 else 0.5,
            linewidth=0.8,
            **extra,
        )


def plot_production_ray_overview(
    original_rays: "RayBatch",
    surface: "Surface",
    config: dict,
    output_path: str,
    timestamp: str,
    max_bounces: int = 2,
) -> Figure:
    """
    Create production simulation ray overview with surface bounce points.

    Shows incoming rays, bounce points on the wave surface (colored by bounce
    number), and reflected rays toward the recording sphere. Up to 500 rays
    are randomly subsampled, using the global NumPy RNG, for visualization.
    Bounce statistics are printed to stdout. The figure is saved as
    ``simulation_<timestamp>_overview.png`` in ``output_path`` and then closed.

    Parameters
    ----------
    original_rays : RayBatch
        Original rays before tracing.
    surface : Surface
        The wave surface (e.g., CurvedWaveSurface). Must provide batched
        ``intersect(positions, directions)`` and ``normal_at``.
    config : dict
        Simulation configuration with keys:
        - grazing_angle: Beam grazing angle in degrees
        - beam_radius: Beam radius in meters
        - earth_radius: Earth radius in meters (optional; defaults to
          ``EARTH_RADIUS``). Used for both the sea surface and the
          recording sphere.
        - recording_altitude: Recording sphere altitude in meters
        - source_distance: Source distance in meters (optional)
    output_path : str
        Directory to save the figure (created if missing).
    timestamp : str
        Timestamp for filename.
    max_bounces : int
        Maximum number of bounces to visualize (default: 2). Any value is
        supported; styles cycle after the fourth bounce.

    Returns
    -------
    Figure
        Matplotlib figure (already closed).
    """
    from pathlib import Path

    from ..surfaces import EARTH_RADIUS
    from ..utilities.ray_data import create_ray_batch

    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # Subsample rays for visualization
    n_vis_rays = min(500, original_rays.num_rays)
    vis_sample_idx = np.random.choice(original_rays.num_rays, n_vis_rays, replace=False)
    original_rays_vis = create_ray_batch(num_rays=n_vis_rays)
    original_rays_vis.positions[:] = original_rays.positions[vis_sample_idx]
    original_rays_vis.directions[:] = original_rays.directions[vis_sample_idx]
    original_rays_vis.wavelengths[:] = original_rays.wavelengths[vis_sample_idx]
    original_rays_vis.intensities[:] = original_rays.intensities[vis_sample_idx]
    original_rays_vis.active[:] = True

    bounce_points, first_src_idx, first_reflected_dirs = _trace_specular_bounces(
        surface, original_rays_vis, max_bounces
    )

    # Beam footprint on the surface is elongated by 1 / sin(grazing angle).
    grazing_angle_rad = np.radians(config["grazing_angle"])
    elongation_factor = 1.0 / np.sin(grazing_angle_rad)
    footprint_radius = config["beam_radius"] * elongation_factor
    filter_radius = footprint_radius * 2.0  # Keep points within 2x the footprint

    for i, points in enumerate(bounce_points):
        # Filter to remove rays that escaped to infinity
        distances = np.sqrt(points[:, 0] ** 2 + points[:, 1] ** 2)
        valid_mask = distances < filter_radius
        bounce_points[i] = points[valid_mask]
        if i == 0:
            # Keep source indices / directions aligned with the filtered hits.
            first_src_idx = first_src_idx[valid_mask]
            first_reflected_dirs = first_reflected_dirs[valid_mask]

    _print_bounce_statistics(bounce_points)

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    # ----- Left panel: full-scale view -----
    ax1 = axes[0]
    # One radius for both arcs so the sea surface and recording sphere share a center.
    earth_radius = config.get("earth_radius", EARTH_RADIUS)
    earth_center_z = -earth_radius
    theta = np.linspace(-0.01, 0.01, 100)
    earth_x = earth_radius * np.sin(theta)
    earth_z = earth_center_z + earth_radius * np.cos(theta)
    ax1.fill_between(
        earth_x / 1000, earth_z / 1000, -10, color="#4a90d9", alpha=0.3, label="Ocean"
    )
    ax1.plot(earth_x / 1000, earth_z / 1000, "b-", linewidth=2, label="Sea surface")

    recording_radius = earth_radius + config["recording_altitude"]
    rec_x = recording_radius * np.sin(theta)
    rec_z = earth_center_z + recording_radius * np.cos(theta)
    ax1.plot(
        rec_x / 1000,
        rec_z / 1000,
        "g--",
        linewidth=1.5,
        label=f'Recording sphere ({config["recording_altitude"]/1000:.0f} km)',
    )

    first_points = bounce_points[0] if bounce_points else np.empty((0, 3))
    if len(first_points) > 0:
        n_plot = min(200, len(first_points))
        plot_indices = np.linspace(0, len(first_points) - 1, n_plot, dtype=int)
        hits = first_points[plot_indices]

        # Incoming rays: start at the ray that actually produced each hit.
        starts = original_rays_vis.positions[first_src_idx[plot_indices]]
        _plot_km_segments(ax1, starts, hits, "r-", "Incoming rays")

        # Reflected rays (from first bounce), reusing the traced directions.
        ray_length = config["recording_altitude"] * 1.5
        reflected_ends = hits + first_reflected_dirs[plot_indices] * ray_length
        _plot_km_segments(ax1, hits, reflected_ends, "g-", "Reflected rays")

    ax1.set_xlabel("X (km)", fontsize=12)
    ax1.set_ylabel("Z (km)", fontsize=12)
    ax1.set_title("Ray Paths Overview (X-Z Plane)", fontsize=14)
    ax1.legend(loc="upper right")
    ax1.set_aspect("equal")
    ax1.grid(True, alpha=0.3)
    max_range = (
        max(config.get("source_distance", 10000), config["recording_altitude"]) * 1.5
    )
    ax1.set_xlim(-max_range / 1000 * 0.5, max_range / 1000 * 1.5)
    ax1.set_ylim(-5, config["recording_altitude"] / 1000 * 1.3)

    # ----- Right panel: surface detail with multi-bounce points -----
    ax2 = axes[1]
    x_range = np.linspace(-footprint_radius * 1.2, footprint_radius * 1.2, 400)
    z_wave = np.array([_wave_height_at(surface, x) for x in x_range])
    ax2.fill_between(x_range, z_wave, z_wave.min() - 2, color="#4a90d9", alpha=0.5)
    ax2.plot(x_range, z_wave, "b-", linewidth=2, label="Wave surface")

    # Styles cycle (colors) or saturate (sizes) beyond the fourth bounce.
    bounce_colors = ["red", "cyan", "magenta", "yellow"]
    bounce_sizes = [20, 15, 12, 10]
    for bounce_idx, points in enumerate(bounce_points):
        if len(points) == 0:
            continue
        color = bounce_colors[bounce_idx % len(bounce_colors)]
        size = bounce_sizes[min(bounce_idx, len(bounce_sizes) - 1)]
        bounce_x = points[:, 0]
        bounce_z = points[:, 2]
        mean_x = np.mean(bounce_x)
        ax2.scatter(
            bounce_x,
            bounce_z,
            c=color,
            s=size,
            alpha=0.7,
            label=f"{_ordinal(bounce_idx + 1)} bounce (mean X={mean_x:.1f}m)",
            zorder=3 + bounce_idx,
            edgecolors="black",
            linewidths=0.5,
        )
        # Draw vertical line at mean X position
        ax2.axvline(mean_x, color=color, linestyle=":", alpha=0.4, linewidth=1.5)

    # Mark beam footprint
    ax2.axvline(-footprint_radius, color="orange", linestyle="--", alpha=0.5)
    ax2.axvline(
        footprint_radius,
        color="orange",
        linestyle="--",
        alpha=0.5,
        label=f"Beam footprint (±{footprint_radius:.0f}m)",
    )
    ax2.set_xlabel("X (m)", fontsize=12)
    ax2.set_ylabel("Z (m)", fontsize=12)
    ax2.set_title("Wave Surface Detail with Multi-Bounce Points", fontsize=14)
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig_path = output_path / f"simulation_{timestamp}_overview.png"
    fig.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig
[docs]
def plot_wave_surface_detail(
    reflected_rays: "RayBatch",
    surface: "Surface",
    x_range: tuple[float, float] = (-200, 200),
    figsize: tuple[float, float] = (12, 6),
    save_path: str | None = None,
    hit_offset: float = 0.01,
) -> Figure:
    """
    Plot wave surface detail with ray intersection points.

    The surface is sampled along y = 0 with the vectorized
    ``surface._surface_z(x, y)``. Intersection points are reconstructed from
    ``reflected_rays`` by stepping ``hit_offset`` back along each ray's
    direction.

    Parameters
    ----------
    reflected_rays : RayBatch
        Batch of reflected rays.
    surface : Surface
        The surface object. It must have a ``_surface_z`` method that
        accepts arrays.
    x_range : tuple
        X-axis range for plotting.
    figsize : tuple
        Figure size (width, height).
    save_path : str, optional
        Path to save the figure.
    hit_offset : float
        Distance the reflected rays were pushed off the surface along their
        direction (default 0.01 m, matching the tracer's offset).

    Returns
    -------
    Figure
        Matplotlib figure (left open).
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Plot wave surface
    x_detail = np.linspace(x_range[0], x_range[1], 1000)
    y_detail = np.zeros_like(x_detail)
    z_detail = surface._surface_z(x_detail, y_detail)
    ax.plot(x_detail, z_detail, "b-", linewidth=3, label="Wave Surface", zorder=3)
    ax.fill_between(
        x_detail,
        z_detail,
        z_detail.min() - 0.5,
        color="lightblue",
        alpha=0.4,
        zorder=1,
    )

    # Undo the offset the tracer applied along each direction to recover
    # the actual hit positions.
    actual_hit_positions = (
        reflected_rays.positions - hit_offset * reflected_rays.directions
    )
    ax.scatter(
        actual_hit_positions[:, 0],
        actual_hit_positions[:, 2],
        c="red",
        s=8,
        alpha=0.6,
        zorder=5,
        label="Intersection Points",
    )

    ax.set_xlim(x_range[0], x_range[1])
    z_min, z_max = z_detail.min(), z_detail.max()
    z_span = z_max - z_min
    if not z_span > 0:
        # A flat (or degenerate) profile would give identical y-limits;
        # use a nominal 1 m span instead.
        z_span = 1.0
    ax.set_ylim(z_min - z_span * 0.3, z_max + z_span * 0.5)

    ax.set_xlabel("X Position (m)", fontsize=11, fontweight="bold")
    ax.set_ylabel("Z Position (m)", fontsize=11, fontweight="bold")
    ax.set_title("Wave Surface Detail with Ray Intersections", fontweight="bold")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="upper right", fontsize=10)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig
[docs]
def _draw_tracked_paths(
    ax: Axes,
    ray_paths: dict,
    scale_factor: float,
    detector_distance: float,
) -> bool:
    """
    Draw multi-bounce paths from a ``trace_rays_multi_bounce`` result dict.

    Returns True when at least one refracted segment was drawn, so the
    caller can add a matching legend entry.
    """
    reflected_paths = ray_paths["reflected_paths"]
    reflected_final_dirs = ray_paths.get("reflected_final_dirs", [])
    refracted_paths = ray_paths.get("refracted_paths", [])
    refracted_final_dirs = ray_paths.get("refracted_final_dirs", [])
    bounce_colors = ["red", "orange", "yellow", "pink", "purple"]

    num_paths = len(reflected_paths)
    path_indices = (
        np.linspace(0, num_paths - 1, min(100, num_paths), dtype=int)
        if num_paths > 0
        else []
    )
    for path_idx in path_indices:
        path = reflected_paths[path_idx]
        if path is None or len(path) < 2:
            continue
        # Complete path
        ax.plot(
            path[:, 0] / scale_factor,
            path[:, 2],
            "b-",
            linewidth=0.8,
            alpha=0.4,
            zorder=2,
        )
        # Bounce points (the start position is skipped), colored by bounce
        # order. One scatter call per path instead of one per point.
        bounces = path[1:]
        ax.scatter(
            bounces[:, 0] / scale_factor,
            bounces[:, 2],
            c=[bounce_colors[k % len(bounce_colors)] for k in range(len(bounces))],
            s=10,
            alpha=0.6,
            zorder=4,
        )
        # Extend the path along its final direction.
        if (
            path_idx < len(reflected_final_dirs)
            and reflected_final_dirs[path_idx] is not None
        ):
            final_pos = path[-1]
            end_pos = final_pos + reflected_final_dirs[path_idx] * detector_distance * 0.3
            ax.plot(
                [final_pos[0] / scale_factor, end_pos[0] / scale_factor],
                [final_pos[2], end_pos[2]],
                "r-",
                linewidth=0.8,
                alpha=0.5,
                zorder=2,
            )

    # Refracted paths: dashed green segments from the entry point.
    drew_refracted = False
    num_refr_paths = len(refracted_paths)
    refr_path_indices = (
        np.linspace(0, num_refr_paths - 1, min(50, num_refr_paths), dtype=int)
        if num_refr_paths > 0
        else []
    )
    for path_idx in refr_path_indices:
        path = refracted_paths[path_idx]
        if path is None or len(path) < 1:
            continue
        if (
            path_idx < len(refracted_final_dirs)
            and refracted_final_dirs[path_idx] is not None
        ):
            start_pos = path[0]
            end_pos = start_pos + refracted_final_dirs[path_idx] * detector_distance * 0.2
            ax.plot(
                [start_pos[0] / scale_factor, end_pos[0] / scale_factor],
                [start_pos[2], end_pos[2]],
                "g--",
                linewidth=0.6,
                alpha=0.4,
                zorder=2,
            )
            drew_refracted = True
    return drew_refracted


def _draw_single_bounce(
    ax: Axes,
    rays: "RayBatch",
    reflected_rays: "RayBatch",
    refracted_rays: Optional["RayBatch"],
    indices_to_plot: np.ndarray,
    scale_factor: float,
    detector_distance: float,
) -> bool:
    """
    Draw incoming, reflected and (optionally) refracted single-bounce rays.

    Hit positions are reconstructed by stepping 0.01 m back along the
    outgoing directions, undoing the tracer's offset. Returns True when
    refracted rays were drawn.
    """
    for idx in indices_to_plot:
        if idx >= reflected_rays.num_rays:
            continue
        reflect_dir = reflected_rays.directions[idx, :]
        actual_hit_pos = reflected_rays.positions[idx, :] - 0.01 * reflect_dir
        start_pos = rays.positions[idx, :]
        # Incoming ray
        ax.plot(
            [start_pos[0] / scale_factor, actual_hit_pos[0] / scale_factor],
            [start_pos[2], actual_hit_pos[2]],
            "b-",
            linewidth=0.8,
            alpha=0.4,
            zorder=2,
        )
        # Reflected ray: from actual hit position outward
        end_pos = actual_hit_pos + reflect_dir * (detector_distance * 0.5)
        ax.plot(
            [actual_hit_pos[0] / scale_factor, end_pos[0] / scale_factor],
            [actual_hit_pos[2], end_pos[2]],
            "r-",
            linewidth=0.8,
            alpha=0.5,
            zorder=2,
        )

    if refracted_rays is None or refracted_rays.num_rays <= 0:
        return False

    last = min(refracted_rays.num_rays - 1, rays.num_rays - 1)
    # np.unique: with fewer refracted rays than requested samples, linspace
    # repeats indices and the same ray would be overdrawn (darker lines).
    refr_indices = np.unique(
        np.linspace(0, last, len(indices_to_plot), dtype=int)
    )
    for idx in refr_indices:
        if idx >= refracted_rays.num_rays:
            continue
        refract_dir = refracted_rays.directions[idx, :]
        actual_hit_pos = refracted_rays.positions[idx, :] - 0.01 * refract_dir
        end_pos = actual_hit_pos + refract_dir * (detector_distance * 0.3)
        ax.plot(
            [actual_hit_pos[0] / scale_factor, end_pos[0] / scale_factor],
            [actual_hit_pos[2], end_pos[2]],
            "g--",
            linewidth=0.8,
            alpha=0.5,
            zorder=2,
        )
    return True


def plot_ray_paths_with_surface(
    rays: "RayBatch",
    reflected_rays: "RayBatch",
    surface: "Surface",
    detector_distance: float = 1000.0,
    source_distance: float = 1000.0,
    refracted_rays: Optional["RayBatch"] = None,
    ray_paths: dict | None = None,
    figsize: tuple[float, float] = (16, 10),
    save_path: str | None = None,
) -> Figure:
    """
    Plot full ray paths (incoming, reflected, and refracted) with the surface.

    X is shown in km when ``detector_distance >= 100`` and in m otherwise;
    Z is always in m. When ``ray_paths`` contains 'reflected_paths', the
    multi-bounce paths are drawn. Otherwise single-bounce rays are drawn
    from ``rays`` / ``reflected_rays`` / ``refracted_rays``.

    Parameters
    ----------
    rays : RayBatch
        Initial ray batch (before interaction). ``rays.positions[0]`` is
        marked as the beam source.
    reflected_rays : RayBatch
        Reflected ray batch (after interaction).
    surface : Surface
        The surface object. Wave surfaces with ``_surface_z`` are sampled
        along y = 0; otherwise a z = 0 plane is drawn.
    detector_distance : float
        Detector distance in meters.
    source_distance : float
        Source distance in meters (unused, for API compatibility).
    refracted_rays : RayBatch, optional
        Refracted ray batch (after interaction).
    ray_paths : dict, optional
        Dictionary with ray path data from trace_rays_multi_bounce containing:
        - 'reflected_paths': list of Nx3 arrays, one per ray
        - 'refracted_paths': list of Nx3 arrays for refracted rays
        - 'reflected_final_dirs': final direction for each reflected path
        - 'refracted_final_dirs': final direction for each refracted path
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save the figure. The figure is closed after saving.

    Returns
    -------
    Figure
        Matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Determine scale based on distances (use km for large distances)
    use_km = detector_distance >= 100.0
    scale_factor = 1000.0 if use_km else 1.0
    distance_label = "km" if use_km else "m"

    # Plot surface with extended range
    x_min = min(rays.positions[:, 0].min(), reflected_rays.positions[:, 0].min())
    x_max = max(rays.positions[:, 0].max(), reflected_rays.positions[:, 0].max())
    x_range = x_max - x_min
    x_surf = np.linspace(x_min - x_range * 0.2, x_max + detector_distance * 0.3, 1000)
    y_surf = np.zeros_like(x_surf)

    # Handle both wave surfaces and planar surfaces
    if hasattr(surface, "_surface_z"):
        z_surf = surface._surface_z(x_surf, y_surf)
        surface_label = "Wave Surface"
    else:
        # Planar surface - assume z=0 horizontal plane
        z_surf = np.zeros_like(x_surf)
        surface_label = "Planar Surface"

    ax.plot(
        x_surf / scale_factor,
        z_surf,
        "b-",
        linewidth=3,
        label=surface_label,
        zorder=3,
    )
    z_fill_bottom = z_surf.min() - max(abs(z_surf.max() - z_surf.min()) * 0.5, 0.01)
    ax.fill_between(
        x_surf / scale_factor,
        z_surf,
        z_fill_bottom,
        color="lightblue",
        alpha=0.3,
        zorder=1,
    )

    # Sample rays to plot (for performance and clarity)
    num_rays_to_plot = min(100, rays.num_rays)
    indices_to_plot = np.linspace(0, rays.num_rays - 1, num_rays_to_plot, dtype=int)

    multi_bounce = ray_paths is not None and "reflected_paths" in ray_paths
    if multi_bounce:
        drew_refracted = _draw_tracked_paths(
            ax, ray_paths, scale_factor, detector_distance
        )
    else:
        drew_refracted = _draw_single_bounce(
            ax,
            rays,
            reflected_rays,
            refracted_rays,
            indices_to_plot,
            scale_factor,
            detector_distance,
        )
    has_refracted_rays = refracted_rays is not None and refracted_rays.num_rays > 0

    # Add beam source indicator
    beam_source = rays.positions[0, :]
    ax.scatter(
        [beam_source[0] / scale_factor],
        [beam_source[2]],
        c="blue",
        s=300,
        marker="*",
        edgecolors="black",
        linewidths=2,
        label="Beam Source",
        zorder=6,
    )

    ax.set_xlabel(f"X Position ({distance_label})", fontsize=13, fontweight="bold")
    ax.set_ylabel("Z Position (m)", fontsize=13, fontweight="bold")

    # Title reflects which data is shown
    if multi_bounce:
        num_paths = len(ray_paths["reflected_paths"])
        title_text = f"Ray Paths: Multi-Bounce Reflection ({num_paths} ray paths)"
    elif has_refracted_rays:
        title_text = (
            f"Ray Paths: Reflection & Refraction ({num_rays_to_plot} rays shown)"
        )
    else:
        title_text = (
            f"Ray Paths: Reflection from Wave Surface ({num_rays_to_plot} rays shown)"
        )
    ax.set_title(title_text, fontweight="bold", fontsize=15)
    ax.grid(True, alpha=0.3, linewidth=0.5)

    # Build legend
    legend_elements = [
        plt.Line2D([0], [0], color="b", linewidth=2, alpha=0.5, label="Incoming"),
        plt.Line2D([0], [0], color="r", linewidth=2, alpha=0.5, label="Reflected"),
    ]
    if has_refracted_rays or drew_refracted:
        legend_elements.append(
            plt.Line2D(
                [0],
                [0],
                color="g",
                linewidth=2,
                linestyle="--",
                alpha=0.5,
                label="Refracted",
            )
        )
    legend_elements.append(
        plt.Line2D(
            [0],
            [0],
            marker="*",
            color="w",
            markerfacecolor="blue",
            markersize=12,
            label="Beam Source",
        )
    )
    ax.legend(handles=legend_elements, loc="upper left", fontsize=11, framealpha=0.9)

    ax.set_xlim(
        x_min / scale_factor - x_range * 0.2 / scale_factor,
        (x_max + detector_distance * 0.3) / scale_factor,
    )
    ax.set_ylim(
        min(
            z_surf.min() - abs(z_surf.max() - z_surf.min()) * 0.5,
            -detector_distance * 0.01,
        ),
        max(z_surf.max(), detector_distance * 0.1),
    )
    if not use_km:
        ax.set_aspect("equal", adjustable="datalim")

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
    return fig
# =============================================================================
# Legacy Convenience Functions (Backward Compatibility)
# =============================================================================
def plot_ray_paths_2d(
    ray_history: list["RayBatch"],
    max_rays: int = 100,
    color_by: str = "wavelength",
    alpha: float = 0.4,
    linewidth: float = 0.8,
    figsize: tuple[float, float] = (15, 5),
    save_path: str | None = None,
) -> Figure:
    """
    Create figure with three 2D projections of ray paths.

    This is a convenience function for quick visualization. For custom layouts,
    use plot_ray_paths_projection() on individual axes.

    The figure is a single row of three axes showing the XY, XZ and YZ
    projections (left to right). Only the rightmost (YZ) panel is asked to
    draw a colorbar, so the figure carries a single colorbar.

    Parameters
    ----------
    ray_history : List[RayBatch]
        List of ray batches at different propagation steps.
    max_rays : int
        Maximum rays to plot (sampled uniformly if exceeded).
    color_by : str
        Color rays by: 'wavelength', 'intensity', 'generation', 'index'.
    alpha : float
        Line transparency.
    linewidth : float
        Line width.
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure. When non-empty, the figure is written via
        save_figure(); the figure is returned either way.

    Returns
    -------
    Figure
        Matplotlib figure with three subplots (XY, XZ, YZ).
    """
    fig, axes = plt.subplots(1, 3, figsize=figsize, constrained_layout=True)
    fig.suptitle("Ray Paths - 2D Projections", fontsize=14, fontweight="bold")
    # One projection per axis; plt.subplots(1, 3) yields exactly three axes.
    projections = ("xy", "xz", "yz")
    for ax, proj in zip(axes, projections, strict=True):
        plot_ray_paths_projection(
            ax,
            ray_history,
            projection=proj,
            max_rays=max_rays,
            color_by=color_by,
            alpha=alpha,
            linewidth=linewidth,
            # Draw a single colorbar on the last panel only.
            show_colorbar=(proj == "yz"),
        )
    if save_path:
        save_figure(fig, save_path)
    return fig
# Re-export Fresnel/Brewster functions from fresnel_plots module for backward compatibility
# Re-export polarization functions from polarization_plots module for backward compatibility
def plot_ray_endpoints(
    rays: "RayBatch",
    plane: str = "xy",
    color_by: str = "wavelength",
    bins: int = 50,
    figsize: tuple[float, float] = (12, 5),
    save_path: str | None = None,
) -> Figure:
    """
    Create figure with scatter and histogram of ray endpoints.

    This is a convenience function for quick visualization. For custom layouts,
    use plot_ray_endpoints_scatter() and plot_ray_endpoints_histogram().

    The figure is a single row of two axes: the left axis shows a scatter of
    the endpoints projected onto ``plane`` (with a colorbar), and the right
    axis shows a histogram of the same projection.

    Parameters
    ----------
    rays : RayBatch
        Ray batch with endpoint positions.
    plane : str
        Projection plane: 'xy', 'xz', 'yz'. Also used (upper-cased) in the
        figure title.
    color_by : str
        Color scatter by: 'wavelength', 'intensity'.
    bins : int
        Histogram bins.
    figsize : tuple
        Figure size.
    save_path : str, optional
        Path to save figure. When non-empty, the figure is written via
        save_figure(); the figure is returned either way.

    Returns
    -------
    Figure
        Matplotlib figure with scatter and histogram.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize, constrained_layout=True)
    fig.suptitle(
        f"Ray Endpoints - {plane.upper()} Plane", fontsize=14, fontweight="bold"
    )
    scatter_ax, hist_ax = axes
    plot_ray_endpoints_scatter(
        scatter_ax, rays, projection=plane, color_by=color_by, show_colorbar=True
    )
    plot_ray_endpoints_histogram(hist_ax, rays, projection=plane, bins=bins)
    if save_path:
        save_figure(fig, save_path)
    return fig