Add patient management, deployment scripts, and Docker fixes
This commit is contained in:
@@ -27,7 +27,7 @@ WORKDIR /app
|
||||
RUN pip install --no-cache-dir --upgrade pip && \
|
||||
pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Copy and install requirements (from brace-generator folder)
|
||||
# Copy and install requirements
|
||||
COPY brace-generator/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
@@ -35,8 +35,16 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY scoliovis-api/requirements.txt /app/requirements-scoliovis.txt
|
||||
RUN pip install --no-cache-dir -r requirements-scoliovis.txt || true
|
||||
|
||||
# Copy brace-generator code
|
||||
COPY brace-generator/ /app/brace_generator/server_DEV/
|
||||
# Create brace_generator package structure
|
||||
RUN mkdir -p /app/brace_generator
|
||||
|
||||
# Copy brace-generator code as a package
|
||||
COPY brace-generator/*.py /app/brace_generator/
|
||||
COPY brace-generator/__init__.py /app/brace_generator/__init__.py
|
||||
|
||||
# Also keep server_DEV structure for compatibility
|
||||
RUN mkdir -p /app/brace_generator/server_DEV
|
||||
COPY brace-generator/*.py /app/brace_generator/server_DEV/
|
||||
|
||||
# Copy scoliovis-api
|
||||
COPY scoliovis-api/ /app/scoliovis-api/
|
||||
@@ -44,8 +52,8 @@ COPY scoliovis-api/ /app/scoliovis-api/
|
||||
# Copy templates
|
||||
COPY templates/ /app/templates/
|
||||
|
||||
# Set Python path
|
||||
ENV PYTHONPATH=/app:/app/brace_generator/server_DEV:/app/scoliovis-api
|
||||
# Set Python path - include both locations
|
||||
ENV PYTHONPATH=/app:/app/scoliovis-api
|
||||
|
||||
# Environment variables
|
||||
ENV HOST=0.0.0.0
|
||||
@@ -61,8 +69,8 @@ RUN mkdir -p /tmp/brace_generator /app/data/uploads /app/data/outputs
|
||||
EXPOSE 8002
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
||||
CMD curl -f http://localhost:8002/health || exit 1
|
||||
|
||||
# Run the server
|
||||
# Run the server from the brace_generator package
|
||||
CMD ["python", "-m", "uvicorn", "brace_generator.server_DEV.app:app", "--host", "0.0.0.0", "--port", "8002"]
|
||||
|
||||
508
brace-generator/adapters.py
Normal file
508
brace-generator/adapters.py
Normal file
@@ -0,0 +1,508 @@
|
||||
"""
|
||||
Model adapters that convert different model outputs to unified Spine2D format.
|
||||
Each adapter wraps a specific model and produces consistent output.
|
||||
"""
|
||||
import sys
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from data_models import VertebraLandmark, Spine2D
|
||||
|
||||
|
||||
class BaseLandmarkAdapter(ABC):
|
||||
"""Base class for landmark detection model adapters."""
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, image: np.ndarray) -> Spine2D:
|
||||
"""
|
||||
Run inference on an image and return unified spine landmarks.
|
||||
|
||||
Args:
|
||||
image: Input image as numpy array (grayscale or RGB)
|
||||
|
||||
Returns:
|
||||
Spine2D object with detected landmarks
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Model name for identification."""
|
||||
pass
|
||||
|
||||
|
||||
class ScolioVisAdapter(BaseLandmarkAdapter):
|
||||
"""
|
||||
Adapter for ScolioVis-API (Keypoint R-CNN model).
|
||||
|
||||
Uses the original ScolioVis inference code for best accuracy.
|
||||
Outputs: 4 keypoints per vertebra + Cobb angles (PT, MT, TL) + curve type (S/C)
|
||||
"""
|
||||
|
||||
def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'):
|
||||
"""
|
||||
Initialize ScolioVis model.
|
||||
|
||||
Args:
|
||||
weights_path: Path to keypointsrcnn_weights.pt (auto-detects if None)
|
||||
device: 'cpu' or 'cuda'
|
||||
"""
|
||||
self.device = device
|
||||
self.model = None
|
||||
self.weights_path = weights_path
|
||||
self._scoliovis_path = None
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Load the Keypoint R-CNN model."""
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision.models.detection import keypointrcnn_resnet50_fpn
|
||||
from torchvision.models.detection.rpn import AnchorGenerator
|
||||
|
||||
# Find weights and scoliovis module
|
||||
scoliovis_api_path = Path(__file__).parent.parent / 'scoliovis-api'
|
||||
if self.weights_path is None:
|
||||
possible_paths = [
|
||||
scoliovis_api_path / 'models' / 'keypointsrcnn_weights.pt',
|
||||
scoliovis_api_path / 'keypointsrcnn_weights.pt',
|
||||
scoliovis_api_path / 'weights' / 'keypointsrcnn_weights.pt',
|
||||
]
|
||||
for p in possible_paths:
|
||||
if p.exists():
|
||||
self.weights_path = str(p)
|
||||
break
|
||||
|
||||
if self.weights_path is None or not Path(self.weights_path).exists():
|
||||
raise FileNotFoundError(
|
||||
"ScolioVis weights not found. Please provide weights_path or ensure "
|
||||
"scoliovis-api/models/keypointsrcnn_weights.pt exists."
|
||||
)
|
||||
|
||||
# Store path to scoliovis module for Cobb angle calculation
|
||||
self._scoliovis_path = scoliovis_api_path
|
||||
|
||||
# Create model with same anchor generator as original training
|
||||
anchor_generator = AnchorGenerator(
|
||||
sizes=(32, 64, 128, 256, 512),
|
||||
aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
|
||||
)
|
||||
|
||||
self.model = keypointrcnn_resnet50_fpn(
|
||||
weights=None,
|
||||
weights_backbone=None,
|
||||
num_classes=2, # background + vertebra
|
||||
num_keypoints=4, # 4 corners per vertebra
|
||||
rpn_anchor_generator=anchor_generator
|
||||
)
|
||||
|
||||
# Load weights
|
||||
checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False)
|
||||
self.model.load_state_dict(checkpoint)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
print(f"ScolioVis model loaded from {self.weights_path}")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "ScolioVis-API"
|
||||
|
||||
def _filter_output(self, output, max_verts: int = 17):
|
||||
"""
|
||||
Filter model output using NMS and score threshold.
|
||||
Matches the original ScolioVis filtering logic.
|
||||
"""
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
scores = output['scores'].detach().cpu().numpy()
|
||||
|
||||
# Get indices of scores over threshold (0.5)
|
||||
high_scores_idxs = np.where(scores > 0.5)[0].tolist()
|
||||
if len(high_scores_idxs) == 0:
|
||||
return [], [], []
|
||||
|
||||
# Apply NMS with IoU threshold 0.3
|
||||
post_nms_idxs = torchvision.ops.nms(
|
||||
output['boxes'][high_scores_idxs],
|
||||
output['scores'][high_scores_idxs],
|
||||
0.3
|
||||
).cpu().numpy()
|
||||
|
||||
# Get filtered results
|
||||
np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
|
||||
np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
|
||||
np_scores = output['scores'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
|
||||
|
||||
# Take top N by score (usually 17 for full spine)
|
||||
sorted_scores_idxs = np.argsort(-1 * np_scores)
|
||||
np_scores = np_scores[sorted_scores_idxs][:max_verts]
|
||||
np_keypoints = np.array([np_keypoints[idx] for idx in sorted_scores_idxs])[:max_verts]
|
||||
np_bboxes = np.array([np_bboxes[idx] for idx in sorted_scores_idxs])[:max_verts]
|
||||
|
||||
# Sort by ymin (top to bottom)
|
||||
if len(np_keypoints) > 0:
|
||||
ymins = np.array([kps[0][1] for kps in np_keypoints])
|
||||
sorted_ymin_idxs = np.argsort(ymins)
|
||||
np_scores = np.array([np_scores[idx] for idx in sorted_ymin_idxs])
|
||||
np_keypoints = np.array([np_keypoints[idx] for idx in sorted_ymin_idxs])
|
||||
np_bboxes = np.array([np_bboxes[idx] for idx in sorted_ymin_idxs])
|
||||
|
||||
# Convert to lists
|
||||
keypoints_list = []
|
||||
for kps in np_keypoints:
|
||||
keypoints_list.append([list(map(float, kp[:2])) for kp in kps])
|
||||
|
||||
bboxes_list = []
|
||||
for bbox in np_bboxes:
|
||||
bboxes_list.append(list(map(int, bbox.tolist())))
|
||||
|
||||
scores_list = np_scores.tolist()
|
||||
|
||||
return bboxes_list, keypoints_list, scores_list
|
||||
|
||||
def predict(self, image: np.ndarray) -> Spine2D:
|
||||
"""Run inference and return unified landmarks with ScolioVis Cobb angles."""
|
||||
import torch
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
# Ensure RGB
|
||||
if len(image.shape) == 2:
|
||||
image_rgb = np.stack([image, image, image], axis=-1)
|
||||
else:
|
||||
image_rgb = image
|
||||
|
||||
image_shape = image_rgb.shape # (H, W, C)
|
||||
|
||||
# Convert to tensor (ScolioVis uses torchvision's to_tensor)
|
||||
img_tensor = F.to_tensor(image_rgb).to(self.device)
|
||||
|
||||
# Run inference
|
||||
with torch.no_grad():
|
||||
outputs = self.model([img_tensor])
|
||||
|
||||
# Filter output using original ScolioVis logic
|
||||
bboxes, keypoints, scores = self._filter_output(outputs[0])
|
||||
|
||||
if len(keypoints) == 0:
|
||||
return Spine2D(
|
||||
vertebrae=[],
|
||||
image_shape=image_shape[:2],
|
||||
source_model=self.name
|
||||
)
|
||||
|
||||
# Convert to unified format
|
||||
vertebrae = []
|
||||
for i in range(len(bboxes)):
|
||||
kps = np.array(keypoints[i], dtype=np.float32) # (4, 2)
|
||||
|
||||
# Corners order from ScolioVis: [top_left, top_right, bottom_left, bottom_right]
|
||||
corners = kps
|
||||
centroid = np.mean(corners, axis=0)
|
||||
|
||||
# Compute orientation from top edge (kps[0] to kps[1])
|
||||
top_left, top_right = corners[0], corners[1]
|
||||
dx = top_right[0] - top_left[0]
|
||||
dy = top_right[1] - top_left[1]
|
||||
orientation = np.degrees(np.arctan2(dy, dx))
|
||||
|
||||
vert = VertebraLandmark(
|
||||
level=None, # ScolioVis doesn't assign levels
|
||||
centroid_px=centroid,
|
||||
corners_px=corners,
|
||||
endplate_upper_px=corners[:2], # top-left, top-right
|
||||
endplate_lower_px=corners[2:], # bottom-left, bottom-right
|
||||
orientation_deg=orientation,
|
||||
confidence=float(scores[i]),
|
||||
meta={'box': bboxes[i]}
|
||||
)
|
||||
vertebrae.append(vert)
|
||||
|
||||
# Create Spine2D
|
||||
spine = Spine2D(
|
||||
vertebrae=vertebrae,
|
||||
image_shape=image_shape[:2],
|
||||
source_model=self.name
|
||||
)
|
||||
|
||||
# Use original ScolioVis Cobb angle calculation if available
|
||||
if len(keypoints) >= 5:
|
||||
try:
|
||||
# Import original ScolioVis cobb_angle_cal
|
||||
if str(self._scoliovis_path) not in sys.path:
|
||||
sys.path.insert(0, str(self._scoliovis_path))
|
||||
|
||||
from scoliovis.cobb_angle_cal import cobb_angle_cal, keypoints_to_landmark_xy
|
||||
|
||||
landmark_xy = keypoints_to_landmark_xy(keypoints)
|
||||
cobb_angles_list, angles_with_pos, curve_type, midpoint_lines = cobb_angle_cal(
|
||||
landmark_xy, image_shape
|
||||
)
|
||||
|
||||
# Store Cobb angles in spine object
|
||||
spine.cobb_angles = {
|
||||
'PT': cobb_angles_list[0],
|
||||
'MT': cobb_angles_list[1],
|
||||
'TL': cobb_angles_list[2]
|
||||
}
|
||||
spine.curve_type = curve_type
|
||||
spine.meta = {
|
||||
'angles_with_pos': angles_with_pos,
|
||||
'midpoint_lines': midpoint_lines
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not use ScolioVis Cobb calculation: {e}")
|
||||
# Fallback to our own calculation
|
||||
from spine_analysis import compute_cobb_angles
|
||||
compute_cobb_angles(spine)
|
||||
|
||||
return spine
|
||||
|
||||
|
||||
class VertLandmarkAdapter(BaseLandmarkAdapter):
|
||||
"""
|
||||
Adapter for Vertebra-Landmark-Detection (SpineNet model).
|
||||
|
||||
Outputs: 68 landmarks (4 corners × 17 vertebrae)
|
||||
"""
|
||||
|
||||
def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'):
|
||||
"""
|
||||
Initialize SpineNet model.
|
||||
|
||||
Args:
|
||||
weights_path: Path to model_last.pth (auto-detects if None)
|
||||
device: 'cpu' or 'cuda'
|
||||
"""
|
||||
self.device = device
|
||||
self.model = None
|
||||
self.weights_path = weights_path
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""Load the SpineNet model."""
|
||||
import torch
|
||||
|
||||
# Find weights
|
||||
if self.weights_path is None:
|
||||
possible_paths = [
|
||||
Path(__file__).parent.parent / 'Vertebra-Landmark-Detection' / 'weights_spinal' / 'model_last.pth',
|
||||
]
|
||||
for p in possible_paths:
|
||||
if p.exists():
|
||||
self.weights_path = str(p)
|
||||
break
|
||||
|
||||
if self.weights_path is None or not Path(self.weights_path).exists():
|
||||
raise FileNotFoundError(
|
||||
"Vertebra-Landmark-Detection weights not found. "
|
||||
"Download from Google Drive and place in weights_spinal/model_last.pth"
|
||||
)
|
||||
|
||||
# Add repo to path to import model
|
||||
repo_path = Path(__file__).parent.parent / 'Vertebra-Landmark-Detection'
|
||||
if str(repo_path) not in sys.path:
|
||||
sys.path.insert(0, str(repo_path))
|
||||
|
||||
from models import spinal_net
|
||||
|
||||
# Create model
|
||||
heads = {'hm': 1, 'reg': 2, 'wh': 8}
|
||||
self.model = spinal_net.SpineNet(
|
||||
heads=heads,
|
||||
pretrained=False,
|
||||
down_ratio=4,
|
||||
final_kernel=1,
|
||||
head_conv=256
|
||||
)
|
||||
|
||||
# Load weights
|
||||
checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False)
|
||||
self.model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
print(f"Vertebra-Landmark-Detection model loaded from {self.weights_path}")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "Vertebra-Landmark-Detection"
|
||||
|
||||
def _nms(self, heat, kernel=3):
|
||||
"""Apply NMS using max pooling."""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=(kernel - 1) // 2)
|
||||
keep = (hmax == heat).float()
|
||||
return heat * keep
|
||||
|
||||
def _gather_feat(self, feat, ind):
|
||||
"""Gather features by index."""
|
||||
dim = feat.size(2)
|
||||
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
|
||||
feat = feat.gather(1, ind)
|
||||
return feat
|
||||
|
||||
def _tranpose_and_gather_feat(self, feat, ind):
|
||||
"""Transpose and gather features - matches original decoder."""
|
||||
feat = feat.permute(0, 2, 3, 1).contiguous()
|
||||
feat = feat.view(feat.size(0), -1, feat.size(3))
|
||||
feat = self._gather_feat(feat, ind)
|
||||
return feat
|
||||
|
||||
def _decode_predictions(self, output: Dict, down_ratio: int = 4, K: int = 17):
|
||||
"""Decode model output using original decoder logic."""
|
||||
import torch
|
||||
|
||||
hm = output['hm'].sigmoid()
|
||||
reg = output['reg']
|
||||
wh = output['wh']
|
||||
|
||||
batch, cat, height, width = hm.size()
|
||||
|
||||
# Apply NMS
|
||||
hm = self._nms(hm)
|
||||
|
||||
# Get top K from heatmap
|
||||
topk_scores, topk_inds = torch.topk(hm.view(batch, cat, -1), K)
|
||||
topk_inds = topk_inds % (height * width)
|
||||
topk_ys = (topk_inds // width).float()
|
||||
topk_xs = (topk_inds % width).float()
|
||||
|
||||
# Get overall top K
|
||||
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
|
||||
topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
|
||||
topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
|
||||
topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
|
||||
|
||||
scores = topk_score.view(batch, K, 1)
|
||||
|
||||
# Get regression offset and apply
|
||||
reg = self._tranpose_and_gather_feat(reg, topk_inds)
|
||||
reg = reg.view(batch, K, 2)
|
||||
xs = topk_xs.view(batch, K, 1) + reg[:, :, 0:1]
|
||||
ys = topk_ys.view(batch, K, 1) + reg[:, :, 1:2]
|
||||
|
||||
# Get corner offsets
|
||||
wh = self._tranpose_and_gather_feat(wh, topk_inds)
|
||||
wh = wh.view(batch, K, 8)
|
||||
|
||||
# Calculate corners by SUBTRACTING offsets (original decoder logic)
|
||||
tl_x = xs - wh[:, :, 0:1]
|
||||
tl_y = ys - wh[:, :, 1:2]
|
||||
tr_x = xs - wh[:, :, 2:3]
|
||||
tr_y = ys - wh[:, :, 3:4]
|
||||
bl_x = xs - wh[:, :, 4:5]
|
||||
bl_y = ys - wh[:, :, 5:6]
|
||||
br_x = xs - wh[:, :, 6:7]
|
||||
br_y = ys - wh[:, :, 7:8]
|
||||
|
||||
# Combine into output format: [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score]
|
||||
pts = torch.cat([xs, ys, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, scores], dim=2)
|
||||
|
||||
# Scale to image coordinates
|
||||
pts[:, :, :10] *= down_ratio
|
||||
|
||||
return pts[0].cpu().numpy() # (K, 11)
|
||||
|
||||
def predict(self, image: np.ndarray) -> Spine2D:
|
||||
"""Run inference and return unified landmarks."""
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
# Ensure RGB
|
||||
if len(image.shape) == 2:
|
||||
image_rgb = np.stack([image, image, image], axis=-1)
|
||||
else:
|
||||
image_rgb = image
|
||||
|
||||
orig_h, orig_w = image_rgb.shape[:2]
|
||||
|
||||
# Resize to model input size (1024x512)
|
||||
input_h, input_w = 1024, 512
|
||||
img_resized = cv2.resize(image_rgb, (input_w, input_h))
|
||||
|
||||
# Normalize and convert to tensor - use original preprocessing!
|
||||
# Original: out_image = image / 255. - 0.5 (NOT ImageNet stats)
|
||||
img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0 - 0.5
|
||||
img_tensor = img_tensor.unsqueeze(0).to(self.device)
|
||||
|
||||
# Run inference
|
||||
with torch.no_grad():
|
||||
output = self.model(img_tensor)
|
||||
|
||||
# Decode predictions - returns (K, 11) array
|
||||
# Format: [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score]
|
||||
pts = self._decode_predictions(output, down_ratio=4, K=17)
|
||||
|
||||
# Scale coordinates back to original image size
|
||||
scale_x = orig_w / input_w
|
||||
scale_y = orig_h / input_h
|
||||
|
||||
# Convert to unified format
|
||||
vertebrae = []
|
||||
threshold = 0.3
|
||||
|
||||
for i in range(len(pts)):
|
||||
score = pts[i, 10]
|
||||
if score < threshold:
|
||||
continue
|
||||
|
||||
# Get center and corners (already scaled by down_ratio in decoder)
|
||||
cx = pts[i, 0] * scale_x
|
||||
cy = pts[i, 1] * scale_y
|
||||
|
||||
# Corners: tl, tr, bl, br
|
||||
tl = np.array([pts[i, 2] * scale_x, pts[i, 3] * scale_y])
|
||||
tr = np.array([pts[i, 4] * scale_x, pts[i, 5] * scale_y])
|
||||
bl = np.array([pts[i, 6] * scale_x, pts[i, 7] * scale_y])
|
||||
br = np.array([pts[i, 8] * scale_x, pts[i, 9] * scale_y])
|
||||
|
||||
# Reorder to [tl, tr, br, bl] for consistency with ScolioVis
|
||||
corners = np.array([tl, tr, br, bl], dtype=np.float32)
|
||||
centroid = np.array([cx, cy], dtype=np.float32)
|
||||
|
||||
# Compute orientation from top edge
|
||||
dx = tr[0] - tl[0]
|
||||
dy = tr[1] - tl[1]
|
||||
orientation = np.degrees(np.arctan2(dy, dx))
|
||||
|
||||
vert = VertebraLandmark(
|
||||
level=None,
|
||||
centroid_px=centroid,
|
||||
corners_px=corners,
|
||||
endplate_upper_px=np.array([tl, tr]), # top edge
|
||||
endplate_lower_px=np.array([bl, br]), # bottom edge
|
||||
orientation_deg=orientation,
|
||||
confidence=float(score),
|
||||
meta={'raw_pts': pts[i].tolist()}
|
||||
)
|
||||
vertebrae.append(vert)
|
||||
|
||||
# Sort by y-coordinate (top to bottom)
|
||||
vertebrae.sort(key=lambda v: float(v.centroid_px[1]))
|
||||
|
||||
# Assign vertebra levels (T1-L5 = 17 vertebrae typically)
|
||||
level_names = ['T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7', 'T8', 'T9', 'T10', 'T11', 'T12', 'L1', 'L2', 'L3', 'L4', 'L5']
|
||||
for i, vert in enumerate(vertebrae):
|
||||
if i < len(level_names):
|
||||
vert.level = level_names[i]
|
||||
|
||||
# Create Spine2D
|
||||
spine = Spine2D(
|
||||
vertebrae=vertebrae,
|
||||
image_shape=(orig_h, orig_w),
|
||||
source_model=self.name
|
||||
)
|
||||
|
||||
# Compute Cobb angles
|
||||
if len(vertebrae) >= 7:
|
||||
from spine_analysis import compute_cobb_angles
|
||||
compute_cobb_angles(spine)
|
||||
|
||||
return spine
|
||||
354
brace-generator/brace_surface.py
Normal file
354
brace-generator/brace_surface.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
Brace surface generation from spine landmarks.
|
||||
|
||||
Two modes:
|
||||
- Version A: Generic/average body shape (parametric torso)
|
||||
- Version B: Uses actual 3D body scan mesh
|
||||
"""
|
||||
import numpy as np
|
||||
from typing import Tuple, Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
from .data_models import Spine2D, BraceConfig
|
||||
from .spine_analysis import compute_spine_curve, find_apex_vertebrae
|
||||
|
||||
try:
|
||||
import trimesh
|
||||
HAS_TRIMESH = True
|
||||
except ImportError:
|
||||
HAS_TRIMESH = False
|
||||
|
||||
|
||||
class BraceGenerator:
|
||||
"""
|
||||
Generates 3D brace shell from spine landmarks.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[BraceConfig] = None):
|
||||
"""
|
||||
Initialize brace generator.
|
||||
|
||||
Args:
|
||||
config: Brace configuration parameters
|
||||
"""
|
||||
if not HAS_TRIMESH:
|
||||
raise ImportError("trimesh is required for brace generation. Install with: pip install trimesh")
|
||||
|
||||
self.config = config or BraceConfig()
|
||||
|
||||
def generate(self, spine: Spine2D) -> 'trimesh.Trimesh':
|
||||
"""
|
||||
Generate brace mesh from spine landmarks.
|
||||
|
||||
Args:
|
||||
spine: Spine2D object with detected vertebrae
|
||||
|
||||
Returns:
|
||||
trimesh.Trimesh object representing the brace shell
|
||||
"""
|
||||
if self.config.use_body_scan and self.config.body_scan_path:
|
||||
return self._generate_from_body_scan(spine)
|
||||
else:
|
||||
return self._generate_from_average_body(spine)
|
||||
|
||||
def _torso_profile(self, z01: float) -> Tuple[float, float]:
|
||||
"""
|
||||
Get torso cross-section radii at a given height.
|
||||
|
||||
Args:
|
||||
z01: Normalized height (0=top, 1=bottom)
|
||||
|
||||
Returns:
|
||||
(a_mm, b_mm): Radii in left-right and front-back directions
|
||||
"""
|
||||
# Torso shape varies with height
|
||||
# Wider at chest (z~0.3) and hips (z~0.8), narrower at waist (z~0.5)
|
||||
|
||||
# Base radii from config
|
||||
base_a = self.config.torso_width_mm / 2
|
||||
base_b = self.config.torso_depth_mm / 2
|
||||
|
||||
# Shape modulation
|
||||
# Chest region (z ~ 0.2-0.4): wider
|
||||
# Waist region (z ~ 0.5): narrower
|
||||
# Hip region (z ~ 0.8-1.0): wider
|
||||
|
||||
if z01 < 0.3:
|
||||
# Upper chest - moderate width
|
||||
mod = 1.0
|
||||
elif z01 < 0.5:
|
||||
# Transition to waist
|
||||
t = (z01 - 0.3) / 0.2
|
||||
mod = 1.0 - 0.15 * t # Decrease by 15%
|
||||
elif z01 < 0.7:
|
||||
# Waist region - narrowest
|
||||
mod = 0.85
|
||||
else:
|
||||
# Hips - widen again
|
||||
t = (z01 - 0.7) / 0.3
|
||||
mod = 0.85 + 0.2 * t # Increase by 20%
|
||||
|
||||
return base_a * mod, base_b * mod
|
||||
|
||||
def _generate_from_average_body(self, spine: Spine2D) -> 'trimesh.Trimesh':
|
||||
"""
|
||||
Generate brace using parametric average body shape.
|
||||
|
||||
The brace follows the spine curve laterally and applies
|
||||
pressure zones at curve apexes.
|
||||
"""
|
||||
cfg = self.config
|
||||
|
||||
# 1) Compute spine curve
|
||||
try:
|
||||
C_px, T_px, N_px, curvature = compute_spine_curve(spine, smooth=5.0, n_samples=cfg.n_vertical_slices)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Cannot generate brace: {e}")
|
||||
|
||||
# 2) Convert to mm
|
||||
if spine.pixel_spacing_mm is not None:
|
||||
sx, sy = spine.pixel_spacing_mm
|
||||
elif cfg.pixel_spacing_mm is not None:
|
||||
sx, sy = cfg.pixel_spacing_mm
|
||||
else:
|
||||
sx = sy = 0.25 # Default assumption
|
||||
|
||||
C_mm = np.zeros_like(C_px)
|
||||
C_mm[:, 0] = C_px[:, 0] * sx
|
||||
C_mm[:, 1] = C_px[:, 1] * sy
|
||||
|
||||
# 3) Determine brace vertical extent
|
||||
y_mm = C_mm[:, 1]
|
||||
y_min, y_max = y_mm.min(), y_mm.max()
|
||||
spine_height = y_max - y_min
|
||||
|
||||
# Brace height (might extend beyond detected vertebrae)
|
||||
brace_height = min(cfg.brace_height_mm, spine_height * 1.1)
|
||||
|
||||
# 4) Normalize curvature for pressure zones
|
||||
curv_norm = (curvature - curvature.min()) / (curvature.max() - curvature.min() + 1e-8)
|
||||
|
||||
# 5) Build vertices
|
||||
n_z = cfg.n_vertical_slices
|
||||
n_theta = cfg.n_circumference_points
|
||||
|
||||
# Opening angle (front of brace might be open)
|
||||
opening_half = np.radians(cfg.front_opening_deg / 2)
|
||||
|
||||
vertices = []
|
||||
|
||||
for i in range(n_z):
|
||||
z01 = i / (n_z - 1) # 0 to 1
|
||||
|
||||
# Z coordinate (vertical position in 3D)
|
||||
z_mm = y_min + z01 * spine_height
|
||||
|
||||
# Get torso profile at this height
|
||||
a_mm, b_mm = self._torso_profile(z01)
|
||||
|
||||
# Lateral offset from spine curve
|
||||
x_offset = C_mm[i, 0] - (C_mm[0, 0] + C_mm[-1, 0]) / 2 # Deviation from midline
|
||||
|
||||
# Pressure modulation based on curvature
|
||||
pressure = cfg.pressure_strength_mm * curv_norm[i]
|
||||
|
||||
for j in range(n_theta):
|
||||
theta = 2 * np.pi * (j / n_theta)
|
||||
|
||||
# Skip vertices in the opening region (front = theta around 0)
|
||||
# Actually, we'll still create them but can mark them for later removal
|
||||
|
||||
# Base ellipse point
|
||||
x = a_mm * np.cos(theta)
|
||||
y = b_mm * np.sin(theta)
|
||||
|
||||
# Apply lateral offset (brace follows spine curve)
|
||||
x += x_offset
|
||||
|
||||
# Apply pressure zones
|
||||
# Pressure on sides (theta near π/2 or 3π/2 = sides)
|
||||
# The side that's convex gets pushed in
|
||||
side_factor = abs(np.cos(theta)) # Max at sides (theta=0 or π)
|
||||
|
||||
# Determine which side based on spine deviation
|
||||
if x_offset > 0:
|
||||
# Spine deviated right, push on right side
|
||||
if np.cos(theta) > 0: # Right side
|
||||
x -= pressure * side_factor
|
||||
else:
|
||||
# Spine deviated left, push on left side
|
||||
if np.cos(theta) < 0: # Left side
|
||||
x -= pressure * side_factor * np.sign(np.cos(theta))
|
||||
|
||||
# Vertex position: x=left/right, y=front/back, z=vertical
|
||||
vertices.append([x, y, z_mm])
|
||||
|
||||
vertices = np.array(vertices, dtype=np.float32)
|
||||
|
||||
# 6) Build faces (quad strips between adjacent rings)
|
||||
faces = []
|
||||
|
||||
def vid(i, j):
|
||||
return i * n_theta + (j % n_theta)
|
||||
|
||||
for i in range(n_z - 1):
|
||||
for j in range(n_theta):
|
||||
j2 = (j + 1) % n_theta
|
||||
|
||||
# Two triangles per quad
|
||||
a = vid(i, j)
|
||||
b = vid(i, j2)
|
||||
c = vid(i + 1, j2)
|
||||
d = vid(i + 1, j)
|
||||
|
||||
faces.append([a, b, c])
|
||||
faces.append([a, c, d])
|
||||
|
||||
faces = np.array(faces, dtype=np.int32)
|
||||
|
||||
# 7) Create outer shell mesh
|
||||
outer_shell = trimesh.Trimesh(vertices=vertices, faces=faces, process=True)
|
||||
|
||||
# 8) Create inner shell (offset inward by wall thickness)
|
||||
outer_shell.fix_normals()
|
||||
vn = outer_shell.vertex_normals
|
||||
inner_vertices = vertices - cfg.wall_thickness_mm * vn
|
||||
|
||||
# Inner faces need reversed winding
|
||||
inner_faces = faces[:, ::-1]
|
||||
|
||||
# 9) Combine into solid shell
|
||||
all_vertices = np.vstack([vertices, inner_vertices])
|
||||
inner_faces_offset = inner_faces + len(vertices)
|
||||
all_faces = np.vstack([faces, inner_faces_offset])
|
||||
|
||||
# 10) Add end caps (top and bottom rings)
|
||||
# Top cap (connect outer to inner at z=0)
|
||||
top_faces = []
|
||||
for j in range(n_theta):
|
||||
j2 = (j + 1) % n_theta
|
||||
outer_j = vid(0, j)
|
||||
outer_j2 = vid(0, j2)
|
||||
inner_j = outer_j + len(vertices)
|
||||
inner_j2 = outer_j2 + len(vertices)
|
||||
top_faces.append([outer_j, inner_j, inner_j2])
|
||||
top_faces.append([outer_j, inner_j2, outer_j2])
|
||||
|
||||
# Bottom cap
|
||||
bottom_faces = []
|
||||
for j in range(n_theta):
|
||||
j2 = (j + 1) % n_theta
|
||||
outer_j = vid(n_z - 1, j)
|
||||
outer_j2 = vid(n_z - 1, j2)
|
||||
inner_j = outer_j + len(vertices)
|
||||
inner_j2 = outer_j2 + len(vertices)
|
||||
bottom_faces.append([outer_j, outer_j2, inner_j2])
|
||||
bottom_faces.append([outer_j, inner_j2, inner_j])
|
||||
|
||||
all_faces = np.vstack([all_faces, top_faces, bottom_faces])
|
||||
|
||||
# Create final mesh
|
||||
brace = trimesh.Trimesh(vertices=all_vertices, faces=all_faces, process=True)
|
||||
brace.merge_vertices()
|
||||
# Remove degenerate faces
|
||||
valid_faces = brace.nondegenerate_faces()
|
||||
brace.update_faces(valid_faces)
|
||||
brace.fix_normals()
|
||||
|
||||
return brace
|
||||
|
||||
def _generate_from_body_scan(self, spine: Spine2D) -> 'trimesh.Trimesh':
|
||||
"""
|
||||
Generate brace by offsetting from a 3D body scan mesh.
|
||||
|
||||
The body scan provides the actual torso shape, and we:
|
||||
1. Offset outward for clearance
|
||||
2. Apply pressure zones based on spine curvature
|
||||
3. Thicken for wall thickness
|
||||
"""
|
||||
cfg = self.config
|
||||
|
||||
if not cfg.body_scan_path or not Path(cfg.body_scan_path).exists():
|
||||
raise FileNotFoundError(f"Body scan not found: {cfg.body_scan_path}")
|
||||
|
||||
# Load body scan
|
||||
body = trimesh.load(cfg.body_scan_path, force='mesh')
|
||||
body.remove_unreferenced_vertices()
|
||||
body.fix_normals()
|
||||
|
||||
# Compute spine curve for pressure mapping
|
||||
try:
|
||||
C_px, T_px, N_px, curvature = compute_spine_curve(spine, smooth=5.0, n_samples=200)
|
||||
except ValueError:
|
||||
curvature = np.zeros(200)
|
||||
|
||||
# Convert spine coordinates to mm
|
||||
if spine.pixel_spacing_mm is not None:
|
||||
sx, sy = spine.pixel_spacing_mm
|
||||
else:
|
||||
sx = sy = 0.25
|
||||
|
||||
y_mm = C_px[:, 1] * sy
|
||||
y_min, y_max = y_mm.min(), y_mm.max()
|
||||
H = y_max - y_min + 1e-6
|
||||
|
||||
# Normalize curvature
|
||||
curv_norm = (curvature - curvature.min()) / (curvature.max() - curvature.min() + 1e-8)
|
||||
|
||||
# 1) Offset body surface outward for clearance (inner brace surface)
|
||||
clearance_mm = 6.0 # Gap between body and brace
|
||||
vn = body.vertex_normals
|
||||
inner_surface = trimesh.Trimesh(
|
||||
vertices=body.vertices + clearance_mm * vn,
|
||||
faces=body.faces.copy(),
|
||||
process=True
|
||||
)
|
||||
|
||||
# 2) Apply pressure deformation
|
||||
# Map each vertex's Z coordinate to spine curvature
|
||||
z_coords = inner_surface.vertices[:, 2] # Assuming Z is vertical
|
||||
z_min, z_max = z_coords.min(), z_coords.max()
|
||||
z01 = (z_coords - z_min) / (z_max - z_min + 1e-6)
|
||||
|
||||
# Sample curvature at each vertex height
|
||||
curv_idx = np.clip((z01 * (len(curv_norm) - 1)).astype(int), 0, len(curv_norm) - 1)
|
||||
pressure_per_vertex = cfg.pressure_strength_mm * curv_norm[curv_idx]
|
||||
|
||||
# Apply pressure on sides (based on X coordinate)
|
||||
x_coords = inner_surface.vertices[:, 0]
|
||||
x_range = np.abs(x_coords).max() + 1e-6
|
||||
side_factor = np.abs(x_coords) / x_range # 0 at center, 1 at sides
|
||||
|
||||
deformation = (pressure_per_vertex * side_factor)[:, np.newaxis] * inner_surface.vertex_normals
|
||||
inner_surface.vertices = inner_surface.vertices - deformation
|
||||
|
||||
# 3) Create outer surface (offset by wall thickness)
|
||||
inner_surface.fix_normals()
|
||||
outer_surface = trimesh.Trimesh(
|
||||
vertices=inner_surface.vertices + cfg.wall_thickness_mm * inner_surface.vertex_normals,
|
||||
faces=inner_surface.faces.copy(),
|
||||
process=True
|
||||
)
|
||||
|
||||
# 4) Combine surfaces
|
||||
# For a true solid, we'd need to stitch edges - simplified here
|
||||
brace = trimesh.util.concatenate([inner_surface, outer_surface])
|
||||
brace.merge_vertices()
|
||||
valid_faces = brace.nondegenerate_faces()
|
||||
brace.update_faces(valid_faces)
|
||||
brace.fix_normals()
|
||||
|
||||
return brace
|
||||
|
||||
def export_stl(self, mesh: 'trimesh.Trimesh', output_path: str):
|
||||
"""
|
||||
Export mesh to STL file.
|
||||
|
||||
Args:
|
||||
mesh: trimesh.Trimesh object
|
||||
output_path: Path for output STL file
|
||||
"""
|
||||
mesh.export(output_path)
|
||||
print(f"Exported brace to {output_path}")
|
||||
print(f" Vertices: {len(mesh.vertices)}")
|
||||
print(f" Faces: {len(mesh.faces)}")
|
||||
177
brace-generator/data_models.py
Normal file
177
brace-generator/data_models.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Data models for unified spine landmark representation.
|
||||
This is the "glue" that connects different model outputs to the brace generator.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Any
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class VertebraLandmark:
|
||||
"""
|
||||
Unified representation of a single vertebra's landmarks.
|
||||
All coordinates are in pixels (can be converted to mm with pixel_spacing).
|
||||
"""
|
||||
# Vertebra level identifier (e.g., "T1", "T4", "L1", etc.) - None if unknown
|
||||
level: Optional[str] = None
|
||||
|
||||
# Center point of vertebra [x, y] in pixels
|
||||
centroid_px: np.ndarray = field(default_factory=lambda: np.zeros(2))
|
||||
|
||||
# Four corner points [top_left, top_right, bottom_right, bottom_left] shape (4, 2)
|
||||
corners_px: Optional[np.ndarray] = None
|
||||
|
||||
# Upper endplate points [left, right] shape (2, 2)
|
||||
endplate_upper_px: Optional[np.ndarray] = None
|
||||
|
||||
# Lower endplate points [left, right] shape (2, 2)
|
||||
endplate_lower_px: Optional[np.ndarray] = None
|
||||
|
||||
# Orientation angle of vertebra in degrees (tilt in coronal plane)
|
||||
orientation_deg: Optional[float] = None
|
||||
|
||||
# Detection confidence (0-1)
|
||||
confidence: float = 1.0
|
||||
|
||||
# Additional metadata from source model
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
|
||||
def compute_orientation(self) -> float:
|
||||
"""Compute vertebra orientation from corners or endplates."""
|
||||
if self.orientation_deg is not None:
|
||||
return self.orientation_deg
|
||||
|
||||
# Try to compute from upper endplate
|
||||
if self.endplate_upper_px is not None:
|
||||
left, right = self.endplate_upper_px[0], self.endplate_upper_px[1]
|
||||
dx = right[0] - left[0]
|
||||
dy = right[1] - left[1]
|
||||
angle = np.degrees(np.arctan2(dy, dx))
|
||||
self.orientation_deg = angle
|
||||
return angle
|
||||
|
||||
# Try to compute from corners (top-left to top-right)
|
||||
if self.corners_px is not None:
|
||||
top_left, top_right = self.corners_px[0], self.corners_px[1]
|
||||
dx = top_right[0] - top_left[0]
|
||||
dy = top_right[1] - top_left[1]
|
||||
angle = np.degrees(np.arctan2(dy, dx))
|
||||
self.orientation_deg = angle
|
||||
return angle
|
||||
|
||||
return 0.0
|
||||
|
||||
def compute_centroid(self) -> np.ndarray:
|
||||
"""Compute centroid from corners if not set."""
|
||||
if self.corners_px is not None and np.all(self.centroid_px == 0):
|
||||
self.centroid_px = np.mean(self.corners_px, axis=0)
|
||||
return self.centroid_px
|
||||
|
||||
|
||||
@dataclass
|
||||
class Spine2D:
|
||||
"""
|
||||
Complete 2D spine representation from an X-ray.
|
||||
Contains all detected vertebrae and computed angles.
|
||||
"""
|
||||
# List of vertebrae, ordered from top (C7/T1) to bottom (L5/S1)
|
||||
vertebrae: List[VertebraLandmark] = field(default_factory=list)
|
||||
|
||||
# Pixel spacing in mm [sx, sy] - from DICOM if available
|
||||
pixel_spacing_mm: Optional[np.ndarray] = None
|
||||
|
||||
# Original image shape (height, width)
|
||||
image_shape: Optional[tuple] = None
|
||||
|
||||
# Computed Cobb angles in degrees (individual fields)
|
||||
cobb_pt: Optional[float] = None # Proximal Thoracic
|
||||
cobb_mt: Optional[float] = None # Main Thoracic
|
||||
cobb_tl: Optional[float] = None # Thoracolumbar/Lumbar
|
||||
|
||||
# Cobb angles as dictionary (alternative format)
|
||||
cobb_angles: Optional[Dict[str, float]] = None # {'PT': angle, 'MT': angle, 'TL': angle}
|
||||
|
||||
# Curve type: "S" (double curve) or "C" (single curve) or "Normal"
|
||||
curve_type: Optional[str] = None
|
||||
|
||||
# Rigo-Chêneau classification
|
||||
rigo_type: Optional[str] = None # A1, A2, A3, B1, B2, C1, C2, E1, E2, Normal
|
||||
rigo_description: Optional[str] = None # Detailed description
|
||||
|
||||
# Source model that generated this data
|
||||
source_model: Optional[str] = None
|
||||
|
||||
# Additional metadata
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
|
||||
def get_cobb_angles(self) -> Dict[str, float]:
|
||||
"""Get Cobb angles as dictionary, preferring computed individual fields."""
|
||||
# Prefer individual fields (set by compute_cobb_angles) over dictionary
|
||||
# This ensures consistency between displayed values and classification
|
||||
if self.cobb_pt is not None or self.cobb_mt is not None or self.cobb_tl is not None:
|
||||
return {
|
||||
'PT': self.cobb_pt or 0.0,
|
||||
'MT': self.cobb_mt or 0.0,
|
||||
'TL': self.cobb_tl or 0.0
|
||||
}
|
||||
if self.cobb_angles is not None:
|
||||
return self.cobb_angles
|
||||
return {'PT': 0.0, 'MT': 0.0, 'TL': 0.0}
|
||||
|
||||
def get_centroids(self) -> np.ndarray:
|
||||
"""Get array of all vertebra centroids, shape (N, 2)."""
|
||||
centroids = []
|
||||
for v in self.vertebrae:
|
||||
v.compute_centroid()
|
||||
centroids.append(v.centroid_px)
|
||||
return np.array(centroids, dtype=np.float32)
|
||||
|
||||
def get_orientations(self) -> np.ndarray:
|
||||
"""Get array of all vertebra orientations in degrees, shape (N,)."""
|
||||
return np.array([v.compute_orientation() for v in self.vertebrae], dtype=np.float32)
|
||||
|
||||
def to_mm(self, coords_px: np.ndarray) -> np.ndarray:
|
||||
"""Convert pixel coordinates to millimeters."""
|
||||
if self.pixel_spacing_mm is None:
|
||||
# Default assumption: 0.25 mm/pixel (typical for spine X-rays)
|
||||
spacing = np.array([0.25, 0.25])
|
||||
else:
|
||||
spacing = self.pixel_spacing_mm
|
||||
return coords_px * spacing
|
||||
|
||||
def sort_vertebrae(self):
|
||||
"""Sort vertebrae by vertical position (top to bottom)."""
|
||||
self.vertebrae.sort(key=lambda v: float(v.centroid_px[1]))
|
||||
|
||||
|
||||
@dataclass
|
||||
class BraceConfig:
|
||||
"""
|
||||
Configuration parameters for brace generation.
|
||||
"""
|
||||
# Brace dimensions
|
||||
brace_height_mm: float = 400.0 # Total height of brace
|
||||
wall_thickness_mm: float = 4.0 # Shell thickness
|
||||
|
||||
# Torso shape parameters (for average body mode)
|
||||
torso_width_mm: float = 280.0 # Left-right diameter at widest
|
||||
torso_depth_mm: float = 200.0 # Front-back diameter at widest
|
||||
|
||||
# Correction parameters
|
||||
pressure_strength_mm: float = 15.0 # Max indentation at apex
|
||||
pressure_spread_deg: float = 45.0 # Angular spread of pressure zone
|
||||
|
||||
# Mesh resolution
|
||||
n_vertical_slices: int = 100 # Number of cross-sections
|
||||
n_circumference_points: int = 72 # Points per cross-section (every 5°)
|
||||
|
||||
# Opening (for brace accessibility)
|
||||
front_opening_deg: float = 60.0 # Angular width of front opening (0 = closed)
|
||||
|
||||
# Mode
|
||||
use_body_scan: bool = False # True = use 3D body scan, False = average body
|
||||
body_scan_path: Optional[str] = None # Path to body scan mesh
|
||||
|
||||
# Scale
|
||||
pixel_spacing_mm: Optional[np.ndarray] = None # Override pixel spacing
|
||||
115
brace-generator/image_loader.py
Normal file
115
brace-generator/image_loader.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Image loader supporting JPEG, PNG, and DICOM formats.
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
try:
|
||||
import pydicom
|
||||
HAS_PYDICOM = True
|
||||
except ImportError:
|
||||
HAS_PYDICOM = False
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
HAS_PIL = True
|
||||
except ImportError:
|
||||
HAS_PIL = False
|
||||
|
||||
try:
|
||||
import cv2
|
||||
HAS_CV2 = True
|
||||
except ImportError:
|
||||
HAS_CV2 = False
|
||||
|
||||
|
||||
def load_xray(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
"""
|
||||
Load an X-ray image from file.
|
||||
|
||||
Supports: JPEG, PNG, BMP, DICOM (.dcm)
|
||||
|
||||
Args:
|
||||
path: Path to the image file
|
||||
|
||||
Returns:
|
||||
img_u8: Grayscale image as uint8 array (H, W)
|
||||
spacing_mm: Pixel spacing [sx, sy] in mm, or None if not available
|
||||
"""
|
||||
path = Path(path)
|
||||
suffix = path.suffix.lower()
|
||||
|
||||
# DICOM
|
||||
if suffix in ['.dcm', '.dicom']:
|
||||
if not HAS_PYDICOM:
|
||||
raise ImportError("pydicom is required for DICOM files. Install with: pip install pydicom")
|
||||
return _load_dicom(str(path))
|
||||
|
||||
# Standard image formats
|
||||
if suffix in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
|
||||
return _load_standard_image(str(path))
|
||||
|
||||
# Try to load as standard image anyway
|
||||
try:
|
||||
return _load_standard_image(str(path))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not load image: {path}. Error: {e}")
|
||||
|
||||
|
||||
def _load_dicom(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
"""Load DICOM file."""
|
||||
ds = pydicom.dcmread(path)
|
||||
arr = ds.pixel_array.astype(np.float32)
|
||||
|
||||
# Apply modality LUT if present
|
||||
if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
|
||||
arr = arr * ds.RescaleSlope + ds.RescaleIntercept
|
||||
|
||||
# Normalize to 0-255
|
||||
arr = arr - arr.min()
|
||||
if arr.max() > 0:
|
||||
arr = arr / arr.max()
|
||||
img_u8 = (arr * 255).astype(np.uint8)
|
||||
|
||||
# Get pixel spacing
|
||||
spacing_mm = None
|
||||
if hasattr(ds, 'PixelSpacing'):
|
||||
# PixelSpacing is [row_spacing, col_spacing] in mm
|
||||
sy, sx = [float(x) for x in ds.PixelSpacing]
|
||||
spacing_mm = np.array([sx, sy], dtype=np.float32)
|
||||
elif hasattr(ds, 'ImagerPixelSpacing'):
|
||||
sy, sx = [float(x) for x in ds.ImagerPixelSpacing]
|
||||
spacing_mm = np.array([sx, sy], dtype=np.float32)
|
||||
|
||||
return img_u8, spacing_mm
|
||||
|
||||
|
||||
def _load_standard_image(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
"""Load standard image format (JPEG, PNG, etc.)."""
|
||||
if HAS_CV2:
|
||||
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
||||
if img is None:
|
||||
raise ValueError(f"Could not read image: {path}")
|
||||
return img.astype(np.uint8), None
|
||||
elif HAS_PIL:
|
||||
img = Image.open(path).convert('L') # Convert to grayscale
|
||||
return np.array(img, dtype=np.uint8), None
|
||||
else:
|
||||
raise ImportError("Either opencv-python or Pillow is required. Install with: pip install opencv-python")
|
||||
|
||||
|
||||
def load_xray_rgb(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
"""
|
||||
Load X-ray as RGB (for models that expect 3-channel input).
|
||||
|
||||
Returns:
|
||||
img_rgb: RGB image as uint8 array (H, W, 3)
|
||||
spacing_mm: Pixel spacing or None
|
||||
"""
|
||||
img_gray, spacing = load_xray(path)
|
||||
|
||||
# Convert grayscale to RGB by stacking
|
||||
img_rgb = np.stack([img_gray, img_gray, img_gray], axis=-1)
|
||||
|
||||
return img_rgb, spacing
|
||||
346
brace-generator/pipeline.py
Normal file
346
brace-generator/pipeline.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Complete pipeline: X-ray → Landmarks → Brace STL
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Union
|
||||
import json
|
||||
|
||||
from brace_generator.data_models import Spine2D, BraceConfig, VertebraLandmark
|
||||
from brace_generator.image_loader import load_xray, load_xray_rgb
|
||||
from brace_generator.adapters import BaseLandmarkAdapter, ScolioVisAdapter, VertLandmarkAdapter
|
||||
from brace_generator.spine_analysis import (
|
||||
compute_spine_curve, compute_cobb_angles, find_apex_vertebrae,
|
||||
get_curve_severity, classify_rigo_type
|
||||
)
|
||||
from brace_generator.brace_surface import BraceGenerator
|
||||
|
||||
|
||||
class BracePipeline:
|
||||
"""
|
||||
End-to-end pipeline for generating scoliosis braces from X-rays.
|
||||
|
||||
Usage:
|
||||
# Basic usage with default model
|
||||
pipeline = BracePipeline()
|
||||
pipeline.process("xray.png", "brace.stl")
|
||||
|
||||
# With specific model
|
||||
pipeline = BracePipeline(model="vertebra-landmark")
|
||||
pipeline.process("xray.dcm", "brace.stl")
|
||||
|
||||
# With body scan
|
||||
config = BraceConfig(use_body_scan=True, body_scan_path="body.obj")
|
||||
pipeline = BracePipeline(config=config)
|
||||
pipeline.process("xray.png", "brace.stl")
|
||||
"""
|
||||
|
||||
AVAILABLE_MODELS = {
|
||||
'scoliovis': ScolioVisAdapter,
|
||||
'vertebra-landmark': VertLandmarkAdapter,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = 'scoliovis',
|
||||
config: Optional[BraceConfig] = None,
|
||||
device: str = 'cpu'
|
||||
):
|
||||
"""
|
||||
Initialize pipeline.
|
||||
|
||||
Args:
|
||||
model: Model to use ('scoliovis' or 'vertebra-landmark')
|
||||
config: Brace configuration
|
||||
device: 'cpu' or 'cuda'
|
||||
"""
|
||||
self.device = device
|
||||
self.config = config or BraceConfig()
|
||||
self.model_name = model.lower()
|
||||
|
||||
# Initialize model adapter
|
||||
if self.model_name not in self.AVAILABLE_MODELS:
|
||||
raise ValueError(f"Unknown model: {model}. Available: {list(self.AVAILABLE_MODELS.keys())}")
|
||||
|
||||
self.adapter: BaseLandmarkAdapter = self.AVAILABLE_MODELS[self.model_name](device=device)
|
||||
self.brace_generator = BraceGenerator(self.config)
|
||||
|
||||
# Store last results for inspection
|
||||
self.last_spine: Optional[Spine2D] = None
|
||||
self.last_image: Optional[np.ndarray] = None
|
||||
|
||||
def process(
|
||||
self,
|
||||
xray_path: str,
|
||||
output_stl_path: str,
|
||||
visualize: bool = False,
|
||||
save_landmarks: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process X-ray and generate brace STL.
|
||||
|
||||
Args:
|
||||
xray_path: Path to input X-ray (JPEG, PNG, or DICOM)
|
||||
output_stl_path: Path for output STL file
|
||||
visualize: If True, also save visualization image
|
||||
save_landmarks: If True, also save landmarks JSON
|
||||
|
||||
Returns:
|
||||
Dictionary with analysis results
|
||||
"""
|
||||
print(f"=" * 60)
|
||||
print(f"Brace Generation Pipeline")
|
||||
print(f"Model: {self.adapter.name}")
|
||||
print(f"=" * 60)
|
||||
|
||||
# 1) Load X-ray
|
||||
print(f"\n1. Loading X-ray: {xray_path}")
|
||||
image_rgb, pixel_spacing = load_xray_rgb(xray_path)
|
||||
self.last_image = image_rgb
|
||||
print(f" Image size: {image_rgb.shape[:2]}")
|
||||
if pixel_spacing is not None:
|
||||
print(f" Pixel spacing: {pixel_spacing} mm")
|
||||
|
||||
# 2) Detect landmarks
|
||||
print(f"\n2. Detecting landmarks...")
|
||||
spine = self.adapter.predict(image_rgb)
|
||||
spine.pixel_spacing_mm = pixel_spacing
|
||||
self.last_spine = spine
|
||||
|
||||
print(f" Detected {len(spine.vertebrae)} vertebrae")
|
||||
|
||||
if len(spine.vertebrae) < 5:
|
||||
raise ValueError(f"Insufficient vertebrae detected ({len(spine.vertebrae)}). Need at least 5.")
|
||||
|
||||
# 3) Compute spine analysis
|
||||
print(f"\n3. Analyzing spine curvature...")
|
||||
compute_cobb_angles(spine)
|
||||
apexes = find_apex_vertebrae(spine)
|
||||
|
||||
# Classify Rigo type
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
print(f" Cobb Angles:")
|
||||
print(f" PT (Proximal Thoracic): {spine.cobb_pt:.1f}° - {get_curve_severity(spine.cobb_pt)}")
|
||||
print(f" MT (Main Thoracic): {spine.cobb_mt:.1f}° - {get_curve_severity(spine.cobb_mt)}")
|
||||
print(f" TL (Thoracolumbar): {spine.cobb_tl:.1f}° - {get_curve_severity(spine.cobb_tl)}")
|
||||
print(f" Curve type: {spine.curve_type}")
|
||||
print(f" Rigo Classification: {rigo_result['rigo_type']}")
|
||||
print(f" - {rigo_result['description']}")
|
||||
print(f" Apex vertebrae indices: {apexes}")
|
||||
|
||||
# 4) Generate brace
|
||||
print(f"\n4. Generating brace mesh...")
|
||||
if self.config.use_body_scan:
|
||||
print(f" Mode: Using body scan ({self.config.body_scan_path})")
|
||||
else:
|
||||
print(f" Mode: Average body shape")
|
||||
|
||||
brace_mesh = self.brace_generator.generate(spine)
|
||||
print(f" Mesh: {len(brace_mesh.vertices)} vertices, {len(brace_mesh.faces)} faces")
|
||||
|
||||
# 5) Export STL
|
||||
print(f"\n5. Exporting STL: {output_stl_path}")
|
||||
self.brace_generator.export_stl(brace_mesh, output_stl_path)
|
||||
|
||||
# 6) Optional: Save visualization
|
||||
if visualize:
|
||||
vis_path = str(Path(output_stl_path).with_suffix('.png'))
|
||||
self._save_visualization(vis_path, spine, image_rgb)
|
||||
print(f" Visualization saved: {vis_path}")
|
||||
|
||||
# 7) Optional: Save landmarks JSON
|
||||
if save_landmarks:
|
||||
json_path = str(Path(output_stl_path).with_suffix('.json'))
|
||||
self._save_landmarks_json(json_path, spine)
|
||||
print(f" Landmarks saved: {json_path}")
|
||||
|
||||
# Prepare results
|
||||
results = {
|
||||
'input_image': xray_path,
|
||||
'output_stl': output_stl_path,
|
||||
'model': self.adapter.name,
|
||||
'vertebrae_detected': len(spine.vertebrae),
|
||||
'cobb_angles': {
|
||||
'PT': spine.cobb_pt,
|
||||
'MT': spine.cobb_mt,
|
||||
'TL': spine.cobb_tl,
|
||||
},
|
||||
'curve_type': spine.curve_type,
|
||||
'rigo_type': rigo_result['rigo_type'],
|
||||
'rigo_description': rigo_result['description'],
|
||||
'apex_indices': apexes,
|
||||
'mesh_vertices': len(brace_mesh.vertices),
|
||||
'mesh_faces': len(brace_mesh.faces),
|
||||
}
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Pipeline complete!")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
return results
|
||||
|
||||
def _save_visualization(self, path: str, spine: Spine2D, image: np.ndarray):
|
||||
"""Save visualization of detected landmarks and spine curve."""
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
print(" Warning: matplotlib not available for visualization")
|
||||
return
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 10))
|
||||
|
||||
# Left: Original with landmarks
|
||||
ax1 = axes[0]
|
||||
ax1.imshow(image)
|
||||
|
||||
# Draw vertebra centers
|
||||
centroids = spine.get_centroids()
|
||||
ax1.scatter(centroids[:, 0], centroids[:, 1], c='red', s=30, zorder=5)
|
||||
|
||||
# Draw corners if available
|
||||
for vert in spine.vertebrae:
|
||||
if vert.corners_px is not None:
|
||||
corners = vert.corners_px
|
||||
# Draw quadrilateral
|
||||
for i in range(4):
|
||||
j = (i + 1) % 4
|
||||
ax1.plot([corners[i, 0], corners[j, 0]],
|
||||
[corners[i, 1], corners[j, 1]], 'g-', linewidth=1)
|
||||
|
||||
ax1.set_title(f"Detected Landmarks ({len(spine.vertebrae)} vertebrae)")
|
||||
ax1.axis('off')
|
||||
|
||||
# Right: Spine curve analysis
|
||||
ax2 = axes[1]
|
||||
ax2.imshow(image, alpha=0.5)
|
||||
|
||||
# Draw spine curve
|
||||
try:
|
||||
C, T, N, curv = compute_spine_curve(spine)
|
||||
ax2.plot(C[:, 0], C[:, 1], 'b-', linewidth=2, label='Spine curve')
|
||||
|
||||
# Highlight high curvature regions
|
||||
high_curv_mask = curv > curv.mean() + curv.std()
|
||||
ax2.scatter(C[high_curv_mask, 0], C[high_curv_mask, 1],
|
||||
c='orange', s=20, label='High curvature')
|
||||
except:
|
||||
pass
|
||||
|
||||
# Get Rigo classification for display
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
# Add Cobb angles and Rigo type text
|
||||
text = f"Cobb Angles:\n"
|
||||
text += f"PT: {spine.cobb_pt:.1f}°\n"
|
||||
text += f"MT: {spine.cobb_mt:.1f}°\n"
|
||||
text += f"TL: {spine.cobb_tl:.1f}°\n"
|
||||
text += f"Curve: {spine.curve_type}\n"
|
||||
text += f"-----------\n"
|
||||
text += f"Rigo: {rigo_result['rigo_type']}"
|
||||
ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
|
||||
|
||||
ax2.set_title("Spine Analysis")
|
||||
ax2.axis('off')
|
||||
ax2.legend(loc='lower right')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def _save_landmarks_json(self, path: str, spine: Spine2D):
|
||||
"""Save landmarks to JSON file with Rigo classification."""
|
||||
def to_native(val):
|
||||
"""Convert numpy types to native Python types."""
|
||||
if isinstance(val, np.ndarray):
|
||||
return val.tolist()
|
||||
elif isinstance(val, (np.float32, np.float64)):
|
||||
return float(val)
|
||||
elif isinstance(val, (np.int32, np.int64)):
|
||||
return int(val)
|
||||
return val
|
||||
|
||||
# Get Rigo classification
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
data = {
|
||||
'source_model': spine.source_model,
|
||||
'image_shape': list(spine.image_shape) if spine.image_shape else None,
|
||||
'pixel_spacing_mm': spine.pixel_spacing_mm.tolist() if spine.pixel_spacing_mm is not None else None,
|
||||
'cobb_angles': {
|
||||
'PT': to_native(spine.cobb_pt),
|
||||
'MT': to_native(spine.cobb_mt),
|
||||
'TL': to_native(spine.cobb_tl),
|
||||
},
|
||||
'curve_type': spine.curve_type,
|
||||
'rigo_classification': {
|
||||
'type': rigo_result['rigo_type'],
|
||||
'description': rigo_result['description'],
|
||||
'curve_pattern': rigo_result['curve_pattern'],
|
||||
'n_significant_curves': rigo_result['n_significant_curves'],
|
||||
},
|
||||
'vertebrae': []
|
||||
}
|
||||
|
||||
for vert in spine.vertebrae:
|
||||
vert_data = {
|
||||
'level': vert.level,
|
||||
'centroid_px': vert.centroid_px.tolist(),
|
||||
'orientation_deg': to_native(vert.orientation_deg),
|
||||
'confidence': to_native(vert.confidence),
|
||||
}
|
||||
if vert.corners_px is not None:
|
||||
vert_data['corners_px'] = vert.corners_px.tolist()
|
||||
data['vertebrae'].append(vert_data)
|
||||
|
||||
with open(path, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
"""Command-line interface for brace generation."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Generate scoliosis brace from X-ray')
|
||||
parser.add_argument('input', help='Input X-ray image (JPEG, PNG, or DICOM)')
|
||||
parser.add_argument('output', help='Output STL file path')
|
||||
parser.add_argument('--model', choices=['scoliovis', 'vertebra-landmark'],
|
||||
default='scoliovis', help='Landmark detection model')
|
||||
parser.add_argument('--device', default='cpu', help='Device (cpu or cuda)')
|
||||
parser.add_argument('--body-scan', help='Path to 3D body scan mesh (optional)')
|
||||
parser.add_argument('--visualize', action='store_true', help='Save visualization')
|
||||
parser.add_argument('--save-landmarks', action='store_true', help='Save landmarks JSON')
|
||||
parser.add_argument('--pressure', type=float, default=15.0,
|
||||
help='Pressure strength in mm (default: 15)')
|
||||
parser.add_argument('--thickness', type=float, default=4.0,
|
||||
help='Wall thickness in mm (default: 4)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Build config
|
||||
config = BraceConfig(
|
||||
pressure_strength_mm=args.pressure,
|
||||
wall_thickness_mm=args.thickness,
|
||||
)
|
||||
|
||||
if args.body_scan:
|
||||
config.use_body_scan = True
|
||||
config.body_scan_path = args.body_scan
|
||||
|
||||
# Run pipeline
|
||||
pipeline = BracePipeline(model=args.model, config=config, device=args.device)
|
||||
results = pipeline.process(
|
||||
args.input,
|
||||
args.output,
|
||||
visualize=args.visualize,
|
||||
save_landmarks=args.save_landmarks
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
464
brace-generator/spine_analysis.py
Normal file
464
brace-generator/spine_analysis.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
Spine analysis functions for computing curves, Cobb angles, and identifying apex vertebrae.
|
||||
"""
|
||||
import numpy as np
|
||||
from scipy.interpolate import splprep, splev
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
from data_models import Spine2D, VertebraLandmark
|
||||
|
||||
|
||||
def compute_spine_curve(
|
||||
spine: Spine2D,
|
||||
smooth: float = 1.0,
|
||||
n_samples: int = 200
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Compute smooth spine centerline from vertebra centroids.
|
||||
|
||||
Args:
|
||||
spine: Spine2D object with detected vertebrae
|
||||
smooth: Smoothing factor for spline (higher = smoother)
|
||||
n_samples: Number of points to sample along the curve
|
||||
|
||||
Returns:
|
||||
C: Curve points, shape (n_samples, 2)
|
||||
T: Tangent vectors, shape (n_samples, 2)
|
||||
N: Normal vectors, shape (n_samples, 2)
|
||||
curvature: Curvature at each point, shape (n_samples,)
|
||||
"""
|
||||
pts = spine.get_centroids()
|
||||
|
||||
if len(pts) < 4:
|
||||
raise ValueError(f"Need at least 4 vertebrae for spline, got {len(pts)}")
|
||||
|
||||
# Fit parametric spline through centroids
|
||||
x = pts[:, 0]
|
||||
y = pts[:, 1]
|
||||
|
||||
try:
|
||||
tck, u = splprep([x, y], s=smooth, k=min(3, len(pts)-1))
|
||||
except Exception as e:
|
||||
# Fallback: simple linear interpolation
|
||||
t = np.linspace(0, 1, n_samples)
|
||||
xs = np.interp(t, np.linspace(0, 1, len(x)), x)
|
||||
ys = np.interp(t, np.linspace(0, 1, len(y)), y)
|
||||
C = np.stack([xs, ys], axis=1).astype(np.float32)
|
||||
T = np.gradient(C, axis=0)
|
||||
T = T / (np.linalg.norm(T, axis=1, keepdims=True) + 1e-8)
|
||||
N = np.stack([-T[:, 1], T[:, 0]], axis=1)
|
||||
curvature = np.zeros(n_samples, dtype=np.float32)
|
||||
return C, T, N, curvature
|
||||
|
||||
# Sample the spline
|
||||
u_new = np.linspace(0, 1, n_samples)
|
||||
xs, ys = splev(u_new, tck)
|
||||
|
||||
# First and second derivatives
|
||||
dx, dy = splev(u_new, tck, der=1)
|
||||
ddx, ddy = splev(u_new, tck, der=2)
|
||||
|
||||
# Curve points
|
||||
C = np.stack([xs, ys], axis=1).astype(np.float32)
|
||||
|
||||
# Tangent vectors (normalized)
|
||||
T = np.stack([dx, dy], axis=1)
|
||||
T_norm = np.linalg.norm(T, axis=1, keepdims=True) + 1e-8
|
||||
T = (T / T_norm).astype(np.float32)
|
||||
|
||||
# Normal vectors (perpendicular to tangent)
|
||||
N = np.stack([-T[:, 1], T[:, 0]], axis=1).astype(np.float32)
|
||||
|
||||
# Curvature: |x'y'' - y'x''| / (x'^2 + y'^2)^(3/2)
|
||||
curvature = np.abs(dx * ddy - dy * ddx) / (np.power(dx**2 + dy**2, 1.5) + 1e-8)
|
||||
curvature = curvature.astype(np.float32)
|
||||
|
||||
return C, T, N, curvature
|
||||
|
||||
|
||||
def compute_cobb_angles(spine: Spine2D) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Compute Cobb angles from vertebra orientations.
|
||||
|
||||
The Cobb angle is measured as the angle between:
|
||||
- The superior endplate of the most tilted vertebra at the top of a curve
|
||||
- The inferior endplate of the most tilted vertebra at the bottom
|
||||
|
||||
This function estimates 3 Cobb angles:
|
||||
- PT: Proximal Thoracic (T1-T6 region)
|
||||
- MT: Main Thoracic (T6-T12 region)
|
||||
- TL: Thoracolumbar/Lumbar (T12-L5 region)
|
||||
|
||||
Args:
|
||||
spine: Spine2D object with detected vertebrae
|
||||
|
||||
Returns:
|
||||
(pt_angle, mt_angle, tl_angle) in degrees
|
||||
"""
|
||||
orientations = spine.get_orientations()
|
||||
n = len(orientations)
|
||||
|
||||
if n < 7:
|
||||
spine.cobb_pt = 0.0
|
||||
spine.cobb_mt = 0.0
|
||||
spine.cobb_tl = 0.0
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
# Divide into regions (approximately)
|
||||
# PT: top 1/3, MT: middle 1/3, TL: bottom 1/3
|
||||
pt_end = n // 3
|
||||
mt_end = 2 * n // 3
|
||||
|
||||
# Find max tilt difference in each region
|
||||
def region_cobb(start_idx: int, end_idx: int) -> float:
|
||||
if end_idx <= start_idx:
|
||||
return 0.0
|
||||
region_angles = orientations[start_idx:end_idx]
|
||||
if len(region_angles) < 2:
|
||||
return 0.0
|
||||
# Cobb angle = max angle - min angle in region
|
||||
return abs(float(np.max(region_angles) - np.min(region_angles)))
|
||||
|
||||
pt_angle = region_cobb(0, pt_end)
|
||||
mt_angle = region_cobb(pt_end, mt_end)
|
||||
tl_angle = region_cobb(mt_end, n)
|
||||
|
||||
# Store in spine object
|
||||
spine.cobb_pt = pt_angle
|
||||
spine.cobb_mt = mt_angle
|
||||
spine.cobb_tl = tl_angle
|
||||
|
||||
# Determine curve type
|
||||
if mt_angle > 10 and tl_angle > 10:
|
||||
spine.curve_type = "S" # Double curve
|
||||
elif mt_angle > 10 or tl_angle > 10:
|
||||
spine.curve_type = "C" # Single curve
|
||||
else:
|
||||
spine.curve_type = "Normal"
|
||||
|
||||
return pt_angle, mt_angle, tl_angle
|
||||
|
||||
|
||||
def find_apex_vertebrae(spine: Spine2D) -> List[int]:
|
||||
"""
|
||||
Find indices of apex vertebrae (most deviated from midline).
|
||||
|
||||
Args:
|
||||
spine: Spine2D with computed curve
|
||||
|
||||
Returns:
|
||||
List of vertebra indices that are curve apexes
|
||||
"""
|
||||
centroids = spine.get_centroids()
|
||||
|
||||
if len(centroids) < 5:
|
||||
return []
|
||||
|
||||
# Find midline (linear fit through endpoints)
|
||||
start = centroids[0]
|
||||
end = centroids[-1]
|
||||
|
||||
# Distance from midline for each vertebra
|
||||
midline_vec = end - start
|
||||
midline_len = np.linalg.norm(midline_vec)
|
||||
|
||||
if midline_len < 1e-6:
|
||||
return []
|
||||
|
||||
midline_unit = midline_vec / midline_len
|
||||
|
||||
# Calculate perpendicular distance to midline
|
||||
deviations = []
|
||||
for i, pt in enumerate(centroids):
|
||||
v = pt - start
|
||||
# Project onto midline
|
||||
proj_len = np.dot(v, midline_unit)
|
||||
proj = proj_len * midline_unit
|
||||
# Perpendicular distance
|
||||
perp = v - proj
|
||||
dist = np.linalg.norm(perp)
|
||||
# Sign: positive if to the right of midline
|
||||
sign = np.sign(np.cross(midline_unit, v / (np.linalg.norm(v) + 1e-8)))
|
||||
deviations.append(dist * sign)
|
||||
|
||||
deviations = np.array(deviations)
|
||||
|
||||
# Find local extrema (peaks and valleys)
|
||||
apexes = []
|
||||
for i in range(1, len(deviations) - 1):
|
||||
# Local maximum
|
||||
if deviations[i] > deviations[i-1] and deviations[i] > deviations[i+1]:
|
||||
if abs(deviations[i]) > 5: # Minimum deviation threshold (pixels)
|
||||
apexes.append(i)
|
||||
# Local minimum
|
||||
elif deviations[i] < deviations[i-1] and deviations[i] < deviations[i+1]:
|
||||
if abs(deviations[i]) > 5:
|
||||
apexes.append(i)
|
||||
|
||||
return apexes
|
||||
|
||||
|
||||
def get_curve_severity(cobb_angle: float) -> str:
|
||||
"""
|
||||
Get clinical severity classification from Cobb angle.
|
||||
|
||||
Args:
|
||||
cobb_angle: Cobb angle in degrees
|
||||
|
||||
Returns:
|
||||
Severity string: "Normal", "Mild", "Moderate", or "Severe"
|
||||
"""
|
||||
if cobb_angle < 10:
|
||||
return "Normal"
|
||||
elif cobb_angle < 25:
|
||||
return "Mild"
|
||||
elif cobb_angle < 40:
|
||||
return "Moderate"
|
||||
else:
|
||||
return "Severe"
|
||||
|
||||
|
||||
def classify_rigo_type(spine: Spine2D) -> dict:
|
||||
"""
|
||||
Classify scoliosis according to Rigo-Chêneau brace classification.
|
||||
|
||||
Rigo Classification Types:
|
||||
- A1, A2, A3: 3-curve patterns (thoracic major)
|
||||
- B1, B2: 4-curve patterns (double major)
|
||||
- C1, C2: Single thoracolumbar/lumbar
|
||||
- E1, E2: Single thoracic
|
||||
|
||||
Args:
|
||||
spine: Spine2D object with detected vertebrae and Cobb angles
|
||||
|
||||
Returns:
|
||||
dict with 'rigo_type', 'description', 'apex_region', 'curve_pattern'
|
||||
"""
|
||||
# Get Cobb angles
|
||||
cobb_angles = spine.get_cobb_angles()
|
||||
pt = cobb_angles.get('PT', 0)
|
||||
mt = cobb_angles.get('MT', 0)
|
||||
tl = cobb_angles.get('TL', 0)
|
||||
|
||||
n_verts = len(spine.vertebrae)
|
||||
|
||||
# Calculate lateral deviations to determine curve direction
|
||||
centroids = spine.get_centroids()
|
||||
deviations = _calculate_lateral_deviations(centroids)
|
||||
|
||||
# Find apex positions and directions
|
||||
apex_info = _find_apex_info(centroids, deviations, n_verts)
|
||||
|
||||
# Determine curve pattern based on number of significant curves
|
||||
significant_curves = []
|
||||
if pt >= 10:
|
||||
significant_curves.append(('PT', pt))
|
||||
if mt >= 10:
|
||||
significant_curves.append(('MT', mt))
|
||||
if tl >= 10:
|
||||
significant_curves.append(('TL', tl))
|
||||
|
||||
n_curves = len(significant_curves)
|
||||
|
||||
# Classification logic
|
||||
rigo_type = "N/A"
|
||||
description = ""
|
||||
curve_pattern = ""
|
||||
|
||||
# No significant scoliosis
|
||||
if n_curves == 0 or max(pt, mt, tl) < 10:
|
||||
rigo_type = "Normal"
|
||||
description = "No significant scoliosis (all Cobb angles < 10°)"
|
||||
curve_pattern = "None"
|
||||
|
||||
# Single curve patterns
|
||||
elif n_curves == 1:
|
||||
max_curve = significant_curves[0][0]
|
||||
max_angle = significant_curves[0][1]
|
||||
|
||||
if max_curve == 'MT' or max_curve == 'PT':
|
||||
# Thoracic single curve
|
||||
if apex_info['thoracic_apex_idx'] is not None:
|
||||
# Check if there's a compensatory lumbar
|
||||
if tl > 5:
|
||||
rigo_type = "E2"
|
||||
description = f"Single thoracic curve ({max_angle:.1f}°) with lumbar compensatory ({tl:.1f}°)"
|
||||
curve_pattern = "Thoracic with compensation"
|
||||
else:
|
||||
rigo_type = "E1"
|
||||
description = f"True single thoracic curve ({max_angle:.1f}°)"
|
||||
curve_pattern = "Single thoracic"
|
||||
else:
|
||||
rigo_type = "E1"
|
||||
description = f"Single thoracic curve ({max_angle:.1f}°)"
|
||||
curve_pattern = "Single thoracic"
|
||||
|
||||
elif max_curve == 'TL':
|
||||
# Thoracolumbar/Lumbar single curve
|
||||
if mt > 5 or pt > 5:
|
||||
rigo_type = "C2"
|
||||
description = f"Thoracolumbar curve ({tl:.1f}°) with upper compensatory"
|
||||
curve_pattern = "TL/L with compensation"
|
||||
else:
|
||||
rigo_type = "C1"
|
||||
description = f"Single thoracolumbar/lumbar curve ({tl:.1f}°)"
|
||||
curve_pattern = "Single TL/L"
|
||||
|
||||
# Double curve patterns
|
||||
elif n_curves >= 2:
|
||||
# Determine which curves are primary
|
||||
thoracic_total = pt + mt
|
||||
lumbar_total = tl
|
||||
|
||||
# Check curve directions for S vs C pattern
|
||||
is_s_curve = apex_info['is_s_pattern']
|
||||
|
||||
if is_s_curve:
|
||||
# S-curve: typically 3 or 4 curve patterns
|
||||
if thoracic_total > lumbar_total * 1.5:
|
||||
# Thoracic dominant - Type A (3-curve)
|
||||
if apex_info['lumbar_apex_low']:
|
||||
rigo_type = "A1"
|
||||
description = f"3-curve: Thoracic major ({mt:.1f}°), lumbar apex low"
|
||||
elif apex_info['apex_at_tl_junction']:
|
||||
rigo_type = "A2"
|
||||
description = f"3-curve: Thoracolumbar transition ({mt:.1f}°/{tl:.1f}°)"
|
||||
else:
|
||||
rigo_type = "A3"
|
||||
description = f"3-curve: Thoracic major ({mt:.1f}°) with structural lumbar ({tl:.1f}°)"
|
||||
curve_pattern = "3-curve (thoracic major)"
|
||||
|
||||
elif lumbar_total > thoracic_total * 1.5:
|
||||
# Lumbar dominant
|
||||
rigo_type = "C2"
|
||||
description = f"Lumbar major ({tl:.1f}°) with thoracic compensatory ({mt:.1f}°)"
|
||||
curve_pattern = "Lumbar major"
|
||||
|
||||
else:
|
||||
# Double major - Type B (4-curve)
|
||||
if tl >= mt:
|
||||
rigo_type = "B1"
|
||||
description = f"4-curve: Double major, lumbar prominent ({tl:.1f}°/{mt:.1f}°)"
|
||||
else:
|
||||
rigo_type = "B2"
|
||||
description = f"4-curve: Double major, thoracic prominent ({mt:.1f}°/{tl:.1f}°)"
|
||||
curve_pattern = "4-curve (double major)"
|
||||
else:
|
||||
# C-curve pattern (curves in same direction)
|
||||
if mt >= tl:
|
||||
if tl > 5:
|
||||
rigo_type = "A3"
|
||||
description = f"Long thoracic curve ({mt:.1f}°) extending to lumbar ({tl:.1f}°)"
|
||||
else:
|
||||
rigo_type = "E2"
|
||||
description = f"Thoracic curve ({mt:.1f}°) with minor lumbar ({tl:.1f}°)"
|
||||
curve_pattern = "Extended thoracic"
|
||||
else:
|
||||
rigo_type = "C2"
|
||||
description = f"TL/Lumbar curve ({tl:.1f}°) with thoracic involvement ({mt:.1f}°)"
|
||||
curve_pattern = "Extended lumbar"
|
||||
|
||||
# Store in spine object
|
||||
spine.rigo_type = rigo_type
|
||||
spine.rigo_description = description
|
||||
|
||||
return {
|
||||
'rigo_type': rigo_type,
|
||||
'description': description,
|
||||
'curve_pattern': curve_pattern,
|
||||
'apex_info': apex_info,
|
||||
'cobb_angles': cobb_angles,
|
||||
'n_significant_curves': n_curves
|
||||
}
|
||||
|
||||
|
||||
def _calculate_lateral_deviations(centroids: np.ndarray) -> np.ndarray:
|
||||
"""Calculate lateral deviation from midline for each vertebra."""
|
||||
if len(centroids) < 2:
|
||||
return np.zeros(len(centroids))
|
||||
|
||||
# Midline from first to last vertebra
|
||||
start = centroids[0]
|
||||
end = centroids[-1]
|
||||
midline_vec = end - start
|
||||
midline_len = np.linalg.norm(midline_vec)
|
||||
|
||||
if midline_len < 1e-6:
|
||||
return np.zeros(len(centroids))
|
||||
|
||||
midline_unit = midline_vec / midline_len
|
||||
|
||||
deviations = []
|
||||
for pt in centroids:
|
||||
v = pt - start
|
||||
# Project onto midline
|
||||
proj_len = np.dot(v, midline_unit)
|
||||
proj = proj_len * midline_unit
|
||||
# Perpendicular vector
|
||||
perp = v - proj
|
||||
dist = np.linalg.norm(perp)
|
||||
# Sign: positive = right, negative = left
|
||||
sign = np.sign(np.cross(midline_unit, perp / (dist + 1e-8)))
|
||||
deviations.append(dist * sign)
|
||||
|
||||
return np.array(deviations)
|
||||
|
||||
|
||||
def _find_apex_info(centroids: np.ndarray, deviations: np.ndarray, n_verts: int) -> dict:
|
||||
"""Find apex positions and determine curve pattern."""
|
||||
info = {
|
||||
'thoracic_apex_idx': None,
|
||||
'lumbar_apex_idx': None,
|
||||
'lumbar_apex_low': False,
|
||||
'apex_at_tl_junction': False,
|
||||
'is_s_pattern': False,
|
||||
'apex_directions': []
|
||||
}
|
||||
|
||||
if len(deviations) < 3:
|
||||
return info
|
||||
|
||||
# Find local extrema (apexes)
|
||||
apexes = []
|
||||
apex_values = []
|
||||
for i in range(1, len(deviations) - 1):
|
||||
if (deviations[i] > deviations[i-1] and deviations[i] > deviations[i+1]) or \
|
||||
(deviations[i] < deviations[i-1] and deviations[i] < deviations[i+1]):
|
||||
if abs(deviations[i]) > 3: # Minimum threshold
|
||||
apexes.append(i)
|
||||
apex_values.append(deviations[i])
|
||||
|
||||
# Determine S-pattern (alternating signs at apexes)
|
||||
if len(apex_values) >= 2:
|
||||
signs = [np.sign(v) for v in apex_values]
|
||||
# S-pattern if adjacent apexes have opposite signs
|
||||
for i in range(len(signs) - 1):
|
||||
if signs[i] != signs[i+1]:
|
||||
info['is_s_pattern'] = True
|
||||
break
|
||||
|
||||
# Classify apex regions
|
||||
# Assume: top 40% = thoracic, middle 20% = TL junction, bottom 40% = lumbar
|
||||
thoracic_end = int(0.4 * n_verts)
|
||||
tl_junction_end = int(0.6 * n_verts)
|
||||
|
||||
for apex_idx in apexes:
|
||||
if apex_idx < thoracic_end:
|
||||
if info['thoracic_apex_idx'] is None or \
|
||||
abs(deviations[apex_idx]) > abs(deviations[info['thoracic_apex_idx']]):
|
||||
info['thoracic_apex_idx'] = apex_idx
|
||||
elif apex_idx < tl_junction_end:
|
||||
info['apex_at_tl_junction'] = True
|
||||
else:
|
||||
if info['lumbar_apex_idx'] is None or \
|
||||
abs(deviations[apex_idx]) > abs(deviations[info['lumbar_apex_idx']]):
|
||||
info['lumbar_apex_idx'] = apex_idx
|
||||
|
||||
# Check if lumbar apex is in lower region (bottom 30%)
|
||||
if info['lumbar_apex_idx'] is not None:
|
||||
if info['lumbar_apex_idx'] > int(0.7 * n_verts):
|
||||
info['lumbar_apex_low'] = True
|
||||
|
||||
info['apex_directions'] = apex_values
|
||||
|
||||
return info
|
||||
Reference in New Issue
Block a user