# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Ray-Surface Interactions

Handles the physics of ray-surface interactions including reflection,
refraction, and intensity updates based on the Fresnel equations.

At each interface, rays are split into reflected and refracted components
with intensities determined by the Fresnel equations. This models the
physical behavior where light partially reflects and partially transmits
at each interface.

Functions
---------
process_surface_interaction
    Process rays intersecting a surface (generates both reflected and
    refracted rays)
reflect_rays
    Apply reflection to rays at a surface (returns a new batch, or modifies
    the input batch when ``in_place=True``)
refract_rays
    Apply refraction to rays at a surface (returns a new batch, or modifies
    the input batch when ``in_place=True``)
trace_rays_multi_bounce
    Trace reflected rays through multiple bounces (legacy; only the reflected
    branch is traced further, refracted rays are merely collected)
trace_rays_with_splitting
    Trace rays with proper Fresnel splitting (both reflected AND refracted rays)

References
----------
.. [1] Glassner, A. S. (1989). An Introduction to Ray Tracing.
   Academic Press.
"""
import numpy as np
from numpy.typing import NDArray
from ..surfaces import Surface
from .fresnel import (
compute_reflection_direction,
compute_refraction_direction,
fresnel_coefficients,
initialize_polarization_vectors,
transform_polarization_reflection,
transform_polarization_refraction,
)
from .ray_data import RayBatch, create_ray_batch
def _orient_media(
    surface: Surface,
    hit_positions: NDArray[np.float32],
    hit_directions: NDArray[np.float32],
    hit_wavelengths: NDArray[np.float32],
) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
    """Orient normals toward each incoming ray and pick the media on each side.

    ``surface.normal_at`` defines the surface's *front* side. A ray whose
    direction has a negative dot product with that raw normal arrives from
    the front; every other ray (including an exactly-zero dot product) is
    treated as arriving from the back.

    Parameters
    ----------
    surface : Surface
        Surface providing ``normal_at`` and ``material_front`` /
        ``material_back`` objects with ``get_refractive_index(x, y, z, wl)``.
    hit_positions : ndarray
        Intersection points, one row per ray (x, y, z taken from columns 0-2).
    hit_directions : ndarray
        Ray directions at the intersection points, one row per ray.
    hit_wavelengths : ndarray
        Wavelength of each ray.

    Returns
    -------
    normals : ndarray
        Surface normals, negated where needed so each one faces its
        incoming ray.
    n_in : ndarray of float32
        Refractive index of the medium each ray arrives from.
    n_out : ndarray of float32
        Refractive index of the medium each ray passes into.
    """
    front_normals = surface.normal_at(hit_positions)
    arrives_from_front = np.sum(hit_directions * front_normals, axis=1) < 0

    # Back-side hits get their normal negated so every normal opposes its ray.
    normals = np.where(
        arrives_from_front[:, np.newaxis], front_normals, -front_normals
    )

    samples = list(zip(hit_positions, hit_wavelengths, strict=False))

    def _indices_of(material) -> NDArray[np.float32]:
        # Materials are sampled point by point at each ray's own wavelength.
        return np.array(
            [
                material.get_refractive_index(pos[0], pos[1], pos[2], wl)
                for pos, wl in samples
            ],
            dtype=np.float32,
        )

    n_front = _indices_of(surface.material_front)
    n_back = _indices_of(surface.material_back)

    # A front-side ray travels from the front medium into the back medium,
    # and vice versa.
    n_in = np.where(arrives_from_front, n_front, n_back)
    n_out = np.where(arrives_from_front, n_back, n_front)
    return normals, n_in, n_out
[docs]
def process_surface_interaction(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    generate_reflected: bool = True,
    generate_refracted: bool = True,
    polarization: str = "unpolarized",
    track_polarization_vector: bool = False,
) -> tuple[RayBatch | None, RayBatch | None]:
    """
    Process rays intersecting a surface.

    Computes intersections, applies the Fresnel equations, and generates
    reflected and/or refracted ray bundles. Only active rays that hit the
    surface produce child rays; row ``i`` of each returned batch belongs to
    the ``i``-th active input ray (in input order) that hit the surface.

    Parameters
    ----------
    rays : RayBatch
        Input rays to test for intersection.
    surface : Surface
        Surface to intersect with.
    wavelength : float or ndarray, optional
        Currently unused; kept for backward compatibility. Refractive indices
        are always evaluated at each ray's own ``rays.wavelengths``.
    generate_reflected : bool, optional
        Whether to generate reflected rays (default: True).
    generate_refracted : bool, optional
        Whether to generate refracted rays (default: True).
    polarization : str, optional
        Polarization state: 's', 'p', or 'unpolarized' (default). Used for
        the Fresnel coefficients and, when the input batch carries no
        polarization vectors, to initialize them.
    track_polarization_vector : bool, optional
        Whether to track 3D polarization vectors through the interaction.
        If True, polarization vectors are initialized (if not present) and
        transformed through reflection/refraction. (default: False)

    Returns
    -------
    reflected_rays : RayBatch or None
        Reflected rays, or None if ``generate_reflected=False``, no ray is
        active, or no active ray hits the surface. A reflected ray is active
        only if its intensity exceeds 1e-10.
    refracted_rays : RayBatch or None
        Refracted rays, or None under the same conditions with
        ``generate_refracted``. Rays undergoing total internal reflection
        are kept with zero intensity (and zero polarization vector) but are
        inactive.

    Notes
    -----
    Input rays are not modified.

    Child rays start 0.01 length units past the hit point along their new
    direction, have ``generations`` incremented by one, and have
    ``accumulated_time`` advanced by ``distance * n_in / c`` with
    ``c = 299792458`` m/s (so positions are presumably in metres).

    When track_polarization_vector=True, the function:

    1. Initializes polarization vectors if rays.polarization_vector is None
    2. Transforms polarization vectors through reflection/refraction
       (reflection weighted by the s/p Fresnel reflectances)
    3. Stores the transformed vectors in the output ray batches

    Examples
    --------
    >>> # Create surface and rays
    >>> surface = PlanarSurface(
    ...     point=(0, 0, 1),
    ...     normal=(0, 0, -1),
    ...     material_front=BK7_GLASS,
    ...     material_back=AIR_STP
    ... )
    >>> rays = ...  # Create rays
    >>>
    >>> # Process interaction with polarization tracking
    >>> reflected, refracted = process_surface_interaction(
    ...     rays, surface, generate_reflected=True, generate_refracted=True,
    ...     track_polarization_vector=True
    ... )
    """
    speed_of_light = 299792458.0  # m/s
    # Distance child rays are pushed off the surface to avoid immediately
    # re-intersecting it. NOTE(review): reflect_rays/refract_rays use 0.02,
    # citing a 0.01 min_distance in surface.intersect, whereas the original
    # comment here cited a 1e-3 tolerance -- confirm which applies.
    child_offset = 0.01
    min_child_intensity = 1e-10

    # Only active rays are tested.
    active_mask = rays.active
    if not np.any(active_mask):
        return None, None

    active_origins = rays.positions[active_mask]
    active_directions = rays.directions[active_mask]

    active_polarization_vectors = None
    if track_polarization_vector:
        if rays.polarization_vector is not None:
            active_polarization_vectors = rays.polarization_vector[active_mask]
        else:
            active_polarization_vectors = initialize_polarization_vectors(
                active_directions, polarization=polarization
            )

    distances, hit_mask = surface.intersect(active_origins, active_directions)
    if not np.any(hit_mask):
        return None, None

    # Per-hit quantities (row order = active rays that hit, in input order).
    hit_directions = active_directions[hit_mask]
    hit_distances = distances[hit_mask]
    hit_positions = (
        active_origins[hit_mask] + hit_distances[:, np.newaxis] * hit_directions
    )
    hit_wavelengths = rays.wavelengths[active_mask][hit_mask]
    hit_intensities = rays.intensities[active_mask][hit_mask]
    hit_times = rays.accumulated_time[active_mask][hit_mask]
    child_generations = rays.generations[active_mask][hit_mask] + 1
    num_hits = len(hit_positions)

    hit_polarization_vectors = None
    if track_polarization_vector and active_polarization_vectors is not None:
        hit_polarization_vectors = active_polarization_vectors[hit_mask]

    # Normals facing the incoming ray; n1 = medium the ray comes from,
    # n2 = medium on the far side of the surface.
    normals, n1_values, n2_values = _orient_media(
        surface, hit_positions, hit_directions, hit_wavelengths
    )

    # Travel time to the surface at phase velocity c / n_in.
    updated_times = hit_times + hit_distances * n1_values / speed_of_light

    # Normals face the incoming ray, so -dot(d, n) >= 0; abs guards rounding.
    cos_theta_i = np.abs(np.sum(hit_directions * normals, axis=1))
    R, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    def _spawn_children(directions, intensities, active):
        """Build a child batch sharing the hit points' per-ray attributes."""
        children = create_ray_batch(
            num_rays=num_hits,
            enable_polarization_vector=track_polarization_vector,
        )
        children.positions[:] = hit_positions + child_offset * directions
        children.directions[:] = directions
        children.wavelengths[:] = hit_wavelengths
        children.intensities[:] = intensities
        children.active[:] = active
        children.accumulated_time[:] = updated_times
        children.generations[:] = child_generations
        return children

    reflected_rays = None
    if generate_reflected:
        reflected_directions = compute_reflection_direction(hit_directions, normals)
        reflected_intensities = hit_intensities * R
        reflected_rays = _spawn_children(
            reflected_directions,
            reflected_intensities,
            reflected_intensities > min_child_intensity,
        )
        if track_polarization_vector and hit_polarization_vectors is not None:
            # Separate s/p reflectances weight the polarization components.
            R_s, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, "s")
            R_p, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, "p")
            reflected_rays.polarization_vector[:] = transform_polarization_reflection(
                hit_polarization_vectors,
                hit_directions,
                reflected_directions,
                normals,
                R_s=R_s,
                R_p=R_p,
            )

    refracted_rays = None
    if generate_refracted:
        refracted_directions, tir_mask = compute_refraction_direction(
            hit_directions, normals, n1_values, n2_values
        )
        # No transmitted energy under total internal reflection.
        refracted_intensities = hit_intensities * T
        refracted_intensities[tir_mask] = 0.0
        refracted_rays = _spawn_children(
            refracted_directions,
            refracted_intensities,
            (refracted_intensities > min_child_intensity) & ~tir_mask,
        )
        if track_polarization_vector and hit_polarization_vectors is not None:
            refracted_pol = transform_polarization_refraction(
                hit_polarization_vectors,
                hit_directions,
                refracted_directions,
                normals,
            )
            # TIR rays are inactive; zero their polarization vectors.
            refracted_pol[tir_mask] = 0.0
            refracted_rays.polarization_vector[:] = refracted_pol

    return reflected_rays, refracted_rays
[docs]
def reflect_rays(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    polarization: str = "unpolarized",
    in_place: bool = False,
) -> RayBatch:
    """
    Apply reflection to rays at a surface.

    Every active ray that hits ``surface`` is moved to just past its hit
    point (0.02 length units along the reflected direction), redirected by
    the law of reflection, has its intensity scaled by the Fresnel
    reflectance R, its generation incremented, and its accumulated time
    advanced by the travel time to the hit point (``distance * n_in / c``),
    consistent with ``process_surface_interaction``. Active rays that miss
    the surface, and any ray whose intensity is below 1e-10, are deactivated.

    Parameters
    ----------
    rays : RayBatch
        Input rays.
    surface : Surface
        Surface to reflect from.
    wavelength : float or ndarray, optional
        Currently unused; kept for backward compatibility. Refractive indices
        are evaluated at each ray's own ``rays.wavelengths``.
    polarization : str, optional
        Polarization state passed to the Fresnel calculation
        ('s', 'p', or 'unpolarized').
    in_place : bool, optional
        If True, modify ``rays`` in place. If False, operate on a clone.

    Returns
    -------
    RayBatch
        Reflected rays (the same object as ``rays`` if in_place=True).

    Notes
    -----
    Polarization vectors carried by the batch, if any, are not transformed.
    """
    speed_of_light = 299792458.0  # m/s

    if not in_place:
        rays = rays.clone()

    active_mask = rays.active
    if not np.any(active_mask):
        return rays

    distances, hit_mask = surface.intersect(
        rays.positions[active_mask], rays.directions[active_mask]
    )
    if not np.any(hit_mask):
        rays.active[:] = False
        return rays

    # Full-size mask of rays that were active AND hit the surface.
    full_hit_mask = np.zeros(len(rays.positions), dtype=bool)
    full_hit_mask[active_mask] = hit_mask

    hit_indices = np.flatnonzero(active_mask)[hit_mask]
    hit_distances = distances[hit_mask]
    hit_dirs = rays.directions[hit_indices]
    hit_positions = rays.positions[hit_indices] + hit_distances[:, np.newaxis] * hit_dirs
    hit_wls = rays.wavelengths[hit_indices]

    # Normals facing the incoming ray, and media on either side.
    normals, n1_values, n2_values = _orient_media(
        surface, hit_positions, hit_dirs, hit_wls
    )

    reflected_directions = compute_reflection_direction(hit_dirs, normals)
    cos_theta_i = np.abs(np.sum(hit_dirs * normals, axis=1))
    R, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    # Propagation to the hit point takes distance / (c / n_in).
    rays.accumulated_time[hit_indices] += hit_distances * n1_values / speed_of_light
    # Offset must exceed surface.intersect's min_distance (0.01).
    rays.positions[hit_indices] = hit_positions + 0.02 * reflected_directions
    rays.directions[hit_indices] = reflected_directions
    rays.intensities[hit_indices] *= R
    rays.generations[hit_indices] += 1

    # Deactivate non-hits and weak rays.
    rays.active[~full_hit_mask] = False
    rays.active[rays.intensities < 1e-10] = False
    return rays
[docs]
def refract_rays(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    polarization: str = "unpolarized",
    in_place: bool = False,
) -> RayBatch:
    """
    Apply refraction to rays at a surface.

    Every active ray that hits ``surface`` is moved to just past its hit
    point (0.02 length units along the refracted direction), redirected by
    Snell's law, has its intensity scaled by the Fresnel transmittance T,
    its generation incremented, and its accumulated time advanced by the
    travel time to the hit point (``distance * n_in / c``), consistent with
    ``process_surface_interaction``. Rays undergoing total internal
    reflection, active rays that miss the surface, and any ray whose
    intensity is below 1e-10 are deactivated.

    Parameters
    ----------
    rays : RayBatch
        Input rays.
    surface : Surface
        Surface to refract through.
    wavelength : float or ndarray, optional
        Currently unused; kept for backward compatibility. Refractive indices
        are evaluated at each ray's own ``rays.wavelengths``.
    polarization : str, optional
        Polarization state passed to the Fresnel calculation
        ('s', 'p', or 'unpolarized').
    in_place : bool, optional
        If True, modify ``rays`` in place. If False, operate on a clone.

    Returns
    -------
    RayBatch
        Refracted rays (the same object as ``rays`` if in_place=True).

    Notes
    -----
    Polarization vectors carried by the batch, if any, are not transformed.
    """
    speed_of_light = 299792458.0  # m/s

    if not in_place:
        rays = rays.clone()

    active_mask = rays.active
    if not np.any(active_mask):
        return rays

    distances, hit_mask = surface.intersect(
        rays.positions[active_mask], rays.directions[active_mask]
    )
    if not np.any(hit_mask):
        rays.active[:] = False
        return rays

    # Full-size mask of rays that were active AND hit the surface.
    full_hit_mask = np.zeros(len(rays.positions), dtype=bool)
    full_hit_mask[active_mask] = hit_mask

    hit_indices = np.flatnonzero(active_mask)[hit_mask]
    hit_distances = distances[hit_mask]
    hit_dirs = rays.directions[hit_indices]
    hit_positions = rays.positions[hit_indices] + hit_distances[:, np.newaxis] * hit_dirs
    hit_wls = rays.wavelengths[hit_indices]

    # Normals facing the incoming ray, and media on either side.
    normals, n1_values, n2_values = _orient_media(
        surface, hit_positions, hit_dirs, hit_wls
    )

    refracted_directions, tir_mask = compute_refraction_direction(
        hit_dirs, normals, n1_values, n2_values
    )
    cos_theta_i = np.abs(np.sum(hit_dirs * normals, axis=1))
    _, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    # Propagation to the hit point takes distance / (c / n_in).
    rays.accumulated_time[hit_indices] += hit_distances * n1_values / speed_of_light
    # Offset must exceed surface.intersect's min_distance (0.01).
    rays.positions[hit_indices] = hit_positions + 0.02 * refracted_directions
    rays.directions[hit_indices] = refracted_directions
    rays.intensities[hit_indices] *= T
    rays.generations[hit_indices] += 1

    # Deactivate TIR rays, non-hits, and weak rays.
    rays.active[hit_indices[tir_mask]] = False
    rays.active[~full_hit_mask] = False
    rays.active[rays.intensities < 1e-10] = False
    return rays
[docs]
def _snapshot_ray_rows(rays: RayBatch, rows: NDArray) -> dict[str, NDArray]:
    """Copy the per-ray fields of ``rays`` selected by ``rows`` (mask or indices)."""
    return {
        "positions": rays.positions[rows].copy(),
        "directions": rays.directions[rows].copy(),
        "wavelengths": rays.wavelengths[rows].copy(),
        "intensities": rays.intensities[rows].copy(),
        "accumulated_time": rays.accumulated_time[rows].copy(),
        "generations": rays.generations[rows].copy(),
    }


def _batch_from_snapshots(snapshots: list[dict[str, NDArray]]) -> RayBatch:
    """Concatenate row snapshots into one all-active RayBatch (empty if none)."""
    if not snapshots:
        return create_ray_batch(num_rays=0)
    merged = {
        field: np.concatenate([snapshot[field] for snapshot in snapshots])
        for field in snapshots[0]
    }
    batch = create_ray_batch(num_rays=len(merged["positions"]))
    for field, values in merged.items():
        getattr(batch, field)[:] = values
    batch.active[:] = True
    return batch


def trace_rays_multi_bounce(
    rays: RayBatch,
    surface: Surface,
    max_bounces: int = 10,
    bounding_radius: float = 10000.0,
    wavelength: float = 532e-9,
    min_intensity: float = 1e-10,
    track_refracted: bool = True,
) -> tuple[RayBatch, RayBatch, dict]:
    """
    Trace rays through multiple surface interactions until termination.

    Only the reflected branch is followed from bounce to bounce. A ray
    lineage terminates when it:

    - Exits the bounding sphere (centred at the origin)
    - Misses the surface
    - Reaches the maximum number of bounces
    - Has reflected intensity at or below ``min_intensity`` (dropped)

    Parameters
    ----------
    rays : RayBatch
        Initial ray batch to trace (not modified).
    surface : Surface
        Surface to interact with.
    max_bounces : int, optional
        Maximum number of surface interactions (default: 10).
    bounding_radius : float, optional
        Radius of the bounding sphere in meters (default: 10000).
    wavelength : float, optional
        Forwarded to ``process_surface_interaction`` (default: 532 nm).
    min_intensity : float, optional
        Reflected rays with intensity not above this are dropped
        (default: 1e-10).
    track_refracted : bool, optional
        Whether to collect refracted rays (default: True).

    Returns
    -------
    final_reflected : RayBatch
        Terminal state of the reflected ray lineages: rays that left the
        bounding sphere, rays that missed the surface, and rays still active
        when ``max_bounces`` was reached. Dropped weak rays are not included.
    final_refracted : RayBatch
        All active refracted rays generated at any bounce (not traced further).
    ray_paths : dict
        Dictionary containing:

        - 'reflected_paths': one (K, 3) array per input ray: its starting
          position, then the post-bounce start point of each reflection
          (hit point offset slightly along the reflected direction), plus
          its position again if it exited the bounding sphere
        - 'refracted_paths': one (1, 3) array per collected refracted ray,
          holding its starting point
        - 'reflected_final_dirs': final direction for each input ray, or
          None if its lineage was dropped for low intensity
        - 'refracted_final_dirs': direction of each collected refracted ray

    Notes
    -----
    Refracted rays are collected but not further traced (e.g. they go into
    the water and do not re-interact with an air-water interface from below
    in typical scenarios).

    Each bounce calls ``surface.intersect`` once more than strictly needed:
    ``process_surface_interaction`` only returns child rays for hits, so the
    hit mask is computed here to attribute its outputs to their parent rays.

    Examples
    --------
    >>> # Trace rays with up to 5 bounces
    >>> final_refl, final_refr, paths = trace_rays_multi_bounce(
    ...     rays, surface, max_bounces=5, bounding_radius=5000.0
    ... )
    >>> # Plot a ray's path
    >>> path = paths['reflected_paths'][0]  # First ray's path
    >>> plt.plot(path[:, 0], path[:, 2])  # x-z projection
    """
    current_rays = rays.clone()
    num_original = rays.num_rays
    # current_to_original[k] is the input ray that row k of current_rays
    # descends from; it always has one entry per row of current_rays.
    current_to_original = np.arange(num_original)

    ray_path_lists = [[rays.positions[i].copy()] for i in range(num_original)]
    ray_final_dirs: list = [None] * num_original

    terminal_snapshots: list[dict[str, NDArray]] = []
    refracted_snapshots: list[dict[str, NDArray]] = []
    refracted_path_lists = []
    refracted_final_dirs = []

    for _bounce in range(max_bounces):
        # 1. Terminate rays that have left the bounding sphere.
        inside_sphere = np.linalg.norm(current_rays.positions, axis=1) < bounding_radius
        exited_rows = np.flatnonzero(current_rays.active & ~inside_sphere)
        if exited_rows.size:
            for row in exited_rows:
                orig_idx = current_to_original[row]
                ray_path_lists[orig_idx].append(current_rays.positions[row].copy())
                ray_final_dirs[orig_idx] = current_rays.directions[row].copy()
            terminal_snapshots.append(_snapshot_ray_rows(current_rays, exited_rows))
        current_rays.active &= inside_sphere

        active_rows = np.flatnonzero(current_rays.active)
        if active_rows.size == 0:
            break

        # 2. Terminate rays that miss the surface, and remember which rows
        #    hit so the children can be mapped back to their parents.
        _, hit_mask = surface.intersect(
            current_rays.positions[active_rows], current_rays.directions[active_rows]
        )
        hit_mask = np.asarray(hit_mask, dtype=bool)
        hit_rows = active_rows[hit_mask]
        missed_rows = active_rows[~hit_mask]
        if missed_rows.size:
            for row in missed_rows:
                ray_final_dirs[current_to_original[row]] = (
                    current_rays.directions[row].copy()
                )
            terminal_snapshots.append(_snapshot_ray_rows(current_rays, missed_rows))
            current_rays.active[missed_rows] = False
        if hit_rows.size == 0:
            break

        # 3. Split the hitting rays. Row i of each output belongs to hit_rows[i].
        reflected_rays, refracted_rays = process_surface_interaction(
            current_rays,
            surface,
            wavelength=wavelength,
            generate_reflected=True,
            generate_refracted=track_refracted,
        )
        if reflected_rays is None:
            break
        if reflected_rays.num_rays != hit_rows.size:
            raise RuntimeError(
                "process_surface_interaction returned a different number of rays "
                "than surface.intersect reported hits; cannot attribute ray paths."
            )

        parent_to_original = current_to_original[hit_rows]
        for orig_idx, position in zip(parent_to_original, reflected_rays.positions):
            ray_path_lists[orig_idx].append(position.copy())

        if refracted_rays is not None and np.any(refracted_rays.active):
            live_rows = np.flatnonzero(refracted_rays.active)
            for row in live_rows:
                refracted_path_lists.append(
                    np.array([refracted_rays.positions[row].copy()])
                )
                refracted_final_dirs.append(refracted_rays.directions[row].copy())
            refracted_snapshots.append(_snapshot_ray_rows(refracted_rays, live_rows))

        # 4. Continue with reflected rays that are still bright enough; the
        #    mapping stays aligned with every row of the new batch.
        reflected_rays.active &= reflected_rays.intensities > min_intensity
        current_rays = reflected_rays
        current_to_original = parent_to_original
        if not np.any(current_rays.active):
            break

    # Rays still active here ran out of bounces (or could not be split).
    remaining_rows = np.flatnonzero(current_rays.active)
    if remaining_rows.size:
        for row in remaining_rows:
            ray_final_dirs[current_to_original[row]] = current_rays.directions[row].copy()
        terminal_snapshots.append(_snapshot_ray_rows(current_rays, remaining_rows))

    final_reflected = _batch_from_snapshots(terminal_snapshots)
    final_refracted = _batch_from_snapshots(refracted_snapshots)

    ray_paths = {
        "reflected_paths": [np.array(path) for path in ray_path_lists],
        "refracted_paths": refracted_path_lists,
        "reflected_final_dirs": ray_final_dirs,
        "refracted_final_dirs": refracted_final_dirs,
    }
    return final_reflected, final_refracted, ray_paths
[docs]
def trace_rays_with_splitting(
rays: RayBatch,
surfaces: list,
max_bounces: int = 10,
bounding_radius: float = 10000.0,
bounding_center: tuple = None,
wavelength: float = 532e-9,
min_intensity: float = 1e-10,
polarization: str = "unpolarized",
) -> tuple[RayBatch, dict]:
"""
Trace rays through multiple surfaces with proper Fresnel ray splitting.
At each surface interaction, every ray is split into two child rays:
- A reflected ray with intensity scaled by the Fresnel reflection coefficient R
- A refracted ray with intensity scaled by the Fresnel transmission coefficient T
This creates a ray tree where both reflected and refracted paths are traced
until termination conditions are met.
Parameters
----------
rays : RayBatch
Initial ray batch to trace
surfaces : list of Surface
List of surfaces to interact with (tested in order for closest hit)
max_bounces : int, optional
Maximum number of surface interactions per ray tree branch (default: 10)
bounding_radius : float, optional
Radius of bounding sphere in meters (default: 10000)
bounding_center : tuple of float, optional
Center of bounding sphere (x, y, z) in meters. Default is (0, 0, 0).
For curved Earth simulations, use (0, 0, -EARTH_RADIUS).
wavelength : float, optional
Wavelength for Fresnel calculations (default: 532nm)
min_intensity : float, optional
Minimum intensity threshold for ray termination (default: 1e-10)
polarization : str, optional
Polarization state: 's', 'p', or 'unpolarized' (default)
Returns
-------
final_rays : RayBatch
All terminal rays (rays that exited the scene or hit max bounces)
Each ray has its intensity weighted by the product of all Fresnel
coefficients along its path.
trace_info : dict
Dictionary containing tracing statistics:
- 'total_rays_created': Total number of rays created during tracing
- 'max_depth_reached': Maximum tree depth reached
- 'terminated_by_intensity': Number of rays terminated due to low intensity
- 'terminated_by_bounds': Number of rays that exited bounding sphere
- 'terminated_by_max_bounces': Number of rays that hit max bounce limit
Notes
-----
This function implements proper physical ray splitting where each ray at
an interface creates two child rays:
- Reflected ray: direction from law of reflection, intensity = I_parent * R
- Refracted ray: direction from Snell's law, intensity = I_parent * T
where R and T are the Fresnel reflection and transmission coefficients
computed from the refractive indices and incident angle.
For total internal reflection (TIR), T=0 and R=1, so only a reflected
ray is created.
The number of rays grows exponentially with depth (up to 2^depth), so
the min_intensity threshold is critical for pruning weak ray branches.
Examples
--------
>>> from surface_roughness.surfaces import PlanarSurface
>>> from surface_roughness.materials import Glass, Air
>>>
>>> # Create a glass slab
>>> surface1 = PlanarSurface(
... point=(0, 0, 0.01),
... normal=(0, 0, -1),
... material_front=Air(),
... material_back=Glass()
... )
>>> surface2 = PlanarSurface(
... point=(0, 0, 0.02),
... normal=(0, 0, -1),
... material_front=Glass(),
... material_back=Air()
... )
>>>
>>> # Trace rays with splitting
>>> final_rays, info = trace_rays_with_splitting(
... rays, [surface1, surface2], max_bounces=5
... )
>>> print(f"Created {info['total_rays_created']} rays total")
"""
from .ray_data import create_ray_batch
# Set default bounding center if not provided
if bounding_center is None:
bounding_center = np.array([0.0, 0.0, 0.0])
else:
bounding_center = np.array(bounding_center)
# Statistics tracking
total_rays_created = rays.num_rays
terminated_by_intensity = 0
terminated_by_bounds = 0
terminated_by_max_bounces = 0
max_depth_reached = 0
# Queue of rays to process: (RayBatch, current_depth)
ray_queue = [(rays.clone(), 0)]
# Collect all terminal rays
terminal_positions = []
terminal_directions = []
terminal_wavelengths = []
terminal_intensities = []
terminal_times = []
terminal_generations = []
while ray_queue:
current_rays, depth = ray_queue.pop(0)
max_depth_reached = max(max_depth_reached, depth)
# Filter out already inactive rays
if not np.any(current_rays.active):
continue
# Check bounding sphere - terminate rays outside
# Distance is measured from bounding_center, not origin
distances_from_center = np.linalg.norm(
current_rays.positions - bounding_center, axis=1
)
outside_bounds = distances_from_center >= bounding_radius
exited_mask = current_rays.active & outside_bounds
if np.any(exited_mask):
terminated_by_bounds += np.sum(exited_mask)
terminal_positions.append(current_rays.positions[exited_mask].copy())
terminal_directions.append(current_rays.directions[exited_mask].copy())
terminal_wavelengths.append(current_rays.wavelengths[exited_mask].copy())
terminal_intensities.append(current_rays.intensities[exited_mask].copy())
terminal_times.append(current_rays.accumulated_time[exited_mask].copy())
terminal_generations.append(current_rays.generations[exited_mask].copy())
current_rays.active[exited_mask] = False
# Check depth limit
if depth >= max_bounces:
remaining_mask = current_rays.active
if np.any(remaining_mask):
terminated_by_max_bounces += np.sum(remaining_mask)
terminal_positions.append(current_rays.positions[remaining_mask].copy())
terminal_directions.append(
current_rays.directions[remaining_mask].copy()
)
terminal_wavelengths.append(
current_rays.wavelengths[remaining_mask].copy()
)
terminal_intensities.append(
current_rays.intensities[remaining_mask].copy()
)
terminal_times.append(
current_rays.accumulated_time[remaining_mask].copy()
)
terminal_generations.append(
current_rays.generations[remaining_mask].copy()
)
continue
# Check intensity threshold - terminate weak rays
weak_mask = current_rays.active & (current_rays.intensities < min_intensity)
if np.any(weak_mask):
terminated_by_intensity += np.sum(weak_mask)
current_rays.active[weak_mask] = False
if not np.any(current_rays.active):
continue
# Find closest surface intersection among all surfaces
active_mask = current_rays.active
active_origins = current_rays.positions[active_mask]
active_directions = current_rays.directions[active_mask]
num_active = np.sum(active_mask)
# Initialize with no hit
closest_distances = np.full(num_active, np.inf, dtype=np.float32)
closest_surface_idx = np.full(num_active, -1, dtype=np.int32)
any_hit = np.zeros(num_active, dtype=bool)
for surf_idx, surface in enumerate(surfaces):
distances, hit_mask = surface.intersect(active_origins, active_directions)
# Update closest hit
closer = hit_mask & (distances < closest_distances)
closest_distances[closer] = distances[closer]
closest_surface_idx[closer] = surf_idx
any_hit |= hit_mask
if not np.any(any_hit):
# No surface hit - rays continue to infinity
# Compute exact intersection with bounding sphere and terminate there
# Solve: |pos + t*dir - center|^2 = R^2
# Let p = pos - center
# |p + t*dir|^2 = R^2
# |dir|^2 * t^2 + 2*(p·dir)*t + (|p|^2 - R^2) = 0
positions = current_rays.positions[active_mask]
directions = current_rays.directions[active_mask]
p = positions - bounding_center
a = np.sum(directions**2, axis=1) # |dir|^2, should be 1
b = 2 * np.sum(p * directions, axis=1) # 2*(p·dir)
c = np.sum(p**2, axis=1) - bounding_radius**2 # |p|^2 - R^2
discriminant = b**2 - 4 * a * c
# For rays inside sphere, discriminant > 0 and we want positive t
# t = (-b ± sqrt(disc)) / 2a
# We want the far intersection (exit point), so use + sqrt
valid = discriminant >= 0
t_exit = np.zeros(len(positions), dtype=np.float32)
t_exit[valid] = (-b[valid] + np.sqrt(discriminant[valid])) / (2 * a[valid])
t_exit[~valid] = 0 # Shouldn't happen for rays inside sphere
t_exit = np.maximum(t_exit, 0) # Only forward intersection
# Move rays to exact bounding sphere intersection
exit_positions = positions + t_exit[:, np.newaxis] * directions
# Store these rays as terminated at bounds
active_indices = np.where(active_mask)[0]
terminated_by_bounds += len(active_indices)
terminal_positions.append(exit_positions.copy())
terminal_directions.append(directions.copy())
terminal_wavelengths.append(current_rays.wavelengths[active_mask].copy())
terminal_intensities.append(current_rays.intensities[active_mask].copy())
terminal_times.append(current_rays.accumulated_time[active_mask].copy())
terminal_generations.append(current_rays.generations[active_mask].copy())
# Deactivate these rays
current_rays.active[active_mask] = False
continue
# Process each hit surface separately
for surf_idx, surface in enumerate(surfaces):
surf_hit_mask = (closest_surface_idx == surf_idx) & any_hit
if not np.any(surf_hit_mask):
continue
# Create a temporary batch for rays hitting this surface
full_hit_mask = np.zeros(len(current_rays.positions), dtype=bool)
active_indices = np.where(active_mask)[0]
full_hit_mask[active_indices[surf_hit_mask]] = True
if not np.any(full_hit_mask):
continue
# Extract hitting rays
hit_origins = current_rays.positions[full_hit_mask]
hit_directions = current_rays.directions[full_hit_mask]
hit_wavelengths = current_rays.wavelengths[full_hit_mask]
hit_intensities = current_rays.intensities[full_hit_mask]
hit_times = current_rays.accumulated_time[full_hit_mask]
hit_generations = current_rays.generations[full_hit_mask]
hit_distances = closest_distances[surf_hit_mask]
# Compute intersection points
hit_positions = hit_origins + hit_distances[:, np.newaxis] * hit_directions
# Correctly-oriented normals and refractive indices
normals, n1_values, n2_values = _orient_media(
surface, hit_positions, hit_directions, hit_wavelengths
)
# Compute travel time (n1_values = n_in = medium the ray was in)
c = 299792458.0
travel_time = hit_distances * n1_values / c
updated_times = hit_times + travel_time
# Compute Fresnel coefficients
cos_theta_i = -np.sum(hit_directions * normals, axis=1)
cos_theta_i = np.abs(cos_theta_i)
R, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)
# Create REFLECTED rays (always created)
reflected_directions = compute_reflection_direction(hit_directions, normals)
reflected_intensities = hit_intensities * R
reflected_active = reflected_intensities > min_intensity
if np.any(reflected_active):
num_reflected = np.sum(reflected_active)
reflected_rays = create_ray_batch(num_rays=num_reflected)
# Offset must exceed surface.intersect min_distance (0.01)
reflected_rays.positions[:] = (
hit_positions[reflected_active]
+ 0.02 * reflected_directions[reflected_active]
)
reflected_rays.directions[:] = reflected_directions[reflected_active]
reflected_rays.wavelengths[:] = hit_wavelengths[reflected_active]
reflected_rays.intensities[:] = reflected_intensities[reflected_active]
reflected_rays.accumulated_time[:] = updated_times[reflected_active]
reflected_rays.generations[:] = hit_generations[reflected_active] + 1
reflected_rays.active[:] = True
total_rays_created += num_reflected
ray_queue.append((reflected_rays, depth + 1))
# Create REFRACTED rays (unless TIR)
refracted_directions, tir_mask = compute_refraction_direction(
hit_directions, normals, n1_values, n2_values
)
refracted_intensities = hit_intensities * T
refracted_intensities[tir_mask] = 0.0
refracted_active = (refracted_intensities > min_intensity) & (~tir_mask)
if np.any(refracted_active):
num_refracted = np.sum(refracted_active)
refracted_rays = create_ray_batch(num_rays=num_refracted)
# Offset must exceed surface.intersect min_distance (0.01)
refracted_rays.positions[:] = (
hit_positions[refracted_active]
+ 0.02 * refracted_directions[refracted_active]
)
refracted_rays.directions[:] = refracted_directions[refracted_active]
refracted_rays.wavelengths[:] = hit_wavelengths[refracted_active]
refracted_rays.intensities[:] = refracted_intensities[refracted_active]
refracted_rays.accumulated_time[:] = updated_times[refracted_active]
refracted_rays.generations[:] = hit_generations[refracted_active] + 1
refracted_rays.active[:] = True
total_rays_created += num_refracted
ray_queue.append((refracted_rays, depth + 1))
# Mark processed rays as inactive in current batch
current_rays.active[full_hit_mask] = False
# Combine all terminal rays
if len(terminal_positions) > 0:
all_positions = np.vstack(terminal_positions)
all_directions = np.vstack(terminal_directions)
all_wavelengths = np.concatenate(terminal_wavelengths)
all_intensities = np.concatenate(terminal_intensities)
all_times = np.concatenate(terminal_times)
all_generations = np.concatenate(terminal_generations)
num_terminal = len(all_positions)
final_rays = create_ray_batch(num_rays=num_terminal)
final_rays.positions[:] = all_positions
final_rays.directions[:] = all_directions
final_rays.wavelengths[:] = all_wavelengths
final_rays.intensities[:] = all_intensities
final_rays.accumulated_time[:] = all_times
final_rays.generations[:] = all_generations
final_rays.active[:] = True
else:
final_rays = create_ray_batch(num_rays=0)
trace_info = {
"total_rays_created": total_rays_created,
"max_depth_reached": max_depth_reached,
"terminated_by_intensity": terminated_by_intensity,
"terminated_by_bounds": terminated_by_bounds,
"terminated_by_max_bounces": terminated_by_max_bounces,
}
return final_rays, trace_info