# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Multi-Wave Curved Surface (GPU-Capable)
GPU-accelerated ocean wave surface on a curved (spherical) Earth with multiple
wave components. Supports up to 8 superimposed wave components for realistic
ocean surface modeling.
Uses geometry_id=5 for GPU kernel dispatch. Parameters are packed into a
64-element tuple with the following layout:
- p0-2: earth_center_x/y/z
- p3: earth_radius
- p4: time
- p5: num_waves
- p6-7: reserved
- For each wave i, starting at offset 8 + i*8 (only i = 0..6 fit in the
  64 slots; a wave at i = 7 would start at offset 64, past the end):
  - amplitude, wave_number, dir_x, dir_y, phase, steepness, reserved, reserved
"""
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import numpy.typing as npt
from ..protocol import Surface, SurfaceRole
from ..registry import register_surface_type
from ..cpu.wave_params import GerstnerWaveParams
# Earth model constants (SI units).
EARTH_RADIUS = 6_371_000.0  # mean radius of the Earth, metres
GRAVITY = 9.81  # gravitational acceleration for deep-water dispersion, m/s^2

# Upper limit on the number of superimposed wave components a surface accepts.
MAX_WAVES = 8
@dataclass
class GPUMultiCurvedWaveSurface(Surface):
    """
    GPU-accelerated curved-earth ocean wave surface with multiple wave components.

    The surface is Earth's sphere perturbed radially by a sum of cosine wave
    components. A point ``p`` lies on the surface when::

        |p - earth_center| = earth_radius + h(p_x, p_y)
        h(x, y) = sum_i A_i * cos(k_i * (d_i . (x, y)) - omega_i * time + phase_i)

    The world x/y coordinates are used directly as local tangent-plane
    coordinates. GPU dispatch uses geometry_id=5 with the 64-element parameter
    tuple built by :meth:`get_gpu_parameters`. That layout has room for at most
    7 wave components, while the CPU methods (``signed_distance``,
    ``intersect``, ``normal_at``) accept up to ``MAX_WAVES`` (8).

    Parameters
    ----------
    wave_params : list of GerstnerWaveParams
        Wave components (1 to MAX_WAVES). Each provides amplitude, wavelength,
        direction, phase and steepness. Amplitude and wavelength must be
        positive.
    role : SurfaceRole
        What happens when a ray hits (typically OPTICAL).
    earth_center : tuple of float, optional
        Center of the Earth sphere (exactly 3 components).
        Default is (0, 0, -EARTH_RADIUS).
    earth_radius : float, optional
        Earth radius in meters; must be positive. Default is EARTH_RADIUS.
    time : float, optional
        Animation time in seconds. Default is 0.0.
    name : str
        Human-readable name.
    material_front : MaterialField, optional
        Material above the surface (atmosphere).
    material_back : MaterialField, optional
        Material below the surface (ocean water).

    Examples
    --------
    >>> from lsurf.surfaces import GPUMultiCurvedWaveSurface, SurfaceRole, GerstnerWaveParams
    >>> from lsurf.materials import ExponentialAtmosphere, WATER
    >>>
    >>> waves = [
    ...     GerstnerWaveParams(amplitude=1.5, wavelength=50.0, direction=(1.0, 0.0)),
    ...     GerstnerWaveParams(amplitude=0.8, wavelength=30.0, direction=(0.7, 0.7)),
    ...     GerstnerWaveParams(amplitude=0.3, wavelength=15.0, direction=(-0.5, 0.8)),
    ... ]
    >>>
    >>> ocean = GPUMultiCurvedWaveSurface(
    ...     wave_params=waves,
    ...     role=SurfaceRole.OPTICAL,
    ...     name="multi_wave_ocean",
    ...     material_front=ExponentialAtmosphere(),
    ...     material_back=WATER,
    ... )
    """

    wave_params: list[GerstnerWaveParams]
    role: SurfaceRole
    earth_center: tuple[float, float, float] = (0, 0, -EARTH_RADIUS)
    earth_radius: float = EARTH_RADIUS
    time: float = 0.0
    name: str = "multi_curved_wave"
    material_front: Any = None
    material_back: Any = None

    # GPU-capable with geometry_id=5 (multi-wave curved surface)
    _gpu_capable: bool = field(default=True, init=False, repr=False)
    _geometry_id: int = field(default=5, init=False, repr=False)

    # Derived values, filled in by __post_init__
    _earth_center_arr: npt.NDArray = field(default=None, init=False, repr=False)
    _max_amplitude: float = field(default=0.0, init=False, repr=False)

    def __post_init__(self) -> None:
        """
        Validate the configuration and precompute derived values.

        Raises
        ------
        ValueError
            Raised in any of these cases:

            - no wave components are given;
            - more than ``MAX_WAVES`` components are given;
            - ``earth_radius`` is not positive;
            - any amplitude or wavelength is not positive;
            - ``earth_center`` does not have exactly 3 components.
        """
        if not self.wave_params:
            raise ValueError("At least one wave component is required")
        if len(self.wave_params) > MAX_WAVES:
            raise ValueError(
                f"Maximum {MAX_WAVES} wave components supported, got {len(self.wave_params)}"
            )
        if self.earth_radius <= 0:
            raise ValueError("Earth radius must be positive")
        for wave in self.wave_params:
            if wave.amplitude <= 0:
                raise ValueError("All wave amplitudes must be positive")
            if wave.wavelength <= 0:
                raise ValueError("All wavelengths must be positive")

        center = np.array(self.earth_center, dtype=np.float64)
        if center.shape != (3,):
            raise ValueError(
                f"earth_center must have exactly 3 components, got shape {center.shape}"
            )
        self._earth_center_arr = center
        # Upper bound on |h(x, y)|: every cosine can peak at the same point.
        self._max_amplitude = float(sum(w.amplitude for w in self.wave_params))

    @property
    def gpu_capable(self) -> bool:
        """Whether this surface supports GPU acceleration (geometry_id=5)."""
        return self._gpu_capable

    @property
    def geometry_id(self) -> int:
        """GPU geometry type ID (multi_curved_wave = 5)."""
        return self._geometry_id

    @property
    def num_waves(self) -> int:
        """Number of wave components."""
        return len(self.wave_params)

    @property
    def max_amplitude(self) -> float:
        """Maximum combined wave amplitude (sum of all amplitudes)."""
        return self._max_amplitude

    def get_gpu_parameters(self) -> tuple:
        """
        Return the 64-element parameter tuple for the GPU kernel.

        Parameter layout (geometry_id = 5):

        - p0-2: earth_center_x/y/z
        - p3: earth_radius
        - p4: time
        - p5: num_waves
        - p6-7: reserved

        For each wave i, starting at offset 8 + i*8:

        - p[offset+0]: amplitude
        - p[offset+1]: wave_number (k)
        - p[offset+2]: dir_x (normalized)
        - p[offset+3]: dir_y (normalized)
        - p[offset+4]: phase
        - p[offset+5]: steepness
        - p[offset+6-7]: reserved

        With 8 header slots and a stride of 8, only waves i = 0..6 fit in 64
        slots. Wave i = 7 would start at offset 64, past the end of the tuple.

        Returns
        -------
        tuple of 64 floats

        Raises
        ------
        ValueError
            If there are more wave components than the 64-slot layout can hold.
        """
        n_params = 64
        header_size = 8
        wave_stride = 8
        capacity = (n_params - header_size) // wave_stride  # 7 waves

        n_waves = len(self.wave_params)
        if n_waves > capacity:
            # Previously this surfaced as an IndexError while packing wave 7.
            raise ValueError(
                f"GPU parameter layout holds at most {capacity} wave components "
                f"in {n_params} slots, got {n_waves}"
            )

        params = [0.0] * n_params

        # Header parameters (params[6:8] stay reserved / zero)
        params[0:3] = [float(c) for c in self.earth_center]
        params[3] = float(self.earth_radius)
        params[4] = float(self.time)
        params[5] = float(n_waves)

        for i, wave in enumerate(self.wave_params):
            offset = header_size + i * wave_stride
            dir_x, dir_y = wave.direction_normalized
            params[offset + 0] = wave.amplitude
            params[offset + 1] = wave.wave_number
            params[offset + 2] = dir_x
            params[offset + 3] = dir_y
            params[offset + 4] = wave.phase
            params[offset + 5] = wave.steepness
            # params[offset + 6 : offset + 8] stay reserved / zero

        return tuple(params)

    def get_materials(self) -> tuple | None:
        """Return (material_front, material_back) for OPTICAL surfaces, else None."""
        if self.role == SurfaceRole.OPTICAL:
            return (self.material_front, self.material_back)
        return None

    def _wave_height(self, x: npt.NDArray, y: npt.NDArray) -> npt.NDArray:
        """
        Compute the combined wave height h(x, y) in local tangent coordinates.

        Parameters
        ----------
        x, y : ndarray
            World x/y coordinates, used as tangent-plane coordinates.

        Returns
        -------
        ndarray
            float64 array shaped like ``x``: the sum over all components of
            ``amplitude * cos(k * (dir . (x, y)) - omega * time + phase)``.
        """
        total_height = np.zeros_like(x, dtype=np.float64)
        for wave in self.wave_params:
            dir_x, dir_y = wave.direction_normalized
            theta = (
                wave.wave_number * (dir_x * x + dir_y * y)
                - wave.angular_frequency * self.time
                + wave.phase
            )
            total_height += wave.amplitude * np.cos(theta)
        return total_height

    def _signed_distance_f64(self, positions: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
        """Signed distance (float64) for (N, 3) float64 positions; see signed_distance."""
        dist_from_center = np.linalg.norm(positions - self._earth_center_arr, axis=1)
        wave_height = self._wave_height(positions[:, 0], positions[:, 1])
        return dist_from_center - (self.earth_radius + wave_height)

    def signed_distance(
        self,
        positions: npt.NDArray[np.float32],
    ) -> npt.NDArray[np.float32]:
        """
        Compute the signed distance from positions to the multi-wave curved surface.

        The value is ``|p - earth_center| - (earth_radius + h(p_x, p_y))``: the
        radial gap to the perturbed sphere. It is not a true Euclidean distance
        to the surface.

        Parameters
        ----------
        positions : ndarray, shape (N, 3)
            Points to compute the distance for.

        Returns
        -------
        ndarray, shape (N,)
            Signed distance (positive outside, negative inside Earth+wave).
        """
        positions = np.asarray(positions, dtype=np.float64)
        return self._signed_distance_f64(positions).astype(np.float32)

    def intersect(
        self,
        origins: npt.NDArray[np.float32],
        directions: npt.NDArray[np.float32],
        min_distance: float = 1e-6,
        max_iterations: int = 200,
        tolerance: float = 1e-3,
        max_distance: float | None = None,
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.bool_]]:
        """
        Find ray-surface intersections using ray marching.

        Each ray is first clipped against the bounding sphere of radius
        ``earth_radius + max_amplitude``, which contains every crest. It is then
        marched with relaxed steps of ``0.5 * signed_distance / c``, where ``c``
        is ``|cos|`` of the angle between the ray and the local vertical, floored
        at 0.01. Each step is clipped to ``±2 * max_amplitude``. When the signed
        distance changes sign between two samples, the bracket is refined with
        15 bisection steps.

        Parameters
        ----------
        origins : ndarray, shape (N, 3)
            Ray origin positions.
        directions : ndarray, shape (N, 3)
            Ray direction unit vectors.
        min_distance : float
            Minimum valid intersection distance. The march never evaluates a ray
            before ``max(min_distance, 0)``, so no closer hit is reported.
        max_iterations : int
            Maximum ray marching iterations.
        tolerance : float
            Convergence tolerance on |signed distance|, in meters.
        max_distance : float, optional
            Maximum search distance. Hits farther than this are not reported.

        Returns
        -------
        distances : ndarray, shape (N,)
            Distance to intersection (inf if no hit).
        hit_mask : ndarray, shape (N,), dtype=bool
            True for rays that hit the surface.
        """
        origins = np.asarray(origins, dtype=np.float64)
        directions = np.asarray(directions, dtype=np.float64)
        n_rays = len(origins)

        distances = np.full(n_rays, np.inf, dtype=np.float64)
        hit_mask = np.zeros(n_rays, dtype=bool)

        t_floor = max(float(min_distance), 0.0)
        t_limit = np.inf if max_distance is None else float(max_distance)

        # Quadratic intersection with the bounding sphere.
        outer_radius = self.earth_radius + self._max_amplitude
        oc = origins - self._earth_center_arr
        a = np.sum(directions * directions, axis=1)
        b = 2.0 * np.sum(directions * oc, axis=1)
        c_outer = np.sum(oc * oc, axis=1) - outer_radius**2
        discriminant = b**2 - 4 * a * c_outer
        sqrt_disc = np.sqrt(np.maximum(discriminant, 0))
        denom = 2 * a + 1e-20  # guards against zero-length directions
        t_near = (-b - sqrt_disc) / denom
        t_far = (-b + sqrt_disc) / denom

        # The surface lies inside the bounding sphere. A ray that misses the
        # sphere, or whose sphere crossings are both behind its start, cannot hit.
        active = (discriminant >= 0) & (t_far >= t_floor)

        # Start at the sphere entry point, or at the floor if already inside.
        t = np.where(t_near > t_floor, t_near, t_floor)

        prev_signed_dist = np.full(n_rays, np.inf)
        prev_t = t.copy()
        relaxation = 0.5
        max_step = 2.0 * self._max_amplitude

        def record_converged(signed_dist: npt.NDArray, t_now: npt.NDArray) -> None:
            """Record active converged rays within ``t_limit`` as hits and retire all converged rays."""
            converged = np.abs(signed_dist) < tolerance
            newly_hit = active & converged & (t_now <= t_limit)
            hit_mask[newly_hit] = True
            distances[newly_hit] = t_now[newly_hit]
            active[converged] = False

        for _ in range(max_iterations):
            if not active.any():
                break

            positions = origins + t[:, np.newaxis] * directions
            to_pos = positions - self._earth_center_arr
            dist_from_center = np.linalg.norm(to_pos, axis=1)
            radial = to_pos / np.maximum(dist_from_center[:, np.newaxis], 1e-10)
            # Converts a radial gap into a distance along the ray. The 0.01
            # floor keeps steps finite for near-horizontal rays.
            cos_angle = np.maximum(np.abs(np.sum(directions * radial, axis=1)), 0.01)
            signed_dist = dist_from_center - (
                self.earth_radius + self._wave_height(positions[:, 0], positions[:, 1])
            )

            record_converged(signed_dist, t)

            # A sign change since the previous sample brackets a root.
            crossed = (
                active
                & np.isfinite(prev_signed_dist)
                & (signed_dist * prev_signed_dist < 0)
            )
            if crossed.any():
                idx = np.flatnonzero(crossed)
                o_sub = origins[idx]
                d_sub = directions[idx]
                prev_above = prev_signed_dist[idx] > 0
                # t_above tracks the bracket end above the surface, t_below the one beneath.
                t_above = np.where(prev_above, prev_t[idx], t[idx])
                t_below = np.where(prev_above, t[idx], prev_t[idx])
                for _ in range(15):
                    t_mid = (t_above + t_below) / 2
                    mid_above = (
                        self._signed_distance_f64(o_sub + t_mid[:, np.newaxis] * d_sub) > 0
                    )
                    t_above = np.where(mid_above, t_mid, t_above)
                    t_below = np.where(mid_above, t_below, t_mid)
                t[idx] = (t_above + t_below) / 2
                signed_dist[idx] = self._signed_distance_f64(
                    o_sub + t[idx][:, np.newaxis] * d_sub
                )
                record_converged(signed_dist, t)

            # Give up on rays that are well below the deepest possible trough,
            # and on rays that have marched past the search limit.
            active[signed_dist < -self._max_amplitude - 10] = False
            active[t > t_limit] = False

            prev_signed_dist = signed_dist.copy()
            prev_t = t.copy()

            step = np.clip(signed_dist / cos_angle * relaxation, -max_step, max_step)
            t[active] += step[active]
            t = np.maximum(t, t_floor)

        return distances.astype(np.float32), hit_mask

    def normal_at(
        self,
        positions: npt.NDArray[np.float32],
        incoming_directions: npt.NDArray[np.float32] | None = None,
    ) -> npt.NDArray[np.float32]:
        """
        Compute the surface normal at given positions.

        The normal is the normalized gradient of the signed-distance field used
        by ``signed_distance`` and ``intersect``::

            grad = radial - (dh/dx, dh/dy, 0)

        It is therefore perpendicular to the surface that is actually traced.
        Wave steepness is not applied here: it is only forwarded to the GPU in
        ``get_gpu_parameters``.

        Parameters
        ----------
        positions : ndarray, shape (N, 3)
            Points on the surface.
        incoming_directions : ndarray, shape (N, 3), optional
            Incoming ray directions. When given, each normal is flipped so that
            its dot product with the incoming direction is not positive.

        Returns
        -------
        normals : ndarray, shape (N, 3)
            Unit normal vectors.
        """
        positions = np.asarray(positions, dtype=np.float64)

        to_pos = positions - self._earth_center_arr
        dist = np.maximum(np.linalg.norm(to_pos, axis=1, keepdims=True), 1e-10)
        radial = to_pos / dist

        x = positions[:, 0]
        y = positions[:, 1]
        # slope_x = -dh/dx and slope_y = -dh/dy for h = sum(A cos(theta)).
        slope_x = np.zeros(len(positions), dtype=np.float64)
        slope_y = np.zeros(len(positions), dtype=np.float64)
        for wave in self.wave_params:
            dir_x, dir_y = wave.direction_normalized
            k = wave.wave_number
            theta = k * (dir_x * x + dir_y * y) - wave.angular_frequency * self.time + wave.phase
            ka_sin = k * wave.amplitude * np.sin(theta)
            slope_x += dir_x * ka_sin
            slope_y += dir_y * ka_sin

        normals = radial.copy()
        normals[:, 0] += slope_x
        normals[:, 1] += slope_y
        normals /= np.maximum(np.linalg.norm(normals, axis=1, keepdims=True), 1e-10)

        if incoming_directions is not None:
            incoming = np.asarray(incoming_directions, dtype=np.float64)
            flip_mask = np.sum(normals * incoming, axis=1) > 0
            normals[flip_mask] *= -1

        return normals.astype(np.float32)

    def set_time(self, time: float) -> None:
        """Update the wave animation time (affects all subsequent evaluations)."""
        self.time = time
# Expose the class through the surface registry under the
# "gpu_multi_curved_wave" key, tagged "gpu" with geometry id 5
# (the same id the class reports via its geometry_id property).
register_surface_type(
    "gpu_multi_curved_wave",
    "gpu",
    5,
    GPUMultiCurvedWaveSurface,
)