465 lines
16 KiB
Python
465 lines
16 KiB
Python
"""
|
|
Spine analysis functions for computing curves, Cobb angles, and identifying apex vertebrae.
|
|
"""
|
|
from typing import List, Optional, Tuple

import numpy as np
from scipy.interpolate import splev, splprep
from scipy.signal import find_peaks

from data_models import Spine2D, VertebraLandmark
|
|
|
|
|
|
def _unit_tangent_normal(raw_tangents: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Normalise raw tangent vectors and derive the matching normals.

    Normalisation is done in the dtype of ``raw_tangents``, with a 1e-8 guard
    against zero-length vectors, and the result is cast to float32. Each
    normal is its tangent rotated by +90 degrees: (tx, ty) -> (-ty, tx).
    """
    lengths = np.linalg.norm(raw_tangents, axis=1, keepdims=True) + 1e-8
    unit_tangents = (raw_tangents / lengths).astype(np.float32)
    normals = np.stack([-unit_tangents[:, 1], unit_tangents[:, 0]], axis=1)
    return unit_tangents, normals


def compute_spine_curve(
    spine: Spine2D,
    smooth: float = 1.0,
    n_samples: int = 200
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit a smooth spine centerline through the vertebra centroids.

    A parametric smoothing spline (``scipy.interpolate.splprep``) is fitted
    through ``spine.get_centroids()``. It is then sampled at ``n_samples``
    evenly spaced values of the spline parameter (not of arc length). If the
    spline fit raises, the centroids are instead resampled by piecewise-linear
    interpolation, and the curvature is reported as all zeros.

    Args:
        spine: Spine2D object with detected vertebrae; only ``get_centroids()``
            is used (indexed as an (N, 2) array of x, y).
        smooth: Smoothing factor passed to ``splprep`` as ``s``
            (higher = smoother).
        n_samples: Number of points to sample along the curve.

    Returns:
        C: Curve points, float32, shape (n_samples, 2).
        T: Unit tangent vectors, float32, shape (n_samples, 2).
        N: Normals (tangent rotated by +90 degrees), float32, shape (n_samples, 2).
        curvature: Unsigned curvature at each sample, float32, shape (n_samples,).

    Raises:
        ValueError: If fewer than 4 centroids are available.
    """
    centroids = spine.get_centroids()
    n_points = len(centroids)
    if n_points < 4:
        raise ValueError(f"Need at least 4 vertebrae for spline, got {n_points}")

    xs_in = centroids[:, 0]
    ys_in = centroids[:, 1]

    try:
        tck, _ = splprep([xs_in, ys_in], s=smooth, k=min(3, n_points - 1))
    except Exception:
        # Best-effort fallback: resample the polyline linearly. The result has
        # no meaningful second derivative, so curvature is reported as zero.
        grid = np.linspace(0, 1, n_samples)
        knots = np.linspace(0, 1, n_points)
        fallback_curve = np.column_stack(
            [np.interp(grid, knots, xs_in), np.interp(grid, knots, ys_in)]
        ).astype(np.float32)
        fb_tangents, fb_normals = _unit_tangent_normal(
            np.gradient(fallback_curve, axis=0)
        )
        return fallback_curve, fb_tangents, fb_normals, np.zeros(n_samples, dtype=np.float32)

    params = np.linspace(0, 1, n_samples)
    xs, ys = splev(params, tck)
    dx, dy = splev(params, tck, der=1)
    ddx, ddy = splev(params, tck, der=2)

    curve = np.column_stack([xs, ys]).astype(np.float32)
    tangents, normals = _unit_tangent_normal(np.column_stack([dx, dy]))

    # kappa = |x'y'' - y'x''| / (x'^2 + y'^2)^(3/2); this formula does not
    # depend on the parametrisation, so parameter-space derivatives are fine.
    speed_sq = dx**2 + dy**2
    cross_term = np.abs(dx * ddy - dy * ddx)
    curvature = (cross_term / (np.power(speed_sq, 1.5) + 1e-8)).astype(np.float32)

    return curve, tangents, normals, curvature
|
|
|
|
|
|
def compute_cobb_angles(spine: Spine2D) -> Tuple[float, float, float]:
    """
    Estimate regional Cobb angles from per-vertebra orientations.

    Clinically, the Cobb angle is the angle between:
        - the superior endplate of the most tilted vertebra at the top of a curve
        - the inferior endplate of the most tilted vertebra at the bottom.
    This function approximates it per region as the spread (max - min) of the
    values returned by ``spine.get_orientations()`` within that region.

    Regions are non-overlapping thirds of the detected vertebrae (by index),
    standing in for:
        - PT: Proximal Thoracic (roughly T1-T6)
        - MT: Main Thoracic (roughly T6-T12)
        - TL: Thoracolumbar/Lumbar (roughly T12-L5)
    With fewer than 7 vertebrae, no angles are estimated and all three are 0.

    Side effects:
        Sets ``spine.cobb_pt``, ``spine.cobb_mt``, ``spine.cobb_tl`` and
        ``spine.curve_type``. The curve type is "S" when MT and TL are both at
        least 10, "C" when exactly one of them is, and "Normal" otherwise. The
        10-degree cut-off matches ``get_curve_severity`` and
        ``classify_rigo_type``.

    Args:
        spine: Spine2D object with detected vertebrae.

    Returns:
        (pt_angle, mt_angle, tl_angle) in the units of the orientations
        (expected to be degrees).
    """
    significant = 10.0  # Cobb >= 10 counts as a curve (same as get_curve_severity)

    orientations = spine.get_orientations()
    n = len(orientations)

    if n < 7:
        # Too few vertebrae to estimate regional curves. Reset every derived
        # field, including curve_type, so values from an earlier call on the
        # same spine object do not linger.
        spine.cobb_pt = 0.0
        spine.cobb_mt = 0.0
        spine.cobb_tl = 0.0
        spine.curve_type = "Normal"
        return 0.0, 0.0, 0.0

    # PT: top third, MT: middle third, TL: bottom third (by index).
    pt_end = n // 3
    mt_end = 2 * n // 3

    def region_cobb(start_idx: int, end_idx: int) -> float:
        """Return the orientation spread (max - min) over [start_idx, end_idx)."""
        region_angles = orientations[start_idx:end_idx]
        if len(region_angles) < 2:
            return 0.0
        return abs(float(np.max(region_angles) - np.min(region_angles)))

    pt_angle = region_cobb(0, pt_end)
    mt_angle = region_cobb(pt_end, mt_end)
    tl_angle = region_cobb(mt_end, n)

    spine.cobb_pt = pt_angle
    spine.cobb_mt = mt_angle
    spine.cobb_tl = tl_angle

    # NOTE(review): PT is not considered when labelling curve_type, although
    # classify_rigo_type does count PT as a significant curve. Confirm this
    # is intended.
    mt_significant = mt_angle >= significant
    tl_significant = tl_angle >= significant
    if mt_significant and tl_significant:
        spine.curve_type = "S"  # Double curve
    elif mt_significant or tl_significant:
        spine.curve_type = "C"  # Single curve
    else:
        spine.curve_type = "Normal"

    return pt_angle, mt_angle, tl_angle
|
|
|
|
|
|
def find_apex_vertebrae(spine: Spine2D, min_deviation: float = 5.0) -> List[int]:
    """
    Find indices of apex vertebrae (local extremes of lateral deviation).

    The lateral deviation of each centroid is its signed perpendicular
    distance from the straight chord joining the first and last centroid.
    An apex is a local maximum or minimum of that profile whose absolute
    value exceeds ``min_deviation``. Flat-topped extremes (several equal
    samples) count once, at their middle sample. The first and last
    vertebrae are never apexes.

    Args:
        spine: Spine2D object; only ``get_centroids()`` is used.
        min_deviation: Minimum |deviation| for an extremum to count as an
            apex, in centroid units (presumably pixels). Defaults to 5.

    Returns:
        Sorted list of vertebra indices that are curve apexes. The list is
        empty when there are fewer than 5 centroids or when the first and
        last centroid coincide.
    """
    centroids = np.asarray(spine.get_centroids(), dtype=np.float64)

    if len(centroids) < 5:
        return []

    start = centroids[0]
    midline_vec = centroids[-1] - start
    midline_len = np.linalg.norm(midline_vec)
    if midline_len < 1e-6:
        return []

    ux, uy = midline_vec / midline_len
    offsets = centroids - start
    # Signed perpendicular distance = z-component of (midline_unit x offset).
    # Written out explicitly because np.cross on 2-element vectors is
    # deprecated in NumPy 2.0. For a midline along +y, points with smaller x
    # come out positive. Which anatomical side that is depends on the image
    # orientation (not known here).
    deviations = ux * offsets[:, 1] - uy * offsets[:, 0]

    # find_peaks reports strict local maxima and handles plateaus (middle
    # sample). Negating the profile finds the minima.
    maxima, _ = find_peaks(deviations)
    minima, _ = find_peaks(-deviations)
    candidates = np.union1d(maxima, minima)

    return [int(i) for i in candidates if abs(deviations[i]) > min_deviation]
|
|
|
|
|
|
def get_curve_severity(cobb_angle: float) -> str:
    """
    Map a Cobb angle to a clinical severity label.

    Bins (upper bounds exclusive): [.., 10) -> "Normal", [10, 25) -> "Mild",
    [25, 40) -> "Moderate", and everything else -> "Severe".

    Args:
        cobb_angle: Cobb angle in degrees.

    Returns:
        One of "Normal", "Mild", "Moderate", or "Severe".
    """
    severity_bins = (
        (10, "Normal"),
        (25, "Mild"),
        (40, "Moderate"),
    )
    for upper_bound, label in severity_bins:
        if cobb_angle < upper_bound:
            return label
    return "Severe"
|
|
|
|
|
|
def classify_rigo_type(spine: Spine2D) -> dict:
    """
    Classify scoliosis according to the Rigo-Chêneau brace classification.

    This is a heuristic mapping. It uses the regional Cobb angles from
    ``spine.get_cobb_angles()`` (keys 'PT', 'MT', 'TL'; a missing key counts
    as 0) and the apexes of the centroids' lateral-deviation profile. A
    region counts as a significant curve when its Cobb angle is >= 10.

    Rigo classification types produced:
        - A1, A2, A3: 3-curve patterns (thoracic major)
        - B1, B2: 4-curve patterns (double major)
        - C1, C2: single thoracolumbar/lumbar
        - E1, E2: single thoracic
        - "Normal": no Cobb angle reaches 10 degrees

    Args:
        spine: Spine2D object with detected vertebrae and Cobb angles.

    Returns:
        dict with keys:
            'rigo_type': the type string above.
            'description': human-readable summary including the angles.
            'curve_pattern': short pattern label.
            'apex_info': the dict produced by ``_find_apex_info``.
            'cobb_angles': the dict returned by ``spine.get_cobb_angles()``.
            'n_significant_curves': number of regions with Cobb >= 10.

    Side effects:
        Sets ``spine.rigo_type`` and ``spine.rigo_description``.
    """
    cobb_angles = spine.get_cobb_angles()
    pt = cobb_angles.get('PT', 0)
    mt = cobb_angles.get('MT', 0)
    tl = cobb_angles.get('TL', 0)

    # The lateral deviations determine curve direction and apex positions.
    centroids = spine.get_centroids()
    deviations = _calculate_lateral_deviations(centroids)

    # Apex indices index into `centroids`, so the region boundaries are
    # scaled by that same length. Previously they used len(spine.vertebrae),
    # which could differ if get_centroids() ever skips vertebrae.
    apex_info = _find_apex_info(centroids, deviations, len(centroids))

    significant_curves = [
        (region, angle)
        for region, angle in (('PT', pt), ('MT', mt), ('TL', tl))
        if angle >= 10
    ]
    n_curves = len(significant_curves)

    rigo_type = "N/A"
    description = ""
    curve_pattern = ""

    if n_curves == 0:
        # No significant scoliosis
        rigo_type = "Normal"
        description = "No significant scoliosis (all Cobb angles < 10°)"
        curve_pattern = "None"

    elif n_curves == 1:
        # Single curve patterns
        max_curve, max_angle = significant_curves[0]

        if max_curve in ('MT', 'PT'):
            # Thoracic single curve
            if apex_info['thoracic_apex_idx'] is None:
                rigo_type = "E1"
                description = f"Single thoracic curve ({max_angle:.1f}°)"
                curve_pattern = "Single thoracic"
            elif tl > 5:
                # A non-significant lumbar curve above 5 is treated as compensatory.
                rigo_type = "E2"
                description = f"Single thoracic curve ({max_angle:.1f}°) with lumbar compensatory ({tl:.1f}°)"
                curve_pattern = "Thoracic with compensation"
            else:
                rigo_type = "E1"
                description = f"True single thoracic curve ({max_angle:.1f}°)"
                curve_pattern = "Single thoracic"
        else:
            # Only 'TL' remains: thoracolumbar/lumbar single curve
            if mt > 5 or pt > 5:
                rigo_type = "C2"
                description = f"Thoracolumbar curve ({tl:.1f}°) with upper compensatory"
                curve_pattern = "TL/L with compensation"
            else:
                rigo_type = "C1"
                description = f"Single thoracolumbar/lumbar curve ({tl:.1f}°)"
                curve_pattern = "Single TL/L"

    else:
        # Two or more significant curves
        thoracic_total = pt + mt
        lumbar_total = tl

        if apex_info['is_s_pattern']:
            # Opposite-signed consecutive apexes: 3- or 4-curve patterns
            if thoracic_total > lumbar_total * 1.5:
                # Thoracic dominant - Type A (3-curve)
                if apex_info['lumbar_apex_low']:
                    rigo_type = "A1"
                    description = f"3-curve: Thoracic major ({mt:.1f}°), lumbar apex low"
                elif apex_info['apex_at_tl_junction']:
                    rigo_type = "A2"
                    description = f"3-curve: Thoracolumbar transition ({mt:.1f}°/{tl:.1f}°)"
                else:
                    rigo_type = "A3"
                    description = f"3-curve: Thoracic major ({mt:.1f}°) with structural lumbar ({tl:.1f}°)"
                curve_pattern = "3-curve (thoracic major)"

            elif lumbar_total > thoracic_total * 1.5:
                # Lumbar dominant
                rigo_type = "C2"
                description = f"Lumbar major ({tl:.1f}°) with thoracic compensatory ({mt:.1f}°)"
                curve_pattern = "Lumbar major"

            else:
                # Double major - Type B (4-curve)
                if tl >= mt:
                    rigo_type = "B1"
                    description = f"4-curve: Double major, lumbar prominent ({tl:.1f}°/{mt:.1f}°)"
                else:
                    rigo_type = "B2"
                    description = f"4-curve: Double major, thoracic prominent ({mt:.1f}°/{tl:.1f}°)"
                curve_pattern = "4-curve (double major)"

        elif mt >= tl:
            # All apexes on the same side: extended thoracic curve
            if tl > 5:
                rigo_type = "A3"
                description = f"Long thoracic curve ({mt:.1f}°) extending to lumbar ({tl:.1f}°)"
            else:
                rigo_type = "E2"
                description = f"Thoracic curve ({mt:.1f}°) with minor lumbar ({tl:.1f}°)"
            curve_pattern = "Extended thoracic"

        else:
            # All apexes on the same side: extended lumbar curve
            rigo_type = "C2"
            description = f"TL/Lumbar curve ({tl:.1f}°) with thoracic involvement ({mt:.1f}°)"
            curve_pattern = "Extended lumbar"

    spine.rigo_type = rigo_type
    spine.rigo_description = description

    return {
        'rigo_type': rigo_type,
        'description': description,
        'curve_pattern': curve_pattern,
        'apex_info': apex_info,
        'cobb_angles': cobb_angles,
        'n_significant_curves': n_curves,
    }
|
|
|
|
|
|
def _calculate_lateral_deviations(centroids: np.ndarray) -> np.ndarray:
|
|
"""Calculate lateral deviation from midline for each vertebra."""
|
|
if len(centroids) < 2:
|
|
return np.zeros(len(centroids))
|
|
|
|
# Midline from first to last vertebra
|
|
start = centroids[0]
|
|
end = centroids[-1]
|
|
midline_vec = end - start
|
|
midline_len = np.linalg.norm(midline_vec)
|
|
|
|
if midline_len < 1e-6:
|
|
return np.zeros(len(centroids))
|
|
|
|
midline_unit = midline_vec / midline_len
|
|
|
|
deviations = []
|
|
for pt in centroids:
|
|
v = pt - start
|
|
# Project onto midline
|
|
proj_len = np.dot(v, midline_unit)
|
|
proj = proj_len * midline_unit
|
|
# Perpendicular vector
|
|
perp = v - proj
|
|
dist = np.linalg.norm(perp)
|
|
# Sign: positive = right, negative = left
|
|
sign = np.sign(np.cross(midline_unit, perp / (dist + 1e-8)))
|
|
deviations.append(dist * sign)
|
|
|
|
return np.array(deviations)
|
|
|
|
|
|
def _find_apex_info(centroids: np.ndarray, deviations: np.ndarray, n_verts: int) -> dict:
|
|
"""Find apex positions and determine curve pattern."""
|
|
info = {
|
|
'thoracic_apex_idx': None,
|
|
'lumbar_apex_idx': None,
|
|
'lumbar_apex_low': False,
|
|
'apex_at_tl_junction': False,
|
|
'is_s_pattern': False,
|
|
'apex_directions': []
|
|
}
|
|
|
|
if len(deviations) < 3:
|
|
return info
|
|
|
|
# Find local extrema (apexes)
|
|
apexes = []
|
|
apex_values = []
|
|
for i in range(1, len(deviations) - 1):
|
|
if (deviations[i] > deviations[i-1] and deviations[i] > deviations[i+1]) or \
|
|
(deviations[i] < deviations[i-1] and deviations[i] < deviations[i+1]):
|
|
if abs(deviations[i]) > 3: # Minimum threshold
|
|
apexes.append(i)
|
|
apex_values.append(deviations[i])
|
|
|
|
# Determine S-pattern (alternating signs at apexes)
|
|
if len(apex_values) >= 2:
|
|
signs = [np.sign(v) for v in apex_values]
|
|
# S-pattern if adjacent apexes have opposite signs
|
|
for i in range(len(signs) - 1):
|
|
if signs[i] != signs[i+1]:
|
|
info['is_s_pattern'] = True
|
|
break
|
|
|
|
# Classify apex regions
|
|
# Assume: top 40% = thoracic, middle 20% = TL junction, bottom 40% = lumbar
|
|
thoracic_end = int(0.4 * n_verts)
|
|
tl_junction_end = int(0.6 * n_verts)
|
|
|
|
for apex_idx in apexes:
|
|
if apex_idx < thoracic_end:
|
|
if info['thoracic_apex_idx'] is None or \
|
|
abs(deviations[apex_idx]) > abs(deviations[info['thoracic_apex_idx']]):
|
|
info['thoracic_apex_idx'] = apex_idx
|
|
elif apex_idx < tl_junction_end:
|
|
info['apex_at_tl_junction'] = True
|
|
else:
|
|
if info['lumbar_apex_idx'] is None or \
|
|
abs(deviations[apex_idx]) > abs(deviations[info['lumbar_apex_idx']]):
|
|
info['lumbar_apex_idx'] = apex_idx
|
|
|
|
# Check if lumbar apex is in lower region (bottom 30%)
|
|
if info['lumbar_apex_idx'] is not None:
|
|
if info['lumbar_apex_idx'] > int(0.7 * n_verts):
|
|
info['lumbar_apex_low'] = True
|
|
|
|
info['apex_directions'] = apex_values
|
|
|
|
return info
|