# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Fresnel Equations
Computes reflection and transmission coefficients at interfaces between
different media using Fresnel equations.
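Functions
---------
fresnel_coefficients
    Compute Fresnel reflection and transmission coefficients
compute_reflection_direction
    Compute reflected ray direction
compute_refraction_direction
    Compute refracted ray direction (Snell's law)
compute_polarization_basis
    Compute local s/p polarization basis vectors at an interface
transform_polarization_reflection
    Carry polarization vectors through a reflection
transform_polarization_refraction
    Carry polarization vectors through a refraction
initialize_polarization_vectors
    Create initial polarization vectors for a set of rays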
References
----------
.. [1] Born, M., & Wolf, E. (1999). Principles of Optics (7th ed.).
Cambridge University Press. Chapter 1.5.
.. [2] Hecht, E. (2017). Optics (5th ed.). Pearson. Chapter 4.
"""
import numpy as np
from numpy.typing import NDArray
[docs]
def fresnel_coefficients(
n1: NDArray[np.float32] | float,
n2: NDArray[np.float32] | float,
cos_theta_i: NDArray[np.float32],
polarization: str = "unpolarized",
) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
"""
Compute Fresnel reflection and transmission coefficients.
Parameters
----------
n1 : float or ndarray
Refractive index of incident medium
n2 : float or ndarray
Refractive index of transmitted medium
cos_theta_i : ndarray, shape (N,)
Cosine of incident angle (dot product of direction and normal)
polarization : str, optional
Polarization state: 's', 'p', or 'unpolarized' (default)
Returns
-------
R : ndarray, shape (N,)
Reflection coefficient (fraction of intensity reflected)
T : ndarray, shape (N,)
Transmission coefficient (fraction of intensity transmitted)
Notes
-----
For unpolarized light, we average s and p polarizations.
Total internal reflection occurs when n1 > n2 and angle exceeds critical.
The Fresnel equations for intensity (not amplitude) are:
- s-polarization: R_s = |r_s|², T_s = (n2*cos_theta_t)/(n1*cos_theta_i) * |t_s|²
- p-polarization: R_p = |r_p|², T_p = (n2*cos_theta_t)/(n1*cos_theta_i) * |t_p|²
Examples
--------
>>> # Air to glass at 45 degrees
>>> cos_theta_i = np.cos(np.radians(45))
>>> R, T = fresnel_coefficients(1.0, 1.5, cos_theta_i)
>>> print(f"Reflection: {R:.3f}, Transmission: {T:.3f}")
"""
# Ensure arrays
cos_theta_i = np.atleast_1d(cos_theta_i).astype(np.float32)
n1 = np.atleast_1d(n1).astype(np.float32)
n2 = np.atleast_1d(n2).astype(np.float32)
# Broadcast to same shape
n_ratio = n1 / n2
# Compute cos(theta_t) using Snell's law
# n1*sin(theta_i) = n2*sin(theta_t)
# sin²(theta_t) = (n1/n2)² * sin²(theta_i)
# cos²(theta_t) = 1 - sin²(theta_t)
sin_theta_i_sq = 1.0 - cos_theta_i**2
sin_theta_t_sq = (n_ratio**2) * sin_theta_i_sq
# Check for total internal reflection
tir_mask = sin_theta_t_sq > 1.0
# Compute cos(theta_t) for non-TIR cases
cos_theta_t = np.sqrt(np.clip(1.0 - sin_theta_t_sq, 0, 1))
# Fresnel equations for amplitude coefficients
# s-polarization (TE): Electric field perpendicular to plane of incidence
r_s_num = n1 * cos_theta_i - n2 * cos_theta_t
r_s_den = n1 * cos_theta_i + n2 * cos_theta_t
r_s = r_s_num / (r_s_den + 1e-10) # Avoid division by zero
# p-polarization (TM): Electric field parallel to plane of incidence
r_p_num = n2 * cos_theta_i - n1 * cos_theta_t
r_p_den = n2 * cos_theta_i + n1 * cos_theta_t
r_p = r_p_num / (r_p_den + 1e-10)
# Intensity reflection coefficients
R_s = r_s**2
R_p = r_p**2
# Handle total internal reflection
R_s = np.where(tir_mask, 1.0, R_s)
R_p = np.where(tir_mask, 1.0, R_p)
# Combine based on polarization
if polarization == "s":
R = R_s
elif polarization == "p":
R = R_p
else: # unpolarized
R = 0.5 * (R_s + R_p)
# Transmission coefficient (energy conservation)
T = 1.0 - R
# Ensure physical bounds
R = np.clip(R, 0.0, 1.0)
T = np.clip(T, 0.0, 1.0)
return R.astype(np.float32), T.astype(np.float32)
def compute_reflection_direction(
    incident: NDArray[np.float32],
    normal: NDArray[np.float32],
) -> NDArray[np.float32]:
    """
    Compute reflected ray direction using law of reflection.

    Parameters
    ----------
    incident : ndarray, shape (N, 3) or (3,)
        Incident ray directions (should be normalized)
    normal : ndarray, shape (N, 3) or (3,)
        Surface normals at intersection points (should be normalized). A
        single (3,) normal is broadcast against all incident rays.

    Returns
    -------
    reflected : ndarray of float32, same shape as the broadcast inputs
        Reflected ray directions (normalized)

    Notes
    -----
    Reflection formula: r = d - 2(d·n)n
    where d is incident direction and n is surface normal.
    The result does not depend on the sign of the normal: replacing n by -n
    leaves (d·n)n unchanged.

    Examples
    --------
    >>> # Reflect ray at 45° off horizontal surface
    >>> incident = np.array([[1/np.sqrt(2), 0, -1/np.sqrt(2)]])
    >>> normal = np.array([[0, 0, 1]])
    >>> reflected = compute_reflection_direction(incident, normal)
    >>> print(reflected)  # Should be [1/√2, 0, 1/√2]
    """
    incident = np.asarray(incident)
    normal = np.asarray(normal)

    # Reduce over the last (xyz) axis so that both (N, 3) batches and single
    # (3,) vectors work. keepdims lets the dot product broadcast back against
    # the normal.
    dot_in = np.sum(incident * normal, axis=-1, keepdims=True)
    reflected = incident - 2.0 * dot_in * normal

    # Re-normalize to absorb rounding or slightly non-unit inputs. The epsilon
    # avoids division by zero.
    norms = np.linalg.norm(reflected, axis=-1, keepdims=True)
    reflected = reflected / (norms + 1e-10)
    return reflected.astype(np.float32)
[docs]
def compute_refraction_direction(
incident: NDArray[np.float32],
normal: NDArray[np.float32],
n1: NDArray[np.float32] | float,
n2: NDArray[np.float32] | float,
) -> tuple[NDArray[np.float32], NDArray[np.bool_]]:
"""
Compute refracted ray direction using Snell's law.
Parameters
----------
incident : ndarray, shape (N, 3)
Incident ray directions (should be normalized)
normal : ndarray, shape (N, 3)
Surface normals at intersection points (should be normalized)
n1 : float or ndarray
Refractive index of incident medium
n2 : float or ndarray
Refractive index of transmitted medium
Returns
-------
refracted : ndarray, shape (N, 3)
Refracted ray directions (normalized)
For TIR cases, returns zero vector
tir_mask : ndarray, shape (N,)
Boolean mask indicating total internal reflection
Notes
-----
Snell's law: n1*sin(θ1) = n2*sin(θ2)
Vector form of Snell's law:
t = (n1/n2)[d - (d·n)n] - n*sqrt(1 - (n1/n2)²*[1-(d·n)²])
Total internal reflection occurs when (n1/n2)²*[1-(d·n)²] > 1
Examples
--------
>>> # Air to glass at normal incidence
>>> incident = np.array([[0, 0, -1]])
>>> normal = np.array([[0, 0, 1]])
>>> refracted, tir = compute_refraction_direction(incident, normal, 1.0, 1.5)
>>> print(refracted) # Should be [0, 0, -1] (straight through)
>>> print(tir) # Should be False
"""
# Ensure arrays
n1 = np.atleast_1d(n1).astype(np.float32)
n2 = np.atleast_1d(n2).astype(np.float32)
# Compute cos(theta_i) = -incident · normal
# (negative because incident and normal point in opposite directions)
cos_theta_i = -np.sum(incident * normal, axis=1)
# Compute n1/n2 ratio
n_ratio = n1 / n2
# Check for total internal reflection
# sin²(theta_t) = (n1/n2)² * sin²(theta_i) = (n1/n2)² * (1 - cos²(theta_i))
sin_theta_t_sq = (n_ratio**2) * (1.0 - cos_theta_i**2)
tir_mask = sin_theta_t_sq > 1.0
# Compute cos(theta_t)
cos_theta_t = np.sqrt(np.clip(1.0 - sin_theta_t_sq, 0, 1))
# Vector form of Snell's law
# t = (n1/n2) * incident + [(n1/n2)*cos(theta_i) - cos(theta_t)] * normal
refracted = (
n_ratio[:, np.newaxis] * incident
+ (n_ratio * cos_theta_i - cos_theta_t)[:, np.newaxis] * normal
)
# Set TIR rays to zero vector (they won't be transmitted)
refracted[tir_mask] = 0.0
# Normalize
norms = np.linalg.norm(refracted, axis=1, keepdims=True)
refracted = np.where(norms > 1e-10, refracted / norms, 0.0)
return refracted.astype(np.float32), tir_mask
def compute_polarization_basis(
    ray_directions: NDArray[np.float32],
    surface_normals: NDArray[np.float32],
) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
    """
    Compute local s and p polarization basis vectors at surface intersections.

    At a surface, the plane of incidence contains both the ray direction and
    the surface normal. The polarization basis is:

    - s-polarization (TE): perpendicular to plane of incidence
    - p-polarization (TM): in plane of incidence, perpendicular to ray direction

    Parameters
    ----------
    ray_directions : ndarray, shape (N, 3)
        Ray direction vectors (should be normalized)
    surface_normals : ndarray, shape (N, 3)
        Surface normal vectors (should be normalized)

    Returns
    -------
    s_vectors : ndarray of float32, shape (N, 3)
        S-polarization unit vectors (perpendicular to plane of incidence)
    p_vectors : ndarray of float32, shape (N, 3)
        P-polarization unit vectors (in plane of incidence, perpendicular to ray)

    Notes
    -----
    The s-vector is computed as: ŝ = (d × n) / |d × n|
    The p-vector is computed as: p̂ = (d × ŝ) (perpendicular to both ray and s)

    For rays at normal incidence (d parallel to n) the plane of incidence is
    undefined. There ŝ is taken as d × Ŷ, or as X̂ when d is itself (nearly)
    parallel to Y (|d·Ŷ| > 0.99). A zero ray direction yields zero vectors.
    """
    ray_directions = np.asarray(ray_directions)
    x_axis = np.array([1, 0, 0], dtype=np.float32)
    y_axis = np.array([0, 1, 0], dtype=np.float32)

    # s = ray_direction × normal, perpendicular to the plane of incidence.
    s_vectors = np.cross(ray_directions, surface_normals)
    s_norms = np.linalg.norm(s_vectors, axis=1, keepdims=True)

    # Take column 0 rather than .squeeze(). With N == 1, squeeze() yields a
    # 0-d mask, and boolean indexing with it adds an axis instead of
    # selecting rows.
    degenerate_mask = s_norms[:, 0] < 1e-6
    if np.any(degenerate_mask):
        degenerate_rays = ray_directions[degenerate_mask]
        # Fall back to ray × Y, or to X when the ray itself runs along Y.
        along_y = np.abs(degenerate_rays @ y_axis) > 0.99
        fallback = np.where(
            along_y[:, np.newaxis], x_axis, np.cross(degenerate_rays, y_axis)
        )
        s_vectors[degenerate_mask] = fallback
        s_norms[degenerate_mask] = np.linalg.norm(fallback, axis=1, keepdims=True)

    s_vectors = s_vectors / np.maximum(s_norms, 1e-10)

    # p = ray_direction × s, perpendicular to the ray and lying in the plane
    # of incidence.
    p_vectors = np.cross(ray_directions, s_vectors)
    p_norms = np.linalg.norm(p_vectors, axis=1, keepdims=True)
    p_vectors = p_vectors / np.maximum(p_norms, 1e-10)

    return s_vectors.astype(np.float32), p_vectors.astype(np.float32)
def transform_polarization_reflection(
    polarization_vectors: NDArray[np.float32],
    incident_directions: NDArray[np.float32],
    reflected_directions: NDArray[np.float32],
    surface_normals: NDArray[np.float32],
    R_s: NDArray[np.float32] | float | None = None,
    R_p: NDArray[np.float32] | float | None = None,
) -> NDArray[np.float32]:
    """
    Transform polarization vectors through reflection with optional Fresnel weighting.

    The incoming E-field is decomposed into amplitudes along the incident
    s/p basis (from ``compute_polarization_basis``). The same amplitudes are
    then re-assembled on the s/p basis built around the reflected ray. The
    s-vector is perpendicular to the plane of incidence, and the p-vector is
    rebuilt around the new ray direction, which is where the in-plane field
    changes direction. No explicit sign flip is applied, and the Fresnel
    phase (sign of r_s / r_p) is not modeled.

    When both R_s and R_p are provided, the amplitudes are weighted by
    sqrt(R_s) and sqrt(R_p) respectively. Unpolarized light therefore becomes
    partially polarized after reflection (more s-polarized, since
    R_s >= R_p for reflection off a dielectric).

    Parameters
    ----------
    polarization_vectors : ndarray, shape (N, 3)
        Input polarization vectors (E-field direction, unit vectors)
    incident_directions : ndarray, shape (N, 3)
        Incident ray directions
    reflected_directions : ndarray, shape (N, 3)
        Reflected ray directions
    surface_normals : ndarray, shape (N, 3)
        Surface normal vectors
    R_s : float or ndarray, shape (N,), optional
        Fresnel reflectance for s-polarization. Used only if R_p is also
        given; applies amplitude weighting sqrt(R_s) to the s-component.
    R_p : float or ndarray, shape (N,), optional
        Fresnel reflectance for p-polarization. Used only if R_s is also
        given; applies amplitude weighting sqrt(R_p) to the p-component.

    Returns
    -------
    reflected_polarization : ndarray of float32, shape (N, 3)
        Polarization vectors after reflection (normalized; a vanishing
        field yields a zero vector)

    Notes
    -----
    The Fresnel weighting works on electric field amplitude:

    - E_s_out = sqrt(R_s) * E_s_in
    - E_p_out = sqrt(R_p) * E_p_in

    Since intensity I ∝ |E|², this means:

    - I_s_out = R_s * I_s_in
    - I_p_out = R_p * I_p_in
    """
    # Project the input field onto the incident s/p basis.
    s_inc, p_inc = compute_polarization_basis(incident_directions, surface_normals)
    E_s = np.sum(polarization_vectors * s_inc, axis=1, keepdims=True)
    E_p = np.sum(polarization_vectors * p_inc, axis=1, keepdims=True)

    if R_s is not None and R_p is not None:
        # Reflectances are intensities, so field amplitudes scale with sqrt(R).
        # Clipping guards sqrt against tiny negative values from rounding.
        # reshape(-1, 1) pairs an (N,) array row-by-row with the (N, 1)
        # amplitudes and lets a scalar broadcast to every ray.
        E_s = E_s * np.sqrt(np.clip(np.asarray(R_s), 0.0, None)).reshape(-1, 1)
        E_p = E_p * np.sqrt(np.clip(np.asarray(R_p), 0.0, None)).reshape(-1, 1)

    # Rebuild the field on the basis attached to the reflected ray. The
    # p-vector is recomputed around the new direction; no separate sign
    # change is applied to either amplitude.
    s_refl, p_refl = compute_polarization_basis(reflected_directions, surface_normals)
    reflected_polarization = E_s * s_refl + E_p * p_refl

    norms = np.linalg.norm(reflected_polarization, axis=1, keepdims=True)
    reflected_polarization = reflected_polarization / np.maximum(norms, 1e-10)
    return reflected_polarization.astype(np.float32)
def transform_polarization_refraction(
    polarization_vectors: NDArray[np.float32],
    incident_directions: NDArray[np.float32],
    refracted_directions: NDArray[np.float32],
    surface_normals: NDArray[np.float32],
) -> NDArray[np.float32]:
    """
    Carry polarization vectors across a refracting interface.

    The field is split into its s (perpendicular to the plane of incidence)
    and p (in the plane of incidence) amplitudes using the incident ray's
    basis. The same amplitudes are then placed on the basis of the refracted
    ray. No Fresnel transmission weighting is applied; only the basis changes
    because the ray direction changes.

    Parameters
    ----------
    polarization_vectors : ndarray, shape (N, 3)
        Input polarization vectors (E-field direction, unit vectors)
    incident_directions : ndarray, shape (N, 3)
        Incident ray directions
    refracted_directions : ndarray, shape (N, 3)
        Refracted ray directions
    surface_normals : ndarray, shape (N, 3)
        Surface normal vectors

    Returns
    -------
    refracted_polarization : ndarray of float32, shape (N, 3)
        Polarization vectors after refraction (normalized; a vanishing
        field, e.g. for a zero refracted direction, yields a zero vector)
    """
    # Basis before and after the interface. Both calls are independent.
    s_before, p_before = compute_polarization_basis(incident_directions, surface_normals)
    s_after, p_after = compute_polarization_basis(refracted_directions, surface_normals)

    # Field amplitudes along the incident s and p directions, kept as (N, 1)
    # columns.
    amp_s = (polarization_vectors * s_before).sum(axis=1, keepdims=True)
    amp_p = (polarization_vectors * p_before).sum(axis=1, keepdims=True)

    # Each amplitude keeps its character: s maps to s, p maps to p.
    field = amp_s * s_after + amp_p * p_after
    length = np.linalg.norm(field, axis=1, keepdims=True)
    return (field / np.maximum(length, 1e-10)).astype(np.float32)
def initialize_polarization_vectors(
ray_directions: NDArray[np.float32],
polarization: str = "unpolarized",
reference_direction: NDArray[np.float32] = None,
) -> NDArray[np.float32]:
"""
Initialize polarization vectors for rays.
Parameters
----------
ray_directions : ndarray, shape (N, 3)
Ray direction vectors (should be normalized)
polarization : str, optional
Initial polarization state:
- 'unpolarized' or 'random': random polarization perpendicular to ray
- 's' or 'horizontal': horizontal polarization (perpendicular to vertical plane)
- 'p' or 'vertical': vertical polarization (in vertical plane)
- 'custom': use reference_direction projected onto plane perpendicular to ray
reference_direction : ndarray, shape (3,), optional
Reference direction for 'custom' polarization. Will be projected onto
the plane perpendicular to each ray.
Returns
-------
polarization_vectors : ndarray, shape (N, 3)
Initial polarization vectors (unit vectors perpendicular to ray directions)
"""
n_rays = len(ray_directions)
if polarization in ["s", "horizontal"]:
# S-polarization: perpendicular to vertical (Z-containing) plane
# Use global Z as reference to define "horizontal"
z_axis = np.array([0, 0, 1], dtype=np.float32)
# s = ray × Z (horizontal direction perpendicular to ray)
pol_vectors = np.cross(ray_directions, z_axis)
norms = np.linalg.norm(pol_vectors, axis=1, keepdims=True)
# Handle rays parallel to Z
parallel_mask = norms.squeeze() < 1e-6
if np.any(parallel_mask):
# For vertical rays, use X as horizontal direction
pol_vectors[parallel_mask] = np.array([1, 0, 0], dtype=np.float32)
norms[parallel_mask] = 1.0
pol_vectors = pol_vectors / np.maximum(norms, 1e-10)
elif polarization in ["p", "vertical"]:
# P-polarization: in vertical plane containing the ray
# First get horizontal direction, then p = ray × horizontal
z_axis = np.array([0, 0, 1], dtype=np.float32)
horizontal = np.cross(ray_directions, z_axis)
h_norms = np.linalg.norm(horizontal, axis=1, keepdims=True)
# Handle rays parallel to Z
parallel_mask = h_norms.squeeze() < 1e-6
if np.any(parallel_mask):
horizontal[parallel_mask] = np.array([1, 0, 0], dtype=np.float32)
h_norms[parallel_mask] = 1.0
horizontal = horizontal / np.maximum(h_norms, 1e-10)
# p = ray × horizontal (vertical component perpendicular to ray)
pol_vectors = np.cross(ray_directions, horizontal)
norms = np.linalg.norm(pol_vectors, axis=1, keepdims=True)
pol_vectors = pol_vectors / np.maximum(norms, 1e-10)
elif polarization == "custom" and reference_direction is not None:
# Project reference direction onto plane perpendicular to each ray
ref = np.array(reference_direction, dtype=np.float32)
ref = ref / np.linalg.norm(ref)
# For each ray, project ref onto plane perpendicular to ray
# proj = ref - (ref · ray) * ray
dots = np.sum(ray_directions * ref, axis=1, keepdims=True)
pol_vectors = ref - dots * ray_directions
norms = np.linalg.norm(pol_vectors, axis=1, keepdims=True)
# Handle rays parallel to reference direction
parallel_mask = norms.squeeze() < 1e-6
if np.any(parallel_mask):
# Fall back to arbitrary perpendicular
y_axis = np.array([0, 1, 0], dtype=np.float32)
fallback = np.cross(ray_directions[parallel_mask], y_axis)
fallback_norms = np.linalg.norm(fallback, axis=1, keepdims=True)
# If still degenerate, use X
still_degen = fallback_norms.squeeze() < 1e-6
if np.any(still_degen):
x_axis = np.array([1, 0, 0], dtype=np.float32)
fallback[still_degen] = np.cross(
ray_directions[parallel_mask][still_degen], x_axis
)
fallback_norms[still_degen] = np.linalg.norm(
fallback[still_degen], axis=1, keepdims=True
)
pol_vectors[parallel_mask] = fallback / np.maximum(fallback_norms, 1e-10)
norms[parallel_mask] = 1.0
pol_vectors = pol_vectors / np.maximum(norms, 1e-10)
else: # unpolarized or random
# Generate random polarization perpendicular to each ray
# First generate random vectors
rng = np.random.default_rng()
random_vecs = rng.standard_normal((n_rays, 3)).astype(np.float32)
# Project onto plane perpendicular to ray
dots = np.sum(ray_directions * random_vecs, axis=1, keepdims=True)
pol_vectors = random_vecs - dots * ray_directions
norms = np.linalg.norm(pol_vectors, axis=1, keepdims=True)
# Handle any degenerate cases
degen_mask = norms.squeeze() < 1e-6
if np.any(degen_mask):
# Try again with different random vectors
new_random = rng.standard_normal((np.sum(degen_mask), 3)).astype(np.float32)
dots_new = np.sum(
ray_directions[degen_mask] * new_random, axis=1, keepdims=True
)
pol_vectors[degen_mask] = new_random - dots_new * ray_directions[degen_mask]
norms[degen_mask] = np.linalg.norm(
pol_vectors[degen_mask], axis=1, keepdims=True
)
pol_vectors = pol_vectors / np.maximum(norms, 1e-10)
return pol_vectors.astype(np.float32)