# The Clear BSD License
#
# Copyright (c) 2026 Tobias Heibges
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Annular Plane Surface (GPU-Capable)
Implements an annular (ring-shaped) planar detector surface with GPU acceleration.
Useful for creating concentric ring detector arrays that share a common center
and normal but have different inner/outer radii.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import numpy.typing as npt
from ..protocol import Surface, SurfaceRole
from ..registry import register_surface_type
if TYPE_CHECKING:
from ...propagation.kernels.registry import IntersectionKernelID
@dataclass
class AnnularPlaneSurface(Surface):
    """
    Annular (ring-shaped) planar detector surface with GPU acceleration.

    The annular plane is defined by a center point, normal vector, and
    inner/outer radii. Useful for creating concentric ring detector arrays
    that all face the same target point.

    Parameters
    ----------
    center : tuple of float
        Center point (cx, cy, cz) of the annulus. Must be a finite 3-vector.
    normal : tuple of float
        Normal vector (nx, ny, nz) pointing "outward" (front side). It is
        normalized internally, so it need not be unit length, but it must be
        a finite, non-zero 3-vector.
    inner_radius : float
        Inner radius of the annulus (meters). Use 0 for a disk.
    outer_radius : float
        Outer radius of the annulus (meters).
    role : SurfaceRole
        What happens when a ray hits (DETECTOR, OPTICAL, or ABSORBER).
    name : str
        Human-readable name.
    material_front : MaterialField, optional
        Material on front side (where normal points). Required for OPTICAL
        (not validated at construction time).
    material_back : MaterialField, optional
        Material on back side. Required for OPTICAL (not validated at
        construction time).

    Raises
    ------
    ValueError
        If a radius is negative or non-finite, if ``inner_radius`` is not
        strictly less than ``outer_radius``, if ``center`` or ``normal`` is
        not a finite 3-vector, or if ``normal`` has (near-)zero length.

    Examples
    --------
    >>> # Annular detector ring at 33 km altitude
    >>> ring = AnnularPlaneSurface(
    ...     center=(0, 0, 33000),
    ...     normal=(0, 0, -1),  # Facing down toward origin
    ...     inner_radius=5000.0,
    ...     outer_radius=10000.0,
    ...     role=SurfaceRole.DETECTOR,
    ...     name="ring_5_10km",
    ... )
    >>>
    >>> # Central disk (inner_radius=0)
    >>> disk = AnnularPlaneSurface(
    ...     center=(0, 0, 33000),
    ...     normal=(0, 0, -1),
    ...     inner_radius=0.0,
    ...     outer_radius=2500.0,
    ...     role=SurfaceRole.DETECTOR,
    ...     name="disk_2.5km",
    ... )
    """

    center: tuple[float, float, float]
    normal: tuple[float, float, float]
    inner_radius: float
    outer_radius: float
    role: SurfaceRole
    name: str = "annular_plane"
    material_front: Any = None
    material_back: Any = None

    # GPU capability flags; single source of truth for the properties below.
    _gpu_capable: bool = field(default=True, init=False, repr=False)
    _geometry_id: int = field(default=7, init=False, repr=False)  # annular plane = 7

    # Internal normalized normal (set in __post_init__)
    _normal: tuple[float, float, float] = field(
        default=(0, 0, 1), init=False, repr=False
    )
    # Local U/V axes spanning the plane (set in __post_init__)
    _u_axis: tuple[float, float, float] = field(
        default=(1, 0, 0), init=False, repr=False
    )
    _v_axis: tuple[float, float, float] = field(
        default=(0, 1, 0), init=False, repr=False
    )
    # Squared radii so bounds checks avoid a sqrt (set in __post_init__)
    _inner_radius_sq: float = field(default=0.0, init=False, repr=False)
    _outer_radius_sq: float = field(default=0.0, init=False, repr=False)
    # Kernel ID for this instance (set in __post_init__)
    _kernel_id: IntersectionKernelID | None = field(
        default=None, init=False, repr=False
    )

    @classmethod
    def _get_supported_kernels(cls) -> list[IntersectionKernelID]:
        """Get supported intersection kernels (import deferred to call time)."""
        from ...propagation.kernels.registry import IntersectionKernelID

        return [IntersectionKernelID.ANNULAR_PLANE_ANALYTICAL]

    @classmethod
    def _get_default_kernel(cls) -> IntersectionKernelID:
        """Get default intersection kernel (import deferred to call time)."""
        from ...propagation.kernels.registry import IntersectionKernelID

        return IntersectionKernelID.ANNULAR_PLANE_ANALYTICAL

    @classmethod
    def supported_kernels(cls) -> list[IntersectionKernelID]:
        """Return list of intersection kernels supported by this surface type."""
        return cls._get_supported_kernels()

    @classmethod
    def default_kernel(cls) -> IntersectionKernelID:
        """Return the default intersection kernel for this surface type."""
        return cls._get_default_kernel()

    def __post_init__(self) -> None:
        """
        Validate constructor arguments and pre-compute derived geometry.

        Raises
        ------
        ValueError
            See the class docstring.
        """
        # Validate radii. The original ordered checks run first so existing
        # error messages are unchanged for the cases they already covered.
        if self.inner_radius < 0:
            raise ValueError("Inner radius must be non-negative")
        if self.outer_radius <= 0:
            raise ValueError("Outer radius must be positive")
        if self.inner_radius >= self.outer_radius:
            raise ValueError("Inner radius must be less than outer radius")
        # NaN compares False against everything, so it slips through all of
        # the checks above; inf would yield an infinite area. Reject both.
        if not (np.isfinite(self.inner_radius) and np.isfinite(self.outer_radius)):
            raise ValueError("Radii must be finite")

        # Pre-compute squared radii
        self._inner_radius_sq = self.inner_radius * self.inner_radius
        self._outer_radius_sq = self.outer_radius * self.outer_radius

        # Center is indexed as a 3-vector everywhere (GPU params, intersect).
        c = np.asarray(self.center, dtype=np.float64)
        if c.shape != (3,) or not np.all(np.isfinite(c)):
            raise ValueError("Center must be a finite 3-vector (cx, cy, cz)")

        # Normalize the normal vector. Shape/finiteness are checked first:
        # a wrong-length vector would otherwise fail later inside
        # _compute_local_axes, and NaN would pass the zero-length check.
        n = np.asarray(self.normal, dtype=np.float64)
        if n.shape != (3,) or not np.all(np.isfinite(n)):
            raise ValueError("Normal must be a finite 3-vector (nx, ny, nz)")
        norm = np.linalg.norm(n)
        if norm < 1e-12:
            raise ValueError("Normal vector cannot be zero")
        n = n / norm
        self._normal = tuple(n.tolist())

        # Compute local U/V axes
        self._compute_local_axes()

        # Set default kernel
        self._kernel_id = self._get_default_kernel()

    def _compute_local_axes(self) -> None:
        """Compute orthonormal U/V axes spanning the plane from ``_normal``."""
        n = np.array(self._normal)
        # Choose a reference vector that is not (nearly) parallel to the normal
        if abs(n[2]) < 0.9:
            ref = np.array([0.0, 0.0, 1.0])
        else:
            ref = np.array([1.0, 0.0, 0.0])
        # U = normalize(ref - (ref.n)n): component of ref lying in the plane
        u = ref - np.dot(ref, n) * n
        u = u / np.linalg.norm(u)
        # V = n x U, so (U, V, n) is a right-handed frame
        v = np.cross(n, u)
        self._u_axis = tuple(u.tolist())
        self._v_axis = tuple(v.tolist())

    @property
    def gpu_capable(self) -> bool:
        """This surface supports GPU acceleration."""
        return self._gpu_capable

    @property
    def geometry_id(self) -> int:
        """GPU geometry type ID (annular plane = 7)."""
        return self._geometry_id

    @property
    def area(self) -> float:
        """Area of the annular ring in square meters."""
        return np.pi * (self._outer_radius_sq - self._inner_radius_sq)

    def get_gpu_parameters(self) -> tuple:
        """
        Return parameters for GPU kernel.

        Returns
        -------
        tuple
            (normal_x, normal_y, normal_z, center_x, center_y, center_z,
            u_axis_x, u_axis_y, u_axis_z, v_axis_x, v_axis_y, v_axis_z,
            inner_radius_sq, outer_radius_sq)
        """
        # center is validated as a 3-vector in __post_init__, so unpacking
        # yields exactly three components.
        return (
            *self._normal,
            *self.center,
            *self._u_axis,
            *self._v_axis,
            self._inner_radius_sq,
            self._outer_radius_sq,
        )

    def get_materials(self) -> tuple | None:
        """
        Return (material_front, material_back) for Fresnel calculation.

        Returns
        -------
        tuple or None
            (material_front, material_back) if the role is OPTICAL,
            otherwise None.
        """
        if self.role == SurfaceRole.OPTICAL:
            return (self.material_front, self.material_back)
        return None

    def signed_distance(
        self,
        positions: npt.NDArray[np.float32],
    ) -> npt.NDArray[np.float32]:
        """
        Compute signed distance from positions to the infinite plane.

        The annular bounds are NOT applied here; bounds checking is
        performed in the intersection routine.

        Parameters
        ----------
        positions : ndarray, shape (N, 3)
            Points to compute distance for.

        Returns
        -------
        ndarray, shape (N,)
            Signed distance (positive on normal side, negative on back side).
        """
        normal = np.array(self._normal, dtype=np.float32)
        center = np.array(self.center, dtype=np.float32)
        # d = dot(p - center, normal)
        diff = positions - center
        return np.dot(diff, normal)

    def intersect(
        self,
        origins: npt.NDArray[np.float32],
        directions: npt.NDArray[np.float32],
        min_distance: float = 1e-6,
    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.bool_]]:
        """
        Compute ray-plane intersection with annular bounds checking.

        Rays hitting either face of the plane count as hits. The radial
        bounds are inclusive: ``inner_radius <= r <= outer_radius``.

        Parameters
        ----------
        origins : ndarray, shape (N, 3)
            Ray origins.
        directions : ndarray, shape (N, 3)
            Ray directions. The returned values are Euclidean distances only
            when these are unit length; otherwise they are the ray parameter t.
        min_distance : float
            Minimum valid intersection distance (rejects hits behind or
            at the ray origin).

        Returns
        -------
        distances : ndarray, shape (N,)
            Distance to intersection (inf if no hit), float32.
        hit_mask : ndarray, shape (N,)
            Boolean mask of valid intersections.
        """
        normal = np.array(self._normal, dtype=np.float32)
        center = np.array(self.center, dtype=np.float32)
        u_axis = np.array(self._u_axis, dtype=np.float32)
        v_axis = np.array(self._v_axis, dtype=np.float32)

        # t = dot(center - origin, normal) / dot(direction, normal)
        denom = np.dot(directions, normal)
        # (Nearly) parallel rays never hit. Divide them by 1.0 instead of ~0
        # to avoid divide warnings; they are masked out below.
        parallel_mask = np.abs(denom) < 1e-10
        diff = center - origins
        t = np.dot(diff, normal) / np.where(parallel_mask, 1.0, denom)

        # Valid intersection: not parallel, t >= min_distance
        valid_t = (~parallel_mask) & (t >= min_distance)

        # Intersection points, expressed relative to the annulus center
        intersection_points = origins + t[:, np.newaxis] * directions
        rel = intersection_points - center

        # Project onto the in-plane axes and test the annular (circular) bounds
        u_coord = np.dot(rel, u_axis)
        v_coord = np.dot(rel, v_axis)
        r_sq = u_coord * u_coord + v_coord * v_coord
        within_bounds = (r_sq >= self._inner_radius_sq) & (
            r_sq <= self._outer_radius_sq
        )

        hit_mask = valid_t & within_bounds
        distances = np.where(hit_mask, t, np.inf)
        return distances.astype(np.float32), hit_mask

    def normal_at(
        self,
        positions: npt.NDArray[np.float32],
        incoming_directions: npt.NDArray[np.float32] | None = None,
    ) -> npt.NDArray[np.float32]:
        """
        Compute surface normal at positions.

        For a plane, the normal is constant everywhere; ``positions`` is only
        used for its length.

        Parameters
        ----------
        positions : ndarray, shape (N, 3)
            Points on the surface.
        incoming_directions : ndarray, shape (N, 3), optional
            Ray directions. When given, each normal is flipped so that it
            opposes the corresponding ray (dot(normal, direction) <= 0).

        Returns
        -------
        ndarray, shape (N, 3)
            Normal vectors at each position (float32).
        """
        n = len(positions)
        normals = np.tile(np.array(self._normal, dtype=np.float32), (n, 1))
        # Optionally flip normals to face incoming rays
        if incoming_directions is not None:
            dot = np.sum(normals * incoming_directions, axis=1)
            flip_mask = dot > 0  # Normal points the same way as the ray
            normals[flip_mask] = -normals[flip_mask]
        return normals

    def get_local_coordinates(
        self,
        positions: npt.NDArray[np.float64],
    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """
        Get local (u, v) coordinates for positions on the annulus.

        Points off the plane are projected orthogonally onto it. Useful for
        post-processing azimuthal binning.

        Parameters
        ----------
        positions : ndarray, shape (N, 3)
            Points on or near the annular plane.

        Returns
        -------
        u_coords : ndarray, shape (N,)
            U coordinate in local frame.
        v_coords : ndarray, shape (N,)
            V coordinate in local frame.
        """
        center = np.array(self.center, dtype=np.float64)
        u_axis = np.array(self._u_axis, dtype=np.float64)
        v_axis = np.array(self._v_axis, dtype=np.float64)
        rel = positions - center
        u_coords = np.dot(rel, u_axis)
        v_coords = np.dot(rel, v_axis)
        return u_coords, v_coords

    def get_polar_coordinates(
        self,
        positions: npt.NDArray[np.float64],
    ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """
        Get polar (r, theta) coordinates for positions on the annulus.

        The azimuth is measured from the local U axis toward the V axis.
        Useful for post-processing azimuthal binning.

        Parameters
        ----------
        positions : ndarray, shape (N, 3)
            Points on or near the annular plane.

        Returns
        -------
        radii : ndarray, shape (N,)
            Radial distance from center, within the plane.
        azimuths : ndarray, shape (N,)
            Azimuthal angle in radians (-pi to pi).
        """
        u_coords, v_coords = self.get_local_coordinates(positions)
        radii = np.sqrt(u_coords * u_coords + v_coords * v_coords)
        azimuths = np.arctan2(v_coords, u_coords)
        return radii, azimuths
# Register AnnularPlaneSurface with the surface-type registry under the key
# "annular_plane". The third argument (7) is the same value as the class's
# GPU geometry type ID (AnnularPlaneSurface.geometry_id); keep them in sync.
# NOTE(review): "gpu" presumably tags this as a GPU-capable surface type —
# confirm against register_surface_type's signature.
register_surface_type("annular_plane", "gpu", 7, AnnularPlaneSurface)