# Source code for lsurf.utilities.interactions

# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
#      * Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#
#      * Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#      * Neither the name of the copyright holder nor the names of its
#      contributors may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""
Ray-Surface Interactions

Handles the physics of ray-surface interactions including reflection,
refraction, and intensity updates based on Fresnel equations.

At each interface, rays are split into reflected and refracted components
with intensities determined by the Fresnel equations. This properly models
the physical behavior where light partially reflects and partially transmits
at each interface.

Functions
---------
process_surface_interaction
    Process rays intersecting a surface (generates both reflected + refracted rays)
reflect_rays
    Apply reflection to rays at surface (in place if ``in_place=True``, otherwise on a copy)
refract_rays
    Apply refraction to rays at surface (in place if ``in_place=True``, otherwise on a copy)
trace_rays_multi_bounce
    Trace reflected rays through multiple bounces (legacy; refracted rays are collected but not traced further)
trace_rays_with_splitting
    Trace rays with proper Fresnel splitting (both reflected AND refracted rays)

References
----------
.. [1] Glassner, A. S. (1989). An Introduction to Ray Tracing.
       Academic Press.
"""

import numpy as np
from numpy.typing import NDArray

from ..surfaces import Surface
from .fresnel import (
    compute_reflection_direction,
    compute_refraction_direction,
    fresnel_coefficients,
    initialize_polarization_vectors,
    transform_polarization_reflection,
    transform_polarization_refraction,
)
from .ray_data import RayBatch, create_ray_batch


def _orient_media(
    surface: Surface,
    hit_positions: NDArray[np.float32],
    hit_directions: NDArray[np.float32],
    hit_wavelengths: NDArray[np.float32],
) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
    """Return incoming-facing normals and the incident/transmitted indices.

    The un-flipped normal from ``surface.normal_at`` marks the surface's
    front side. A ray whose direction has a negative dot product with that
    normal arrives from the front; every other ray (including exactly
    tangential ones) is treated as arriving from the back.

    Returns
    -------
    normals : ndarray
        Per-ray normals, negated where needed so they point back toward the
        incoming ray (consistent with ``surface.normal_at`` when
        ``incoming_directions`` is provided).
    n_in : ndarray of float32
        Refractive index of the medium the ray comes FROM.
    n_out : ndarray of float32
        Refractive index of the medium the ray goes INTO.
    """

    def sample_index(material) -> NDArray[np.float32]:
        # Materials are queried one ray at a time with (x, y, z, wavelength).
        samples = [
            material.get_refractive_index(pos[0], pos[1], pos[2], wl)
            for pos, wl in zip(hit_positions, hit_wavelengths, strict=False)
        ]
        return np.array(samples, dtype=np.float32)

    facing = surface.normal_at(hit_positions)
    arrives_from_front = np.sum(hit_directions * facing, axis=1) < 0

    # Keep front-side normals, negate back-side ones.
    oriented = np.where(arrives_from_front[:, np.newaxis], facing, -facing)

    front_index = sample_index(surface.material_front)
    back_index = sample_index(surface.material_back)

    incident_index = np.where(arrives_from_front, front_index, back_index)
    transmitted_index = np.where(arrives_from_front, back_index, front_index)
    return oriented, incident_index, transmitted_index


def _spawn_child_batch(
    hit_positions: NDArray[np.float32],
    directions: NDArray[np.float32],
    wavelengths: NDArray[np.float32],
    intensities: NDArray[np.float32],
    active: NDArray[np.bool_],
    times: NDArray[np.float32],
    generations: NDArray,
    enable_polarization_vector: bool,
    offset: float = 0.01,
) -> RayBatch:
    """Build a child RayBatch starting ``offset`` past each hit point along ``directions``."""
    batch = create_ray_batch(
        num_rays=len(hit_positions),
        enable_polarization_vector=enable_polarization_vector,
    )
    # Offset along the new direction to prevent immediate re-intersection
    # with the surface the child was spawned from. The original comment cites
    # an intersection tolerance of 1e-3. NOTE(review): reflect_rays/refract_rays
    # assume a 0.01 min_distance and offset by 0.02 — confirm against
    # Surface.intersect.
    batch.positions[:] = hit_positions + offset * directions
    batch.directions[:] = directions
    batch.wavelengths[:] = wavelengths
    batch.intensities[:] = intensities
    batch.active[:] = active
    batch.accumulated_time[:] = times
    batch.generations[:] = generations
    return batch


def process_surface_interaction(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    generate_reflected: bool = True,
    generate_refracted: bool = True,
    polarization: str = "unpolarized",
    track_polarization_vector: bool = False,
) -> tuple[RayBatch | None, RayBatch | None]:
    """
    Process rays intersecting a surface.

    Intersects the active rays with ``surface``, evaluates the Fresnel
    equations at every hit and spawns reflected and/or refracted child rays.

    Parameters
    ----------
    rays : RayBatch
        Input rays. Only rays with ``rays.active`` set are tested; the batch
        itself is not modified.
    surface : Surface
        Surface to intersect with.
    wavelength : float or ndarray, optional
        Currently unused: refractive indices are always evaluated at each
        ray's own ``rays.wavelengths``. Kept for backward compatibility
        (default: 500 nm).
    generate_reflected : bool, optional
        Whether to generate reflected rays (default: True).
    generate_refracted : bool, optional
        Whether to generate refracted rays (default: True).
    polarization : str, optional
        Polarization state: 's', 'p', or 'unpolarized' (default).
        Used for the Fresnel coefficient calculation and, when polarization
        vectors must be created, to initialize them.
    track_polarization_vector : bool, optional
        Whether to track 3D polarization vectors through the interaction.
        If True, polarization vectors are initialized (if not present) and
        transformed through reflection/refraction. (default: False)

    Returns
    -------
    reflected_rays : RayBatch or None
        One child per hit, in the order of the active input rays that hit
        the surface (rays that miss are not represented). None if
        generate_reflected=False or nothing was hit.
    refracted_rays : RayBatch or None
        Same layout as ``reflected_rays``. Rays undergoing total internal
        reflection are kept with zero intensity and marked inactive. None if
        generate_refracted=False or nothing was hit.

    Notes
    -----
    - Child rays start 0.01 (position units) past the hit point along their
      new direction.
    - Child rays with intensity <= 1e-10 are marked inactive.
    - ``accumulated_time`` is advanced by ``distance * n_in / c`` (travel to
      the surface at phase velocity c / n_in).
    - ``generations`` is incremented by one.

    When track_polarization_vector=True, the function:

    1. Initializes polarization vectors if rays.polarization_vector is None
    2. Transforms polarization vectors through reflection/refraction
       (reflection is weighted by the s/p reflectances R_s and R_p;
       refraction is transformed without s/p transmission weighting)
    3. Stores the transformed vectors in the output ray batches

    Examples
    --------
    >>> # Create surface and rays
    >>> surface = PlanarSurface(
    ...     point=(0, 0, 1),
    ...     normal=(0, 0, -1),
    ...     material_front=BK7_GLASS,
    ...     material_back=AIR_STP
    ... )
    >>> rays = ...  # Create rays
    >>>
    >>> # Process interaction with polarization tracking
    >>> reflected, refracted = process_surface_interaction(
    ...     rays, surface, generate_reflected=True, generate_refracted=True,
    ...     track_polarization_vector=True
    ... )
    """
    # Only active rays take part in the interaction.
    active_mask = rays.active
    if not np.any(active_mask):
        return None, None

    active_origins = rays.positions[active_mask]
    active_directions = rays.directions[active_mask]
    active_wavelengths = rays.wavelengths[active_mask]
    active_intensities = rays.intensities[active_mask]
    active_times = rays.accumulated_time[active_mask]

    # Polarization vectors: reuse the batch's own, or seed them from the
    # requested polarization state.
    active_polarization_vectors = None
    if track_polarization_vector:
        if rays.polarization_vector is not None:
            active_polarization_vectors = rays.polarization_vector[active_mask]
        else:
            active_polarization_vectors = initialize_polarization_vectors(
                active_directions, polarization=polarization
            )

    distances, hit_mask = surface.intersect(active_origins, active_directions)
    if not np.any(hit_mask):
        return None, None

    hit_directions = active_directions[hit_mask]
    hit_distances = distances[hit_mask]
    hit_positions = (
        active_origins[hit_mask] + hit_distances[:, np.newaxis] * hit_directions
    )
    hit_wavelengths = active_wavelengths[hit_mask]
    hit_intensities = active_intensities[hit_mask]
    hit_times = active_times[hit_mask]
    child_generations = rays.generations[active_mask][hit_mask] + 1

    hit_polarization_vectors = None
    if active_polarization_vectors is not None:
        hit_polarization_vectors = active_polarization_vectors[hit_mask]

    # _orient_media decides which side each ray arrives from and returns
    # normals facing the incoming ray plus n_in / n_out accordingly.
    normals, n1_values, n2_values = _orient_media(
        surface, hit_positions, hit_directions, hit_wavelengths
    )

    # Time to reach the surface: distance / phase velocity, with phase
    # velocity c / n_in (n1_values is the index of the medium the ray was in).
    speed_of_light = 299792458.0  # m/s
    updated_times = hit_times + hit_distances * n1_values / speed_of_light

    # Normals face the incoming ray, so |dot(d, n)| is cos(theta_i).
    cos_theta_i = np.abs(np.sum(hit_directions * normals, axis=1))

    R, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    # Separate s/p reflectances weight the reflected polarization vectors.
    R_s = None
    R_p = None
    if track_polarization_vector:
        R_s, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, "s")
        R_p, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, "p")

    reflected_rays = None
    if generate_reflected:
        reflected_directions = compute_reflection_direction(hit_directions, normals)
        reflected_intensities = hit_intensities * R
        reflected_rays = _spawn_child_batch(
            hit_positions,
            reflected_directions,
            hit_wavelengths,
            reflected_intensities,
            reflected_intensities > 1e-10,  # deactivate very weak rays
            updated_times,
            child_generations,
            enable_polarization_vector=track_polarization_vector,
        )
        if hit_polarization_vectors is not None:
            reflected_rays.polarization_vector[:] = transform_polarization_reflection(
                hit_polarization_vectors,
                hit_directions,
                reflected_directions,
                normals,
                R_s=R_s,
                R_p=R_p,
            )

    refracted_rays = None
    if generate_refracted:
        refracted_directions, tir_mask = compute_refraction_direction(
            hit_directions, normals, n1_values, n2_values
        )
        # Transmission intensity (0 for total internal reflection).
        refracted_intensities = hit_intensities * T
        refracted_intensities[tir_mask] = 0.0
        refracted_rays = _spawn_child_batch(
            hit_positions,
            refracted_directions,
            hit_wavelengths,
            refracted_intensities,
            (refracted_intensities > 1e-10) & ~tir_mask,
            updated_times,
            child_generations,
            enable_polarization_vector=track_polarization_vector,
        )
        if hit_polarization_vectors is not None:
            refracted_pol = transform_polarization_refraction(
                hit_polarization_vectors,
                hit_directions,
                refracted_directions,
                normals,
            )
            # TIR rays are inactive; zero their vectors so they are not misused.
            refracted_pol[tir_mask] = 0.0
            refracted_rays.polarization_vector[:] = refracted_pol

    return reflected_rays, refracted_rays
def reflect_rays(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    polarization: str = "unpolarized",
    in_place: bool = False,
) -> RayBatch:
    """
    Apply reflection to rays at surface.

    For every active ray that hits ``surface``:

    - moves it to just past its hit point;
    - replaces its direction with the reflected direction;
    - scales its intensity by the Fresnel reflectance R;
    - advances its accumulated time by the travel time to the surface;
    - increments its generation.

    Active rays that miss the surface are deactivated, as are rays whose
    intensity is below 1e-10.

    Parameters
    ----------
    rays : RayBatch
        Input rays; only active rays are processed.
    surface : Surface
        Surface to reflect from.
    wavelength : float or ndarray, optional
        Unused: refractive indices are evaluated at each ray's own
        ``rays.wavelengths``. Kept for backward compatibility.
    polarization : str, optional
        Polarization state passed to the Fresnel calculation
        ('s', 'p', or 'unpolarized').
    in_place : bool, optional
        If True, modify ``rays`` in place. If False (default), work on a clone.

    Returns
    -------
    RayBatch
        Reflected rays (the same object as ``rays`` if in_place=True).

    Notes
    -----
    Polarization vectors are not transformed here.
    """
    if not in_place:
        rays = rays.clone()

    active_mask = rays.active
    if not np.any(active_mask):
        return rays

    distances, hit_mask = surface.intersect(
        rays.positions[active_mask], rays.directions[active_mask]
    )
    if not np.any(hit_mask):
        # Every active ray missed the surface.
        rays.active[:] = False
        return rays

    # Full-length mask of rays that hit (inactive rays count as non-hits).
    full_hit_mask = np.zeros(len(rays.positions), dtype=bool)
    full_hit_mask[active_mask] = hit_mask

    active_indices = np.where(active_mask)[0][hit_mask]
    hit_distances = distances[hit_mask]
    hit_dirs = rays.directions[active_indices]
    hit_positions = (
        rays.positions[active_indices] + hit_distances[:, np.newaxis] * hit_dirs
    )
    hit_wls = rays.wavelengths[active_indices]

    # Correctly-oriented normals and refractive indices.
    normals, n1_values, n2_values = _orient_media(
        surface, hit_positions, hit_dirs, hit_wls
    )

    reflected_directions = compute_reflection_direction(hit_dirs, normals)

    cos_theta_i = np.abs(np.sum(hit_dirs * normals, axis=1))
    R, _ = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    # Travel time to the surface at phase velocity c / n_in, matching
    # process_surface_interaction.
    speed_of_light = 299792458.0  # m/s
    travel_time = hit_distances * n1_values / speed_of_light

    # Offset must exceed surface.intersect min_distance (0.01).
    rays.positions[active_indices] = hit_positions + 0.02 * reflected_directions
    rays.directions[active_indices] = reflected_directions
    rays.intensities[active_indices] *= R
    rays.accumulated_time[active_indices] += travel_time
    rays.generations[active_indices] += 1

    # Deactivate non-hits and weak rays.
    rays.active[~full_hit_mask] = False
    rays.active[rays.intensities < 1e-10] = False
    return rays
def refract_rays(
    rays: RayBatch,
    surface: Surface,
    wavelength: float | NDArray[np.float32] = 500e-9,
    polarization: str = "unpolarized",
    in_place: bool = False,
) -> RayBatch:
    """
    Apply refraction to rays at surface.

    For every active ray that hits ``surface``:

    - moves it to just past its hit point;
    - replaces its direction with the refracted direction;
    - scales its intensity by the Fresnel transmittance T;
    - advances its accumulated time by the travel time to the surface;
    - increments its generation.

    Deactivated rays:

    - rays undergoing total internal reflection;
    - active rays that miss the surface;
    - rays whose intensity is below 1e-10.

    Parameters
    ----------
    rays : RayBatch
        Input rays; only active rays are processed.
    surface : Surface
        Surface to refract through.
    wavelength : float or ndarray, optional
        Unused: refractive indices are evaluated at each ray's own
        ``rays.wavelengths``. Kept for backward compatibility.
    polarization : str, optional
        Polarization state passed to the Fresnel calculation
        ('s', 'p', or 'unpolarized').
    in_place : bool, optional
        If True, modify ``rays`` in place. If False (default), work on a clone.

    Returns
    -------
    RayBatch
        Refracted rays (the same object as ``rays`` if in_place=True).

    Notes
    -----
    Polarization vectors are not transformed here.
    """
    if not in_place:
        rays = rays.clone()

    active_mask = rays.active
    if not np.any(active_mask):
        return rays

    distances, hit_mask = surface.intersect(
        rays.positions[active_mask], rays.directions[active_mask]
    )
    if not np.any(hit_mask):
        # Every active ray missed the surface.
        rays.active[:] = False
        return rays

    # Full-length mask of rays that hit (inactive rays count as non-hits).
    full_hit_mask = np.zeros(len(rays.positions), dtype=bool)
    full_hit_mask[active_mask] = hit_mask

    active_indices = np.where(active_mask)[0][hit_mask]
    hit_distances = distances[hit_mask]
    hit_dirs = rays.directions[active_indices]
    hit_positions = (
        rays.positions[active_indices] + hit_distances[:, np.newaxis] * hit_dirs
    )
    hit_wls = rays.wavelengths[active_indices]

    # Correctly-oriented normals and refractive indices.
    normals, n1_values, n2_values = _orient_media(
        surface, hit_positions, hit_dirs, hit_wls
    )

    refracted_directions, tir_mask = compute_refraction_direction(
        hit_dirs, normals, n1_values, n2_values
    )

    cos_theta_i = np.abs(np.sum(hit_dirs * normals, axis=1))
    _, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization)

    # Travel time to the surface at phase velocity c / n_in, matching
    # process_surface_interaction.
    speed_of_light = 299792458.0  # m/s
    travel_time = hit_distances * n1_values / speed_of_light

    # Offset must exceed surface.intersect min_distance (0.01).
    rays.positions[active_indices] = hit_positions + 0.02 * refracted_directions
    rays.directions[active_indices] = refracted_directions
    rays.intensities[active_indices] *= T
    rays.accumulated_time[active_indices] += travel_time
    rays.generations[active_indices] += 1

    # Deactivate TIR rays, non-hits, and weak rays.
    rays.active[active_indices[tir_mask]] = False
    rays.active[~full_hit_mask] = False
    rays.active[rays.intensities < 1e-10] = False
    return rays
# Per-ray RayBatch fields gathered when rays leave the multi-bounce tracer.
_COLLECTED_RAY_FIELDS = (
    "positions",
    "directions",
    "wavelengths",
    "intensities",
    "accumulated_time",
    "generations",
)


def _collect_rays(
    store: dict[str, list[np.ndarray]], rays: RayBatch, selection: NDArray
) -> None:
    """Append copies of the selected rays' per-ray fields to ``store``."""
    for field in _COLLECTED_RAY_FIELDS:
        store.setdefault(field, []).append(getattr(rays, field)[selection].copy())


def _merge_collected(store: dict[str, list[np.ndarray]]) -> RayBatch:
    """Concatenate collected fields into one all-active RayBatch (empty if none)."""
    if not store:
        return create_ray_batch(num_rays=0)
    merged = {field: np.concatenate(chunks) for field, chunks in store.items()}
    batch = create_ray_batch(num_rays=len(merged["positions"]))
    for field, values in merged.items():
        getattr(batch, field)[:] = values
    batch.active[:] = True
    return batch


def trace_rays_multi_bounce(
    rays: RayBatch,
    surface: Surface,
    max_bounces: int = 10,
    bounding_radius: float = 10000.0,
    wavelength: float = 532e-9,
    min_intensity: float = 1e-10,
    track_refracted: bool = True,
) -> tuple[RayBatch, RayBatch, dict]:
    """
    Trace rays through multiple surface interactions until termination.

    Only the reflected branch is followed from bounce to bounce. A ray of
    that branch stops when it:

    - is outside the bounding sphere (centered on the origin),
    - no longer hits the surface,
    - has intensity at or below ``min_intensity`` (dropped), or
    - reaches ``max_bounces``.

    Parameters
    ----------
    rays : RayBatch
        Initial ray batch to trace (not modified).
    surface : Surface
        Surface to interact with.
    max_bounces : int, optional
        Maximum number of surface interactions (default: 10).
    bounding_radius : float, optional
        Radius of the bounding sphere in meters (default: 10000).
    wavelength : float, optional
        Forwarded to process_surface_interaction, which evaluates refractive
        indices at each ray's own wavelength (default: 532nm).
    min_intensity : float, optional
        Reflected rays with intensity <= this value are dropped
        (default: 1e-10).
    track_refracted : bool, optional
        Whether to collect refracted rays (default: True).

    Returns
    -------
    final_reflected : RayBatch
        Terminal state of the reflected branch. It contains rays that left
        the bounding sphere, rays that stopped hitting the surface, and rays
        still active after ``max_bounces``. All are marked active.
    final_refracted : RayBatch
        All active refracted rays generated at every bounce (combined).
    ray_paths : dict
        Dictionary containing:

        - 'reflected_paths': one Nx3 array per input ray. It starts with the
          ray's initial position and adds the start point of each reflected
          child (hit point plus a small offset). For rays that left the
          sphere, that start point is appended once more when the exit is
          detected.
        - 'refracted_paths': one single-point array per active refracted
          ray (its starting point).
        - 'reflected_final_dirs': final direction per input ray, or None if
          its branch was dropped for low intensity or it was never active.
        - 'refracted_final_dirs': direction of each refracted ray.

    Notes
    -----
    Refracted rays are collected but not traced further. The typical case
    is rays that enter the water and do not re-interact with the air-water
    interface from below.

    Each bounce intersects the surface twice:

    - once here, to split active rays into hits and misses;
    - once inside process_surface_interaction.

    This assumes ``surface.intersect`` is deterministic per ray.

    Examples
    --------
    >>> # Trace rays with up to 5 bounces
    >>> final_refl, final_refr, paths = trace_rays_multi_bounce(
    ...     rays, surface, max_bounces=5, bounding_radius=5000.0
    ... )
    >>> # Plot a ray's path
    >>> path = paths['reflected_paths'][0]  # First ray's path
    >>> plt.plot(path[:, 0], path[:, 2])  # x-z projection
    """
    # Clone input rays to avoid modifying the original.
    current_rays = rays.clone()
    num_original = rays.num_rays
    # current_to_original[k] is the index in ``rays`` of the ray that
    # current_rays[k] descends from. It always has one entry per ray in
    # current_rays (active or not), so local indices can be used directly.
    current_to_original = np.arange(num_original)

    ray_path_lists = [[rays.positions[i].copy()] for i in range(num_original)]
    ray_final_dirs = [None] * num_original

    exited: dict[str, list[np.ndarray]] = {}
    refracted: dict[str, list[np.ndarray]] = {}
    refracted_path_lists = []
    refracted_final_dirs = []

    for _bounce in range(max_bounces):
        # Retire active rays that are outside the bounding sphere.
        inside_sphere = (
            np.linalg.norm(current_rays.positions, axis=1) < bounding_radius
        )
        exited_mask = current_rays.active & ~inside_sphere
        if np.any(exited_mask):
            for local_idx in np.flatnonzero(exited_mask):
                orig_idx = current_to_original[local_idx]
                ray_path_lists[orig_idx].append(
                    current_rays.positions[local_idx].copy()
                )
                ray_final_dirs[orig_idx] = current_rays.directions[local_idx].copy()
            _collect_rays(exited, current_rays, exited_mask)
        current_rays.active &= inside_sphere

        if not np.any(current_rays.active):
            break

        # Split active rays into hits and misses. Misses are terminal; they
        # are deactivated so process_surface_interaction only sees the hits
        # and its outputs line up one-to-one with ``hit_indices``.
        active_indices = np.flatnonzero(current_rays.active)
        _, hit_mask = surface.intersect(
            current_rays.positions[active_indices],
            current_rays.directions[active_indices],
        )
        hit_mask = np.asarray(hit_mask, dtype=bool)
        miss_indices = active_indices[~hit_mask]
        if miss_indices.size:
            for local_idx in miss_indices:
                orig_idx = current_to_original[local_idx]
                ray_final_dirs[orig_idx] = current_rays.directions[local_idx].copy()
            _collect_rays(exited, current_rays, miss_indices)
            current_rays.active[miss_indices] = False

        hit_indices = active_indices[hit_mask]
        if hit_indices.size == 0:
            break

        reflected_rays, refracted_rays = process_surface_interaction(
            current_rays,
            surface,
            wavelength=wavelength,
            generate_reflected=True,
            generate_refracted=track_refracted,
        )
        if reflected_rays is None or reflected_rays.num_rays != hit_indices.size:
            raise RuntimeError(
                "surface.intersect returned inconsistent hits for identical rays"
            )
        hit_to_original = current_to_original[hit_indices]

        # Record each bounce point (the reflected child's start position).
        for orig_idx, position in zip(
            hit_to_original, reflected_rays.positions, strict=False
        ):
            ray_path_lists[orig_idx].append(position.copy())

        # Collect active refracted rays; each starts its own single-point path.
        if refracted_rays is not None and refracted_rays.num_rays > 0:
            active_refr = refracted_rays.active
            if np.any(active_refr):
                for i in np.flatnonzero(active_refr):
                    refracted_path_lists.append(
                        np.array([refracted_rays.positions[i].copy()])
                    )
                    refracted_final_dirs.append(refracted_rays.directions[i].copy())
                _collect_rays(refracted, refracted_rays, active_refr)

        # Continue with the reflected children; weak ones are dropped. The
        # rays that hit are consumed here, so they are never re-emitted below.
        reflected_rays.active &= reflected_rays.intensities > min_intensity
        current_rays = reflected_rays
        current_to_original = hit_to_original

    # Rays still active after the loop (e.g. max_bounces reached).
    remaining_mask = current_rays.active
    if np.any(remaining_mask):
        for local_idx in np.flatnonzero(remaining_mask):
            orig_idx = current_to_original[local_idx]
            ray_final_dirs[orig_idx] = current_rays.directions[local_idx].copy()
        _collect_rays(exited, current_rays, remaining_mask)

    final_reflected = _merge_collected(exited)
    final_refracted = _merge_collected(refracted)

    ray_paths = {
        "reflected_paths": [np.array(path) for path in ray_path_lists],
        "refracted_paths": refracted_path_lists,
        "reflected_final_dirs": list(ray_final_dirs),
        "refracted_final_dirs": refracted_final_dirs,
    }
    return final_reflected, final_refracted, ray_paths
[docs] def trace_rays_with_splitting( rays: RayBatch, surfaces: list, max_bounces: int = 10, bounding_radius: float = 10000.0, bounding_center: tuple = None, wavelength: float = 532e-9, min_intensity: float = 1e-10, polarization: str = "unpolarized", ) -> tuple[RayBatch, dict]: """ Trace rays through multiple surfaces with proper Fresnel ray splitting. At each surface interaction, every ray is split into two child rays: - A reflected ray with intensity scaled by the Fresnel reflection coefficient R - A refracted ray with intensity scaled by the Fresnel transmission coefficient T This creates a ray tree where both reflected and refracted paths are traced until termination conditions are met. Parameters ---------- rays : RayBatch Initial ray batch to trace surfaces : list of Surface List of surfaces to interact with (tested in order for closest hit) max_bounces : int, optional Maximum number of surface interactions per ray tree branch (default: 10) bounding_radius : float, optional Radius of bounding sphere in meters (default: 10000) bounding_center : tuple of float, optional Center of bounding sphere (x, y, z) in meters. Default is (0, 0, 0). For curved Earth simulations, use (0, 0, -EARTH_RADIUS). wavelength : float, optional Wavelength for Fresnel calculations (default: 532nm) min_intensity : float, optional Minimum intensity threshold for ray termination (default: 1e-10) polarization : str, optional Polarization state: 's', 'p', or 'unpolarized' (default) Returns ------- final_rays : RayBatch All terminal rays (rays that exited the scene or hit max bounces) Each ray has its intensity weighted by the product of all Fresnel coefficients along its path. 
trace_info : dict Dictionary containing tracing statistics: - 'total_rays_created': Total number of rays created during tracing - 'max_depth_reached': Maximum tree depth reached - 'terminated_by_intensity': Number of rays terminated due to low intensity - 'terminated_by_bounds': Number of rays that exited bounding sphere - 'terminated_by_max_bounces': Number of rays that hit max bounce limit Notes ----- This function implements proper physical ray splitting where each ray at an interface creates two child rays: - Reflected ray: direction from law of reflection, intensity = I_parent * R - Refracted ray: direction from Snell's law, intensity = I_parent * T where R and T are the Fresnel reflection and transmission coefficients computed from the refractive indices and incident angle. For total internal reflection (TIR), T=0 and R=1, so only a reflected ray is created. The number of rays grows exponentially with depth (up to 2^depth), so the min_intensity threshold is critical for pruning weak ray branches. Examples -------- >>> from surface_roughness.surfaces import PlanarSurface >>> from surface_roughness.materials import Glass, Air >>> >>> # Create a glass slab >>> surface1 = PlanarSurface( ... point=(0, 0, 0.01), ... normal=(0, 0, -1), ... material_front=Air(), ... material_back=Glass() ... ) >>> surface2 = PlanarSurface( ... point=(0, 0, 0.02), ... normal=(0, 0, -1), ... material_front=Glass(), ... material_back=Air() ... ) >>> >>> # Trace rays with splitting >>> final_rays, info = trace_rays_with_splitting( ... rays, [surface1, surface2], max_bounces=5 ... 
) >>> print(f"Created {info['total_rays_created']} rays total") """ from .ray_data import create_ray_batch # Set default bounding center if not provided if bounding_center is None: bounding_center = np.array([0.0, 0.0, 0.0]) else: bounding_center = np.array(bounding_center) # Statistics tracking total_rays_created = rays.num_rays terminated_by_intensity = 0 terminated_by_bounds = 0 terminated_by_max_bounces = 0 max_depth_reached = 0 # Queue of rays to process: (RayBatch, current_depth) ray_queue = [(rays.clone(), 0)] # Collect all terminal rays terminal_positions = [] terminal_directions = [] terminal_wavelengths = [] terminal_intensities = [] terminal_times = [] terminal_generations = [] while ray_queue: current_rays, depth = ray_queue.pop(0) max_depth_reached = max(max_depth_reached, depth) # Filter out already inactive rays if not np.any(current_rays.active): continue # Check bounding sphere - terminate rays outside # Distance is measured from bounding_center, not origin distances_from_center = np.linalg.norm( current_rays.positions - bounding_center, axis=1 ) outside_bounds = distances_from_center >= bounding_radius exited_mask = current_rays.active & outside_bounds if np.any(exited_mask): terminated_by_bounds += np.sum(exited_mask) terminal_positions.append(current_rays.positions[exited_mask].copy()) terminal_directions.append(current_rays.directions[exited_mask].copy()) terminal_wavelengths.append(current_rays.wavelengths[exited_mask].copy()) terminal_intensities.append(current_rays.intensities[exited_mask].copy()) terminal_times.append(current_rays.accumulated_time[exited_mask].copy()) terminal_generations.append(current_rays.generations[exited_mask].copy()) current_rays.active[exited_mask] = False # Check depth limit if depth >= max_bounces: remaining_mask = current_rays.active if np.any(remaining_mask): terminated_by_max_bounces += np.sum(remaining_mask) terminal_positions.append(current_rays.positions[remaining_mask].copy()) terminal_directions.append( 
current_rays.directions[remaining_mask].copy() ) terminal_wavelengths.append( current_rays.wavelengths[remaining_mask].copy() ) terminal_intensities.append( current_rays.intensities[remaining_mask].copy() ) terminal_times.append( current_rays.accumulated_time[remaining_mask].copy() ) terminal_generations.append( current_rays.generations[remaining_mask].copy() ) continue # Check intensity threshold - terminate weak rays weak_mask = current_rays.active & (current_rays.intensities < min_intensity) if np.any(weak_mask): terminated_by_intensity += np.sum(weak_mask) current_rays.active[weak_mask] = False if not np.any(current_rays.active): continue # Find closest surface intersection among all surfaces active_mask = current_rays.active active_origins = current_rays.positions[active_mask] active_directions = current_rays.directions[active_mask] num_active = np.sum(active_mask) # Initialize with no hit closest_distances = np.full(num_active, np.inf, dtype=np.float32) closest_surface_idx = np.full(num_active, -1, dtype=np.int32) any_hit = np.zeros(num_active, dtype=bool) for surf_idx, surface in enumerate(surfaces): distances, hit_mask = surface.intersect(active_origins, active_directions) # Update closest hit closer = hit_mask & (distances < closest_distances) closest_distances[closer] = distances[closer] closest_surface_idx[closer] = surf_idx any_hit |= hit_mask if not np.any(any_hit): # No surface hit - rays continue to infinity # Compute exact intersection with bounding sphere and terminate there # Solve: |pos + t*dir - center|^2 = R^2 # Let p = pos - center # |p + t*dir|^2 = R^2 # |dir|^2 * t^2 + 2*(p·dir)*t + (|p|^2 - R^2) = 0 positions = current_rays.positions[active_mask] directions = current_rays.directions[active_mask] p = positions - bounding_center a = np.sum(directions**2, axis=1) # |dir|^2, should be 1 b = 2 * np.sum(p * directions, axis=1) # 2*(p·dir) c = np.sum(p**2, axis=1) - bounding_radius**2 # |p|^2 - R^2 discriminant = b**2 - 4 * a * c # For rays inside 
sphere, discriminant > 0 and we want positive t # t = (-b ± sqrt(disc)) / 2a # We want the far intersection (exit point), so use + sqrt valid = discriminant >= 0 t_exit = np.zeros(len(positions), dtype=np.float32) t_exit[valid] = (-b[valid] + np.sqrt(discriminant[valid])) / (2 * a[valid]) t_exit[~valid] = 0 # Shouldn't happen for rays inside sphere t_exit = np.maximum(t_exit, 0) # Only forward intersection # Move rays to exact bounding sphere intersection exit_positions = positions + t_exit[:, np.newaxis] * directions # Store these rays as terminated at bounds active_indices = np.where(active_mask)[0] terminated_by_bounds += len(active_indices) terminal_positions.append(exit_positions.copy()) terminal_directions.append(directions.copy()) terminal_wavelengths.append(current_rays.wavelengths[active_mask].copy()) terminal_intensities.append(current_rays.intensities[active_mask].copy()) terminal_times.append(current_rays.accumulated_time[active_mask].copy()) terminal_generations.append(current_rays.generations[active_mask].copy()) # Deactivate these rays current_rays.active[active_mask] = False continue # Process each hit surface separately for surf_idx, surface in enumerate(surfaces): surf_hit_mask = (closest_surface_idx == surf_idx) & any_hit if not np.any(surf_hit_mask): continue # Create a temporary batch for rays hitting this surface full_hit_mask = np.zeros(len(current_rays.positions), dtype=bool) active_indices = np.where(active_mask)[0] full_hit_mask[active_indices[surf_hit_mask]] = True if not np.any(full_hit_mask): continue # Extract hitting rays hit_origins = current_rays.positions[full_hit_mask] hit_directions = current_rays.directions[full_hit_mask] hit_wavelengths = current_rays.wavelengths[full_hit_mask] hit_intensities = current_rays.intensities[full_hit_mask] hit_times = current_rays.accumulated_time[full_hit_mask] hit_generations = current_rays.generations[full_hit_mask] hit_distances = closest_distances[surf_hit_mask] # Compute intersection points 
hit_positions = hit_origins + hit_distances[:, np.newaxis] * hit_directions # Correctly-oriented normals and refractive indices normals, n1_values, n2_values = _orient_media( surface, hit_positions, hit_directions, hit_wavelengths ) # Compute travel time (n1_values = n_in = medium the ray was in) c = 299792458.0 travel_time = hit_distances * n1_values / c updated_times = hit_times + travel_time # Compute Fresnel coefficients cos_theta_i = -np.sum(hit_directions * normals, axis=1) cos_theta_i = np.abs(cos_theta_i) R, T = fresnel_coefficients(n1_values, n2_values, cos_theta_i, polarization) # Create REFLECTED rays (always created) reflected_directions = compute_reflection_direction(hit_directions, normals) reflected_intensities = hit_intensities * R reflected_active = reflected_intensities > min_intensity if np.any(reflected_active): num_reflected = np.sum(reflected_active) reflected_rays = create_ray_batch(num_rays=num_reflected) # Offset must exceed surface.intersect min_distance (0.01) reflected_rays.positions[:] = ( hit_positions[reflected_active] + 0.02 * reflected_directions[reflected_active] ) reflected_rays.directions[:] = reflected_directions[reflected_active] reflected_rays.wavelengths[:] = hit_wavelengths[reflected_active] reflected_rays.intensities[:] = reflected_intensities[reflected_active] reflected_rays.accumulated_time[:] = updated_times[reflected_active] reflected_rays.generations[:] = hit_generations[reflected_active] + 1 reflected_rays.active[:] = True total_rays_created += num_reflected ray_queue.append((reflected_rays, depth + 1)) # Create REFRACTED rays (unless TIR) refracted_directions, tir_mask = compute_refraction_direction( hit_directions, normals, n1_values, n2_values ) refracted_intensities = hit_intensities * T refracted_intensities[tir_mask] = 0.0 refracted_active = (refracted_intensities > min_intensity) & (~tir_mask) if np.any(refracted_active): num_refracted = np.sum(refracted_active) refracted_rays = 
create_ray_batch(num_rays=num_refracted) # Offset must exceed surface.intersect min_distance (0.01) refracted_rays.positions[:] = ( hit_positions[refracted_active] + 0.02 * refracted_directions[refracted_active] ) refracted_rays.directions[:] = refracted_directions[refracted_active] refracted_rays.wavelengths[:] = hit_wavelengths[refracted_active] refracted_rays.intensities[:] = refracted_intensities[refracted_active] refracted_rays.accumulated_time[:] = updated_times[refracted_active] refracted_rays.generations[:] = hit_generations[refracted_active] + 1 refracted_rays.active[:] = True total_rays_created += num_refracted ray_queue.append((refracted_rays, depth + 1)) # Mark processed rays as inactive in current batch current_rays.active[full_hit_mask] = False # Combine all terminal rays if len(terminal_positions) > 0: all_positions = np.vstack(terminal_positions) all_directions = np.vstack(terminal_directions) all_wavelengths = np.concatenate(terminal_wavelengths) all_intensities = np.concatenate(terminal_intensities) all_times = np.concatenate(terminal_times) all_generations = np.concatenate(terminal_generations) num_terminal = len(all_positions) final_rays = create_ray_batch(num_rays=num_terminal) final_rays.positions[:] = all_positions final_rays.directions[:] = all_directions final_rays.wavelengths[:] = all_wavelengths final_rays.intensities[:] = all_intensities final_rays.accumulated_time[:] = all_times final_rays.generations[:] = all_generations final_rays.active[:] = True else: final_rays = create_ray_batch(num_rays=0) trace_info = { "total_rays_created": total_rays_created, "max_depth_reached": max_depth_reached, "terminated_by_intensity": terminated_by_intensity, "terminated_by_bounds": terminated_by_bounds, "terminated_by_max_bounces": terminated_by_max_bounces, } return final_rays, trace_info