Files
braceiqmed/brace-generator/spine_analysis.py

465 lines
16 KiB
Python

"""
Spine analysis functions for computing curves, Cobb angles, and identifying apex vertebrae.
"""
import numpy as np
from scipy.interpolate import splprep, splev
from typing import Tuple, List, Optional
from data_models import Spine2D, VertebraLandmark
def compute_spine_curve(
    spine: Spine2D,
    smooth: float = 1.0,
    n_samples: int = 200
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit a smooth centerline through the vertebra centroids and sample it.

    A parametric spline (``scipy.interpolate.splprep``) is fitted through the
    centroids in the order returned by ``spine.get_centroids()``. If the fit
    raises, the centroids are linearly interpolated instead and the curvature
    is reported as all zeros.

    Args:
        spine: Spine2D object with detected vertebrae
        smooth: Smoothing factor passed to ``splprep`` as ``s``
            (higher = smoother)
        n_samples: Number of points to sample along the curve

    Returns:
        C: Sampled curve points, shape (n_samples, 2), float32
        T: Normalized tangent vectors, shape (n_samples, 2), float32
        N: Normals, i.e. tangents rotated by +90 degrees ``(-ty, tx)``,
            shape (n_samples, 2), float32
        curvature: Unsigned curvature per sample, shape (n_samples,), float32

    Raises:
        ValueError: If fewer than 4 centroids are available.
    """
    centroids = spine.get_centroids()
    count = len(centroids)
    if count < 4:
        raise ValueError(f"Need at least 4 vertebrae for spline, got {count}")

    def _frame(raw_tangents):
        # Normalize tangents (epsilon guards zero-length) and rotate +90 deg for normals.
        unit = raw_tangents / (np.linalg.norm(raw_tangents, axis=1, keepdims=True) + 1e-8)
        unit = unit.astype(np.float32)
        normals = np.stack([-unit[:, 1], unit[:, 0]], axis=1).astype(np.float32)
        return unit, normals

    xs_in = centroids[:, 0]
    ys_in = centroids[:, 1]
    try:
        tck, _ = splprep([xs_in, ys_in], s=smooth, k=min(3, count - 1))
    except Exception:
        # Spline fit failed: fall back to piecewise-linear resampling with a
        # finite-difference tangent and zero curvature.
        t_grid = np.linspace(0, 1, n_samples)
        knots = np.linspace(0, 1, count)
        curve = np.stack(
            [np.interp(t_grid, knots, xs_in), np.interp(t_grid, knots, ys_in)],
            axis=1,
        ).astype(np.float32)
        tangents, normals = _frame(np.gradient(curve, axis=0))
        return curve, tangents, normals, np.zeros(n_samples, dtype=np.float32)

    u_grid = np.linspace(0, 1, n_samples)
    px, py = splev(u_grid, tck)
    vx, vy = splev(u_grid, tck, der=1)   # first derivative
    ax, ay = splev(u_grid, tck, der=2)   # second derivative

    curve = np.stack([px, py], axis=1).astype(np.float32)
    tangents, normals = _frame(np.stack([vx, vy], axis=1))

    # Unsigned curvature of a parametric curve: |x'y'' - y'x''| / (x'^2 + y'^2)^(3/2)
    kappa = np.abs(vx * ay - vy * ax) / (np.power(vx**2 + vy**2, 1.5) + 1e-8)
    return curve, tangents, normals, kappa.astype(np.float32)
def compute_cobb_angles(spine: Spine2D) -> Tuple[float, float, float]:
    """
    Estimate three regional Cobb angles from per-vertebra orientations.

    A true Cobb angle is measured between the superior endplate of the most
    tilted vertebra at the top of a curve and the inferior endplate of the most
    tilted vertebra at the bottom. This function approximates that measurement.
    It splits the orientations returned by ``spine.get_orientations()`` into
    three index-based thirds and takes ``max - min`` orientation within each:

    - PT: Proximal Thoracic (first third; nominally T1-T6)
    - MT: Main Thoracic (middle third; nominally T6-T12)
    - TL: Thoracolumbar/Lumbar (last third; nominally T12-L5)

    The regions are index-based, not anatomically labelled, so they only line
    up with the named levels when roughly the full T1-L5 range was detected.
    NOTE(review): assumes vertebrae are ordered top-to-bottom — confirm.

    Side effects: sets ``spine.cobb_pt``, ``spine.cobb_mt``, ``spine.cobb_tl``
    and ``spine.curve_type`` ("S" when both MT and TL are significant, "C" when
    exactly one is, otherwise "Normal"). A curve counts as significant at
    >= 10 degrees.

    Args:
        spine: Spine2D object with detected vertebrae

    Returns:
        (pt_angle, mt_angle, tl_angle), in the units of the orientations
        (documented as degrees; presumably get_orientations() returns
        degrees — confirm). All zeros when fewer than 7 vertebrae are present.
    """
    # >= 10 degrees is "significant", matching get_curve_severity() (10 -> "Mild")
    # and classify_rigo_type() (pt/mt/tl >= 10) in this module.
    significant_deg = 10

    orientations = spine.get_orientations()
    n = len(orientations)
    if n < 7:
        # Too few vertebrae for a regional estimate. Reset every derived field,
        # including curve_type, so no stale value survives from an earlier call.
        spine.cobb_pt = 0.0
        spine.cobb_mt = 0.0
        spine.cobb_tl = 0.0
        spine.curve_type = "Normal"
        return 0.0, 0.0, 0.0

    # Index-based thirds: PT = [0, pt_end), MT = [pt_end, mt_end), TL = [mt_end, n)
    pt_end = n // 3
    mt_end = 2 * n // 3

    def region_cobb(start_idx: int, end_idx: int) -> float:
        # Spread of orientations inside the region (max tilt - min tilt);
        # an empty or single-vertebra region yields 0.
        region_angles = orientations[start_idx:end_idx]
        if len(region_angles) < 2:
            return 0.0
        return abs(float(np.max(region_angles) - np.min(region_angles)))

    pt_angle = region_cobb(0, pt_end)
    mt_angle = region_cobb(pt_end, mt_end)
    tl_angle = region_cobb(mt_end, n)

    spine.cobb_pt = pt_angle
    spine.cobb_mt = mt_angle
    spine.cobb_tl = tl_angle

    # Curve type looks at MT and TL only; PT does not influence it.
    mt_significant = mt_angle >= significant_deg
    tl_significant = tl_angle >= significant_deg
    if mt_significant and tl_significant:
        spine.curve_type = "S"  # Double curve
    elif mt_significant or tl_significant:
        spine.curve_type = "C"  # Single curve
    else:
        spine.curve_type = "Normal"

    return pt_angle, mt_angle, tl_angle
def find_apex_vertebrae(spine: Spine2D, min_deviation: float = 5.0) -> List[int]:
    """
    Find indices of apex vertebrae: local extrema of lateral deviation.

    Lateral deviation is the signed perpendicular distance of each centroid
    from the straight line joining the first and last centroids. It is computed
    by ``_calculate_lateral_deviations``, which previously was duplicated
    inline here. An interior vertebra is an apex when its deviation is a strict
    local maximum or minimum relative to both neighbours and its magnitude
    exceeds ``min_deviation``.

    Args:
        spine: Spine2D with detected vertebrae
        min_deviation: Minimum |deviation| for an extremum to count as an
            apex, in centroid coordinate units (presumably pixels). Defaults to
            5.0, the previously hard-coded value.

    Returns:
        Ascending list of vertebra indices that are curve apexes. Empty when
        fewer than 5 vertebrae are available or when the first and last
        centroids (nearly) coincide.
    """
    centroids = spine.get_centroids()
    if len(centroids) < 5:
        return []

    # The helper returns all zeros for a degenerate midline (first == last
    # centroid). All-zero deviations have no strict extrema, so that case
    # still yields [] exactly as the old explicit early return did.
    deviations = _calculate_lateral_deviations(centroids)

    inner = deviations[1:-1]
    before = deviations[:-2]
    after = deviations[2:]
    is_peak = (inner > before) & (inner > after)
    is_valley = (inner < before) & (inner < after)
    is_apex = (is_peak | is_valley) & (np.abs(inner) > min_deviation)

    # +1 maps positions in the interior slice back to vertebra indices.
    return [int(i) + 1 for i in np.flatnonzero(is_apex)]
def get_curve_severity(cobb_angle: float) -> str:
    """
    Map a Cobb angle to a coarse clinical severity label.

    Bands use exclusive upper bounds:
    below 10 -> "Normal", below 25 -> "Mild", below 40 -> "Moderate",
    otherwise "Severe".

    NOTE(review): a value that fails every ``<`` comparison (e.g. NaN) falls
    through to "Severe"; confirm callers never pass NaN.

    Args:
        cobb_angle: Cobb angle in degrees

    Returns:
        One of "Normal", "Mild", "Moderate" or "Severe".
    """
    bands = ((10, "Normal"), (25, "Mild"), (40, "Moderate"))
    for upper_bound, label in bands:
        if cobb_angle < upper_bound:
            return label
    return "Severe"
def classify_rigo_type(spine: Spine2D) -> dict:
    """
    Assign a Rigo-Chêneau-style brace type from Cobb angles and apex layout.

    The decision uses:
      * the PT/MT/TL Cobb angles from ``spine.get_cobb_angles()``, where
        missing keys default to 0 and a curve counts as significant at
        >= 10 degrees;
      * apex information derived from the centroids via
        ``_calculate_lateral_deviations`` and ``_find_apex_info``.

    Labels this function can assign:
      - "Normal": no Cobb angle >= 10
      - E1 / E2: single significant thoracic (PT or MT) curve
      - C1 / C2: single significant TL curve, or lumbar-dominant multi-curve
      - A1 / A2 / A3: multi-curve, thoracic dominant (or "long thoracic")
      - B1 / B2: multi-curve, double major

    NOTE(review): these labels come from the heuristics in this function.
    Verify them against the clinical Rigo classification before relying on them.

    Side effects: sets ``spine.rigo_type`` and ``spine.rigo_description``.

    Args:
        spine: Spine2D object with detected vertebrae and Cobb angles
            (presumably compute_cobb_angles() has run first — confirm)

    Returns:
        dict with keys:
            'rigo_type': label string (see above)
            'description': human-readable summary including angles
            'curve_pattern': short pattern name ("None" when normal)
            'apex_info': the dict returned by _find_apex_info
            'cobb_angles': the dict returned by spine.get_cobb_angles()
            'n_significant_curves': number of Cobb angles >= 10
    """
    # Get Cobb angles (missing entries treated as 0)
    cobb_angles = spine.get_cobb_angles()
    pt = cobb_angles.get('PT', 0)
    mt = cobb_angles.get('MT', 0)
    tl = cobb_angles.get('TL', 0)
    n_verts = len(spine.vertebrae)
    # Calculate lateral deviations to determine curve direction
    centroids = spine.get_centroids()
    deviations = _calculate_lateral_deviations(centroids)
    # Find apex positions and directions
    apex_info = _find_apex_info(centroids, deviations, n_verts)
    # Determine curve pattern based on number of significant curves (>= 10 degrees)
    significant_curves = []
    if pt >= 10:
        significant_curves.append(('PT', pt))
    if mt >= 10:
        significant_curves.append(('MT', mt))
    if tl >= 10:
        significant_curves.append(('TL', tl))
    n_curves = len(significant_curves)
    # Classification logic
    rigo_type = "N/A"
    description = ""
    curve_pattern = ""
    # No significant scoliosis
    # NOTE(review): `max(pt, mt, tl) < 10` is implied by `n_curves == 0`; redundant.
    if n_curves == 0 or max(pt, mt, tl) < 10:
        rigo_type = "Normal"
        description = "No significant scoliosis (all Cobb angles < 10°)"
        curve_pattern = "None"
    # Single curve patterns
    elif n_curves == 1:
        max_curve = significant_curves[0][0]
        max_angle = significant_curves[0][1]
        if max_curve == 'MT' or max_curve == 'PT':
            # Thoracic single curve
            if apex_info['thoracic_apex_idx'] is not None:
                # Check if there's a compensatory lumbar (TL above 5 but below 10)
                if tl > 5:
                    rigo_type = "E2"
                    description = f"Single thoracic curve ({max_angle:.1f}°) with lumbar compensatory ({tl:.1f}°)"
                    curve_pattern = "Thoracic with compensation"
                else:
                    rigo_type = "E1"
                    description = f"True single thoracic curve ({max_angle:.1f}°)"
                    curve_pattern = "Single thoracic"
            else:
                # No thoracic apex detected: default to E1 regardless of TL
                rigo_type = "E1"
                description = f"Single thoracic curve ({max_angle:.1f}°)"
                curve_pattern = "Single thoracic"
        elif max_curve == 'TL':
            # Thoracolumbar/Lumbar single curve
            if mt > 5 or pt > 5:
                rigo_type = "C2"
                description = f"Thoracolumbar curve ({tl:.1f}°) with upper compensatory"
                curve_pattern = "TL/L with compensation"
            else:
                rigo_type = "C1"
                description = f"Single thoracolumbar/lumbar curve ({tl:.1f}°)"
                curve_pattern = "Single TL/L"
    # Double curve patterns
    elif n_curves >= 2:
        # Determine which curves are primary
        thoracic_total = pt + mt
        lumbar_total = tl
        # Check curve directions for S vs C pattern
        is_s_curve = apex_info['is_s_pattern']
        if is_s_curve:
            # S-curve: typically 3 or 4 curve patterns
            if thoracic_total > lumbar_total * 1.5:
                # Thoracic dominant - Type A (3-curve)
                # NOTE(review): A-type descriptions always report MT, even when
                # PT is the significant thoracic curve.
                if apex_info['lumbar_apex_low']:
                    rigo_type = "A1"
                    description = f"3-curve: Thoracic major ({mt:.1f}°), lumbar apex low"
                elif apex_info['apex_at_tl_junction']:
                    rigo_type = "A2"
                    description = f"3-curve: Thoracolumbar transition ({mt:.1f}°/{tl:.1f}°)"
                else:
                    rigo_type = "A3"
                    description = f"3-curve: Thoracic major ({mt:.1f}°) with structural lumbar ({tl:.1f}°)"
                curve_pattern = "3-curve (thoracic major)"
            elif lumbar_total > thoracic_total * 1.5:
                # Lumbar dominant
                rigo_type = "C2"
                description = f"Lumbar major ({tl:.1f}°) with thoracic compensatory ({mt:.1f}°)"
                curve_pattern = "Lumbar major"
            else:
                # Double major - Type B (4-curve)
                if tl >= mt:
                    rigo_type = "B1"
                    description = f"4-curve: Double major, lumbar prominent ({tl:.1f}°/{mt:.1f}°)"
                else:
                    rigo_type = "B2"
                    description = f"4-curve: Double major, thoracic prominent ({mt:.1f}°/{tl:.1f}°)"
                curve_pattern = "4-curve (double major)"
        else:
            # C-curve pattern (curves in same direction).
            # NOTE(review): also reached when fewer than two apexes were found,
            # since is_s_pattern defaults to False in _find_apex_info.
            if mt >= tl:
                if tl > 5:
                    rigo_type = "A3"
                    description = f"Long thoracic curve ({mt:.1f}°) extending to lumbar ({tl:.1f}°)"
                else:
                    rigo_type = "E2"
                    description = f"Thoracic curve ({mt:.1f}°) with minor lumbar ({tl:.1f}°)"
                curve_pattern = "Extended thoracic"
            else:
                rigo_type = "C2"
                description = f"TL/Lumbar curve ({tl:.1f}°) with thoracic involvement ({mt:.1f}°)"
                curve_pattern = "Extended lumbar"
    # Store in spine object
    spine.rigo_type = rigo_type
    spine.rigo_description = description
    return {
        'rigo_type': rigo_type,
        'description': description,
        'curve_pattern': curve_pattern,
        'apex_info': apex_info,
        'cobb_angles': cobb_angles,
        'n_significant_curves': n_curves
    }
def _calculate_lateral_deviations(centroids: np.ndarray) -> np.ndarray:
    """
    Signed perpendicular distance of each centroid from the end-to-end midline.

    The midline is the straight segment from the first to the last centroid.
    Each point's offset from the first centroid is split into a component
    along the midline and a perpendicular component. The returned value is the
    length of the perpendicular component times the sign of the 2-D cross
    product ``midline_unit x perp`` (``u_x * p_y - u_y * p_x``).

    NOTE(review): which physical side is positive depends on the coordinate
    convention and on vertebra ordering. With y-down image coordinates and
    centroids ordered top-to-bottom, a centroid at larger x than the midline
    gets a *negative* value. Confirm before attaching left/right labels.

    Args:
        centroids: (n, 2) array-like of (x, y) centroid coordinates.

    Returns:
        float array of shape (n,). All zeros when fewer than 2 centroids are
        given or when the first and last centroids (nearly) coincide.
    """
    pts = np.asarray(centroids, dtype=np.float64)
    if len(pts) < 2:
        return np.zeros(len(pts))

    start = pts[0]
    midline_vec = pts[-1] - start
    midline_len = np.linalg.norm(midline_vec)
    if midline_len < 1e-6:
        # Degenerate midline: no meaningful perpendicular direction.
        return np.zeros(len(pts))
    ux, uy = midline_vec / midline_len

    offsets = pts - start
    along = offsets[:, 0] * ux + offsets[:, 1] * uy          # projection length
    perp = offsets - np.outer(along, (ux, uy))                # perpendicular part
    dist = np.linalg.norm(perp, axis=1)
    # z-component of the 2-D cross product, written out explicitly because
    # np.cross on 2-element vectors is deprecated since NumPy 2.0.
    side = np.sign(ux * perp[:, 1] - uy * perp[:, 0])
    return dist * side
def _find_apex_info(centroids: np.ndarray, deviations: np.ndarray, n_verts: int) -> dict:
    """
    Locate curve apexes from lateral deviations and summarise the pattern.

    An apex is an interior index whose deviation is a strict local maximum or
    minimum relative to both neighbours, with |deviation| > 3 (in deviation
    units, presumably pixels). Apex indices are bucketed by position relative
    to ``n_verts``:
      - below int(0.4 * n_verts): thoracic
      - below int(0.6 * n_verts): thoracolumbar junction
      - otherwise: lumbar

    Args:
        centroids: Vertebra centroids. Not read by this implementation.
        deviations: Signed lateral deviation per vertebra.
        n_verts: Vertebra count used to place the region boundaries.

    Returns:
        dict with keys:
            'thoracic_apex_idx': thoracic apex with the largest |deviation|
                (earliest wins ties), or None
            'lumbar_apex_idx': lumbar apex with the largest |deviation|
                (earliest wins ties), or None
            'lumbar_apex_low': True if the lumbar apex index > int(0.7 * n_verts)
            'apex_at_tl_junction': True if any apex falls in the junction band
            'is_s_pattern': True if some pair of consecutive apexes has
                opposite signs
            'apex_directions': signed deviation at each apex, in index order
        With fewer than 3 deviations, the defaults (None/False/[]) are returned.
    """
    summary = {
        'thoracic_apex_idx': None,
        'lumbar_apex_idx': None,
        'lumbar_apex_low': False,
        'apex_at_tl_junction': False,
        'is_s_pattern': False,
        'apex_directions': []
    }
    count = len(deviations)
    if count < 3:
        return summary

    def _is_extremum(k):
        here = deviations[k]
        left, right = deviations[k - 1], deviations[k + 1]
        return (here > left and here > right) or (here < left and here < right)

    apex_idxs = [
        k for k in range(1, count - 1)
        if _is_extremum(k) and abs(deviations[k]) > 3
    ]
    apex_vals = [deviations[k] for k in apex_idxs]

    # S-pattern: the curve crosses the midline between two consecutive apexes.
    apex_signs = [np.sign(val) for val in apex_vals]
    if any(a != b for a, b in zip(apex_signs, apex_signs[1:])):
        summary['is_s_pattern'] = True

    def _keep_larger(key, k):
        # Replace the stored apex only on a strictly larger |deviation|.
        current = summary[key]
        if current is None or abs(deviations[k]) > abs(deviations[current]):
            summary[key] = k

    thoracic_limit = int(0.4 * n_verts)
    junction_limit = int(0.6 * n_verts)
    for k in apex_idxs:
        if k < thoracic_limit:
            _keep_larger('thoracic_apex_idx', k)
        elif k < junction_limit:
            summary['apex_at_tl_junction'] = True
        else:
            _keep_larger('lumbar_apex_idx', k)

    lumbar_idx = summary['lumbar_apex_idx']
    if lumbar_idx is not None and lumbar_idx > int(0.7 * n_verts):
        summary['lumbar_apex_low'] = True

    summary['apex_directions'] = apex_vals
    return summary