# Source code for lsurf.surfaces.gpu.multi_curved_wave

# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
#      * Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#
#      * Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#      * Neither the name of the copyright holder nor the names of its
#      contributors may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""
Multi-Wave Curved Surface (GPU-Capable)

GPU-accelerated ocean wave surface on a curved (spherical) Earth with multiple
wave components. Supports up to 8 superimposed wave components for realistic
ocean surface modeling.

Uses geometry_id=5 for GPU kernel dispatch. Parameters are packed into a
64-element tuple with the following layout:
- p0-2: earth_center_x/y/z
- p3: earth_radius
- p4: time
- p5: num_waves
- p6-7: reserved
- For each wave i, starting at offset 8 + i*8 (the 64-element layout has room
  for 7 wave slots, i=0..6):
  - amplitude, wave_number, dir_x, dir_y, phase, steepness, reserved, reserved
"""

from dataclasses import dataclass, field
from typing import Any

import numpy as np
import numpy.typing as npt

from ..protocol import Surface, SurfaceRole
from ..registry import register_surface_type
from ..cpu.wave_params import GerstnerWaveParams

# Physical constants for the curved-Earth wave model.
EARTH_RADIUS: float = 6_371_000.0  # mean Earth radius [m]
GRAVITY: float = 9.81  # gravitational acceleration [m/s^2] for deep-water dispersion

# Upper bound on the number of superimposed wave components a surface accepts.
MAX_WAVES: int = 8


@dataclass
class GPUMultiCurvedWaveSurface(Surface):
    """
    GPU-accelerated curved-earth ocean wave surface with multiple wave components.

    This surface supports up to 8 superimposed wave components on a spherical
    Earth. Each wave is treated as a perturbation on top of Earth's spherical
    surface.

    Uses GPU-accelerated signed distance computation with geometry_id=5.
    Parameters are packed into a 64-element tuple for the GPU kernel; that
    layout has room for at most 7 wave components (see ``get_gpu_parameters``).

    Parameters
    ----------
    wave_params : list of GerstnerWaveParams
        List of wave components (max 8). Each wave has amplitude, wavelength,
        direction, phase, and steepness.
    role : SurfaceRole
        What happens when a ray hits (typically OPTICAL).
    earth_center : tuple of float, optional
        Center of Earth sphere. Default is (0, 0, -EARTH_RADIUS).
    earth_radius : float, optional
        Earth radius in meters. Default is EARTH_RADIUS.
    time : float, optional
        Animation time in seconds. Default is 0.0.
    name : str
        Human-readable name.
    material_front : MaterialField, optional
        Material above surface (atmosphere).
    material_back : MaterialField, optional
        Material below surface (ocean water).

    Notes
    -----
    The default ``earth_center`` is ``(0, 0, -EARTH_RADIUS)`` regardless of
    ``earth_radius``. When using a custom radius, pass
    ``earth_center=(0, 0, -earth_radius)`` to keep the mean surface at z = 0.

    The CPU methods (``signed_distance``, ``intersect``, ``normal_at``) model
    the surface as the radial height field
    ``earth_radius + sum_i A_i * cos(k_i * (d_i . (x, y)) - omega_i * t + phase_i)``
    measured from the Earth centre, with world x / y used as local tangent-plane
    coordinates. Wave steepness is only forwarded to the GPU parameter tuple.

    Examples
    --------
    >>> from lsurf.surfaces import GPUMultiCurvedWaveSurface, SurfaceRole, GerstnerWaveParams
    >>> from lsurf.materials import ExponentialAtmosphere, WATER
    >>>
    >>> waves = [
    ...     GerstnerWaveParams(amplitude=1.5, wavelength=50.0, direction=(1.0, 0.0)),
    ...     GerstnerWaveParams(amplitude=0.8, wavelength=30.0, direction=(0.7, 0.7)),
    ...     GerstnerWaveParams(amplitude=0.3, wavelength=15.0, direction=(-0.5, 0.8)),
    ... ]
    >>>
    >>> ocean = GPUMultiCurvedWaveSurface(
    ...     wave_params=waves,
    ...     role=SurfaceRole.OPTICAL,
    ...     name="multi_wave_ocean",
    ...     material_front=ExponentialAtmosphere(),
    ...     material_back=WATER,
    ... )
    """

    wave_params: list[GerstnerWaveParams]
    role: SurfaceRole
    earth_center: tuple[float, float, float] = (0, 0, -EARTH_RADIUS)
    earth_radius: float = EARTH_RADIUS
    time: float = 0.0
    name: str = "multi_curved_wave"
    material_front: Any = None
    material_back: Any = None

    # GPU-capable with geometry_id=5 (multi-wave curved surface)
    _gpu_capable: bool = field(default=True, init=False, repr=False)
    _geometry_id: int = field(default=5, init=False, repr=False)

    # Precomputed values (set in __post_init__). The array is excluded from
    # dataclass comparison: comparing ndarrays inside the generated __eq__
    # would raise "truth value of an array is ambiguous".
    _earth_center_arr: npt.NDArray = field(
        default=None, init=False, repr=False, compare=False
    )
    _max_amplitude: float = field(default=0.0, init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate inputs and precompute the Earth-centre array and amplitude bound.

        Raises
        ------
        ValueError
            If there are no waves or more than MAX_WAVES, if the radius, any
            amplitude or any wavelength is not positive, or if earth_center
            does not have exactly three components.
        """
        if not self.wave_params:
            raise ValueError("At least one wave component is required")
        if len(self.wave_params) > MAX_WAVES:
            raise ValueError(
                f"Maximum {MAX_WAVES} wave components supported, got {len(self.wave_params)}"
            )
        if self.earth_radius <= 0:
            raise ValueError("Earth radius must be positive")
        for wave in self.wave_params:
            if wave.amplitude <= 0:
                raise ValueError("All wave amplitudes must be positive")
            if wave.wavelength <= 0:
                raise ValueError("All wavelengths must be positive")

        # Store earth center as a float64 array (copied, so later mutation of a
        # caller-owned array does not leak in).
        self._earth_center_arr = np.array(self.earth_center, dtype=np.float64)
        if self._earth_center_arr.shape != (3,):
            raise ValueError("earth_center must have exactly 3 components")

        # Upper bound on |sum_i A_i cos(.)|: used as the bounding-sphere margin.
        self._max_amplitude = float(sum(w.amplitude for w in self.wave_params))

    @property
    def gpu_capable(self) -> bool:
        """This surface supports GPU acceleration with geometry_id=5."""
        return self._gpu_capable

    @property
    def geometry_id(self) -> int:
        """GPU geometry type ID (multi_curved_wave = 5)."""
        return self._geometry_id

    @property
    def num_waves(self) -> int:
        """Number of wave components."""
        return len(self.wave_params)

    @property
    def max_amplitude(self) -> float:
        """Maximum combined wave amplitude (sum of all amplitudes)."""
        return self._max_amplitude

    def get_gpu_parameters(self) -> tuple:
        """
        Return 64-element parameter tuple for GPU kernel.

        Parameter layout (geometry_id = 5):
        - p0-2: earth_center_x/y/z
        - p3: earth_radius
        - p4: time
        - p5: num_waves
        - p6-7: reserved
        For each wave i, starting at offset 8 + i*8:
        - p[offset+0]: amplitude
        - p[offset+1]: wave_number (k)
        - p[offset+2]: dir_x (normalized)
        - p[offset+3]: dir_y (normalized)
        - p[offset+4]: phase
        - p[offset+5]: steepness
        - p[offset+6-7]: reserved

        With an 8-slot header and 8 slots per wave, 64 elements hold at most
        (64 - 8) // 8 = 7 wave components.

        Returns
        -------
        tuple of 64 floats

        Raises
        ------
        ValueError
            If the surface has more waves than fit in the 64-element layout
            (previously this surfaced as an IndexError for 8 waves).
        """
        param_count = 64
        header_size = 8
        wave_stride = 8
        max_gpu_waves = (param_count - header_size) // wave_stride

        n_waves = len(self.wave_params)
        if n_waves > max_gpu_waves:
            raise ValueError(
                f"The {param_count}-element GPU parameter layout holds at most "
                f"{max_gpu_waves} wave components, got {n_waves}"
            )

        params = [0.0] * param_count

        # Header parameters (cast to float: the default earth_center holds ints)
        cx, cy, cz = self.earth_center
        params[0] = float(cx)
        params[1] = float(cy)
        params[2] = float(cz)
        params[3] = float(self.earth_radius)
        params[4] = float(self.time)
        params[5] = float(n_waves)
        # params[6-7] reserved

        for i, wave in enumerate(self.wave_params):
            offset = header_size + i * wave_stride
            dir_x, dir_y = wave.direction_normalized
            params[offset + 0] = float(wave.amplitude)
            params[offset + 1] = float(wave.wave_number)
            params[offset + 2] = float(dir_x)
            params[offset + 3] = float(dir_y)
            params[offset + 4] = float(wave.phase)
            params[offset + 5] = float(wave.steepness)
            # params[offset+6-7] reserved

        return tuple(params)

    def get_materials(self) -> tuple | None:
        """Return (front, back) materials for Fresnel calculation, or None.

        Materials are only returned when ``role`` is ``SurfaceRole.OPTICAL``.
        """
        if self.role == SurfaceRole.OPTICAL:
            return (self.material_front, self.material_back)
        return None

    def _wave_phase(self, wave: GerstnerWaveParams, x: npt.NDArray, y: npt.NDArray) -> npt.NDArray:
        """Phase k * (d . (x, y)) - omega * time + phase of one wave at (x, y)."""
        dir_x, dir_y = wave.direction_normalized
        dot = dir_x * x + dir_y * y
        return wave.wave_number * dot - wave.angular_frequency * self.time + wave.phase

    def _wave_height(self, x: npt.NDArray, y: npt.NDArray) -> npt.NDArray:
        """Compute combined wave height at positions (x, y) in local tangent space."""
        total_height = np.zeros_like(x, dtype=np.float64)
        for wave in self.wave_params:
            total_height += wave.amplitude * np.cos(self._wave_phase(wave, x, y))
        return total_height

    def _signed_distance_f64(self, positions: npt.NDArray) -> npt.NDArray:
        """Float64 radial signed distance of (N, 3) float64 positions to the surface."""
        dist_from_center = np.linalg.norm(positions - self._earth_center_arr, axis=1)
        wave_height = self._wave_height(positions[:, 0], positions[:, 1])
        return dist_from_center - (self.earth_radius + wave_height)

    def signed_distance(
        self,
        positions: npt.NDArray[np.float32],
    ) -> npt.NDArray[np.float32]:
        """
        Compute signed distance from positions to multi-wave curved surface.

        The value is the radial gap ``|p - earth_center| - (earth_radius + h(x, y))``,
        with the world x / y coordinates used as local tangent coordinates.

        Parameters
        ----------
        positions : ndarray, shape (N, 3)
            Points to compute distance for

        Returns
        -------
        ndarray, shape (N,)
            Signed distance (positive outside, negative inside Earth+wave)
        """
        positions = np.asarray(positions, dtype=np.float64)
        return self._signed_distance_f64(positions).astype(np.float32)

    def _bisect_crossing(
        self,
        origins: npt.NDArray,
        directions: npt.NDArray,
        t_a: npt.NDArray,
        t_b: npt.NDArray,
        a_outside: npt.NDArray,
        iterations: int = 15,
    ) -> npt.NDArray:
        """Bisect a bracketed surface crossing on each ray; return the final midpoint.

        ``a_outside`` marks rays for which ``t_a`` is the outside (positive
        signed distance) end of the bracket and ``t_b`` the inside end.
        """
        t_out = np.where(a_outside, t_a, t_b)
        t_in = np.where(a_outside, t_b, t_a)
        for _ in range(iterations):
            t_mid = (t_out + t_in) / 2
            sd_mid = self._signed_distance_f64(origins + t_mid[:, np.newaxis] * directions)
            above = sd_mid > 0
            t_out = np.where(above, t_mid, t_out)
            t_in = np.where(above, t_in, t_mid)
        return (t_out + t_in) / 2

    def intersect(
        self,
        origins: npt.NDArray[np.float32],
        directions: npt.NDArray[np.float32],
        min_distance: float = 1e-6,
        max_iterations: int = 200,
        tolerance: float = 1e-3,
        max_distance: float | None = None,
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.bool_]]:
        """
        Find ray-surface intersections using ray marching.

        Rays are first clipped against the bounding sphere of radius
        ``earth_radius + max_amplitude``. They are then marched with
        under-relaxed steps, and sign changes of the signed distance are
        refined by bisection.

        Parameters
        ----------
        origins : ndarray, shape (N, 3)
            Ray origin positions.
        directions : ndarray, shape (N, 3)
            Ray direction unit vectors.
        min_distance : float
            Minimum valid intersection distance; the march parameter never
            drops below it.
        max_iterations : int
            Maximum ray marching iterations.
        tolerance : float
            Convergence tolerance in meters.
        max_distance : float, optional
            Maximum search distance; hits farther than this are discarded.

        Returns
        -------
        distances : ndarray, shape (N,)
            Distance to intersection (inf if no hit).
        hit_mask : ndarray, shape (N,), dtype=bool
            True for rays that hit the surface.
        """
        origins = np.asarray(origins, dtype=np.float64)
        directions = np.asarray(directions, dtype=np.float64)
        n_rays = len(origins)

        distances = np.full(n_rays, np.inf, dtype=np.float64)
        hit_mask = np.zeros(n_rays, dtype=bool)

        # The surface never rises above earth_radius + sum(A_i), so that sphere
        # bounds it. Rays that miss the sphere, or whose exit point lies behind
        # min_distance, cannot hit the surface.
        outer_radius = self.earth_radius + self._max_amplitude
        oc = origins - self._earth_center_arr
        a = np.sum(directions * directions, axis=1)
        b = 2.0 * np.sum(directions * oc, axis=1)
        c = np.sum(oc * oc, axis=1) - outer_radius**2
        discriminant = b**2 - 4.0 * a * c
        sqrt_disc = np.sqrt(np.maximum(discriminant, 0.0))
        denom = 2.0 * a + 1e-20
        t_enter = (-b - sqrt_disc) / denom
        t_exit = (-b + sqrt_disc) / denom
        active = (discriminant >= 0) & (t_exit >= min_distance)

        # Start where the ray enters the bounding sphere, or at min_distance
        # when the origin is already inside it.
        t = np.where(t_enter > min_distance, t_enter, min_distance)
        if max_distance is not None:
            active &= ~(t > max_distance)

        prev_signed_dist = np.full(n_rays, np.inf)
        prev_t = t.copy()
        relaxation = 0.5  # under-relax: the radial gap is not the along-ray distance
        step_limit = 2.0 * self._max_amplitude

        for _ in range(max_iterations):
            if not np.any(active):
                break

            positions = origins + t[:, np.newaxis] * directions
            signed_dist = self._signed_distance_f64(positions)

            # |cos| between ray and local vertical converts the radial gap into
            # an along-ray step; floored to avoid huge steps at grazing angles.
            to_pos = positions - self._earth_center_arr
            radial_len = np.maximum(np.linalg.norm(to_pos, axis=1), 1e-10)
            cos_angle = np.abs(np.sum(directions * to_pos, axis=1)) / radial_len
            cos_angle = np.maximum(cos_angle, 0.01)

            converged = np.abs(signed_dist) < tolerance
            newly_hit = active & converged
            hit_mask[newly_hit] = True
            distances[newly_hit] = t[newly_hit]
            active &= ~converged

            # Refine rays that stepped across the surface since the last iteration.
            crossed = (
                active
                & np.isfinite(prev_signed_dist)
                & (np.sign(signed_dist) * np.sign(prev_signed_dist) < 0)
            )
            if np.any(crossed):
                idx = np.flatnonzero(crossed)
                t[idx] = self._bisect_crossing(
                    origins[idx],
                    directions[idx],
                    prev_t[idx],
                    t[idx],
                    prev_signed_dist[idx] > 0,
                )
                refined = origins[idx] + t[idx][:, np.newaxis] * directions[idx]
                signed_dist[idx] = self._signed_distance_f64(refined)
                refined_hit = idx[np.abs(signed_dist[idx]) < tolerance]
                hit_mask[refined_hit] = True
                distances[refined_hit] = t[refined_hit]
                active[refined_hit] = False

            # Rays far below the deepest possible trough passed through the
            # surface without a detected crossing; stop marching them.
            active &= ~(signed_dist < -self._max_amplitude - 10)
            if max_distance is not None:
                active &= ~(t > max_distance)

            prev_signed_dist = signed_dist.copy()
            prev_t = t.copy()

            step = np.clip(signed_dist / cos_angle * relaxation, -step_limit, step_limit)
            t[active] += step[active]
            # Never march back past the minimum valid distance.
            t = np.maximum(t, min_distance)

        if max_distance is not None:
            # A hit can be recorded in the same iteration that t passes
            # max_distance; drop those so the documented limit holds.
            beyond = hit_mask & (distances > max_distance)
            hit_mask[beyond] = False
            distances[beyond] = np.inf

        return distances.astype(np.float32), hit_mask

    @staticmethod
    def _tangent_direction(axis: npt.NDArray, radial: npt.NDArray) -> npt.NDArray:
        """Project a world axis onto the tangent plane at each point and normalise it."""
        tangent = axis - np.sum(radial * axis, axis=1, keepdims=True) * radial
        return tangent / np.maximum(np.linalg.norm(tangent, axis=1, keepdims=True), 1e-10)

    def normal_at(
        self,
        positions: npt.NDArray[np.float32],
        incoming_directions: npt.NDArray[np.float32] | None = None,
    ) -> npt.NDArray[np.float32]:
        """
        Compute surface normal at given positions.

        The normal is that of the height field traced by ``signed_distance`` /
        ``intersect``, i.e. ``(-dh/dx, -dh/dy, 1)`` in the local frame
        (projected world x, projected world y, radial). It is mapped to world
        coordinates and normalised.

        Parameters
        ----------
        positions : ndarray, shape (N, 3)
            Points on the surface.
        incoming_directions : ndarray, shape (N, 3), optional
            Incoming ray directions. When given, normals are flipped so that
            they oppose the incoming direction.

        Returns
        -------
        normals : ndarray, shape (N, 3)
            Unit normal vectors.
        """
        positions = np.asarray(positions, dtype=np.float64)
        n_points = len(positions)

        # Local frame: radial "up" plus world x / y projected onto the tangent plane.
        to_pos = positions - self._earth_center_arr
        radial = to_pos / np.maximum(np.linalg.norm(to_pos, axis=1, keepdims=True), 1e-10)
        tangent_x = self._tangent_direction(np.array([1.0, 0.0, 0.0]), radial)
        tangent_y = self._tangent_direction(np.array([0.0, 1.0, 0.0]), radial)

        # Gradient of h(x, y) = sum_i A_i cos(theta_i); d(theta_i)/dx = k_i * dir_x.
        x = positions[:, 0]
        y = positions[:, 1]
        dh_dx = np.zeros(n_points, dtype=np.float64)
        dh_dy = np.zeros(n_points, dtype=np.float64)
        for wave in self.wave_params:
            dir_x, dir_y = wave.direction_normalized
            slope = -wave.amplitude * wave.wave_number * np.sin(self._wave_phase(wave, x, y))
            dh_dx += slope * dir_x
            dh_dy += slope * dir_y

        # Height-field normal (-h_x, -h_y, 1) mapped into world coordinates.
        normals = (
            -dh_dx[:, np.newaxis] * tangent_x
            - dh_dy[:, np.newaxis] * tangent_y
            + radial
        )
        # The projected tangents are not exactly orthonormal away from the
        # local origin, so normalise in world space.
        normals /= np.maximum(np.linalg.norm(normals, axis=1, keepdims=True), 1e-10)

        if incoming_directions is not None:
            incoming = np.asarray(incoming_directions, dtype=np.float64)
            flip_mask = np.sum(normals * incoming, axis=1) > 0
            normals[flip_mask] *= -1

        return normals.astype(np.float32)

    def set_time(self, time: float) -> None:
        """Update the wave animation time."""
        self.time = time
# Make the surface class available through the shared surface-type registry.
register_surface_type(
    "gpu_multi_curved_wave",
    "gpu",
    5,
    GPUMultiCurvedWaveSurface,
)