# Source code for lsurf.propagation.kernels.device_functions

# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
#      * Redistributions of source code must retain the above copyright notice,
#      this list of conditions and the following disclaimer.
#
#      * Redistributions in binary form must reproduce the above copyright
#      notice, this list of conditions and the following disclaimer in the
#      documentation and/or other materials provided with the distribution.
#
#      * Neither the name of the copyright holder nor the names of its
#      contributors may be used to endorse or promote products derived from this
#      software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

"""
Shared Ray-Propagation Device Functions and Helpers

This module consolidates the numerical building blocks used by the
propagation kernels:

- CUDA device functions for integration (adaptive step size, Euler and
  RK4 steps of the gradient-index ray equation)
- CUDA device functions for dispersion models (Sellmeier, Cauchy equations)
- Pure-Python scalar step functions for CPU propagation
- Vectorized NumPy batch operations (direction normalization, Euler and
  RK4 steps over many rays at once)

Only the ``device_*`` functions are decorated with
``@cuda.jit(device=True)``; the remaining functions are plain Python /
NumPy. When numba cannot be imported, ``cuda.jit`` is replaced by a
pass-through decorator, so the ``device_*`` functions remain importable
and run as ordinary Python.
"""

import math
from collections.abc import Callable

import numpy as np
import numpy.typing as npt

# GPU support is optional: if numba cannot be imported, install a tiny
# stand-in so that ``@cuda.jit`` decorators below still resolve and the
# module stays importable on CPU-only machines.
try:
    from numba import cuda

    HAS_CUDA = True
except ImportError:

    class _FakeCuda:
        """Minimal stand-in for ``numba.cuda`` used when numba is missing.

        ``jit`` becomes a pass-through decorator (both the bare ``@cuda.jit``
        and the parameterized ``@cuda.jit(...)`` forms leave the function as
        plain Python), ``is_available`` reports ``False``, ``grid`` returns
        ``0`` and ``synchronize`` does nothing.
        """

        @staticmethod
        def jit(*args, **kwargs):
            """Pass-through replacement for ``numba.cuda.jit``."""
            # Bare ``@cuda.jit``: the decorated function arrives directly.
            if args and callable(args[0]):
                return args[0]

            # ``@cuda.jit(...)``: hand back a decorator; options are ignored.
            def _identity(func):
                return func

            return _identity

        @staticmethod
        def is_available():
            """Report that no CUDA device can be used."""
            return False

        @staticmethod
        def grid(n):
            """Return a dummy thread index of ``0`` regardless of ``n``."""
            return 0

        @staticmethod
        def synchronize():
            """Do nothing; there is no device to wait for."""
            return None

    HAS_CUDA = False
    cuda = _FakeCuda()  # type: ignore[assignment]


# =============================================================================
# CUDA Device Functions for Integration
# =============================================================================


@cuda.jit(device=True)
def device_adaptive_step_size(
    gradient_magnitude: float,
    refractive_index: float,
    wavelength: float,
    min_step: float,
    max_step: float,
) -> float:
    """
    Choose an integration step length from the local index field (GPU).

    The proposal is the smaller of one tenth of the ray's radius of
    curvature and the wavelength inside the medium. It is then clamped to
    ``[min_step, max_step]``. The lower bound is applied first, so the
    upper bound wins if the limits are inconsistent.

    Parameters
    ----------
    gradient_magnitude : float
        |∇n| at the current position.
    refractive_index : float
        n at the current position.
    wavelength : float
        Vacuum wavelength in meters.
    min_step : float
        Smallest step that may be returned.
    max_step : float
        Largest step that may be returned.

    Returns
    -------
    float
        Step length to use for the next integration step.

    Notes
    -----
    Radius of curvature is taken as R_c = n / |∇n|. For |∇n| <= 1e-12 the
    ray is treated as straight and no curvature limit is applied beyond
    ``max_step``.
    """
    # Wavelength shortened by the local refractive index.
    local_wavelength = wavelength / refractive_index

    # Curvature-limited step: a tenth of R_c, unless the gradient vanishes.
    curvature_limit = max_step
    if gradient_magnitude > 1e-12:
        curvature_limit = (refractive_index / gradient_magnitude) / 10.0

    proposal = min(curvature_limit, local_wavelength)

    # Clamp: lower bound first, then upper bound.
    if proposal < min_step:
        proposal = min_step
    if proposal > max_step:
        proposal = max_step
    return proposal


@cuda.jit(device=True)
def device_euler_step(
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    n: float,
    grad_x: float,
    grad_y: float,
    grad_z: float,
    step_size: float,
) -> tuple[float, float, float, float, float, float]:
    """
    Advance a ray by one explicit Euler step on the GPU.

    Integrates the gradient-index ray equations

        dr/ds = d̂
        dd̂/ds = (∇n - (d̂·∇n)d̂) / n

    using the index and gradient sampled at the start of the step. The new
    direction is rescaled to unit length unless its norm is <= 1e-12.

    Parameters
    ----------
    x, y, z : float
        Position at the start of the step.
    dx, dy, dz : float
        Unit direction at the start of the step.
    n : float
        Refractive index at the start position.
    grad_x, grad_y, grad_z : float
        Components of ∇n at the start position.
    step_size : float
        Arc-length increment ds.

    Returns
    -------
    tuple
        (new_x, new_y, new_z, new_dx, new_dy, new_dz)
    """
    h = step_size

    # Component of ∇n along the ray; removing it leaves the transverse part
    # that actually bends the ray.
    g_par = dx * grad_x + dy * grad_y + dz * grad_z

    # Position advances along the *old* direction.
    px = x + dx * h
    py = y + dy * h
    pz = z + dz * h

    # Direction is nudged by the curvature κ = (∇n - g_par d̂) / n.
    ux = dx + (grad_x - g_par * dx) / n * h
    uy = dy + (grad_y - g_par * dy) / n * h
    uz = dz + (grad_z - g_par * dz) / n * h

    length = math.sqrt(ux * ux + uy * uy + uz * uz)
    if length > 1e-12:
        ux = ux / length
        uy = uy / length
        uz = uz / length

    return px, py, pz, ux, uy, uz


@cuda.jit(device=True)
def device_rk4_step(
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    step_size: float,
    n_func,
    grad_func,
    *material_params,
) -> tuple[float, float, float, float, float, float, float]:
    """
    RK4 integration step for ray equation on GPU.

    Solves the coupled ODEs:
        dr/ds = d̂
        dd̂/ds = (∇n - (d̂·∇n)d̂) / n

    This is a generic RK4 step that takes material evaluation functions as
    parameters. For material-specific implementations with better
    performance, materials should define their own device functions.

    Parameters
    ----------
    x, y, z : float
        Current position
    dx, dy, dz : float
        Current direction (unit vector)
    step_size : float
        Integration step size
    n_func : device function
        Function to evaluate n(x, y, z, *material_params)
    grad_func : device function
        Function to evaluate ∇n(x, y, z, *material_params) -> (gx, gy, gz)
    material_params : tuple
        Material-specific parameters passed to n_func and grad_func

    Returns
    -------
    tuple of float
        (new_x, new_y, new_z, new_dx, new_dy, new_dz, n_avg)
        ``n_avg`` is the refractive index averaged over the step; multiply
        by ``step_size`` to obtain the optical path increment.

    Notes
    -----
    Intermediate directions are renormalized before each stage evaluation.

    ``n_avg = (n0 + 2*n1 + 2*n2 + n3) / 6``. Here n0 is sampled at the start
    point, n1 and n2 at the two midpoint estimates, and n3 at the end-point
    estimate. This is Simpson's rule with the midpoint value taken as the
    mean of the two midpoint samples. It uses the same weights RK4 applies
    to position and direction, so the average is exact for an index that
    varies linearly along a straight ray.

    For best performance, materials should implement their own specialized
    RK4 device functions that inline the n and gradient calculations. This
    generic version has function call overhead.
    """
    h = step_size
    h2 = h / 2.0

    def _normalize(vx, vy, vz):
        # Rescale to unit length; leave (near-)zero vectors untouched.
        norm = math.sqrt(vx * vx + vy * vy + vz * vz)
        if norm > 1e-12:
            return vx / norm, vy / norm, vz / norm
        return vx, vy, vz

    def _curvature(px, py, pz, dirx, diry, dirz):
        # Returns n and dd̂/ds = (∇n - (d̂·∇n)d̂) / n at the given state.
        n = n_func(px, py, pz, *material_params)
        gx, gy, gz = grad_func(px, py, pz, *material_params)
        dot = dirx * gx + diry * gy + dirz * gz
        kx = (gx - dot * dirx) / n
        ky = (gy - dot * diry) / n
        kz = (gz - dot * dirz) / n
        return n, kx, ky, kz

    # k1 at current point
    n0, k1_dx, k1_dy, k1_dz = _curvature(x, y, z, dx, dy, dz)
    k1_rx, k1_ry, k1_rz = dx, dy, dz

    # Intermediate point 1 (midpoint estimate from k1)
    x1 = x + h2 * k1_rx
    y1 = y + h2 * k1_ry
    z1 = z + h2 * k1_rz
    dx1, dy1, dz1 = _normalize(dx + h2 * k1_dx, dy + h2 * k1_dy, dz + h2 * k1_dz)

    # k2
    n1, k2_dx, k2_dy, k2_dz = _curvature(x1, y1, z1, dx1, dy1, dz1)
    k2_rx, k2_ry, k2_rz = dx1, dy1, dz1

    # Intermediate point 2 (midpoint estimate from k2)
    x2 = x + h2 * k2_rx
    y2 = y + h2 * k2_ry
    z2 = z + h2 * k2_rz
    dx2, dy2, dz2 = _normalize(dx + h2 * k2_dx, dy + h2 * k2_dy, dz + h2 * k2_dz)

    # k3
    n2, k3_dx, k3_dy, k3_dz = _curvature(x2, y2, z2, dx2, dy2, dz2)
    k3_rx, k3_ry, k3_rz = dx2, dy2, dz2

    # End point (full-step estimate from k3)
    x3 = x + h * k3_rx
    y3 = y + h * k3_ry
    z3 = z + h * k3_rz
    dx3, dy3, dz3 = _normalize(dx + h * k3_dx, dy + h * k3_dy, dz + h * k3_dz)

    # k4
    n3, k4_dx, k4_dy, k4_dz = _curvature(x3, y3, z3, dx3, dy3, dz3)
    k4_rx, k4_ry, k4_rz = dx3, dy3, dz3

    # Final RK4 combination
    new_x = x + (h / 6.0) * (k1_rx + 2 * k2_rx + 2 * k3_rx + k4_rx)
    new_y = y + (h / 6.0) * (k1_ry + 2 * k2_ry + 2 * k3_ry + k4_ry)
    new_z = z + (h / 6.0) * (k1_rz + 2 * k2_rz + 2 * k3_rz + k4_rz)

    new_dx = dx + (h / 6.0) * (k1_dx + 2 * k2_dx + 2 * k3_dx + k4_dx)
    new_dy = dy + (h / 6.0) * (k1_dy + 2 * k2_dy + 2 * k3_dy + k4_dy)
    new_dz = dz + (h / 6.0) * (k1_dz + 2 * k2_dz + 2 * k3_dz + k4_dz)
    new_dx, new_dy, new_dz = _normalize(new_dx, new_dy, new_dz)

    # Simpson's rule for the average n: the end-point sample n3 is the right
    # endpoint; n1 and n2 are both midpoint samples and share the midpoint
    # weight (RK4 weights 1-2-2-1).
    n_avg = (n0 + 2.0 * n1 + 2.0 * n2 + n3) / 6.0

    return new_x, new_y, new_z, new_dx, new_dy, new_dz, n_avg


# =============================================================================
# CUDA Device Functions for Dispersion Models
# =============================================================================


@cuda.jit(device=True)
def device_sellmeier_equation(
    wl_um: float,
    B1: float,
    B2: float,
    B3: float,
    C1: float,
    C2: float,
    C3: float,
) -> float:
    """
    Evaluate the three-term Sellmeier dispersion formula on the GPU.

        n² = 1 + Σᵢ Bᵢ λ² / (λ² - Cᵢ),   i = 1, 2, 3

    Parameters
    ----------
    wl_um : float
        Wavelength λ in micrometers.
    B1, B2, B3 : float
        Sellmeier B coefficients (dimensionless).
    C1, C2, C3 : float
        Sellmeier C coefficients in μm².

    Returns
    -------
    n : float
        Refractive index sqrt(n²).
    """
    lam_sq = wl_um * wl_um

    # One resonance term per coefficient pair.
    term1 = B1 * lam_sq / (lam_sq - C1)
    term2 = B2 * lam_sq / (lam_sq - C2)
    term3 = B3 * lam_sq / (lam_sq - C3)

    # Sum the terms first, then add 1 (keeps the original rounding order).
    n_sq_minus_one = term1 + term2 + term3
    return math.sqrt(1.0 + n_sq_minus_one)


@cuda.jit(device=True)
def device_cauchy_equation(
    wl_um: float,
    A: float,
    B: float,
    C: float,
) -> float:
    """
    Evaluate the three-term Cauchy dispersion formula on the GPU.

        n = A + B / λ² + C / λ⁴

    Parameters
    ----------
    wl_um : float
        Wavelength λ in micrometers.
    A : float
        Constant term.
    B : float
        Coefficient of 1/λ², in μm².
    C : float
        Coefficient of 1/λ⁴, in μm⁴.

    Returns
    -------
    n : float
        Refractive index.
    """
    lam_sq = wl_um * wl_um
    lam_quartic = lam_sq * lam_sq
    return A + B / lam_sq + C / lam_quartic


# =============================================================================
# Pure Python Step Functions (for CPU propagation)
# =============================================================================


def euler_step(
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    n: float,
    grad_x: float,
    grad_y: float,
    grad_z: float,
    step_size: float,
) -> tuple[float, float, float, float, float, float]:
    """
    Advance a ray by one explicit Euler step (scalar CPU version).

    The position moves along the current direction. The direction is
    bent by the curvature κ = (∇n - (d̂·∇n)d̂) / n and then rescaled to unit
    length, unless its norm is <= 1e-12.

    Parameters
    ----------
    x, y, z : float
        Position at the start of the step.
    dx, dy, dz : float
        Unit direction at the start of the step.
    n : float
        Refractive index at the start position.
    grad_x, grad_y, grad_z : float
        Components of ∇n at the start position.
    step_size : float
        Arc-length increment ds.

    Returns
    -------
    tuple
        (new_x, new_y, new_z, new_dx, new_dy, new_dz)
    """
    # Projection of ∇n on the ray; the remainder is the bending component.
    along = dx * grad_x + dy * grad_y + dz * grad_z
    kx, ky, kz = (
        (grad_x - along * dx) / n,
        (grad_y - along * dy) / n,
        (grad_z - along * dz) / n,
    )

    ux, uy, uz = dx + kx * step_size, dy + ky * step_size, dz + kz * step_size

    # Renormalize unless the vector has (numerically) collapsed to zero.
    length = math.sqrt(ux**2 + uy**2 + uz**2)
    if length > 1e-12:
        ux, uy, uz = ux / length, uy / length, uz / length

    return (
        x + dx * step_size,
        y + dy * step_size,
        z + dz * step_size,
        ux,
        uy,
        uz,
    )


def rk4_step(
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    n_func: Callable[[float, float, float, float], float],
    grad_func: Callable[[float, float, float, float], tuple[float, float, float]],
    wavelength: float,
    step_size: float,
) -> tuple[float, float, float, float, float, float, float]:
    """
    RK4 integration step for ray equation in gradient medium.

    Solves the coupled ODEs:
        dr/ds = d̂
        dd̂/ds = (∇n - (d̂·∇n)d̂) / n

    Parameters
    ----------
    x, y, z : float
        Current position
    dx, dy, dz : float
        Current direction (unit vector)
    n_func : callable
        Function to evaluate n(x, y, z, wavelength)
    grad_func : callable
        Function to evaluate ∇n(x, y, z, wavelength) -> (gx, gy, gz)
    wavelength : float
        Wavelength
    step_size : float
        Integration step size

    Returns
    -------
    tuple of float
        (new_x, new_y, new_z, new_dx, new_dy, new_dz, optical_path_increment)

    Notes
    -----
    Intermediate directions are renormalized before each stage evaluation.

    The optical path increment is ``n_avg * step_size``, with
    ``n_avg = (n0 + 2*n1 + 2*n2 + n3) / 6``. Here n0 is sampled at the start
    point, n1 and n2 at the two midpoint estimates, and n3 at the end-point
    estimate. This is Simpson's rule with the midpoint value taken as the
    mean of the two midpoint samples. It uses the same weights RK4 applies
    to position and direction, so it is exact for an index that varies
    linearly along a straight ray.
    """
    h = step_size
    h2 = h / 2.0

    def _normalize(vx, vy, vz):
        # Rescale to unit length; leave (near-)zero vectors untouched.
        norm = math.sqrt(vx**2 + vy**2 + vz**2)
        if norm > 1e-12:
            return vx / norm, vy / norm, vz / norm
        return vx, vy, vz

    def _curvature(px, py, pz, dirx, diry, dirz):
        # Returns n and dd̂/ds = (∇n - (d̂·∇n)d̂) / n at the given state.
        n = n_func(px, py, pz, wavelength)
        gx, gy, gz = grad_func(px, py, pz, wavelength)
        dot = dirx * gx + diry * gy + dirz * gz
        kx = (gx - dot * dirx) / n
        ky = (gy - dot * diry) / n
        kz = (gz - dot * dirz) / n
        return n, kx, ky, kz

    # k1 evaluation at current point
    n0, k1_dx, k1_dy, k1_dz = _curvature(x, y, z, dx, dy, dz)
    k1_rx, k1_ry, k1_rz = dx, dy, dz

    # Intermediate point 1 (midpoint estimate from k1)
    x1 = x + h2 * k1_rx
    y1 = y + h2 * k1_ry
    z1 = z + h2 * k1_rz
    dx1, dy1, dz1 = _normalize(dx + h2 * k1_dx, dy + h2 * k1_dy, dz + h2 * k1_dz)

    # k2 evaluation
    n1, k2_dx, k2_dy, k2_dz = _curvature(x1, y1, z1, dx1, dy1, dz1)
    k2_rx, k2_ry, k2_rz = dx1, dy1, dz1

    # Intermediate point 2 (midpoint estimate from k2)
    x2 = x + h2 * k2_rx
    y2 = y + h2 * k2_ry
    z2 = z + h2 * k2_rz
    dx2, dy2, dz2 = _normalize(dx + h2 * k2_dx, dy + h2 * k2_dy, dz + h2 * k2_dz)

    # k3 evaluation
    n2, k3_dx, k3_dy, k3_dz = _curvature(x2, y2, z2, dx2, dy2, dz2)
    k3_rx, k3_ry, k3_rz = dx2, dy2, dz2

    # End point (full-step estimate from k3)
    x3 = x + h * k3_rx
    y3 = y + h * k3_ry
    z3 = z + h * k3_rz
    dx3, dy3, dz3 = _normalize(dx + h * k3_dx, dy + h * k3_dy, dz + h * k3_dz)

    # k4 evaluation
    n3, k4_dx, k4_dy, k4_dz = _curvature(x3, y3, z3, dx3, dy3, dz3)
    k4_rx, k4_ry, k4_rz = dx3, dy3, dz3

    # Final RK4 combination
    new_x = x + (h / 6.0) * (k1_rx + 2 * k2_rx + 2 * k3_rx + k4_rx)
    new_y = y + (h / 6.0) * (k1_ry + 2 * k2_ry + 2 * k3_ry + k4_ry)
    new_z = z + (h / 6.0) * (k1_rz + 2 * k2_rz + 2 * k3_rz + k4_rz)

    new_dx = dx + (h / 6.0) * (k1_dx + 2 * k2_dx + 2 * k3_dx + k4_dx)
    new_dy = dy + (h / 6.0) * (k1_dy + 2 * k2_dy + 2 * k3_dy + k4_dy)
    new_dz = dz + (h / 6.0) * (k1_dz + 2 * k2_dz + 2 * k3_dz + k4_dz)
    new_dx, new_dy, new_dz = _normalize(new_dx, new_dy, new_dz)

    # Optical path via Simpson's rule: the end-point sample n3 is the right
    # endpoint; n1 and n2 are both midpoint samples and share the midpoint
    # weight (RK4 weights 1-2-2-1).
    n_avg = (n0 + 2.0 * n1 + 2.0 * n2 + n3) / 6.0
    optical_increment = n_avg * h

    return new_x, new_y, new_z, new_dx, new_dy, new_dz, optical_increment


def compute_adaptive_step_size(
    gradient_magnitude: float,
    refractive_index: float,
    wavelength: float,
    min_step: float,
    max_step: float,
) -> float:
    """
    Choose an integration step length from the local index field (CPU).

    Parameters
    ----------
    gradient_magnitude : float
        |∇n| at the current position.
    refractive_index : float
        n at the current position.
    wavelength : float
        Vacuum wavelength in meters.
    min_step : float
        Smallest step that may be returned.
    max_step : float
        Largest step that may be returned.

    Returns
    -------
    float
        ``max(min_step, min(proposal, max_step))``. Here ``proposal`` is the
        smaller of the curvature-limited step and the in-medium wavelength.

    Notes
    -----
    Step size is chosen to resolve curvature: Δs ≤ R_c / 10, where
    R_c = n / |∇n| is the radius of curvature. For |∇n| <= 1e-12 the
    curvature limit falls back to ``max_step``.
    """
    has_gradient = gradient_magnitude > 1e-12
    curvature_limit = (
        (refractive_index / gradient_magnitude) / 10.0 if has_gradient else max_step
    )

    local_wavelength = wavelength / refractive_index
    proposal = min(curvature_limit, local_wavelength)

    # Upper bound first, then lower bound (lower bound wins if inconsistent).
    capped = min(proposal, max_step)
    return max(min_step, capped)


# =============================================================================
# Vectorized NumPy Operations (for batch CPU propagation)
# =============================================================================


def normalize_directions(
    directions: npt.NDArray[np.float32],
) -> npt.NDArray[np.float32]:
    """Scale each row of an (N, 3) array to unit length.

    Row norms are floored at 1e-12, so all-zero rows come back as zeros
    instead of NaN.
    """
    lengths = np.linalg.norm(directions, axis=1, keepdims=True)
    safe_lengths = np.maximum(lengths, 1e-12)
    return directions / safe_lengths


def euler_step_batch(
    positions: npt.NDArray[np.float32],
    directions: npt.NDArray[np.float32],
    active_mask: npt.NDArray[np.bool_],
    n: npt.NDArray[np.float32],
    grad_n: npt.NDArray[np.float32],  # (N, 3)
    step_size: float,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
    """
    Advance a batch of rays by one explicit Euler step.

    Only rows selected by ``active_mask`` are updated. The inputs are
    copied first, so the caller's arrays are not modified.

    Parameters
    ----------
    positions : ndarray of shape (N, 3)
        Ray positions.
    directions : ndarray of shape (N, 3)
        Unit ray directions.
    active_mask : ndarray of shape (N,)
        Boolean selector of rays to advance.
    n : ndarray of shape (N,)
        Refractive index at each position.
    grad_n : ndarray of shape (N, 3)
        ∇n at each position.
    step_size : float
        Arc-length increment ds.

    Returns
    -------
    new_positions : ndarray of shape (N, 3)
    new_directions : ndarray of shape (N, 3)
    """
    new_positions = positions.copy()
    new_directions = directions.copy()

    d_act = directions[active_mask]
    g_act = grad_n[active_mask]
    n_col = n[active_mask][:, np.newaxis]

    # Curvature κ = (∇n - (d̂·∇n) d̂) / n for every active ray.
    d_dot_g = np.sum(d_act * g_act, axis=1, keepdims=True)
    kappa = (g_act - d_dot_g * d_act) / n_col

    new_positions[active_mask] += d_act * step_size
    new_directions[active_mask] += kappa * step_size
    new_directions[active_mask] = normalize_directions(new_directions[active_mask])

    return new_positions, new_directions


def rk4_step_batch(
    positions: npt.NDArray[np.float32],
    directions: npt.NDArray[np.float32],
    active_mask: npt.NDArray[np.bool_],
    material,  # MaterialFieldProtocol
    step_size: float,
    wavelength: float,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
    """
    Vectorized RK4 step for batch of rays.

    Parameters
    ----------
    positions : ndarray of shape (N, 3)
        Ray positions
    directions : ndarray of shape (N, 3)
        Ray directions (unit vectors)
    active_mask : ndarray of shape (N,)
        Boolean mask for active rays
    material : MaterialFieldProtocol
        Material providing ``get_refractive_index(x, y, z, wavelength)`` and
        ``get_refractive_index_gradient(x, y, z, wavelength)`` (the latter
        returning ``(gx, gy, gz)``).
    step_size : float
        Step size
    wavelength : float
        Wavelength

    Returns
    -------
    new_positions : ndarray of shape (N, 3)
        Updated positions; rows of inactive rays are copied unchanged.
    new_directions : ndarray of shape (N, 3)
        Updated unit directions; rows of inactive rays are copied unchanged.
    n_avg : ndarray of shape (N,), float32
        Average refractive index over the step (for optical path, multiply
        by ``step_size``). Inactive rays get 1.0.

    Notes
    -----
    ``n_avg = (n0 + 2*n1 + 2*n2 + n3) / 6``. Here n0 is sampled at the start
    point, n1 and n2 at the two midpoint estimates, and n3 at the end-point
    estimate. This is Simpson's rule with the midpoint value taken as the
    mean of the two midpoint samples (the RK4 weights).
    """
    h = step_size
    h2 = h / 2.0
    num_rays = len(positions)

    # Initialize outputs (inactive rays pass through unchanged)
    new_positions = positions.copy()
    new_directions = directions.copy()
    n_avg = np.ones(num_rays, dtype=np.float32)

    if not np.any(active_mask):
        return new_positions, new_directions, n_avg

    # Get active rays
    pos = positions[active_mask]
    dirs = directions[active_mask]

    def get_n_and_grad(p):
        # Evaluate n and ∇n (stacked to shape (M, 3)) at positions p.
        x, y, z = p[:, 0], p[:, 1], p[:, 2]
        n = material.get_refractive_index(x, y, z, wavelength).astype(np.float32)
        gx, gy, gz = material.get_refractive_index_gradient(x, y, z, wavelength)
        grad = np.stack([gx, gy, gz], axis=1).astype(np.float32)
        return n, grad

    def compute_curvature(n, grad, d):
        # dd̂/ds = (∇n - (d̂·∇n)d̂) / n, row-wise.
        dot = np.sum(d * grad, axis=1, keepdims=True)
        kappa = (grad - dot * d) / n[:, np.newaxis]
        return kappa

    # k1
    n0, grad0 = get_n_and_grad(pos)
    k1_r = dirs
    k1_d = compute_curvature(n0, grad0, dirs)

    # Intermediate 1 (midpoint estimate from k1)
    pos1 = pos + h2 * k1_r
    dirs1 = normalize_directions(dirs + h2 * k1_d)

    # k2
    n1, grad1 = get_n_and_grad(pos1)
    k2_r = dirs1
    k2_d = compute_curvature(n1, grad1, dirs1)

    # Intermediate 2 (midpoint estimate from k2)
    pos2 = pos + h2 * k2_r
    dirs2 = normalize_directions(dirs + h2 * k2_d)

    # k3
    n2, grad2 = get_n_and_grad(pos2)
    k3_r = dirs2
    k3_d = compute_curvature(n2, grad2, dirs2)

    # End (full-step estimate from k3)
    pos3 = pos + h * k3_r
    dirs3 = normalize_directions(dirs + h * k3_d)

    # k4
    n3, grad3 = get_n_and_grad(pos3)
    k4_r = dirs3
    k4_d = compute_curvature(n3, grad3, dirs3)

    # RK4 combination
    new_pos = pos + (h / 6.0) * (k1_r + 2 * k2_r + 2 * k3_r + k4_r)
    new_dir = dirs + (h / 6.0) * (k1_d + 2 * k2_d + 2 * k3_d + k4_d)
    new_dir = normalize_directions(new_dir)

    # Store results
    new_positions[active_mask] = new_pos
    new_directions[active_mask] = new_dir

    # Simpson's rule for average n: the end-point sample n3 is the right
    # endpoint; n1 and n2 are both midpoint samples and share the midpoint
    # weight (RK4 weights 1-2-2-1).
    n_avg[active_mask] = (n0 + 2 * n1 + 2 * n2 + n3) / 6.0

    return new_positions, new_directions, n_avg