# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Shared GPU Device Functions
This module consolidates device functions used across GPU kernels:
- Integration functions (Euler, RK4 steps)
- Dispersion models (Sellmeier, Cauchy equations)
- Common utilities (normalization, interpolation)
The ``device_*`` functions are compiled with @cuda.jit(device=True); the remaining functions are pure-Python and NumPy counterparts for CPU propagation.
"""
import math
from collections.abc import Callable
import numpy as np
import numpy.typing as npt
# GPU support is optional
try:
    from numba import cuda
except ImportError:

    class _FakeCuda:
        """Stand-in for ``numba.cuda`` used when numba cannot be imported.

        Exposes no-op versions of ``jit``, ``is_available``, ``grid`` and
        ``synchronize`` so that functions decorated with ``@cuda.jit`` are
        left as plain Python callables.
        """

        @staticmethod
        def jit(*args, **kwargs):
            """Return the decorated function unchanged.

            Handles both the bare form (``@cuda.jit``) and the parameterized
            form (``@cuda.jit(device=True)``).
            """
            # Bare form: the function to decorate is the first positional argument.
            if args and callable(args[0]):
                return args[0]
            # Parameterized form: hand back an identity decorator.
            return lambda func: func

        @staticmethod
        def is_available():
            """Report that CUDA is unavailable."""
            return False

        @staticmethod
        def grid(n):
            """Return 0 for any ``n``."""
            return 0

        @staticmethod
        def synchronize():
            """Do nothing."""

    cuda = _FakeCuda()  # type: ignore[assignment]
    HAS_CUDA = False
else:
    HAS_CUDA = True
# =============================================================================
# CUDA Device Functions for Integration
# =============================================================================
@cuda.jit(device=True)
def device_adaptive_step_size(
    gradient_magnitude: float,
    refractive_index: float,
    wavelength: float,
    min_step: float,
    max_step: float,
) -> float:
    """
    Choose an integration step size from the local index gradient (GPU device).

    Parameters
    ----------
    gradient_magnitude : float
        |∇n| at current position
    refractive_index : float
        n at current position
    wavelength : float
        Wavelength in meters
    min_step : float
        Minimum allowed step size
    max_step : float
        Maximum allowed step size

    Returns
    -------
    float
        Recommended step size

    Notes
    -----
    The candidate step is the smaller of one tenth of n / |∇n| and the
    wavelength inside the medium (wavelength / n). It is then raised to
    ``min_step`` and finally capped at ``max_step``.
    """
    # With a negligible gradient there is no curvature constraint.
    curvature_limit = max_step
    if gradient_magnitude > 1e-12:
        curvature_limit = (refractive_index / gradient_magnitude) / 10.0

    local_wavelength = wavelength / refractive_index
    candidate = min(curvature_limit, local_wavelength)

    # Lower bound first, upper bound last (the upper bound wins on conflict).
    if candidate < min_step:
        candidate = min_step
    if candidate > max_step:
        candidate = max_step
    return candidate
@cuda.jit(device=True)
def device_euler_step(
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    n: float,
    grad_x: float,
    grad_y: float,
    grad_z: float,
    step_size: float,
) -> tuple[float, float, float, float, float, float]:
    """
    Advance one ray by a single explicit Euler step on the GPU.

    Integrates the ray equation in a gradient-index medium:
        dr/ds = d̂
        dd̂/ds = (∇n - (d̂·∇n)d̂) / n

    Parameters
    ----------
    x, y, z : float
        Current position
    dx, dy, dz : float
        Current direction (unit vector)
    n : float
        Refractive index at current position
    grad_x, grad_y, grad_z : float
        Gradient of n at current position
    step_size : float
        Step size ds

    Returns
    -------
    tuple
        (new_x, new_y, new_z, new_dx, new_dy, new_dz)
    """
    # Component of ∇n along the ray direction.
    along = dx * grad_x + dy * grad_y + dz * grad_z

    # The position advances along the incoming direction.
    px = x + dx * step_size
    py = y + dy * step_size
    pz = z + dz * step_size

    # The direction is bent by the transverse part of ∇n, scaled by 1/n.
    ux = dx + ((grad_x - along * dx) / n) * step_size
    uy = dy + ((grad_y - along * dy) / n) * step_size
    uz = dz + ((grad_z - along * dz) / n) * step_size

    # Restore unit length unless the updated direction is (near-)zero.
    length = math.sqrt(ux * ux + uy * uy + uz * uz)
    if length > 1e-12:
        return px, py, pz, ux / length, uy / length, uz / length
    return px, py, pz, ux, uy, uz
@cuda.jit(device=True)
def device_rk4_step(
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    step_size: float,
    n_func,
    grad_func,
    *material_params,
) -> tuple[float, float, float, float, float, float, float]:
    """
    RK4 integration step for ray equation on GPU.

    This is a generic RK4 step that takes material evaluation functions
    as parameters. It integrates
        dr/ds = d̂
        dd̂/ds = (∇n - (d̂·∇n)d̂) / n
    and renormalizes the direction at every stage.

    Parameters
    ----------
    x, y, z : float
        Current position
    dx, dy, dz : float
        Current direction (unit vector)
    step_size : float
        Integration step size
    n_func : device function
        Function to evaluate n(x, y, z, *material_params)
    grad_func : device function
        Function to evaluate ∇n(x, y, z, *material_params) -> (gx, gy, gz)
    material_params : tuple
        Material-specific parameters passed to n_func and grad_func

    Returns
    -------
    tuple of float
        (new_x, new_y, new_z, new_dx, new_dy, new_dz, n_avg), where n_avg
        is the refractive index averaged over the step.

    Notes
    -----
    n_avg combines the index sampled at all four RK4 stages with the RK4
    weights (1, 2, 2, 1) / 6. This is Simpson's rule with the two midpoint
    samples averaged, and it is exact for an index that varies linearly
    along the step.

    For best performance, materials should implement their own specialized
    RK4 device functions that inline the n and gradient calculations.
    This generic version has function call overhead.
    """
    h = step_size
    h2 = h / 2.0

    def _normalize(vx, vy, vz):
        # Leave (near-)zero vectors untouched instead of dividing by ~0.
        norm = math.sqrt(vx * vx + vy * vy + vz * vz)
        if norm > 1e-12:
            return vx / norm, vy / norm, vz / norm
        return vx, vy, vz

    def _curvature(px, py, pz, dirx, diry, dirz):
        # Returns n and the curvature vector (∇n - (d̂·∇n)d̂) / n at a point.
        n = n_func(px, py, pz, *material_params)
        gx, gy, gz = grad_func(px, py, pz, *material_params)
        dot = dirx * gx + diry * gy + dirz * gz
        kx = (gx - dot * dirx) / n
        ky = (gy - dot * diry) / n
        kz = (gz - dot * dirz) / n
        return n, kx, ky, kz

    # k1 at current point
    n0, k1_dx, k1_dy, k1_dz = _curvature(x, y, z, dx, dy, dz)
    k1_rx, k1_ry, k1_rz = dx, dy, dz

    # Intermediate point 1 (half step using k1)
    x1 = x + h2 * k1_rx
    y1 = y + h2 * k1_ry
    z1 = z + h2 * k1_rz
    dx1, dy1, dz1 = _normalize(dx + h2 * k1_dx, dy + h2 * k1_dy, dz + h2 * k1_dz)

    # k2
    n1, k2_dx, k2_dy, k2_dz = _curvature(x1, y1, z1, dx1, dy1, dz1)
    k2_rx, k2_ry, k2_rz = dx1, dy1, dz1

    # Intermediate point 2 (half step using k2)
    x2 = x + h2 * k2_rx
    y2 = y + h2 * k2_ry
    z2 = z + h2 * k2_rz
    dx2, dy2, dz2 = _normalize(dx + h2 * k2_dx, dy + h2 * k2_dy, dz + h2 * k2_dz)

    # k3
    n2, k3_dx, k3_dy, k3_dz = _curvature(x2, y2, z2, dx2, dy2, dz2)
    k3_rx, k3_ry, k3_rz = dx2, dy2, dz2

    # End point (full step using k3)
    x3 = x + h * k3_rx
    y3 = y + h * k3_ry
    z3 = z + h * k3_rz
    dx3, dy3, dz3 = _normalize(dx + h * k3_dx, dy + h * k3_dy, dz + h * k3_dz)

    # k4
    n3, k4_dx, k4_dy, k4_dz = _curvature(x3, y3, z3, dx3, dy3, dz3)
    k4_rx, k4_ry, k4_rz = dx3, dy3, dz3

    # Final RK4 combination
    new_x = x + (h / 6.0) * (k1_rx + 2 * k2_rx + 2 * k3_rx + k4_rx)
    new_y = y + (h / 6.0) * (k1_ry + 2 * k2_ry + 2 * k3_ry + k4_ry)
    new_z = z + (h / 6.0) * (k1_rz + 2 * k2_rz + 2 * k3_rz + k4_rz)
    new_dx = dx + (h / 6.0) * (k1_dx + 2 * k2_dx + 2 * k3_dx + k4_dx)
    new_dy = dy + (h / 6.0) * (k1_dy + 2 * k2_dy + 2 * k3_dy + k4_dy)
    new_dz = dz + (h / 6.0) * (k1_dz + 2 * k2_dz + 2 * k3_dz + k4_dz)
    new_dx, new_dy, new_dz = _normalize(new_dx, new_dy, new_dz)

    # Average n over the step: start sample n0, two midpoint samples n1 and
    # n2, and the end sample n3. The end sample must be included; using a
    # midpoint sample in its place biases the average.
    n_avg = (n0 + 2 * n1 + 2 * n2 + n3) / 6.0
    return new_x, new_y, new_z, new_dx, new_dy, new_dz, n_avg
# =============================================================================
# CUDA Device Functions for Dispersion Models
# =============================================================================
@cuda.jit(device=True)
def device_sellmeier_equation(
    wl_um: float,
    B1: float,
    B2: float,
    B3: float,
    C1: float,
    C2: float,
    C3: float,
) -> float:
    """
    Evaluate the three-term Sellmeier equation on the GPU.

    Computes n = sqrt(1 + Σ Bᵢ λ² / (λ² - Cᵢ)) for i = 1..3.

    Parameters
    ----------
    wl_um : float
        Wavelength in micrometers.
    B1, B2, B3 : float
        Sellmeier B coefficients.
    C1, C2, C3 : float
        Sellmeier C coefficients in μm².

    Returns
    -------
    n : float
        Refractive index.
    """
    lam_sq = wl_um * wl_um

    # One resonance term per (B, C) coefficient pair.
    term1 = B1 * lam_sq / (lam_sq - C1)
    term2 = B2 * lam_sq / (lam_sq - C2)
    term3 = B3 * lam_sq / (lam_sq - C3)

    return math.sqrt(1.0 + (term1 + term2 + term3))
@cuda.jit(device=True)
def device_cauchy_equation(
    wl_um: float,
    A: float,
    B: float,
    C: float,
) -> float:
    """
    Evaluate the Cauchy equation on the GPU.

    Computes n = A + B / λ² + C / λ⁴.

    Parameters
    ----------
    wl_um : float
        Wavelength in micrometers.
    A : float
        Constant term.
    B : float
        First-order dispersion coefficient in μm².
    C : float
        Second-order dispersion coefficient in μm⁴.

    Returns
    -------
    n : float
        Refractive index.
    """
    lam_sq = wl_um * wl_um
    lam_4th = lam_sq * lam_sq
    return A + B / lam_sq + C / lam_4th
# =============================================================================
# Pure Python Step Functions (for CPU propagation)
# =============================================================================
def euler_step(
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    n: float,
    grad_x: float,
    grad_y: float,
    grad_z: float,
    step_size: float,
) -> tuple[float, float, float, float, float, float]:
    """
    Advance one ray by a single explicit Euler step of the ray equation.

    The direction changes by κ·ds with κ = (∇n - (d̂·∇n)d̂) / n, and the
    position advances along the incoming direction.

    Parameters
    ----------
    x, y, z : float
        Current position
    dx, dy, dz : float
        Current direction (unit vector)
    n : float
        Refractive index at current position
    grad_x, grad_y, grad_z : float
        Gradient of n at current position
    step_size : float
        Step size ds

    Returns
    -------
    tuple
        (new_x, new_y, new_z, new_dx, new_dy, new_dz)
    """
    # Component of ∇n along the ray direction.
    along = dx * grad_x + dy * grad_y + dz * grad_z

    # The position advances along the incoming direction.
    px = x + dx * step_size
    py = y + dy * step_size
    pz = z + dz * step_size

    # The direction is bent by the transverse part of ∇n, scaled by 1/n.
    ux = dx + ((grad_x - along * dx) / n) * step_size
    uy = dy + ((grad_y - along * dy) / n) * step_size
    uz = dz + ((grad_z - along * dz) / n) * step_size

    # Restore unit length unless the updated direction is (near-)zero.
    length = math.sqrt(ux**2 + uy**2 + uz**2)
    if length > 1e-12:
        return px, py, pz, ux / length, uy / length, uz / length
    return px, py, pz, ux, uy, uz
def rk4_step(
    x: float,
    y: float,
    z: float,
    dx: float,
    dy: float,
    dz: float,
    n_func: Callable[[float, float, float, float], float],
    grad_func: Callable[[float, float, float, float], tuple[float, float, float]],
    wavelength: float,
    step_size: float,
) -> tuple[float, float, float, float, float, float, float]:
    """
    RK4 integration step for ray equation in gradient medium.

    Solves the coupled ODEs:
        dr/ds = d̂
        dd̂/ds = (∇n - (d̂·∇n)d̂) / n

    Parameters
    ----------
    x, y, z : float
        Current position
    dx, dy, dz : float
        Current direction (unit vector)
    n_func : callable
        Function to evaluate n(x, y, z, wavelength)
    grad_func : callable
        Function to evaluate ∇n(x, y, z, wavelength) -> (gx, gy, gz)
    wavelength : float
        Wavelength
    step_size : float
        Integration step size

    Returns
    -------
    tuple of float
        (new_x, new_y, new_z, new_dx, new_dy, new_dz, optical_path_increment)

    Notes
    -----
    The optical path increment is ``n_avg * step_size``, where ``n_avg``
    combines the index sampled at all four RK4 stages with the RK4 weights
    (1, 2, 2, 1) / 6. This is Simpson's rule with the two midpoint samples
    averaged, and it is exact for an index that varies linearly along the
    step.
    """
    h = step_size
    h2 = h / 2.0

    def _normalize(vx, vy, vz):
        # Leave (near-)zero vectors untouched instead of dividing by ~0.
        norm = math.sqrt(vx**2 + vy**2 + vz**2)
        if norm > 1e-12:
            return vx / norm, vy / norm, vz / norm
        return vx, vy, vz

    def _curvature(px, py, pz, dirx, diry, dirz):
        # Returns n and the curvature vector (∇n - (d̂·∇n)d̂) / n at a point.
        n = n_func(px, py, pz, wavelength)
        gx, gy, gz = grad_func(px, py, pz, wavelength)
        dot = dirx * gx + diry * gy + dirz * gz
        kx = (gx - dot * dirx) / n
        ky = (gy - dot * diry) / n
        kz = (gz - dot * dirz) / n
        return n, kx, ky, kz

    # k1 evaluation at current point
    n0, k1_dx, k1_dy, k1_dz = _curvature(x, y, z, dx, dy, dz)
    k1_rx, k1_ry, k1_rz = dx, dy, dz

    # Intermediate point 1 (half step using k1)
    x1 = x + h2 * k1_rx
    y1 = y + h2 * k1_ry
    z1 = z + h2 * k1_rz
    dx1, dy1, dz1 = _normalize(dx + h2 * k1_dx, dy + h2 * k1_dy, dz + h2 * k1_dz)

    # k2 evaluation
    n1, k2_dx, k2_dy, k2_dz = _curvature(x1, y1, z1, dx1, dy1, dz1)
    k2_rx, k2_ry, k2_rz = dx1, dy1, dz1

    # Intermediate point 2 (half step using k2)
    x2 = x + h2 * k2_rx
    y2 = y + h2 * k2_ry
    z2 = z + h2 * k2_rz
    dx2, dy2, dz2 = _normalize(dx + h2 * k2_dx, dy + h2 * k2_dy, dz + h2 * k2_dz)

    # k3 evaluation
    n2, k3_dx, k3_dy, k3_dz = _curvature(x2, y2, z2, dx2, dy2, dz2)
    k3_rx, k3_ry, k3_rz = dx2, dy2, dz2

    # End point (full step using k3)
    x3 = x + h * k3_rx
    y3 = y + h * k3_ry
    z3 = z + h * k3_rz
    dx3, dy3, dz3 = _normalize(dx + h * k3_dx, dy + h * k3_dy, dz + h * k3_dz)

    # k4 evaluation
    n3, k4_dx, k4_dy, k4_dz = _curvature(x3, y3, z3, dx3, dy3, dz3)
    k4_rx, k4_ry, k4_rz = dx3, dy3, dz3

    # Final RK4 combination
    new_x = x + (h / 6.0) * (k1_rx + 2 * k2_rx + 2 * k3_rx + k4_rx)
    new_y = y + (h / 6.0) * (k1_ry + 2 * k2_ry + 2 * k3_ry + k4_ry)
    new_z = z + (h / 6.0) * (k1_rz + 2 * k2_rz + 2 * k3_rz + k4_rz)
    new_dx = dx + (h / 6.0) * (k1_dx + 2 * k2_dx + 2 * k3_dx + k4_dx)
    new_dy = dy + (h / 6.0) * (k1_dy + 2 * k2_dy + 2 * k3_dy + k4_dy)
    new_dz = dz + (h / 6.0) * (k1_dz + 2 * k2_dz + 2 * k3_dz + k4_dz)
    new_dx, new_dy, new_dz = _normalize(new_dx, new_dy, new_dz)

    # Optical path: average n over the step from the start sample n0, the two
    # midpoint samples n1 and n2, and the end sample n3. The end sample must
    # be included; using a midpoint sample in its place biases the average.
    n_avg = (n0 + 2 * n1 + 2 * n2 + n3) / 6.0
    optical_increment = n_avg * h
    return new_x, new_y, new_z, new_dx, new_dy, new_dz, optical_increment
def compute_adaptive_step_size(
    gradient_magnitude: float,
    refractive_index: float,
    wavelength: float,
    min_step: float,
    max_step: float,
) -> float:
    """
    Choose an integration step size from the local index gradient.

    Parameters
    ----------
    gradient_magnitude : float
        |∇n| at current position
    refractive_index : float
        n at current position
    wavelength : float
        Wavelength in meters
    min_step : float
        Minimum allowed step size
    max_step : float
        Maximum allowed step size

    Returns
    -------
    float
        Recommended step size

    Notes
    -----
    The curvature constraint is Δs ≤ R_c / 10 with R_c = n / |∇n|; when the
    gradient is negligible it is replaced by ``max_step``. The candidate
    step is the smaller of that constraint and the wavelength inside the
    medium (wavelength / n), clamped to [min_step, max_step].
    """
    # With a negligible gradient there is no curvature constraint.
    curvature_limit = (
        refractive_index / gradient_magnitude / 10.0
        if gradient_magnitude > 1e-12
        else max_step
    )

    candidate = min(curvature_limit, wavelength / refractive_index)

    # Cap from above first, then enforce the lower bound.
    capped = min(candidate, max_step)
    return max(min_step, capped)
# =============================================================================
# Vectorized NumPy Operations (for batch CPU propagation)
# =============================================================================
def normalize_directions(
    directions: npt.NDArray[np.float32],
) -> npt.NDArray[np.float32]:
    """
    Scale each row of ``directions`` to unit Euclidean length.

    Row norms are floored at 1e-12 before dividing, so an all-zero row is
    returned as zeros rather than producing a division by zero.
    """
    lengths = np.linalg.norm(directions, axis=1, keepdims=True)
    return directions / np.maximum(lengths, 1e-12)
def euler_step_batch(
    positions: npt.NDArray[np.float32],
    directions: npt.NDArray[np.float32],
    active_mask: npt.NDArray[np.bool_],
    n: npt.NDArray[np.float32],
    grad_n: npt.NDArray[np.float32],  # (N, 3)
    step_size: float,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
    """
    Advance a batch of rays by one explicit Euler step.

    Rays where ``active_mask`` is False are returned unchanged. The inputs
    are not modified; copies are updated and returned.

    Parameters
    ----------
    positions : ndarray of shape (N, 3)
        Ray positions
    directions : ndarray of shape (N, 3)
        Ray directions (unit vectors)
    active_mask : ndarray of shape (N,)
        Boolean mask for active rays
    n : ndarray of shape (N,)
        Refractive index at each position
    grad_n : ndarray of shape (N, 3)
        Gradient of n at each position
    step_size : float
        Step size

    Returns
    -------
    new_positions : ndarray of shape (N, 3)
    new_directions : ndarray of shape (N, 3)
    """
    out_positions = positions.copy()
    out_directions = directions.copy()

    # Work on the active subset only.
    dirs_active = directions[active_mask]
    grad_active = grad_n[active_mask]

    # d̂ · ∇n, kept as a column so it broadcasts against (K, 3) arrays.
    along = np.sum(dirs_active * grad_active, axis=1, keepdims=True)

    # κ = (∇n - (d̂·∇n)d̂) / n
    curvature = (grad_active - along * dirs_active) / n[active_mask][:, None]

    out_positions[active_mask] += dirs_active * step_size
    out_directions[active_mask] += curvature * step_size

    # Bring the updated directions back to unit length.
    out_directions[active_mask] = normalize_directions(out_directions[active_mask])
    return out_positions, out_directions
def rk4_step_batch(
    positions: npt.NDArray[np.float32],
    directions: npt.NDArray[np.float32],
    active_mask: npt.NDArray[np.bool_],
    material,  # MaterialFieldProtocol
    step_size: float,
    wavelength: float,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
    """
    Vectorized RK4 step for batch of rays.

    Rays where ``active_mask`` is False are returned unchanged and report
    ``n_avg == 1.0``. The inputs are not modified.

    Parameters
    ----------
    positions : ndarray of shape (N, 3)
        Ray positions
    directions : ndarray of shape (N, 3)
        Ray directions (unit vectors)
    active_mask : ndarray of shape (N,)
        Boolean mask for active rays
    material : MaterialFieldProtocol
        Material providing n and ∇n evaluation through
        ``get_refractive_index(x, y, z, wavelength)`` and
        ``get_refractive_index_gradient(x, y, z, wavelength)``
    step_size : float
        Step size
    wavelength : float
        Wavelength

    Returns
    -------
    new_positions : ndarray of shape (N, 3)
    new_directions : ndarray of shape (N, 3)
    n_avg : ndarray of shape (N,)
        Average refractive index over step (for optical path)

    Notes
    -----
    ``n_avg`` combines the index sampled at all four RK4 stages with the
    RK4 weights (1, 2, 2, 1) / 6. This is Simpson's rule with the two
    midpoint samples averaged, and it is exact for an index that varies
    linearly along the step.
    """
    h = step_size
    h2 = h / 2.0
    num_rays = len(positions)

    # Initialize outputs; inactive rays keep these values.
    new_positions = positions.copy()
    new_directions = directions.copy()
    n_avg = np.ones(num_rays, dtype=np.float32)
    if not np.any(active_mask):
        return new_positions, new_directions, n_avg

    # Get active rays
    pos = positions[active_mask]
    dirs = directions[active_mask]

    def get_n_and_grad(p):
        # Query the material at points p (K, 3); returns n (K,) and ∇n (K, 3).
        x, y, z = p[:, 0], p[:, 1], p[:, 2]
        n = material.get_refractive_index(x, y, z, wavelength).astype(np.float32)
        gx, gy, gz = material.get_refractive_index_gradient(x, y, z, wavelength)
        grad = np.stack([gx, gy, gz], axis=1).astype(np.float32)
        return n, grad

    def compute_curvature(n, grad, d):
        # κ = (∇n - (d̂·∇n)d̂) / n, evaluated row-wise.
        dot = np.sum(d * grad, axis=1, keepdims=True)
        kappa = (grad - dot * d) / n[:, np.newaxis]
        return kappa

    # k1 at the current points
    n0, grad0 = get_n_and_grad(pos)
    k1_r = dirs
    k1_d = compute_curvature(n0, grad0, dirs)

    # Intermediate 1 (half step using k1)
    pos1 = pos + h2 * k1_r
    dirs1 = normalize_directions(dirs + h2 * k1_d)

    # k2
    n1, grad1 = get_n_and_grad(pos1)
    k2_r = dirs1
    k2_d = compute_curvature(n1, grad1, dirs1)

    # Intermediate 2 (half step using k2)
    pos2 = pos + h2 * k2_r
    dirs2 = normalize_directions(dirs + h2 * k2_d)

    # k3
    n2, grad2 = get_n_and_grad(pos2)
    k3_r = dirs2
    k3_d = compute_curvature(n2, grad2, dirs2)

    # End (full step using k3)
    pos3 = pos + h * k3_r
    dirs3 = normalize_directions(dirs + h * k3_d)

    # k4
    n3, grad3 = get_n_and_grad(pos3)
    k4_r = dirs3
    k4_d = compute_curvature(n3, grad3, dirs3)

    # RK4 combination
    new_pos = pos + (h / 6.0) * (k1_r + 2 * k2_r + 2 * k3_r + k4_r)
    new_dir = dirs + (h / 6.0) * (k1_d + 2 * k2_d + 2 * k3_d + k4_d)
    new_dir = normalize_directions(new_dir)

    # Store results
    new_positions[active_mask] = new_pos
    new_directions[active_mask] = new_dir

    # Average n over the step: start sample n0, two midpoint samples n1 and
    # n2, and the end sample n3. The end sample must be included; using a
    # midpoint sample in its place biases the average.
    n_avg[active_mask] = (n0 + 2 * n1 + 2 * n2 + n3) / 6.0
    return new_positions, new_directions, n_avg