Add patient management, deployment scripts, and Docker fixes

This commit is contained in:
2026-01-30 01:51:33 -08:00
parent 745f9f827f
commit d28d2f20c6
33 changed files with 7496 additions and 284 deletions

View File

@@ -27,7 +27,7 @@ WORKDIR /app
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
# Copy and install requirements (from brace-generator folder)
# Copy and install requirements
COPY brace-generator/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
@@ -35,8 +35,16 @@ RUN pip install --no-cache-dir -r requirements.txt
COPY scoliovis-api/requirements.txt /app/requirements-scoliovis.txt
RUN pip install --no-cache-dir -r requirements-scoliovis.txt || true
# Copy brace-generator code
COPY brace-generator/ /app/brace_generator/server_DEV/
# Create brace_generator package structure
RUN mkdir -p /app/brace_generator
# Copy brace-generator code as a package
COPY brace-generator/*.py /app/brace_generator/
COPY brace-generator/__init__.py /app/brace_generator/__init__.py
# Also keep server_DEV structure for compatibility
RUN mkdir -p /app/brace_generator/server_DEV
COPY brace-generator/*.py /app/brace_generator/server_DEV/
# Copy scoliovis-api
COPY scoliovis-api/ /app/scoliovis-api/
@@ -44,8 +52,8 @@ COPY scoliovis-api/ /app/scoliovis-api/
# Copy templates
COPY templates/ /app/templates/
# Set Python path
ENV PYTHONPATH=/app:/app/brace_generator/server_DEV:/app/scoliovis-api
# Set Python path - include both locations
ENV PYTHONPATH=/app:/app/scoliovis-api
# Environment variables
ENV HOST=0.0.0.0
@@ -61,8 +69,8 @@ RUN mkdir -p /tmp/brace_generator /app/data/uploads /app/data/outputs
EXPOSE 8002
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
CMD curl -f http://localhost:8002/health || exit 1
# Run the server
# Run the server from the brace_generator package
CMD ["python", "-m", "uvicorn", "brace_generator.server_DEV.app:app", "--host", "0.0.0.0", "--port", "8002"]

508
brace-generator/adapters.py Normal file
View File

@@ -0,0 +1,508 @@
"""
Model adapters that convert different model outputs to unified Spine2D format.
Each adapter wraps a specific model and produces consistent output.
"""
import sys
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any
from abc import ABC, abstractmethod
from data_models import VertebraLandmark, Spine2D
class BaseLandmarkAdapter(ABC):
"""Base class for landmark detection model adapters."""
@abstractmethod
def predict(self, image: np.ndarray) -> Spine2D:
"""
Run inference on an image and return unified spine landmarks.
Args:
image: Input image as numpy array (grayscale or RGB)
Returns:
Spine2D object with detected landmarks
"""
pass
@property
@abstractmethod
def name(self) -> str:
"""Model name for identification."""
pass
class ScolioVisAdapter(BaseLandmarkAdapter):
"""
Adapter for ScolioVis-API (Keypoint R-CNN model).
Uses the original ScolioVis inference code for best accuracy.
Outputs: 4 keypoints per vertebra + Cobb angles (PT, MT, TL) + curve type (S/C)
"""
def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'):
"""
Initialize ScolioVis model.
Args:
weights_path: Path to keypointsrcnn_weights.pt (auto-detects if None)
device: 'cpu' or 'cuda'
"""
self.device = device
self.model = None
self.weights_path = weights_path
self._scoliovis_path = None
self._load_model()
def _load_model(self):
"""Load the Keypoint R-CNN model."""
import torch
import torchvision
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.models.detection.rpn import AnchorGenerator
# Find weights and scoliovis module
scoliovis_api_path = Path(__file__).parent.parent / 'scoliovis-api'
if self.weights_path is None:
possible_paths = [
scoliovis_api_path / 'models' / 'keypointsrcnn_weights.pt',
scoliovis_api_path / 'keypointsrcnn_weights.pt',
scoliovis_api_path / 'weights' / 'keypointsrcnn_weights.pt',
]
for p in possible_paths:
if p.exists():
self.weights_path = str(p)
break
if self.weights_path is None or not Path(self.weights_path).exists():
raise FileNotFoundError(
"ScolioVis weights not found. Please provide weights_path or ensure "
"scoliovis-api/models/keypointsrcnn_weights.pt exists."
)
# Store path to scoliovis module for Cobb angle calculation
self._scoliovis_path = scoliovis_api_path
# Create model with same anchor generator as original training
anchor_generator = AnchorGenerator(
sizes=(32, 64, 128, 256, 512),
aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
)
self.model = keypointrcnn_resnet50_fpn(
weights=None,
weights_backbone=None,
num_classes=2, # background + vertebra
num_keypoints=4, # 4 corners per vertebra
rpn_anchor_generator=anchor_generator
)
# Load weights
checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False)
self.model.load_state_dict(checkpoint)
self.model.to(self.device)
self.model.eval()
print(f"ScolioVis model loaded from {self.weights_path}")
@property
def name(self) -> str:
return "ScolioVis-API"
def _filter_output(self, output, max_verts: int = 17):
"""
Filter model output using NMS and score threshold.
Matches the original ScolioVis filtering logic.
"""
import torch
import torchvision
scores = output['scores'].detach().cpu().numpy()
# Get indices of scores over threshold (0.5)
high_scores_idxs = np.where(scores > 0.5)[0].tolist()
if len(high_scores_idxs) == 0:
return [], [], []
# Apply NMS with IoU threshold 0.3
post_nms_idxs = torchvision.ops.nms(
output['boxes'][high_scores_idxs],
output['scores'][high_scores_idxs],
0.3
).cpu().numpy()
# Get filtered results
np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
np_scores = output['scores'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
# Take top N by score (usually 17 for full spine)
sorted_scores_idxs = np.argsort(-1 * np_scores)
np_scores = np_scores[sorted_scores_idxs][:max_verts]
np_keypoints = np.array([np_keypoints[idx] for idx in sorted_scores_idxs])[:max_verts]
np_bboxes = np.array([np_bboxes[idx] for idx in sorted_scores_idxs])[:max_verts]
# Sort by ymin (top to bottom)
if len(np_keypoints) > 0:
ymins = np.array([kps[0][1] for kps in np_keypoints])
sorted_ymin_idxs = np.argsort(ymins)
np_scores = np.array([np_scores[idx] for idx in sorted_ymin_idxs])
np_keypoints = np.array([np_keypoints[idx] for idx in sorted_ymin_idxs])
np_bboxes = np.array([np_bboxes[idx] for idx in sorted_ymin_idxs])
# Convert to lists
keypoints_list = []
for kps in np_keypoints:
keypoints_list.append([list(map(float, kp[:2])) for kp in kps])
bboxes_list = []
for bbox in np_bboxes:
bboxes_list.append(list(map(int, bbox.tolist())))
scores_list = np_scores.tolist()
return bboxes_list, keypoints_list, scores_list
def predict(self, image: np.ndarray) -> Spine2D:
"""Run inference and return unified landmarks with ScolioVis Cobb angles."""
import torch
from torchvision.transforms import functional as F
# Ensure RGB
if len(image.shape) == 2:
image_rgb = np.stack([image, image, image], axis=-1)
else:
image_rgb = image
image_shape = image_rgb.shape # (H, W, C)
# Convert to tensor (ScolioVis uses torchvision's to_tensor)
img_tensor = F.to_tensor(image_rgb).to(self.device)
# Run inference
with torch.no_grad():
outputs = self.model([img_tensor])
# Filter output using original ScolioVis logic
bboxes, keypoints, scores = self._filter_output(outputs[0])
if len(keypoints) == 0:
return Spine2D(
vertebrae=[],
image_shape=image_shape[:2],
source_model=self.name
)
# Convert to unified format
vertebrae = []
for i in range(len(bboxes)):
kps = np.array(keypoints[i], dtype=np.float32) # (4, 2)
# Corners order from ScolioVis: [top_left, top_right, bottom_left, bottom_right]
corners = kps
centroid = np.mean(corners, axis=0)
# Compute orientation from top edge (kps[0] to kps[1])
top_left, top_right = corners[0], corners[1]
dx = top_right[0] - top_left[0]
dy = top_right[1] - top_left[1]
orientation = np.degrees(np.arctan2(dy, dx))
vert = VertebraLandmark(
level=None, # ScolioVis doesn't assign levels
centroid_px=centroid,
corners_px=corners,
endplate_upper_px=corners[:2], # top-left, top-right
endplate_lower_px=corners[2:], # bottom-left, bottom-right
orientation_deg=orientation,
confidence=float(scores[i]),
meta={'box': bboxes[i]}
)
vertebrae.append(vert)
# Create Spine2D
spine = Spine2D(
vertebrae=vertebrae,
image_shape=image_shape[:2],
source_model=self.name
)
# Use original ScolioVis Cobb angle calculation if available
if len(keypoints) >= 5:
try:
# Import original ScolioVis cobb_angle_cal
if str(self._scoliovis_path) not in sys.path:
sys.path.insert(0, str(self._scoliovis_path))
from scoliovis.cobb_angle_cal import cobb_angle_cal, keypoints_to_landmark_xy
landmark_xy = keypoints_to_landmark_xy(keypoints)
cobb_angles_list, angles_with_pos, curve_type, midpoint_lines = cobb_angle_cal(
landmark_xy, image_shape
)
# Store Cobb angles in spine object
spine.cobb_angles = {
'PT': cobb_angles_list[0],
'MT': cobb_angles_list[1],
'TL': cobb_angles_list[2]
}
spine.curve_type = curve_type
spine.meta = {
'angles_with_pos': angles_with_pos,
'midpoint_lines': midpoint_lines
}
except Exception as e:
print(f"Warning: Could not use ScolioVis Cobb calculation: {e}")
# Fallback to our own calculation
from spine_analysis import compute_cobb_angles
compute_cobb_angles(spine)
return spine
class VertLandmarkAdapter(BaseLandmarkAdapter):
"""
Adapter for Vertebra-Landmark-Detection (SpineNet model).
Outputs: 68 landmarks (4 corners × 17 vertebrae)
"""
def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'):
"""
Initialize SpineNet model.
Args:
weights_path: Path to model_last.pth (auto-detects if None)
device: 'cpu' or 'cuda'
"""
self.device = device
self.model = None
self.weights_path = weights_path
self._load_model()
def _load_model(self):
"""Load the SpineNet model."""
import torch
# Find weights
if self.weights_path is None:
possible_paths = [
Path(__file__).parent.parent / 'Vertebra-Landmark-Detection' / 'weights_spinal' / 'model_last.pth',
]
for p in possible_paths:
if p.exists():
self.weights_path = str(p)
break
if self.weights_path is None or not Path(self.weights_path).exists():
raise FileNotFoundError(
"Vertebra-Landmark-Detection weights not found. "
"Download from Google Drive and place in weights_spinal/model_last.pth"
)
# Add repo to path to import model
repo_path = Path(__file__).parent.parent / 'Vertebra-Landmark-Detection'
if str(repo_path) not in sys.path:
sys.path.insert(0, str(repo_path))
from models import spinal_net
# Create model
heads = {'hm': 1, 'reg': 2, 'wh': 8}
self.model = spinal_net.SpineNet(
heads=heads,
pretrained=False,
down_ratio=4,
final_kernel=1,
head_conv=256
)
# Load weights
checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False)
self.model.load_state_dict(checkpoint['state_dict'], strict=False)
self.model.to(self.device)
self.model.eval()
print(f"Vertebra-Landmark-Detection model loaded from {self.weights_path}")
@property
def name(self) -> str:
return "Vertebra-Landmark-Detection"
def _nms(self, heat, kernel=3):
"""Apply NMS using max pooling."""
import torch
import torch.nn.functional as F
hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=(kernel - 1) // 2)
keep = (hmax == heat).float()
return heat * keep
def _gather_feat(self, feat, ind):
"""Gather features by index."""
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
return feat
def _tranpose_and_gather_feat(self, feat, ind):
"""Transpose and gather features - matches original decoder."""
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = self._gather_feat(feat, ind)
return feat
def _decode_predictions(self, output: Dict, down_ratio: int = 4, K: int = 17):
"""Decode model output using original decoder logic."""
import torch
hm = output['hm'].sigmoid()
reg = output['reg']
wh = output['wh']
batch, cat, height, width = hm.size()
# Apply NMS
hm = self._nms(hm)
# Get top K from heatmap
topk_scores, topk_inds = torch.topk(hm.view(batch, cat, -1), K)
topk_inds = topk_inds % (height * width)
topk_ys = (topk_inds // width).float()
topk_xs = (topk_inds % width).float()
# Get overall top K
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
scores = topk_score.view(batch, K, 1)
# Get regression offset and apply
reg = self._tranpose_and_gather_feat(reg, topk_inds)
reg = reg.view(batch, K, 2)
xs = topk_xs.view(batch, K, 1) + reg[:, :, 0:1]
ys = topk_ys.view(batch, K, 1) + reg[:, :, 1:2]
# Get corner offsets
wh = self._tranpose_and_gather_feat(wh, topk_inds)
wh = wh.view(batch, K, 8)
# Calculate corners by SUBTRACTING offsets (original decoder logic)
tl_x = xs - wh[:, :, 0:1]
tl_y = ys - wh[:, :, 1:2]
tr_x = xs - wh[:, :, 2:3]
tr_y = ys - wh[:, :, 3:4]
bl_x = xs - wh[:, :, 4:5]
bl_y = ys - wh[:, :, 5:6]
br_x = xs - wh[:, :, 6:7]
br_y = ys - wh[:, :, 7:8]
# Combine into output format: [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score]
pts = torch.cat([xs, ys, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, scores], dim=2)
# Scale to image coordinates
pts[:, :, :10] *= down_ratio
return pts[0].cpu().numpy() # (K, 11)
def predict(self, image: np.ndarray) -> Spine2D:
"""Run inference and return unified landmarks."""
import torch
import cv2
# Ensure RGB
if len(image.shape) == 2:
image_rgb = np.stack([image, image, image], axis=-1)
else:
image_rgb = image
orig_h, orig_w = image_rgb.shape[:2]
# Resize to model input size (1024x512)
input_h, input_w = 1024, 512
img_resized = cv2.resize(image_rgb, (input_w, input_h))
# Normalize and convert to tensor - use original preprocessing!
# Original: out_image = image / 255. - 0.5 (NOT ImageNet stats)
img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0 - 0.5
img_tensor = img_tensor.unsqueeze(0).to(self.device)
# Run inference
with torch.no_grad():
output = self.model(img_tensor)
# Decode predictions - returns (K, 11) array
# Format: [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score]
pts = self._decode_predictions(output, down_ratio=4, K=17)
# Scale coordinates back to original image size
scale_x = orig_w / input_w
scale_y = orig_h / input_h
# Convert to unified format
vertebrae = []
threshold = 0.3
for i in range(len(pts)):
score = pts[i, 10]
if score < threshold:
continue
# Get center and corners (already scaled by down_ratio in decoder)
cx = pts[i, 0] * scale_x
cy = pts[i, 1] * scale_y
# Corners: tl, tr, bl, br
tl = np.array([pts[i, 2] * scale_x, pts[i, 3] * scale_y])
tr = np.array([pts[i, 4] * scale_x, pts[i, 5] * scale_y])
bl = np.array([pts[i, 6] * scale_x, pts[i, 7] * scale_y])
br = np.array([pts[i, 8] * scale_x, pts[i, 9] * scale_y])
# Reorder to [tl, tr, br, bl] for consistency with ScolioVis
corners = np.array([tl, tr, br, bl], dtype=np.float32)
centroid = np.array([cx, cy], dtype=np.float32)
# Compute orientation from top edge
dx = tr[0] - tl[0]
dy = tr[1] - tl[1]
orientation = np.degrees(np.arctan2(dy, dx))
vert = VertebraLandmark(
level=None,
centroid_px=centroid,
corners_px=corners,
endplate_upper_px=np.array([tl, tr]), # top edge
endplate_lower_px=np.array([bl, br]), # bottom edge
orientation_deg=orientation,
confidence=float(score),
meta={'raw_pts': pts[i].tolist()}
)
vertebrae.append(vert)
# Sort by y-coordinate (top to bottom)
vertebrae.sort(key=lambda v: float(v.centroid_px[1]))
# Assign vertebra levels (T1-L5 = 17 vertebrae typically)
level_names = ['T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7', 'T8', 'T9', 'T10', 'T11', 'T12', 'L1', 'L2', 'L3', 'L4', 'L5']
for i, vert in enumerate(vertebrae):
if i < len(level_names):
vert.level = level_names[i]
# Create Spine2D
spine = Spine2D(
vertebrae=vertebrae,
image_shape=(orig_h, orig_w),
source_model=self.name
)
# Compute Cobb angles
if len(vertebrae) >= 7:
from spine_analysis import compute_cobb_angles
compute_cobb_angles(spine)
return spine

View File

@@ -0,0 +1,354 @@
"""
Brace surface generation from spine landmarks.
Two modes:
- Version A: Generic/average body shape (parametric torso)
- Version B: Uses actual 3D body scan mesh
"""
import numpy as np
from typing import Tuple, Optional, List
from pathlib import Path
from .data_models import Spine2D, BraceConfig
from .spine_analysis import compute_spine_curve, find_apex_vertebrae
try:
import trimesh
HAS_TRIMESH = True
except ImportError:
HAS_TRIMESH = False
class BraceGenerator:
"""
Generates 3D brace shell from spine landmarks.
"""
def __init__(self, config: Optional[BraceConfig] = None):
"""
Initialize brace generator.
Args:
config: Brace configuration parameters
"""
if not HAS_TRIMESH:
raise ImportError("trimesh is required for brace generation. Install with: pip install trimesh")
self.config = config or BraceConfig()
def generate(self, spine: Spine2D) -> 'trimesh.Trimesh':
"""
Generate brace mesh from spine landmarks.
Args:
spine: Spine2D object with detected vertebrae
Returns:
trimesh.Trimesh object representing the brace shell
"""
if self.config.use_body_scan and self.config.body_scan_path:
return self._generate_from_body_scan(spine)
else:
return self._generate_from_average_body(spine)
def _torso_profile(self, z01: float) -> Tuple[float, float]:
"""
Get torso cross-section radii at a given height.
Args:
z01: Normalized height (0=top, 1=bottom)
Returns:
(a_mm, b_mm): Radii in left-right and front-back directions
"""
# Torso shape varies with height
# Wider at chest (z~0.3) and hips (z~0.8), narrower at waist (z~0.5)
# Base radii from config
base_a = self.config.torso_width_mm / 2
base_b = self.config.torso_depth_mm / 2
# Shape modulation
# Chest region (z ~ 0.2-0.4): wider
# Waist region (z ~ 0.5): narrower
# Hip region (z ~ 0.8-1.0): wider
if z01 < 0.3:
# Upper chest - moderate width
mod = 1.0
elif z01 < 0.5:
# Transition to waist
t = (z01 - 0.3) / 0.2
mod = 1.0 - 0.15 * t # Decrease by 15%
elif z01 < 0.7:
# Waist region - narrowest
mod = 0.85
else:
# Hips - widen again
t = (z01 - 0.7) / 0.3
mod = 0.85 + 0.2 * t # Increase by 20%
return base_a * mod, base_b * mod
def _generate_from_average_body(self, spine: Spine2D) -> 'trimesh.Trimesh':
"""
Generate brace using parametric average body shape.
The brace follows the spine curve laterally and applies
pressure zones at curve apexes.
"""
cfg = self.config
# 1) Compute spine curve
try:
C_px, T_px, N_px, curvature = compute_spine_curve(spine, smooth=5.0, n_samples=cfg.n_vertical_slices)
except ValueError as e:
raise ValueError(f"Cannot generate brace: {e}")
# 2) Convert to mm
if spine.pixel_spacing_mm is not None:
sx, sy = spine.pixel_spacing_mm
elif cfg.pixel_spacing_mm is not None:
sx, sy = cfg.pixel_spacing_mm
else:
sx = sy = 0.25 # Default assumption
C_mm = np.zeros_like(C_px)
C_mm[:, 0] = C_px[:, 0] * sx
C_mm[:, 1] = C_px[:, 1] * sy
# 3) Determine brace vertical extent
y_mm = C_mm[:, 1]
y_min, y_max = y_mm.min(), y_mm.max()
spine_height = y_max - y_min
# Brace height (might extend beyond detected vertebrae)
brace_height = min(cfg.brace_height_mm, spine_height * 1.1)
# 4) Normalize curvature for pressure zones
curv_norm = (curvature - curvature.min()) / (curvature.max() - curvature.min() + 1e-8)
# 5) Build vertices
n_z = cfg.n_vertical_slices
n_theta = cfg.n_circumference_points
# Opening angle (front of brace might be open)
opening_half = np.radians(cfg.front_opening_deg / 2)
vertices = []
for i in range(n_z):
z01 = i / (n_z - 1) # 0 to 1
# Z coordinate (vertical position in 3D)
z_mm = y_min + z01 * spine_height
# Get torso profile at this height
a_mm, b_mm = self._torso_profile(z01)
# Lateral offset from spine curve
x_offset = C_mm[i, 0] - (C_mm[0, 0] + C_mm[-1, 0]) / 2 # Deviation from midline
# Pressure modulation based on curvature
pressure = cfg.pressure_strength_mm * curv_norm[i]
for j in range(n_theta):
theta = 2 * np.pi * (j / n_theta)
# Skip vertices in the opening region (front = theta around 0)
# Actually, we'll still create them but can mark them for later removal
# Base ellipse point
x = a_mm * np.cos(theta)
y = b_mm * np.sin(theta)
# Apply lateral offset (brace follows spine curve)
x += x_offset
# Apply pressure zones
# Pressure on sides (theta near π/2 or 3π/2 = sides)
# The side that's convex gets pushed in
side_factor = abs(np.cos(theta)) # Max at sides (theta=0 or π)
# Determine which side based on spine deviation
if x_offset > 0:
# Spine deviated right, push on right side
if np.cos(theta) > 0: # Right side
x -= pressure * side_factor
else:
# Spine deviated left, push on left side
if np.cos(theta) < 0: # Left side
x -= pressure * side_factor * np.sign(np.cos(theta))
# Vertex position: x=left/right, y=front/back, z=vertical
vertices.append([x, y, z_mm])
vertices = np.array(vertices, dtype=np.float32)
# 6) Build faces (quad strips between adjacent rings)
faces = []
def vid(i, j):
return i * n_theta + (j % n_theta)
for i in range(n_z - 1):
for j in range(n_theta):
j2 = (j + 1) % n_theta
# Two triangles per quad
a = vid(i, j)
b = vid(i, j2)
c = vid(i + 1, j2)
d = vid(i + 1, j)
faces.append([a, b, c])
faces.append([a, c, d])
faces = np.array(faces, dtype=np.int32)
# 7) Create outer shell mesh
outer_shell = trimesh.Trimesh(vertices=vertices, faces=faces, process=True)
# 8) Create inner shell (offset inward by wall thickness)
outer_shell.fix_normals()
vn = outer_shell.vertex_normals
inner_vertices = vertices - cfg.wall_thickness_mm * vn
# Inner faces need reversed winding
inner_faces = faces[:, ::-1]
# 9) Combine into solid shell
all_vertices = np.vstack([vertices, inner_vertices])
inner_faces_offset = inner_faces + len(vertices)
all_faces = np.vstack([faces, inner_faces_offset])
# 10) Add end caps (top and bottom rings)
# Top cap (connect outer to inner at z=0)
top_faces = []
for j in range(n_theta):
j2 = (j + 1) % n_theta
outer_j = vid(0, j)
outer_j2 = vid(0, j2)
inner_j = outer_j + len(vertices)
inner_j2 = outer_j2 + len(vertices)
top_faces.append([outer_j, inner_j, inner_j2])
top_faces.append([outer_j, inner_j2, outer_j2])
# Bottom cap
bottom_faces = []
for j in range(n_theta):
j2 = (j + 1) % n_theta
outer_j = vid(n_z - 1, j)
outer_j2 = vid(n_z - 1, j2)
inner_j = outer_j + len(vertices)
inner_j2 = outer_j2 + len(vertices)
bottom_faces.append([outer_j, outer_j2, inner_j2])
bottom_faces.append([outer_j, inner_j2, inner_j])
all_faces = np.vstack([all_faces, top_faces, bottom_faces])
# Create final mesh
brace = trimesh.Trimesh(vertices=all_vertices, faces=all_faces, process=True)
brace.merge_vertices()
# Remove degenerate faces
valid_faces = brace.nondegenerate_faces()
brace.update_faces(valid_faces)
brace.fix_normals()
return brace
def _generate_from_body_scan(self, spine: Spine2D) -> 'trimesh.Trimesh':
"""
Generate brace by offsetting from a 3D body scan mesh.
The body scan provides the actual torso shape, and we:
1. Offset outward for clearance
2. Apply pressure zones based on spine curvature
3. Thicken for wall thickness
"""
cfg = self.config
if not cfg.body_scan_path or not Path(cfg.body_scan_path).exists():
raise FileNotFoundError(f"Body scan not found: {cfg.body_scan_path}")
# Load body scan
body = trimesh.load(cfg.body_scan_path, force='mesh')
body.remove_unreferenced_vertices()
body.fix_normals()
# Compute spine curve for pressure mapping
try:
C_px, T_px, N_px, curvature = compute_spine_curve(spine, smooth=5.0, n_samples=200)
except ValueError:
curvature = np.zeros(200)
# Convert spine coordinates to mm
if spine.pixel_spacing_mm is not None:
sx, sy = spine.pixel_spacing_mm
else:
sx = sy = 0.25
y_mm = C_px[:, 1] * sy
y_min, y_max = y_mm.min(), y_mm.max()
H = y_max - y_min + 1e-6
# Normalize curvature
curv_norm = (curvature - curvature.min()) / (curvature.max() - curvature.min() + 1e-8)
# 1) Offset body surface outward for clearance (inner brace surface)
clearance_mm = 6.0 # Gap between body and brace
vn = body.vertex_normals
inner_surface = trimesh.Trimesh(
vertices=body.vertices + clearance_mm * vn,
faces=body.faces.copy(),
process=True
)
# 2) Apply pressure deformation
# Map each vertex's Z coordinate to spine curvature
z_coords = inner_surface.vertices[:, 2] # Assuming Z is vertical
z_min, z_max = z_coords.min(), z_coords.max()
z01 = (z_coords - z_min) / (z_max - z_min + 1e-6)
# Sample curvature at each vertex height
curv_idx = np.clip((z01 * (len(curv_norm) - 1)).astype(int), 0, len(curv_norm) - 1)
pressure_per_vertex = cfg.pressure_strength_mm * curv_norm[curv_idx]
# Apply pressure on sides (based on X coordinate)
x_coords = inner_surface.vertices[:, 0]
x_range = np.abs(x_coords).max() + 1e-6
side_factor = np.abs(x_coords) / x_range # 0 at center, 1 at sides
deformation = (pressure_per_vertex * side_factor)[:, np.newaxis] * inner_surface.vertex_normals
inner_surface.vertices = inner_surface.vertices - deformation
# 3) Create outer surface (offset by wall thickness)
inner_surface.fix_normals()
outer_surface = trimesh.Trimesh(
vertices=inner_surface.vertices + cfg.wall_thickness_mm * inner_surface.vertex_normals,
faces=inner_surface.faces.copy(),
process=True
)
# 4) Combine surfaces
# For a true solid, we'd need to stitch edges - simplified here
brace = trimesh.util.concatenate([inner_surface, outer_surface])
brace.merge_vertices()
valid_faces = brace.nondegenerate_faces()
brace.update_faces(valid_faces)
brace.fix_normals()
return brace
def export_stl(self, mesh: 'trimesh.Trimesh', output_path: str):
"""
Export mesh to STL file.
Args:
mesh: trimesh.Trimesh object
output_path: Path for output STL file
"""
mesh.export(output_path)
print(f"Exported brace to {output_path}")
print(f" Vertices: {len(mesh.vertices)}")
print(f" Faces: {len(mesh.faces)}")

View File

@@ -0,0 +1,177 @@
"""
Data models for unified spine landmark representation.
This is the "glue" that connects different model outputs to the brace generator.
"""
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
import numpy as np
@dataclass
class VertebraLandmark:
"""
Unified representation of a single vertebra's landmarks.
All coordinates are in pixels (can be converted to mm with pixel_spacing).
"""
# Vertebra level identifier (e.g., "T1", "T4", "L1", etc.) - None if unknown
level: Optional[str] = None
# Center point of vertebra [x, y] in pixels
centroid_px: np.ndarray = field(default_factory=lambda: np.zeros(2))
# Four corner points [top_left, top_right, bottom_right, bottom_left] shape (4, 2)
corners_px: Optional[np.ndarray] = None
# Upper endplate points [left, right] shape (2, 2)
endplate_upper_px: Optional[np.ndarray] = None
# Lower endplate points [left, right] shape (2, 2)
endplate_lower_px: Optional[np.ndarray] = None
# Orientation angle of vertebra in degrees (tilt in coronal plane)
orientation_deg: Optional[float] = None
# Detection confidence (0-1)
confidence: float = 1.0
# Additional metadata from source model
meta: Optional[Dict[str, Any]] = None
def compute_orientation(self) -> float:
"""Compute vertebra orientation from corners or endplates."""
if self.orientation_deg is not None:
return self.orientation_deg
# Try to compute from upper endplate
if self.endplate_upper_px is not None:
left, right = self.endplate_upper_px[0], self.endplate_upper_px[1]
dx = right[0] - left[0]
dy = right[1] - left[1]
angle = np.degrees(np.arctan2(dy, dx))
self.orientation_deg = angle
return angle
# Try to compute from corners (top-left to top-right)
if self.corners_px is not None:
top_left, top_right = self.corners_px[0], self.corners_px[1]
dx = top_right[0] - top_left[0]
dy = top_right[1] - top_left[1]
angle = np.degrees(np.arctan2(dy, dx))
self.orientation_deg = angle
return angle
return 0.0
def compute_centroid(self) -> np.ndarray:
"""Compute centroid from corners if not set."""
if self.corners_px is not None and np.all(self.centroid_px == 0):
self.centroid_px = np.mean(self.corners_px, axis=0)
return self.centroid_px
@dataclass
class Spine2D:
"""
Complete 2D spine representation from an X-ray.
Contains all detected vertebrae and computed angles.
"""
# List of vertebrae, ordered from top (C7/T1) to bottom (L5/S1)
vertebrae: List[VertebraLandmark] = field(default_factory=list)
# Pixel spacing in mm [sx, sy] - from DICOM if available
pixel_spacing_mm: Optional[np.ndarray] = None
# Original image shape (height, width)
image_shape: Optional[tuple] = None
# Computed Cobb angles in degrees (individual fields)
cobb_pt: Optional[float] = None # Proximal Thoracic
cobb_mt: Optional[float] = None # Main Thoracic
cobb_tl: Optional[float] = None # Thoracolumbar/Lumbar
# Cobb angles as dictionary (alternative format)
cobb_angles: Optional[Dict[str, float]] = None # {'PT': angle, 'MT': angle, 'TL': angle}
# Curve type: "S" (double curve) or "C" (single curve) or "Normal"
curve_type: Optional[str] = None
# Rigo-Chêneau classification
rigo_type: Optional[str] = None # A1, A2, A3, B1, B2, C1, C2, E1, E2, Normal
rigo_description: Optional[str] = None # Detailed description
# Source model that generated this data
source_model: Optional[str] = None
# Additional metadata
meta: Optional[Dict[str, Any]] = None
def get_cobb_angles(self) -> Dict[str, float]:
"""Get Cobb angles as dictionary, preferring computed individual fields."""
# Prefer individual fields (set by compute_cobb_angles) over dictionary
# This ensures consistency between displayed values and classification
if self.cobb_pt is not None or self.cobb_mt is not None or self.cobb_tl is not None:
return {
'PT': self.cobb_pt or 0.0,
'MT': self.cobb_mt or 0.0,
'TL': self.cobb_tl or 0.0
}
if self.cobb_angles is not None:
return self.cobb_angles
return {'PT': 0.0, 'MT': 0.0, 'TL': 0.0}
def get_centroids(self) -> np.ndarray:
"""Get array of all vertebra centroids, shape (N, 2)."""
centroids = []
for v in self.vertebrae:
v.compute_centroid()
centroids.append(v.centroid_px)
return np.array(centroids, dtype=np.float32)
def get_orientations(self) -> np.ndarray:
"""Get array of all vertebra orientations in degrees, shape (N,)."""
return np.array([v.compute_orientation() for v in self.vertebrae], dtype=np.float32)
def to_mm(self, coords_px: np.ndarray) -> np.ndarray:
"""Convert pixel coordinates to millimeters."""
if self.pixel_spacing_mm is None:
# Default assumption: 0.25 mm/pixel (typical for spine X-rays)
spacing = np.array([0.25, 0.25])
else:
spacing = self.pixel_spacing_mm
return coords_px * spacing
def sort_vertebrae(self):
"""Sort vertebrae by vertical position (top to bottom)."""
self.vertebrae.sort(key=lambda v: float(v.centroid_px[1]))
@dataclass
class BraceConfig:
"""
Configuration parameters for brace generation.
"""
# Brace dimensions
brace_height_mm: float = 400.0 # Total height of brace
wall_thickness_mm: float = 4.0 # Shell thickness
# Torso shape parameters (for average body mode)
torso_width_mm: float = 280.0 # Left-right diameter at widest
torso_depth_mm: float = 200.0 # Front-back diameter at widest
# Correction parameters
pressure_strength_mm: float = 15.0 # Max indentation at apex
pressure_spread_deg: float = 45.0 # Angular spread of pressure zone
# Mesh resolution
n_vertical_slices: int = 100 # Number of cross-sections
n_circumference_points: int = 72 # Points per cross-section (every 5°)
# Opening (for brace accessibility)
front_opening_deg: float = 60.0 # Angular width of front opening (0 = closed)
# Mode
use_body_scan: bool = False # True = use 3D body scan, False = average body
body_scan_path: Optional[str] = None # Path to body scan mesh
# Scale
pixel_spacing_mm: Optional[np.ndarray] = None # Override pixel spacing

View File

@@ -0,0 +1,115 @@
"""
Image loader supporting JPEG, PNG, and DICOM formats.
"""
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
try:
import pydicom
HAS_PYDICOM = True
except ImportError:
HAS_PYDICOM = False
try:
from PIL import Image
HAS_PIL = True
except ImportError:
HAS_PIL = False
try:
import cv2
HAS_CV2 = True
except ImportError:
HAS_CV2 = False
def load_xray(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Load an X-ray image from file.
Supports: JPEG, PNG, BMP, DICOM (.dcm)
Args:
path: Path to the image file
Returns:
img_u8: Grayscale image as uint8 array (H, W)
spacing_mm: Pixel spacing [sx, sy] in mm, or None if not available
"""
path = Path(path)
suffix = path.suffix.lower()
# DICOM
if suffix in ['.dcm', '.dicom']:
if not HAS_PYDICOM:
raise ImportError("pydicom is required for DICOM files. Install with: pip install pydicom")
return _load_dicom(str(path))
# Standard image formats
if suffix in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
return _load_standard_image(str(path))
# Try to load as standard image anyway
try:
return _load_standard_image(str(path))
except Exception as e:
raise ValueError(f"Could not load image: {path}. Error: {e}")
def _load_dicom(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""Load DICOM file."""
ds = pydicom.dcmread(path)
arr = ds.pixel_array.astype(np.float32)
# Apply modality LUT if present
if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
arr = arr * ds.RescaleSlope + ds.RescaleIntercept
# Normalize to 0-255
arr = arr - arr.min()
if arr.max() > 0:
arr = arr / arr.max()
img_u8 = (arr * 255).astype(np.uint8)
# Get pixel spacing
spacing_mm = None
if hasattr(ds, 'PixelSpacing'):
# PixelSpacing is [row_spacing, col_spacing] in mm
sy, sx = [float(x) for x in ds.PixelSpacing]
spacing_mm = np.array([sx, sy], dtype=np.float32)
elif hasattr(ds, 'ImagerPixelSpacing'):
sy, sx = [float(x) for x in ds.ImagerPixelSpacing]
spacing_mm = np.array([sx, sy], dtype=np.float32)
return img_u8, spacing_mm
def _load_standard_image(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""Load standard image format (JPEG, PNG, etc.)."""
if HAS_CV2:
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
if img is None:
raise ValueError(f"Could not read image: {path}")
return img.astype(np.uint8), None
elif HAS_PIL:
img = Image.open(path).convert('L') # Convert to grayscale
return np.array(img, dtype=np.uint8), None
else:
raise ImportError("Either opencv-python or Pillow is required. Install with: pip install opencv-python")
def load_xray_rgb(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Load X-ray as RGB (for models that expect 3-channel input).
Returns:
img_rgb: RGB image as uint8 array (H, W, 3)
spacing_mm: Pixel spacing or None
"""
img_gray, spacing = load_xray(path)
# Convert grayscale to RGB by stacking
img_rgb = np.stack([img_gray, img_gray, img_gray], axis=-1)
return img_rgb, spacing

346
brace-generator/pipeline.py Normal file
View File

@@ -0,0 +1,346 @@
"""
Complete pipeline: X-ray → Landmarks → Brace STL
"""
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any, Union
import json
from brace_generator.data_models import Spine2D, BraceConfig, VertebraLandmark
from brace_generator.image_loader import load_xray, load_xray_rgb
from brace_generator.adapters import BaseLandmarkAdapter, ScolioVisAdapter, VertLandmarkAdapter
from brace_generator.spine_analysis import (
compute_spine_curve, compute_cobb_angles, find_apex_vertebrae,
get_curve_severity, classify_rigo_type
)
from brace_generator.brace_surface import BraceGenerator
class BracePipeline:
"""
End-to-end pipeline for generating scoliosis braces from X-rays.
Usage:
# Basic usage with default model
pipeline = BracePipeline()
pipeline.process("xray.png", "brace.stl")
# With specific model
pipeline = BracePipeline(model="vertebra-landmark")
pipeline.process("xray.dcm", "brace.stl")
# With body scan
config = BraceConfig(use_body_scan=True, body_scan_path="body.obj")
pipeline = BracePipeline(config=config)
pipeline.process("xray.png", "brace.stl")
"""
AVAILABLE_MODELS = {
'scoliovis': ScolioVisAdapter,
'vertebra-landmark': VertLandmarkAdapter,
}
def __init__(
self,
model: str = 'scoliovis',
config: Optional[BraceConfig] = None,
device: str = 'cpu'
):
"""
Initialize pipeline.
Args:
model: Model to use ('scoliovis' or 'vertebra-landmark')
config: Brace configuration
device: 'cpu' or 'cuda'
"""
self.device = device
self.config = config or BraceConfig()
self.model_name = model.lower()
# Initialize model adapter
if self.model_name not in self.AVAILABLE_MODELS:
raise ValueError(f"Unknown model: {model}. Available: {list(self.AVAILABLE_MODELS.keys())}")
self.adapter: BaseLandmarkAdapter = self.AVAILABLE_MODELS[self.model_name](device=device)
self.brace_generator = BraceGenerator(self.config)
# Store last results for inspection
self.last_spine: Optional[Spine2D] = None
self.last_image: Optional[np.ndarray] = None
def process(
self,
xray_path: str,
output_stl_path: str,
visualize: bool = False,
save_landmarks: bool = False
) -> Dict[str, Any]:
"""
Process X-ray and generate brace STL.
Args:
xray_path: Path to input X-ray (JPEG, PNG, or DICOM)
output_stl_path: Path for output STL file
visualize: If True, also save visualization image
save_landmarks: If True, also save landmarks JSON
Returns:
Dictionary with analysis results
"""
print(f"=" * 60)
print(f"Brace Generation Pipeline")
print(f"Model: {self.adapter.name}")
print(f"=" * 60)
# 1) Load X-ray
print(f"\n1. Loading X-ray: {xray_path}")
image_rgb, pixel_spacing = load_xray_rgb(xray_path)
self.last_image = image_rgb
print(f" Image size: {image_rgb.shape[:2]}")
if pixel_spacing is not None:
print(f" Pixel spacing: {pixel_spacing} mm")
# 2) Detect landmarks
print(f"\n2. Detecting landmarks...")
spine = self.adapter.predict(image_rgb)
spine.pixel_spacing_mm = pixel_spacing
self.last_spine = spine
print(f" Detected {len(spine.vertebrae)} vertebrae")
if len(spine.vertebrae) < 5:
raise ValueError(f"Insufficient vertebrae detected ({len(spine.vertebrae)}). Need at least 5.")
# 3) Compute spine analysis
print(f"\n3. Analyzing spine curvature...")
compute_cobb_angles(spine)
apexes = find_apex_vertebrae(spine)
# Classify Rigo type
rigo_result = classify_rigo_type(spine)
print(f" Cobb Angles:")
print(f" PT (Proximal Thoracic): {spine.cobb_pt:.1f}° - {get_curve_severity(spine.cobb_pt)}")
print(f" MT (Main Thoracic): {spine.cobb_mt:.1f}° - {get_curve_severity(spine.cobb_mt)}")
print(f" TL (Thoracolumbar): {spine.cobb_tl:.1f}° - {get_curve_severity(spine.cobb_tl)}")
print(f" Curve type: {spine.curve_type}")
print(f" Rigo Classification: {rigo_result['rigo_type']}")
print(f" - {rigo_result['description']}")
print(f" Apex vertebrae indices: {apexes}")
# 4) Generate brace
print(f"\n4. Generating brace mesh...")
if self.config.use_body_scan:
print(f" Mode: Using body scan ({self.config.body_scan_path})")
else:
print(f" Mode: Average body shape")
brace_mesh = self.brace_generator.generate(spine)
print(f" Mesh: {len(brace_mesh.vertices)} vertices, {len(brace_mesh.faces)} faces")
# 5) Export STL
print(f"\n5. Exporting STL: {output_stl_path}")
self.brace_generator.export_stl(brace_mesh, output_stl_path)
# 6) Optional: Save visualization
if visualize:
vis_path = str(Path(output_stl_path).with_suffix('.png'))
self._save_visualization(vis_path, spine, image_rgb)
print(f" Visualization saved: {vis_path}")
# 7) Optional: Save landmarks JSON
if save_landmarks:
json_path = str(Path(output_stl_path).with_suffix('.json'))
self._save_landmarks_json(json_path, spine)
print(f" Landmarks saved: {json_path}")
# Prepare results
results = {
'input_image': xray_path,
'output_stl': output_stl_path,
'model': self.adapter.name,
'vertebrae_detected': len(spine.vertebrae),
'cobb_angles': {
'PT': spine.cobb_pt,
'MT': spine.cobb_mt,
'TL': spine.cobb_tl,
},
'curve_type': spine.curve_type,
'rigo_type': rigo_result['rigo_type'],
'rigo_description': rigo_result['description'],
'apex_indices': apexes,
'mesh_vertices': len(brace_mesh.vertices),
'mesh_faces': len(brace_mesh.faces),
}
print(f"\n{'=' * 60}")
print(f"Pipeline complete!")
print(f"{'=' * 60}")
return results
def _save_visualization(self, path: str, spine: Spine2D, image: np.ndarray):
"""Save visualization of detected landmarks and spine curve."""
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
except ImportError:
print(" Warning: matplotlib not available for visualization")
return
fig, axes = plt.subplots(1, 2, figsize=(14, 10))
# Left: Original with landmarks
ax1 = axes[0]
ax1.imshow(image)
# Draw vertebra centers
centroids = spine.get_centroids()
ax1.scatter(centroids[:, 0], centroids[:, 1], c='red', s=30, zorder=5)
# Draw corners if available
for vert in spine.vertebrae:
if vert.corners_px is not None:
corners = vert.corners_px
# Draw quadrilateral
for i in range(4):
j = (i + 1) % 4
ax1.plot([corners[i, 0], corners[j, 0]],
[corners[i, 1], corners[j, 1]], 'g-', linewidth=1)
ax1.set_title(f"Detected Landmarks ({len(spine.vertebrae)} vertebrae)")
ax1.axis('off')
# Right: Spine curve analysis
ax2 = axes[1]
ax2.imshow(image, alpha=0.5)
# Draw spine curve
try:
C, T, N, curv = compute_spine_curve(spine)
ax2.plot(C[:, 0], C[:, 1], 'b-', linewidth=2, label='Spine curve')
# Highlight high curvature regions
high_curv_mask = curv > curv.mean() + curv.std()
ax2.scatter(C[high_curv_mask, 0], C[high_curv_mask, 1],
c='orange', s=20, label='High curvature')
except:
pass
# Get Rigo classification for display
rigo_result = classify_rigo_type(spine)
# Add Cobb angles and Rigo type text
text = f"Cobb Angles:\n"
text += f"PT: {spine.cobb_pt:.1f}°\n"
text += f"MT: {spine.cobb_mt:.1f}°\n"
text += f"TL: {spine.cobb_tl:.1f}°\n"
text += f"Curve: {spine.curve_type}\n"
text += f"-----------\n"
text += f"Rigo: {rigo_result['rigo_type']}"
ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
ax2.set_title("Spine Analysis")
ax2.axis('off')
ax2.legend(loc='lower right')
plt.tight_layout()
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
def _save_landmarks_json(self, path: str, spine: Spine2D):
"""Save landmarks to JSON file with Rigo classification."""
def to_native(val):
"""Convert numpy types to native Python types."""
if isinstance(val, np.ndarray):
return val.tolist()
elif isinstance(val, (np.float32, np.float64)):
return float(val)
elif isinstance(val, (np.int32, np.int64)):
return int(val)
return val
# Get Rigo classification
rigo_result = classify_rigo_type(spine)
data = {
'source_model': spine.source_model,
'image_shape': list(spine.image_shape) if spine.image_shape else None,
'pixel_spacing_mm': spine.pixel_spacing_mm.tolist() if spine.pixel_spacing_mm is not None else None,
'cobb_angles': {
'PT': to_native(spine.cobb_pt),
'MT': to_native(spine.cobb_mt),
'TL': to_native(spine.cobb_tl),
},
'curve_type': spine.curve_type,
'rigo_classification': {
'type': rigo_result['rigo_type'],
'description': rigo_result['description'],
'curve_pattern': rigo_result['curve_pattern'],
'n_significant_curves': rigo_result['n_significant_curves'],
},
'vertebrae': []
}
for vert in spine.vertebrae:
vert_data = {
'level': vert.level,
'centroid_px': vert.centroid_px.tolist(),
'orientation_deg': to_native(vert.orientation_deg),
'confidence': to_native(vert.confidence),
}
if vert.corners_px is not None:
vert_data['corners_px'] = vert.corners_px.tolist()
data['vertebrae'].append(vert_data)
with open(path, 'w') as f:
json.dump(data, f, indent=2)
def main():
"""Command-line interface for brace generation."""
import argparse
parser = argparse.ArgumentParser(description='Generate scoliosis brace from X-ray')
parser.add_argument('input', help='Input X-ray image (JPEG, PNG, or DICOM)')
parser.add_argument('output', help='Output STL file path')
parser.add_argument('--model', choices=['scoliovis', 'vertebra-landmark'],
default='scoliovis', help='Landmark detection model')
parser.add_argument('--device', default='cpu', help='Device (cpu or cuda)')
parser.add_argument('--body-scan', help='Path to 3D body scan mesh (optional)')
parser.add_argument('--visualize', action='store_true', help='Save visualization')
parser.add_argument('--save-landmarks', action='store_true', help='Save landmarks JSON')
parser.add_argument('--pressure', type=float, default=15.0,
help='Pressure strength in mm (default: 15)')
parser.add_argument('--thickness', type=float, default=4.0,
help='Wall thickness in mm (default: 4)')
args = parser.parse_args()
# Build config
config = BraceConfig(
pressure_strength_mm=args.pressure,
wall_thickness_mm=args.thickness,
)
if args.body_scan:
config.use_body_scan = True
config.body_scan_path = args.body_scan
# Run pipeline
pipeline = BracePipeline(model=args.model, config=config, device=args.device)
results = pipeline.process(
args.input,
args.output,
visualize=args.visualize,
save_landmarks=args.save_landmarks
)
return results
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,464 @@
"""
Spine analysis functions for computing curves, Cobb angles, and identifying apex vertebrae.
"""
import numpy as np
from scipy.interpolate import splprep, splev
from typing import Tuple, List, Optional
from data_models import Spine2D, VertebraLandmark
def compute_spine_curve(
spine: Spine2D,
smooth: float = 1.0,
n_samples: int = 200
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Compute smooth spine centerline from vertebra centroids.
Args:
spine: Spine2D object with detected vertebrae
smooth: Smoothing factor for spline (higher = smoother)
n_samples: Number of points to sample along the curve
Returns:
C: Curve points, shape (n_samples, 2)
T: Tangent vectors, shape (n_samples, 2)
N: Normal vectors, shape (n_samples, 2)
curvature: Curvature at each point, shape (n_samples,)
"""
pts = spine.get_centroids()
if len(pts) < 4:
raise ValueError(f"Need at least 4 vertebrae for spline, got {len(pts)}")
# Fit parametric spline through centroids
x = pts[:, 0]
y = pts[:, 1]
try:
tck, u = splprep([x, y], s=smooth, k=min(3, len(pts)-1))
except Exception as e:
# Fallback: simple linear interpolation
t = np.linspace(0, 1, n_samples)
xs = np.interp(t, np.linspace(0, 1, len(x)), x)
ys = np.interp(t, np.linspace(0, 1, len(y)), y)
C = np.stack([xs, ys], axis=1).astype(np.float32)
T = np.gradient(C, axis=0)
T = T / (np.linalg.norm(T, axis=1, keepdims=True) + 1e-8)
N = np.stack([-T[:, 1], T[:, 0]], axis=1)
curvature = np.zeros(n_samples, dtype=np.float32)
return C, T, N, curvature
# Sample the spline
u_new = np.linspace(0, 1, n_samples)
xs, ys = splev(u_new, tck)
# First and second derivatives
dx, dy = splev(u_new, tck, der=1)
ddx, ddy = splev(u_new, tck, der=2)
# Curve points
C = np.stack([xs, ys], axis=1).astype(np.float32)
# Tangent vectors (normalized)
T = np.stack([dx, dy], axis=1)
T_norm = np.linalg.norm(T, axis=1, keepdims=True) + 1e-8
T = (T / T_norm).astype(np.float32)
# Normal vectors (perpendicular to tangent)
N = np.stack([-T[:, 1], T[:, 0]], axis=1).astype(np.float32)
# Curvature: |x'y'' - y'x''| / (x'^2 + y'^2)^(3/2)
curvature = np.abs(dx * ddy - dy * ddx) / (np.power(dx**2 + dy**2, 1.5) + 1e-8)
curvature = curvature.astype(np.float32)
return C, T, N, curvature
def compute_cobb_angles(spine: Spine2D) -> Tuple[float, float, float]:
"""
Compute Cobb angles from vertebra orientations.
The Cobb angle is measured as the angle between:
- The superior endplate of the most tilted vertebra at the top of a curve
- The inferior endplate of the most tilted vertebra at the bottom
This function estimates 3 Cobb angles:
- PT: Proximal Thoracic (T1-T6 region)
- MT: Main Thoracic (T6-T12 region)
- TL: Thoracolumbar/Lumbar (T12-L5 region)
Args:
spine: Spine2D object with detected vertebrae
Returns:
(pt_angle, mt_angle, tl_angle) in degrees
"""
orientations = spine.get_orientations()
n = len(orientations)
if n < 7:
spine.cobb_pt = 0.0
spine.cobb_mt = 0.0
spine.cobb_tl = 0.0
return 0.0, 0.0, 0.0
# Divide into regions (approximately)
# PT: top 1/3, MT: middle 1/3, TL: bottom 1/3
pt_end = n // 3
mt_end = 2 * n // 3
# Find max tilt difference in each region
def region_cobb(start_idx: int, end_idx: int) -> float:
if end_idx <= start_idx:
return 0.0
region_angles = orientations[start_idx:end_idx]
if len(region_angles) < 2:
return 0.0
# Cobb angle = max angle - min angle in region
return abs(float(np.max(region_angles) - np.min(region_angles)))
pt_angle = region_cobb(0, pt_end)
mt_angle = region_cobb(pt_end, mt_end)
tl_angle = region_cobb(mt_end, n)
# Store in spine object
spine.cobb_pt = pt_angle
spine.cobb_mt = mt_angle
spine.cobb_tl = tl_angle
# Determine curve type
if mt_angle > 10 and tl_angle > 10:
spine.curve_type = "S" # Double curve
elif mt_angle > 10 or tl_angle > 10:
spine.curve_type = "C" # Single curve
else:
spine.curve_type = "Normal"
return pt_angle, mt_angle, tl_angle
def find_apex_vertebrae(spine: Spine2D) -> List[int]:
"""
Find indices of apex vertebrae (most deviated from midline).
Args:
spine: Spine2D with computed curve
Returns:
List of vertebra indices that are curve apexes
"""
centroids = spine.get_centroids()
if len(centroids) < 5:
return []
# Find midline (linear fit through endpoints)
start = centroids[0]
end = centroids[-1]
# Distance from midline for each vertebra
midline_vec = end - start
midline_len = np.linalg.norm(midline_vec)
if midline_len < 1e-6:
return []
midline_unit = midline_vec / midline_len
# Calculate perpendicular distance to midline
deviations = []
for i, pt in enumerate(centroids):
v = pt - start
# Project onto midline
proj_len = np.dot(v, midline_unit)
proj = proj_len * midline_unit
# Perpendicular distance
perp = v - proj
dist = np.linalg.norm(perp)
# Sign: positive if to the right of midline
sign = np.sign(np.cross(midline_unit, v / (np.linalg.norm(v) + 1e-8)))
deviations.append(dist * sign)
deviations = np.array(deviations)
# Find local extrema (peaks and valleys)
apexes = []
for i in range(1, len(deviations) - 1):
# Local maximum
if deviations[i] > deviations[i-1] and deviations[i] > deviations[i+1]:
if abs(deviations[i]) > 5: # Minimum deviation threshold (pixels)
apexes.append(i)
# Local minimum
elif deviations[i] < deviations[i-1] and deviations[i] < deviations[i+1]:
if abs(deviations[i]) > 5:
apexes.append(i)
return apexes
def get_curve_severity(cobb_angle: float) -> str:
"""
Get clinical severity classification from Cobb angle.
Args:
cobb_angle: Cobb angle in degrees
Returns:
Severity string: "Normal", "Mild", "Moderate", or "Severe"
"""
if cobb_angle < 10:
return "Normal"
elif cobb_angle < 25:
return "Mild"
elif cobb_angle < 40:
return "Moderate"
else:
return "Severe"
def classify_rigo_type(spine: Spine2D) -> dict:
"""
Classify scoliosis according to Rigo-Chêneau brace classification.
Rigo Classification Types:
- A1, A2, A3: 3-curve patterns (thoracic major)
- B1, B2: 4-curve patterns (double major)
- C1, C2: Single thoracolumbar/lumbar
- E1, E2: Single thoracic
Args:
spine: Spine2D object with detected vertebrae and Cobb angles
Returns:
dict with 'rigo_type', 'description', 'apex_region', 'curve_pattern'
"""
# Get Cobb angles
cobb_angles = spine.get_cobb_angles()
pt = cobb_angles.get('PT', 0)
mt = cobb_angles.get('MT', 0)
tl = cobb_angles.get('TL', 0)
n_verts = len(spine.vertebrae)
# Calculate lateral deviations to determine curve direction
centroids = spine.get_centroids()
deviations = _calculate_lateral_deviations(centroids)
# Find apex positions and directions
apex_info = _find_apex_info(centroids, deviations, n_verts)
# Determine curve pattern based on number of significant curves
significant_curves = []
if pt >= 10:
significant_curves.append(('PT', pt))
if mt >= 10:
significant_curves.append(('MT', mt))
if tl >= 10:
significant_curves.append(('TL', tl))
n_curves = len(significant_curves)
# Classification logic
rigo_type = "N/A"
description = ""
curve_pattern = ""
# No significant scoliosis
if n_curves == 0 or max(pt, mt, tl) < 10:
rigo_type = "Normal"
description = "No significant scoliosis (all Cobb angles < 10°)"
curve_pattern = "None"
# Single curve patterns
elif n_curves == 1:
max_curve = significant_curves[0][0]
max_angle = significant_curves[0][1]
if max_curve == 'MT' or max_curve == 'PT':
# Thoracic single curve
if apex_info['thoracic_apex_idx'] is not None:
# Check if there's a compensatory lumbar
if tl > 5:
rigo_type = "E2"
description = f"Single thoracic curve ({max_angle:.1f}°) with lumbar compensatory ({tl:.1f}°)"
curve_pattern = "Thoracic with compensation"
else:
rigo_type = "E1"
description = f"True single thoracic curve ({max_angle:.1f}°)"
curve_pattern = "Single thoracic"
else:
rigo_type = "E1"
description = f"Single thoracic curve ({max_angle:.1f}°)"
curve_pattern = "Single thoracic"
elif max_curve == 'TL':
# Thoracolumbar/Lumbar single curve
if mt > 5 or pt > 5:
rigo_type = "C2"
description = f"Thoracolumbar curve ({tl:.1f}°) with upper compensatory"
curve_pattern = "TL/L with compensation"
else:
rigo_type = "C1"
description = f"Single thoracolumbar/lumbar curve ({tl:.1f}°)"
curve_pattern = "Single TL/L"
# Double curve patterns
elif n_curves >= 2:
# Determine which curves are primary
thoracic_total = pt + mt
lumbar_total = tl
# Check curve directions for S vs C pattern
is_s_curve = apex_info['is_s_pattern']
if is_s_curve:
# S-curve: typically 3 or 4 curve patterns
if thoracic_total > lumbar_total * 1.5:
# Thoracic dominant - Type A (3-curve)
if apex_info['lumbar_apex_low']:
rigo_type = "A1"
description = f"3-curve: Thoracic major ({mt:.1f}°), lumbar apex low"
elif apex_info['apex_at_tl_junction']:
rigo_type = "A2"
description = f"3-curve: Thoracolumbar transition ({mt:.1f}°/{tl:.1f}°)"
else:
rigo_type = "A3"
description = f"3-curve: Thoracic major ({mt:.1f}°) with structural lumbar ({tl:.1f}°)"
curve_pattern = "3-curve (thoracic major)"
elif lumbar_total > thoracic_total * 1.5:
# Lumbar dominant
rigo_type = "C2"
description = f"Lumbar major ({tl:.1f}°) with thoracic compensatory ({mt:.1f}°)"
curve_pattern = "Lumbar major"
else:
# Double major - Type B (4-curve)
if tl >= mt:
rigo_type = "B1"
description = f"4-curve: Double major, lumbar prominent ({tl:.1f}°/{mt:.1f}°)"
else:
rigo_type = "B2"
description = f"4-curve: Double major, thoracic prominent ({mt:.1f}°/{tl:.1f}°)"
curve_pattern = "4-curve (double major)"
else:
# C-curve pattern (curves in same direction)
if mt >= tl:
if tl > 5:
rigo_type = "A3"
description = f"Long thoracic curve ({mt:.1f}°) extending to lumbar ({tl:.1f}°)"
else:
rigo_type = "E2"
description = f"Thoracic curve ({mt:.1f}°) with minor lumbar ({tl:.1f}°)"
curve_pattern = "Extended thoracic"
else:
rigo_type = "C2"
description = f"TL/Lumbar curve ({tl:.1f}°) with thoracic involvement ({mt:.1f}°)"
curve_pattern = "Extended lumbar"
# Store in spine object
spine.rigo_type = rigo_type
spine.rigo_description = description
return {
'rigo_type': rigo_type,
'description': description,
'curve_pattern': curve_pattern,
'apex_info': apex_info,
'cobb_angles': cobb_angles,
'n_significant_curves': n_curves
}
def _calculate_lateral_deviations(centroids: np.ndarray) -> np.ndarray:
"""Calculate lateral deviation from midline for each vertebra."""
if len(centroids) < 2:
return np.zeros(len(centroids))
# Midline from first to last vertebra
start = centroids[0]
end = centroids[-1]
midline_vec = end - start
midline_len = np.linalg.norm(midline_vec)
if midline_len < 1e-6:
return np.zeros(len(centroids))
midline_unit = midline_vec / midline_len
deviations = []
for pt in centroids:
v = pt - start
# Project onto midline
proj_len = np.dot(v, midline_unit)
proj = proj_len * midline_unit
# Perpendicular vector
perp = v - proj
dist = np.linalg.norm(perp)
# Sign: positive = right, negative = left
sign = np.sign(np.cross(midline_unit, perp / (dist + 1e-8)))
deviations.append(dist * sign)
return np.array(deviations)
def _find_apex_info(centroids: np.ndarray, deviations: np.ndarray, n_verts: int) -> dict:
"""Find apex positions and determine curve pattern."""
info = {
'thoracic_apex_idx': None,
'lumbar_apex_idx': None,
'lumbar_apex_low': False,
'apex_at_tl_junction': False,
'is_s_pattern': False,
'apex_directions': []
}
if len(deviations) < 3:
return info
# Find local extrema (apexes)
apexes = []
apex_values = []
for i in range(1, len(deviations) - 1):
if (deviations[i] > deviations[i-1] and deviations[i] > deviations[i+1]) or \
(deviations[i] < deviations[i-1] and deviations[i] < deviations[i+1]):
if abs(deviations[i]) > 3: # Minimum threshold
apexes.append(i)
apex_values.append(deviations[i])
# Determine S-pattern (alternating signs at apexes)
if len(apex_values) >= 2:
signs = [np.sign(v) for v in apex_values]
# S-pattern if adjacent apexes have opposite signs
for i in range(len(signs) - 1):
if signs[i] != signs[i+1]:
info['is_s_pattern'] = True
break
# Classify apex regions
# Assume: top 40% = thoracic, middle 20% = TL junction, bottom 40% = lumbar
thoracic_end = int(0.4 * n_verts)
tl_junction_end = int(0.6 * n_verts)
for apex_idx in apexes:
if apex_idx < thoracic_end:
if info['thoracic_apex_idx'] is None or \
abs(deviations[apex_idx]) > abs(deviations[info['thoracic_apex_idx']]):
info['thoracic_apex_idx'] = apex_idx
elif apex_idx < tl_junction_end:
info['apex_at_tl_junction'] = True
else:
if info['lumbar_apex_idx'] is None or \
abs(deviations[apex_idx]) > abs(deviations[info['lumbar_apex_idx']]):
info['lumbar_apex_idx'] = apex_idx
# Check if lumbar apex is in lower region (bottom 30%)
if info['lumbar_apex_idx'] is not None:
if info['lumbar_apex_idx'] > int(0.7 * n_verts):
info['lumbar_apex_low'] = True
info['apex_directions'] = apex_values
return info