"""
Spine analysis functions for computing curves, Cobb angles, and identifying apex vertebrae.
"""
import numpy as np
from scipy.interpolate import splprep, splev
from typing import Tuple, List, Optional

from data_models import Spine2D, VertebraLandmark

# Cobb angle (degrees) at or above which a curve counts as significant.
_SIGNIFICANT_COBB_DEG = 10
# Cobb angle (degrees) above which a secondary curve counts as compensatory.
_COMPENSATORY_COBB_DEG = 5
# Minimum |lateral deviation| for an apex reported by find_apex_vertebrae
# (the original inline comment gives the unit as pixels).
_APEX_MIN_DEVIATION = 5
# Minimum |lateral deviation| for an apex considered by _find_apex_info.
_APEX_INFO_MIN_DEVIATION = 3


def compute_spine_curve(
    spine: Spine2D,
    smooth: float = 1.0,
    n_samples: int = 200
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute smooth spine centerline from vertebra centroids.

    A parametric spline is fitted through the centroids and sampled at
    ``n_samples`` evenly spaced parameter values. If the spline fit raises,
    the function falls back to piecewise-linear interpolation between the
    centroids and reports zero curvature everywhere.

    Args:
        spine: Spine2D object with detected vertebrae
        smooth: Smoothing factor for spline (higher = smoother)
        n_samples: Number of points to sample along the curve

    Returns:
        C: Curve points, shape (n_samples, 2)
        T: Tangent vectors, shape (n_samples, 2)
        N: Normal vectors, shape (n_samples, 2)
        curvature: Curvature at each point, shape (n_samples,)

    Raises:
        ValueError: If fewer than 4 centroids are available.
    """
    pts = spine.get_centroids()

    if len(pts) < 4:
        raise ValueError(f"Need at least 4 vertebrae for spline, got {len(pts)}")

    # Fit parametric spline through centroids
    x = pts[:, 0]
    y = pts[:, 1]

    try:
        tck, _ = splprep([x, y], s=smooth, k=min(3, len(pts) - 1))
    except Exception:
        # Deliberate best-effort: whatever makes the fit fail, still return a
        # usable (piecewise-linear) centerline instead of propagating.
        t = np.linspace(0, 1, n_samples)
        xs = np.interp(t, np.linspace(0, 1, len(x)), x)
        ys = np.interp(t, np.linspace(0, 1, len(y)), y)
        C = np.stack([xs, ys], axis=1).astype(np.float32)
        T = np.gradient(C, axis=0)
        T = T / (np.linalg.norm(T, axis=1, keepdims=True) + 1e-8)
        N = np.stack([-T[:, 1], T[:, 0]], axis=1)
        curvature = np.zeros(n_samples, dtype=np.float32)
        return C, T, N, curvature

    # Sample the spline
    u_new = np.linspace(0, 1, n_samples)
    xs, ys = splev(u_new, tck)

    # First and second derivatives
    dx, dy = splev(u_new, tck, der=1)
    ddx, ddy = splev(u_new, tck, der=2)

    # Curve points
    C = np.stack([xs, ys], axis=1).astype(np.float32)

    # Tangent vectors (normalized; epsilon guards against zero-length tangents)
    T = np.stack([dx, dy], axis=1)
    T_norm = np.linalg.norm(T, axis=1, keepdims=True) + 1e-8
    T = (T / T_norm).astype(np.float32)

    # Normal vectors (tangent rotated by +90 degrees)
    N = np.stack([-T[:, 1], T[:, 0]], axis=1).astype(np.float32)

    # Curvature: |x'y'' - y'x''| / (x'^2 + y'^2)^(3/2)
    curvature = np.abs(dx * ddy - dy * ddx) / (np.power(dx**2 + dy**2, 1.5) + 1e-8)
    curvature = curvature.astype(np.float32)

    return C, T, N, curvature


def compute_cobb_angles(spine: Spine2D) -> Tuple[float, float, float]:
    """
    Compute Cobb angles from vertebra orientations.

    The Cobb angle is measured as the angle between:
    - The superior endplate of the most tilted vertebra at the top of a curve
    - The inferior endplate of the most tilted vertebra at the bottom

    This function estimates 3 Cobb angles, each as the spread (max - min) of
    the vertebra orientations inside one third of the detected vertebrae:
    - PT: Proximal Thoracic (top third)
    - MT: Main Thoracic (middle third)
    - TL: Thoracolumbar/Lumbar (bottom third)

    Side effects: stores the angles on ``spine.cobb_pt``, ``spine.cobb_mt``
    and ``spine.cobb_tl``. When at least 7 orientations are available it also
    sets ``spine.curve_type`` to "S", "C" or "Normal"; with fewer than 7 the
    angles are set to 0.0 and ``curve_type`` is left untouched.

    Args:
        spine: Spine2D object with detected vertebrae

    Returns:
        (pt_angle, mt_angle, tl_angle) in degrees
    """
    orientations = spine.get_orientations()
    n = len(orientations)

    if n < 7:
        spine.cobb_pt = 0.0
        spine.cobb_mt = 0.0
        spine.cobb_tl = 0.0
        return 0.0, 0.0, 0.0

    # Divide into regions (approximately)
    # PT: top 1/3, MT: middle 1/3, TL: bottom 1/3
    pt_end = n // 3
    mt_end = 2 * n // 3

    def region_cobb(start_idx: int, end_idx: int) -> float:
        """Return max - min orientation within [start_idx, end_idx)."""
        region_angles = orientations[start_idx:end_idx]
        # An empty or single-element slice has no tilt difference.
        if len(region_angles) < 2:
            return 0.0
        return abs(float(np.max(region_angles) - np.min(region_angles)))

    pt_angle = region_cobb(0, pt_end)
    mt_angle = region_cobb(pt_end, mt_end)
    tl_angle = region_cobb(mt_end, n)

    # Store in spine object
    spine.cobb_pt = pt_angle
    spine.cobb_mt = mt_angle
    spine.cobb_tl = tl_angle

    # Determine curve type. Uses >= so that exactly 10 degrees is treated the
    # same way as in get_curve_severity ("Mild") and classify_rigo_type.
    mt_significant = mt_angle >= _SIGNIFICANT_COBB_DEG
    tl_significant = tl_angle >= _SIGNIFICANT_COBB_DEG
    if mt_significant and tl_significant:
        spine.curve_type = "S"  # Double curve
    elif mt_significant or tl_significant:
        spine.curve_type = "C"  # Single curve
    else:
        spine.curve_type = "Normal"

    return pt_angle, mt_angle, tl_angle


def find_apex_vertebrae(spine: Spine2D) -> List[int]:
    """
    Find indices of apex vertebrae (most deviated from midline).

    The midline is the straight line from the first to the last centroid.
    An apex is an interior vertebra whose signed lateral deviation is a
    strict local maximum or minimum and whose absolute deviation exceeds
    the minimum-deviation threshold.

    Args:
        spine: Spine2D with computed curve

    Returns:
        List of vertebra indices that are curve apexes (empty when fewer
        than 5 centroids are available or the midline is degenerate)
    """
    centroids = spine.get_centroids()

    if len(centroids) < 5:
        return []

    # A degenerate midline yields all-zero deviations, hence no apexes.
    deviations = _calculate_lateral_deviations(centroids)
    return _local_extrema_indices(deviations, _APEX_MIN_DEVIATION)


def get_curve_severity(cobb_angle: float) -> str:
    """
    Get clinical severity classification from Cobb angle.

    Args:
        cobb_angle: Cobb angle in degrees

    Returns:
        Severity string: "Normal" (< 10), "Mild" (< 25), "Moderate" (< 40),
        or "Severe" (everything else)
    """
    if cobb_angle < 10:
        return "Normal"
    elif cobb_angle < 25:
        return "Mild"
    elif cobb_angle < 40:
        return "Moderate"
    else:
        return "Severe"


def classify_rigo_type(spine: Spine2D) -> dict:
    """
    Classify scoliosis according to Rigo-Chêneau brace classification.

    Rigo Classification Types:
    - A1, A2, A3: 3-curve patterns (thoracic major)
    - B1, B2: 4-curve patterns (double major)
    - C1, C2: Single thoracolumbar/lumbar
    - E1, E2: Single thoracic

    Side effects: stores the result on ``spine.rigo_type`` and
    ``spine.rigo_description``.

    Args:
        spine: Spine2D object with detected vertebrae and Cobb angles

    Returns:
        dict with 'rigo_type', 'description', 'curve_pattern', 'apex_info',
        'cobb_angles' and 'n_significant_curves'
    """
    # Get Cobb angles
    cobb_angles = spine.get_cobb_angles()
    pt = cobb_angles.get('PT', 0)
    mt = cobb_angles.get('MT', 0)
    tl = cobb_angles.get('TL', 0)

    n_verts = len(spine.vertebrae)

    # Calculate lateral deviations to determine curve direction
    centroids = spine.get_centroids()
    deviations = _calculate_lateral_deviations(centroids)

    # Find apex positions and directions
    apex_info = _find_apex_info(centroids, deviations, n_verts)

    # Determine curve pattern based on number of significant curves
    significant_curves = [
        (label, angle)
        for label, angle in (('PT', pt), ('MT', mt), ('TL', tl))
        if angle >= _SIGNIFICANT_COBB_DEG
    ]
    n_curves = len(significant_curves)

    # Classification logic
    rigo_type = "N/A"
    description = ""
    curve_pattern = ""

    if n_curves == 0:
        # No significant scoliosis
        rigo_type = "Normal"
        description = "No significant scoliosis (all Cobb angles < 10°)"
        curve_pattern = "None"

    elif n_curves == 1:
        # Single curve patterns
        max_curve, max_angle = significant_curves[0]

        if max_curve in ('MT', 'PT'):
            # Thoracic single curve
            if apex_info['thoracic_apex_idx'] is not None:
                # Check if there's a compensatory lumbar
                if tl > _COMPENSATORY_COBB_DEG:
                    rigo_type = "E2"
                    description = f"Single thoracic curve ({max_angle:.1f}°) with lumbar compensatory ({tl:.1f}°)"
                    curve_pattern = "Thoracic with compensation"
                else:
                    rigo_type = "E1"
                    description = f"True single thoracic curve ({max_angle:.1f}°)"
                    curve_pattern = "Single thoracic"
            else:
                rigo_type = "E1"
                description = f"Single thoracic curve ({max_angle:.1f}°)"
                curve_pattern = "Single thoracic"

        elif max_curve == 'TL':
            # Thoracolumbar/Lumbar single curve
            if mt > _COMPENSATORY_COBB_DEG or pt > _COMPENSATORY_COBB_DEG:
                rigo_type = "C2"
                description = f"Thoracolumbar curve ({tl:.1f}°) with upper compensatory"
                curve_pattern = "TL/L with compensation"
            else:
                rigo_type = "C1"
                description = f"Single thoracolumbar/lumbar curve ({tl:.1f}°)"
                curve_pattern = "Single TL/L"

    else:
        # Double curve patterns (two or more significant curves)
        # Determine which curves are primary
        thoracic_total = pt + mt
        lumbar_total = tl

        # Check curve directions for S vs C pattern
        if apex_info['is_s_pattern']:
            # S-curve: typically 3 or 4 curve patterns
            if thoracic_total > lumbar_total * 1.5:
                # Thoracic dominant - Type A (3-curve)
                if apex_info['lumbar_apex_low']:
                    rigo_type = "A1"
                    description = f"3-curve: Thoracic major ({mt:.1f}°), lumbar apex low"
                elif apex_info['apex_at_tl_junction']:
                    rigo_type = "A2"
                    description = f"3-curve: Thoracolumbar transition ({mt:.1f}°/{tl:.1f}°)"
                else:
                    rigo_type = "A3"
                    description = f"3-curve: Thoracic major ({mt:.1f}°) with structural lumbar ({tl:.1f}°)"
                curve_pattern = "3-curve (thoracic major)"

            elif lumbar_total > thoracic_total * 1.5:
                # Lumbar dominant
                rigo_type = "C2"
                description = f"Lumbar major ({tl:.1f}°) with thoracic compensatory ({mt:.1f}°)"
                curve_pattern = "Lumbar major"

            else:
                # Double major - Type B (4-curve)
                if tl >= mt:
                    rigo_type = "B1"
                    description = f"4-curve: Double major, lumbar prominent ({tl:.1f}°/{mt:.1f}°)"
                else:
                    rigo_type = "B2"
                    description = f"4-curve: Double major, thoracic prominent ({mt:.1f}°/{tl:.1f}°)"
                curve_pattern = "4-curve (double major)"
        else:
            # C-curve pattern (curves in same direction)
            if mt >= tl:
                if tl > _COMPENSATORY_COBB_DEG:
                    rigo_type = "A3"
                    description = f"Long thoracic curve ({mt:.1f}°) extending to lumbar ({tl:.1f}°)"
                else:
                    rigo_type = "E2"
                    description = f"Thoracic curve ({mt:.1f}°) with minor lumbar ({tl:.1f}°)"
                curve_pattern = "Extended thoracic"
            else:
                rigo_type = "C2"
                description = f"TL/Lumbar curve ({tl:.1f}°) with thoracic involvement ({mt:.1f}°)"
                curve_pattern = "Extended lumbar"

    # Store in spine object
    spine.rigo_type = rigo_type
    spine.rigo_description = description

    return {
        'rigo_type': rigo_type,
        'description': description,
        'curve_pattern': curve_pattern,
        'apex_info': apex_info,
        'cobb_angles': cobb_angles,
        'n_significant_curves': n_curves
    }


def _local_extrema_indices(deviations: np.ndarray, min_abs_deviation: float) -> List[int]:
    """Return interior indices that are strict local maxima or minima of
    ``deviations`` and whose absolute value exceeds ``min_abs_deviation``."""
    extrema = []
    for i in range(1, len(deviations) - 1):
        prev_dev, cur_dev, next_dev = deviations[i - 1], deviations[i], deviations[i + 1]
        is_peak = cur_dev > prev_dev and cur_dev > next_dev
        is_valley = cur_dev < prev_dev and cur_dev < next_dev
        if (is_peak or is_valley) and abs(cur_dev) > min_abs_deviation:
            extrema.append(i)
    return extrema


def _calculate_lateral_deviations(centroids: np.ndarray) -> np.ndarray:
    """Calculate lateral deviation from midline for each vertebra.

    The midline runs from the first to the last centroid. Each deviation is
    the perpendicular distance to that line, signed by the z-component of
    ``midline x perpendicular`` (the original comments label positive as
    right and negative as left). Returns zeros when fewer than two centroids
    are given or the midline is shorter than 1e-6.
    """
    if len(centroids) < 2:
        return np.zeros(len(centroids))

    centroids = np.asarray(centroids)

    # Midline from first to last vertebra
    start = centroids[0]
    midline_vec = centroids[-1] - start
    midline_len = np.linalg.norm(midline_vec)

    if midline_len < 1e-6:
        return np.zeros(len(centroids))

    midline_unit = midline_vec / midline_len

    # Split every offset into its component along the midline and the
    # perpendicular remainder.
    offsets = centroids - start
    along = offsets @ midline_unit
    perp = offsets - along[:, np.newaxis] * midline_unit
    dist = np.linalg.norm(perp, axis=1)

    # Scalar 2-D cross product written out explicitly: calling np.cross on
    # 2-element vectors is deprecated as of NumPy 2.0.
    cross_z = midline_unit[0] * perp[:, 1] - midline_unit[1] * perp[:, 0]

    return dist * np.sign(cross_z)


def _find_apex_info(centroids: np.ndarray, deviations: np.ndarray, n_verts: int) -> dict:
    """Find apex positions and determine curve pattern.

    Args:
        centroids: Vertebra centroids (not used by the computation; kept so
            the call signature stays unchanged)
        deviations: Signed lateral deviation per vertebra
        n_verts: Number of vertebrae, used to split indices into regions

    Returns:
        dict with keys 'thoracic_apex_idx', 'lumbar_apex_idx',
        'lumbar_apex_low', 'apex_at_tl_junction', 'is_s_pattern' and
        'apex_directions' (the signed deviations at the detected apexes)
    """
    info = {
        'thoracic_apex_idx': None,
        'lumbar_apex_idx': None,
        'lumbar_apex_low': False,
        'apex_at_tl_junction': False,
        'is_s_pattern': False,
        'apex_directions': []
    }

    if len(deviations) < 3:
        return info

    # Find local extrema (apexes)
    apexes = _local_extrema_indices(deviations, _APEX_INFO_MIN_DEVIATION)
    apex_values = [deviations[i] for i in apexes]

    # S-pattern: at least one pair of adjacent apexes lies on opposite sides
    signs = [np.sign(v) for v in apex_values]
    info['is_s_pattern'] = any(a != b for a, b in zip(signs, signs[1:]))

    # Classify apex regions
    # Assume: top 40% = thoracic, middle 20% = TL junction, bottom 40% = lumbar
    thoracic_end = int(0.4 * n_verts)
    tl_junction_end = int(0.6 * n_verts)

    for apex_idx in apexes:
        if apex_idx < thoracic_end:
            # Keep the thoracic apex with the largest absolute deviation
            best = info['thoracic_apex_idx']
            if best is None or abs(deviations[apex_idx]) > abs(deviations[best]):
                info['thoracic_apex_idx'] = apex_idx
        elif apex_idx < tl_junction_end:
            info['apex_at_tl_junction'] = True
        else:
            # Keep the lumbar apex with the largest absolute deviation
            best = info['lumbar_apex_idx']
            if best is None or abs(deviations[apex_idx]) > abs(deviations[best]):
                info['lumbar_apex_idx'] = apex_idx

    # Check if lumbar apex is in lower region (bottom 30%)
    if info['lumbar_apex_idx'] is not None:
        if info['lumbar_apex_idx'] > int(0.7 * n_verts):
            info['lumbar_apex_low'] = True

    info['apex_directions'] = apex_values

    return info