Initial commit - BraceIQMed platform with frontend, API, and brace generator

This commit is contained in:
2026-01-29 14:34:05 -08:00
commit 745f9f827f
187 changed files with 534688 additions and 0 deletions

View File

@@ -0,0 +1,68 @@
# ============================================
# BraceIQMed Brace Generator - FastAPI + PyTorch (CPU)
# Build context: repo root (braceiqmed/)
# ============================================
FROM python:3.10-slim
# Prevent interactive prompts
ENV DEBIAN_FRONTEND=noninteractive
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
libgl1 \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender-dev \
wget \
curl \
git \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Install PyTorch CPU version (smaller, no CUDA)
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
# Copy and install requirements (from brace-generator folder)
COPY brace-generator/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
# Copy scoliovis-api requirements and install
COPY scoliovis-api/requirements.txt /app/requirements-scoliovis.txt
RUN pip install --no-cache-dir -r requirements-scoliovis.txt || true
# Copy brace-generator code
COPY brace-generator/ /app/brace_generator/server_DEV/
# Copy scoliovis-api
COPY scoliovis-api/ /app/scoliovis-api/
# Copy templates
COPY templates/ /app/templates/
# Set Python path
ENV PYTHONPATH=/app:/app/brace_generator/server_DEV:/app/scoliovis-api
# Environment variables
ENV HOST=0.0.0.0
ENV PORT=8002
ENV DEVICE=cpu
ENV MODEL=scoliovis
ENV TEMP_DIR=/tmp/brace_generator
ENV CORS_ORIGINS=*
# Create directories
RUN mkdir -p /tmp/brace_generator /app/data/uploads /app/data/outputs
EXPOSE 8002
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
CMD curl -f http://localhost:8002/health || exit 1
# Run the server
CMD ["python", "-m", "uvicorn", "brace_generator.server_DEV.app:app", "--host", "0.0.0.0", "--port", "8002"]

View File

@@ -0,0 +1,8 @@
"""
Brace Generator Server Package.
"""
from .app import app
from .config import config
from .services import BraceService
__all__ = ["app", "config", "BraceService"]

137
brace-generator/app.py Normal file
View File

@@ -0,0 +1,137 @@
"""
FastAPI server for Brace Generator.
Provides REST API for:
- X-ray analysis and landmark detection
- Cobb angle measurement
- Rigo-Chêneau classification
- Adaptive brace generation (STL/PLY)
Usage:
uvicorn server.app:app --host 0.0.0.0 --port 8000
Or with Docker:
docker run -p 8000:8000 --gpus all brace-generator
"""
import sys
from pathlib import Path
from contextlib import asynccontextmanager
# Add parent directories to path for imports
server_dir = Path(__file__).parent
brace_generator_dir = server_dir.parent
spine_dir = brace_generator_dir.parent
sys.path.insert(0, str(spine_dir))
sys.path.insert(0, str(brace_generator_dir))
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import torch
from .config import config
from .routes import router
from .services import BraceService
# Global service instance (initialized on startup)
brace_service: BraceService = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize and cleanup resources."""
global brace_service
print("=" * 60)
print("Brace Generator Server Starting")
print("=" * 60)
# Ensure directories exist
config.ensure_dirs()
# Initialize service with model
device = config.get_device()
print(f"Device: {device}")
if device == "cuda":
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
print(f"Loading model: {config.MODEL}")
brace_service = BraceService(device=device, model=config.MODEL)
print("Model loaded successfully!")
# Make service available to routes
app.state.brace_service = brace_service
print("=" * 60)
print(f"Server ready at http://{config.HOST}:{config.PORT}")
print("=" * 60)
yield
# Cleanup
print("Shutting down...")
del brace_service
# Create FastAPI app
app = FastAPI(
title="Brace Generator API",
description="""
API for generating scoliosis braces from X-ray images.
## Features
- Vertebrae landmark detection (ScolioVis model)
- Cobb angle measurement (PT, MT, TL)
- Rigo-Chêneau classification
- Adaptive brace generation with research-based deformations
## Experiments
- **standard**: Original template-based pipeline
- **experiment_3**: Research-based adaptive deformation (Guy et al. 2024)
""",
version="1.0.0",
lifespan=lifespan,
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=config.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routes
app.include_router(router)
# Exception handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
return JSONResponse(
status_code=exc.status_code,
content={"error": exc.detail}
)
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
return JSONResponse(
status_code=500,
content={"error": "Internal server error", "detail": str(exc)}
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"server.app:app",
host=config.HOST,
port=config.PORT,
reload=config.DEBUG
)

View File

@@ -0,0 +1,411 @@
"""
Body scan integration for patient-specific brace fitting.
Based on EXPERIMENT_10's approach:
1. Extract body measurements from 3D scan
2. Compute body basis (coordinate frame)
3. Select template based on Rigo classification
4. Fit shell to body using basis alignment
"""
import sys
import json
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
from dataclasses import dataclass, asdict
try:
import trimesh
HAS_TRIMESH = True
except ImportError:
HAS_TRIMESH = False
# Add EXPERIMENT_10 to path for imports
EXPERIMENTS_DIR = Path(__file__).parent.parent / "EXPERIMENTS"
EXP10_DIR = EXPERIMENTS_DIR / "EXPERIMENT_10"
if str(EXP10_DIR) not in sys.path:
sys.path.insert(0, str(EXP10_DIR))
# Import EXPERIMENT_10 modules
try:
from body_measurements import extract_body_measurements, measurements_to_dict, BodyMeasurements
from body_basis import compute_body_basis, body_basis_to_dict, BodyBasis
from shell_fitter_v2 import (
fit_shell_to_body_v2,
compute_brace_basis_from_geometry,
brace_basis_to_dict,
RIGO_TO_VASE,
FittingFeedback
)
HAS_EXP10 = True
except ImportError as e:
print(f"Warning: Could not import EXPERIMENT_10 modules: {e}")
HAS_EXP10 = False
# Vase templates directory
VASES_DIR = Path(__file__).parent.parent.parent / "_vase" / "_vase"
def extract_measurements_from_scan(scan_path: str) -> Dict[str, Any]:
"""
Extract body measurements from a 3D body scan.
Args:
scan_path: Path to STL/OBJ/PLY body scan file
Returns:
Dictionary with measurements suitable for API response
"""
if not HAS_TRIMESH:
raise ImportError("trimesh is required for body scan processing")
# Try EXPERIMENT_10 first
if HAS_EXP10:
try:
measurements = extract_body_measurements(scan_path)
result = measurements_to_dict(measurements)
# Flatten for API-friendly format
flat = {
"total_height_mm": result["overall_dimensions"]["total_height_mm"],
"shoulder_width_mm": result["widths_mm"]["shoulder_width"],
"chest_width_mm": result["widths_mm"]["chest_width"],
"chest_depth_mm": result["depths_mm"]["chest_depth"],
"waist_width_mm": result["widths_mm"]["waist_width"],
"waist_depth_mm": result["depths_mm"]["waist_depth"],
"hip_width_mm": result["widths_mm"]["hip_width"],
"hip_depth_mm": result["depths_mm"]["hip_depth"],
"brace_coverage_height_mm": result["brace_coverage_region"]["coverage_height_mm"],
"chest_circumference_mm": result["circumferences_mm"]["chest"],
"waist_circumference_mm": result["circumferences_mm"]["waist"],
"hip_circumference_mm": result["circumferences_mm"]["hip"],
}
# Also include full detailed result
flat["detailed"] = result
return flat
except Exception as e:
print(f"EXPERIMENT_10 measurement extraction failed: {e}, using fallback")
# Fallback: Simple trimesh-based measurements
return _extract_measurements_trimesh_fallback(scan_path)
def _extract_measurements_trimesh_fallback(scan_path: str) -> Dict[str, Any]:
"""
Simple fallback for body measurements using trimesh bounding box analysis.
Less accurate than EXPERIMENT_10 but provides basic measurements.
"""
mesh = trimesh.load(scan_path)
# Get bounding box
bounds = mesh.bounds
min_pt, max_pt = bounds[0], bounds[1]
# Assuming Y is up (typical human scan orientation)
# Try to auto-detect orientation
extents = max_pt - min_pt
height_axis = np.argmax(extents) # Longest axis is usually height
if height_axis == 1: # Y-up
total_height = extents[1]
width_axis, depth_axis = 0, 2
elif height_axis == 2: # Z-up
total_height = extents[2]
width_axis, depth_axis = 0, 1
else: # X-up (unusual)
total_height = extents[0]
width_axis, depth_axis = 1, 2
width = extents[width_axis]
depth = extents[depth_axis]
# Estimate body segments using height percentages
# These are approximate ratios for human body
chest_height_ratio = 0.75 # Chest at 75% of height from bottom
waist_height_ratio = 0.60 # Waist at 60% of height
hip_height_ratio = 0.50 # Hips at 50% of height
shoulder_height_ratio = 0.82 # Shoulders at 82%
# Get cross-sections at different heights to estimate widths
def get_width_at_height(height_ratio):
if height_axis == 1:
h = min_pt[1] + total_height * height_ratio
mask = (mesh.vertices[:, 1] > h - total_height * 0.05) & \
(mesh.vertices[:, 1] < h + total_height * 0.05)
elif height_axis == 2:
h = min_pt[2] + total_height * height_ratio
mask = (mesh.vertices[:, 2] > h - total_height * 0.05) & \
(mesh.vertices[:, 2] < h + total_height * 0.05)
else:
h = min_pt[0] + total_height * height_ratio
mask = (mesh.vertices[:, 0] > h - total_height * 0.05) & \
(mesh.vertices[:, 0] < h + total_height * 0.05)
if not np.any(mask):
return width, depth
slice_verts = mesh.vertices[mask]
slice_width = np.ptp(slice_verts[:, width_axis])
slice_depth = np.ptp(slice_verts[:, depth_axis])
return slice_width, slice_depth
shoulder_w, shoulder_d = get_width_at_height(shoulder_height_ratio)
chest_w, chest_d = get_width_at_height(chest_height_ratio)
waist_w, waist_d = get_width_at_height(waist_height_ratio)
hip_w, hip_d = get_width_at_height(hip_height_ratio)
# Estimate circumferences using ellipse approximation
def estimate_circumference(w, d):
a, b = w / 2, d / 2
# Ramanujan's approximation for ellipse circumference
h = ((a - b) ** 2) / ((a + b) ** 2)
return np.pi * (a + b) * (1 + 3 * h / (10 + np.sqrt(4 - 3 * h)))
return {
"total_height_mm": float(total_height),
"shoulder_width_mm": float(shoulder_w),
"chest_width_mm": float(chest_w),
"chest_depth_mm": float(chest_d),
"waist_width_mm": float(waist_w),
"waist_depth_mm": float(waist_d),
"hip_width_mm": float(hip_w),
"hip_depth_mm": float(hip_d),
"brace_coverage_height_mm": float(total_height * 0.55), # 55% coverage
"chest_circumference_mm": float(estimate_circumference(chest_w, chest_d)),
"waist_circumference_mm": float(estimate_circumference(waist_w, waist_d)),
"hip_circumference_mm": float(estimate_circumference(hip_w, hip_d)),
"measurement_source": "trimesh_fallback"
}
def generate_fitted_brace(
body_scan_path: str,
rigo_type: str,
output_dir: str,
case_id: str,
clearance_mm: float = 8.0,
wall_thickness_mm: float = 2.4
) -> Dict[str, Any]:
"""
Generate a patient-specific brace fitted to body scan.
Args:
body_scan_path: Path to 3D body scan (STL/OBJ/PLY)
rigo_type: Rigo classification (A1, A2, B1, etc.)
output_dir: Directory to save output files
case_id: Case identifier for naming files
clearance_mm: Clearance between body and shell (default 8mm)
wall_thickness_mm: Shell wall thickness (default 2.4mm for 3D printing)
Returns:
Dictionary with output file paths and fitting info
"""
if not HAS_TRIMESH:
raise ImportError("trimesh is required for brace fitting")
if not HAS_EXP10:
raise ImportError("EXPERIMENT_10 modules not available")
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Select template based on Rigo type
template_file = RIGO_TO_VASE.get(rigo_type, "A1_vase.OBJ")
template_path = VASES_DIR / template_file
if not template_path.exists():
# Try alternative paths
alt_paths = [
EXPERIMENTS_DIR / "EXPERIMENT_10" / "_vase" / template_file,
Path(__file__).parent.parent.parent / "_vase" / template_file,
]
for alt in alt_paths:
if alt.exists():
template_path = alt
break
else:
raise FileNotFoundError(f"Template not found: {template_file}")
# Fit shell to body
# Returns: (shell_mesh, body_mesh, combined_mesh, feedback)
fitted_mesh, body_mesh, combined_mesh, feedback = fit_shell_to_body_v2(
body_scan_path=body_scan_path,
template_path=str(template_path),
clearance_mm=clearance_mm
)
# Generate output files
outputs = {}
# Shell STL (for 3D printing)
shell_stl = output_path / f"{case_id}_shell.stl"
fitted_mesh.export(str(shell_stl))
outputs["shell_stl"] = str(shell_stl)
# Shell GLB (for web viewing)
shell_glb = output_path / f"{case_id}_shell.glb"
fitted_mesh.export(str(shell_glb))
outputs["shell_glb"] = str(shell_glb)
# Combined body + shell STL (for visualization)
# combined_mesh is already returned from fit_shell_to_body_v2
combined_stl = output_path / f"{case_id}_body_with_shell.stl"
combined_mesh.export(str(combined_stl))
outputs["combined_stl"] = str(combined_stl)
# Feedback JSON
feedback_json = output_path / f"{case_id}_feedback.json"
with open(feedback_json, "w") as f:
json.dump(asdict(feedback), f, indent=2, default=_json_serializer)
outputs["feedback_json"] = str(feedback_json)
# Create visualization
try:
viz_path = output_path / f"{case_id}_visualization.png"
create_fitting_visualization(body_mesh, fitted_mesh, feedback, str(viz_path))
outputs["visualization"] = str(viz_path)
except Exception as e:
print(f"Warning: Could not create visualization: {e}")
# Return result
return {
"template_used": template_file,
"rigo_type": rigo_type,
"clearance_mm": clearance_mm,
"fitting": {
"scale_right": feedback.scale_right,
"scale_up": feedback.scale_up,
"scale_forward": feedback.scale_forward,
"pelvis_distance_mm": feedback.pelvis_distance_mm,
"up_alignment_dot": feedback.up_alignment_dot,
"warnings": feedback.warnings,
},
"body_measurements": {
"max_width_mm": feedback.max_body_width_mm,
"max_depth_mm": feedback.max_body_depth_mm,
},
"shell_dimensions": {
"width_mm": feedback.target_shell_width_mm,
"depth_mm": feedback.target_shell_depth_mm,
"bounds_min": feedback.final_bounds_min,
"bounds_max": feedback.final_bounds_max,
},
"mesh_stats": {
"vertices": len(fitted_mesh.vertices),
"faces": len(fitted_mesh.faces),
},
"outputs": outputs,
}
def create_fitting_visualization(
body_mesh: 'trimesh.Trimesh',
shell_mesh: 'trimesh.Trimesh',
feedback: 'FittingFeedback',
output_path: str
):
"""Create a multi-panel visualization of the fitted brace."""
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
except ImportError:
return
fig = plt.figure(figsize=(16, 10))
# Panel 1: Front view
ax1 = fig.add_subplot(2, 3, 1, projection='3d')
plot_mesh_silhouette(ax1, body_mesh, 'gray', alpha=0.3)
plot_mesh_silhouette(ax1, shell_mesh, 'blue', alpha=0.6)
ax1.set_title('Front View')
ax1.view_init(elev=0, azim=0)
# Panel 2: Side view
ax2 = fig.add_subplot(2, 3, 2, projection='3d')
plot_mesh_silhouette(ax2, body_mesh, 'gray', alpha=0.3)
plot_mesh_silhouette(ax2, shell_mesh, 'blue', alpha=0.6)
ax2.set_title('Side View')
ax2.view_init(elev=0, azim=90)
# Panel 3: Top view
ax3 = fig.add_subplot(2, 3, 3, projection='3d')
plot_mesh_silhouette(ax3, body_mesh, 'gray', alpha=0.3)
plot_mesh_silhouette(ax3, shell_mesh, 'blue', alpha=0.6)
ax3.set_title('Top View')
ax3.view_init(elev=90, azim=0)
# Panel 4: Fitting info
ax4 = fig.add_subplot(2, 3, 4)
ax4.axis('off')
info_text = f"""
Fitting Information
-------------------
Template: {feedback.template_name}
Clearance: {feedback.clearance_mm} mm
Scale Factors:
Right: {feedback.scale_right:.3f}
Up: {feedback.scale_up:.3f}
Forward: {feedback.scale_forward:.3f}
Alignment:
Pelvis Distance: {feedback.pelvis_distance_mm:.2f} mm
Up Alignment: {feedback.up_alignment_dot:.4f}
Shell vs Body:
Width Margin: {feedback.shell_minus_body_width_mm:.1f} mm
Depth Margin: {feedback.shell_minus_body_depth_mm:.1f} mm
"""
ax4.text(0.1, 0.9, info_text, transform=ax4.transAxes, fontsize=10,
verticalalignment='top', fontfamily='monospace')
# Panel 5: Warnings
ax5 = fig.add_subplot(2, 3, 5)
ax5.axis('off')
warnings_text = "Warnings:\n" + ("\n".join(feedback.warnings) if feedback.warnings else "None")
ax5.text(0.1, 0.9, warnings_text, transform=ax5.transAxes, fontsize=10,
verticalalignment='top', color='orange' if feedback.warnings else 'green')
# Panel 6: Isometric view
ax6 = fig.add_subplot(2, 3, 6, projection='3d')
plot_mesh_silhouette(ax6, body_mesh, 'gray', alpha=0.3)
plot_mesh_silhouette(ax6, shell_mesh, 'blue', alpha=0.6)
ax6.set_title('Isometric View')
ax6.view_init(elev=20, azim=45)
plt.tight_layout()
plt.savefig(output_path, dpi=150, bbox_inches='tight')
plt.close()
def plot_mesh_silhouette(ax, mesh, color, alpha=0.5):
"""Plot a simplified mesh representation."""
# Sample vertices for plotting
verts = mesh.vertices
if len(verts) > 5000:
indices = np.random.choice(len(verts), 5000, replace=False)
verts = verts[indices]
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2],
c=color, alpha=alpha, s=1)
# Set equal aspect ratio
max_range = np.max(mesh.extents) / 2
mid = mesh.centroid
ax.set_xlim(mid[0] - max_range, mid[0] + max_range)
ax.set_ylim(mid[1] - max_range, mid[1] + max_range)
ax.set_zlim(mid[2] - max_range, mid[2] + max_range)
def _json_serializer(obj):
"""JSON serializer for numpy types."""
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, (np.float32, np.float64)):
return float(obj)
if isinstance(obj, (np.int32, np.int64)):
return int(obj)
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

58
brace-generator/config.py Normal file
View File

@@ -0,0 +1,58 @@
"""
Server configuration for Brace Generator API.
"""
import os
from pathlib import Path
class Config:
"""Server configuration loaded from environment variables."""
# Server settings (DEV uses port 8001)
HOST: str = os.getenv("HOST", "0.0.0.0")
PORT: int = int(os.getenv("PORT", "8001"))
DEBUG: bool = os.getenv("DEBUG", "true").lower() == "true"
# Model settings
DEVICE: str = os.getenv("DEVICE", "cuda") # 'cuda' or 'cpu'
MODEL: str = os.getenv("MODEL", "scoliovis") # 'scoliovis' or 'vertebra-landmark'
# Paths
BASE_DIR: Path = Path(__file__).parent.parent
TEMPLATES_DIR: Path = BASE_DIR / "rigoBrace(2)"
WEIGHTS_DIR: Path = BASE_DIR.parent / "scoliovis-api" / "models"
# TEMP_DIR: Use system temp on Windows, /tmp on Linux
@staticmethod
def _get_temp_dir() -> Path:
env_temp = os.getenv("TEMP_DIR")
if env_temp:
return Path(env_temp)
# Use system temp directory (works on both Windows and Linux)
import tempfile
return Path(tempfile.gettempdir()) / "brace_generator"
TEMP_DIR: Path = _get_temp_dir()
# CORS
CORS_ORIGINS: list = os.getenv("CORS_ORIGINS", "*").split(",")
# Request limits
MAX_IMAGE_SIZE_MB: int = int(os.getenv("MAX_IMAGE_SIZE_MB", "50"))
REQUEST_TIMEOUT_SECONDS: int = int(os.getenv("REQUEST_TIMEOUT_SECONDS", "120"))
@classmethod
def ensure_dirs(cls):
"""Create necessary directories."""
cls.TEMP_DIR.mkdir(parents=True, exist_ok=True)
@classmethod
def get_device(cls) -> str:
"""Get device, falling back to CPU if CUDA unavailable."""
import torch
if cls.DEVICE == "cuda" and torch.cuda.is_available():
return "cuda"
return "cpu"
config = Config()

View File

@@ -0,0 +1,906 @@
"""
GLB Brace Generator with Markers
This module generates GLB brace files with embedded markers for editing.
Supports both regular (fitted) and vase-shaped templates.
PRESSURE ZONES EXPLANATION:
===========================
The brace has 4 main pressure/expansion zones that correct spinal curvature:
1. THORACIC PAD (LM_PAD_TH) - PUSH ZONE
- Location: On the CONVEX side of the thoracic curve (the side that bulges out)
- Function: Pushes INWARD to correct the thoracic curvature
- For right thoracic curves: pad is on the RIGHT back
- Depth: 8-25mm depending on Cobb angle severity
2. THORACIC BAY (LM_BAY_TH) - EXPANSION ZONE
- Location: OPPOSITE the thoracic pad (concave side)
- Function: Creates SPACE for the body to move INTO during correction
- The ribs/body shift into this space as the pad pushes
- Clearance: 10-35mm
3. LUMBAR PAD (LM_PAD_LUM) - PUSH ZONE
- Location: On the CONVEX side of the lumbar curve
- Function: Pushes INWARD to correct lumbar curvature
- Usually on the opposite side of thoracic pad (for S-curves)
- Depth: 6-20mm
4. LUMBAR BAY (LM_BAY_LUM) - EXPANSION ZONE
- Location: OPPOSITE the lumbar pad
- Function: Creates SPACE for lumbar correction
- Clearance: 8-25mm
5. HIP ANCHORS (LM_ANCHOR_HIP_L/R) - STABILITY ZONES
- Location: Around the hip/pelvis area on both sides
- Function: Grip the pelvis to prevent brace from riding up
- Slight inward pressure to anchor the brace
The Rigo classification determines which zones are primary:
- A types (3-curve): Strong thoracic pad, minor lumbar
- B types (4-curve): Both thoracic and lumbar pads are primary
- C types (non-3-non-4): Balanced thoracic, neutral pelvis
- E types (single lumbar/TL): Strong lumbar/TL pad, counter-thoracic
"""
import json
import numpy as np
import trimesh
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, Literal
from dataclasses import dataclass, asdict
# Paths to template directories
BASE_DIR = Path(__file__).parent.parent
BRACES_DIR = BASE_DIR / "braces"
REGULAR_TEMPLATES_DIR = BRACES_DIR / "brace_templates"
VASE_TEMPLATES_DIR = BRACES_DIR / "vase_brace_templates"
# Template types
TemplateType = Literal["regular", "vase"]
@dataclass
class MarkerPositions:
"""Marker positions for a brace template."""
LM_PELVIS_CENTER: Tuple[float, float, float]
LM_TOP_CENTER: Tuple[float, float, float]
LM_PAD_TH: Tuple[float, float, float]
LM_BAY_TH: Tuple[float, float, float]
LM_PAD_LUM: Tuple[float, float, float]
LM_BAY_LUM: Tuple[float, float, float]
LM_ANCHOR_HIP_L: Tuple[float, float, float]
LM_ANCHOR_HIP_R: Tuple[float, float, float]
@dataclass
class PressureZone:
"""Describes a pressure or expansion zone on the brace."""
name: str
marker_name: str
position: Tuple[float, float, float]
zone_type: Literal["pad", "bay", "anchor"]
function: str
direction: Literal["inward", "outward", "grip"]
depth_mm: float = 0.0
radius_mm: Tuple[float, float, float] = (50.0, 80.0, 40.0)
@dataclass
class BraceGenerationResult:
"""Result of brace generation with markers."""
glb_path: str
stl_path: str
json_path: str
template_type: str
rigo_type: str
markers: Dict[str, Tuple[float, float, float]]
basis: Dict[str, Any]
pressure_zones: list
mesh_stats: Dict[str, int]
transform_applied: Optional[Dict[str, Any]] = None
def get_template_paths(rigo_type: str, template_type: TemplateType) -> Tuple[Path, Path]:
"""
Get paths to GLB template and markers JSON.
Args:
rigo_type: Rigo classification (A1, A2, A3, B1, B2, C1, C2, E1, E2)
template_type: "regular" or "vase"
Returns:
Tuple of (glb_path, markers_json_path)
"""
if template_type == "regular":
glb_path = REGULAR_TEMPLATES_DIR / f"{rigo_type}_marked_v3.glb"
json_path = REGULAR_TEMPLATES_DIR / f"{rigo_type}_marked_v3.markers.json"
else: # vase
glb_path = VASE_TEMPLATES_DIR / "glb" / f"{rigo_type}_vase_marked.glb"
json_path = VASE_TEMPLATES_DIR / "markers_json" / f"{rigo_type}_vase_marked.markers.json"
return glb_path, json_path
def load_template_markers(rigo_type: str, template_type: TemplateType) -> Dict[str, Any]:
"""Load markers from JSON file for a template."""
_, json_path = get_template_paths(rigo_type, template_type)
if not json_path.exists():
raise FileNotFoundError(f"Markers JSON not found: {json_path}")
with open(json_path, "r") as f:
return json.load(f)
def load_glb_template(rigo_type: str, template_type: TemplateType) -> trimesh.Trimesh:
"""Load GLB template as trimesh."""
glb_path, _ = get_template_paths(rigo_type, template_type)
if not glb_path.exists():
raise FileNotFoundError(f"GLB template not found: {glb_path}")
scene = trimesh.load(str(glb_path))
# If it's a scene, concatenate all meshes
if isinstance(scene, trimesh.Scene):
meshes = [g for g in scene.geometry.values() if isinstance(g, trimesh.Trimesh)]
if meshes:
mesh = trimesh.util.concatenate(meshes)
else:
raise ValueError(f"No valid meshes found in GLB: {glb_path}")
else:
mesh = scene
return mesh
def calculate_pressure_zones(
markers: Dict[str, Any],
rigo_type: str,
cobb_angles: Dict[str, float]
) -> list:
"""
Calculate pressure zone parameters based on markers and analysis.
Args:
markers: Marker positions from template
rigo_type: Rigo classification
cobb_angles: Dict with PT, MT, TL Cobb angles
Returns:
List of PressureZone objects
"""
marker_pos = markers.get("markers", markers)
# Calculate severity from Cobb angles
mt_angle = cobb_angles.get("MT", 0)
tl_angle = cobb_angles.get("TL", 0)
# Severity mapping: Cobb -> depth
def cobb_to_depth(angle: float, min_depth: float = 6.0, max_depth: float = 22.0) -> float:
severity = min(max((angle - 10) / 40, 0), 1) # 0-1 range
return min_depth + severity * (max_depth - min_depth)
th_depth = cobb_to_depth(mt_angle, 8.0, 22.0)
lum_depth = cobb_to_depth(tl_angle, 6.0, 18.0)
# Bay clearance is typically 1.2-1.5x pad depth
th_clearance = th_depth * 1.3 + 5
lum_clearance = lum_depth * 1.3 + 4
zones = [
PressureZone(
name="Thoracic Pad",
marker_name="LM_PAD_TH",
position=tuple(marker_pos.get("LM_PAD_TH", [0, 0, 0])),
zone_type="pad",
function="Pushes INWARD on thoracic curve convex side to correct curvature",
direction="inward",
depth_mm=th_depth,
radius_mm=(50.0, 90.0, 40.0)
),
PressureZone(
name="Thoracic Bay",
marker_name="LM_BAY_TH",
position=tuple(marker_pos.get("LM_BAY_TH", [0, 0, 0])),
zone_type="bay",
function="Creates SPACE on thoracic concave side for body to shift into",
direction="outward",
depth_mm=th_clearance,
radius_mm=(65.0, 110.0, 55.0)
),
PressureZone(
name="Lumbar Pad",
marker_name="LM_PAD_LUM",
position=tuple(marker_pos.get("LM_PAD_LUM", [0, 0, 0])),
zone_type="pad",
function="Pushes INWARD on lumbar curve convex side to correct curvature",
direction="inward",
depth_mm=lum_depth,
radius_mm=(55.0, 85.0, 45.0)
),
PressureZone(
name="Lumbar Bay",
marker_name="LM_BAY_LUM",
position=tuple(marker_pos.get("LM_BAY_LUM", [0, 0, 0])),
zone_type="bay",
function="Creates SPACE on lumbar concave side for body to shift into",
direction="outward",
depth_mm=lum_clearance,
radius_mm=(70.0, 100.0, 60.0)
),
PressureZone(
name="Left Hip Anchor",
marker_name="LM_ANCHOR_HIP_L",
position=tuple(marker_pos.get("LM_ANCHOR_HIP_L", [0, 0, 0])),
zone_type="anchor",
function="Grips left hip/pelvis to stabilize brace and prevent riding up",
direction="grip",
depth_mm=4.0,
radius_mm=(40.0, 60.0, 40.0)
),
PressureZone(
name="Right Hip Anchor",
marker_name="LM_ANCHOR_HIP_R",
position=tuple(marker_pos.get("LM_ANCHOR_HIP_R", [0, 0, 0])),
zone_type="anchor",
function="Grips right hip/pelvis to stabilize brace and prevent riding up",
direction="grip",
depth_mm=4.0,
radius_mm=(40.0, 60.0, 40.0)
),
]
return zones
def transform_markers(
markers: Dict[str, list],
transform_matrix: np.ndarray
) -> Dict[str, Tuple[float, float, float]]:
"""Apply transformation matrix to all marker positions."""
transformed = {}
for name, pos in markers.items():
if isinstance(pos, (list, tuple)) and len(pos) == 3:
# Convert to homogeneous coordinates
pos_h = np.array([pos[0], pos[1], pos[2], 1.0])
# Apply transform
new_pos = transform_matrix @ pos_h
transformed[name] = (float(new_pos[0]), float(new_pos[1]), float(new_pos[2]))
return transformed
def generate_glb_brace(
rigo_type: str,
template_type: TemplateType,
output_dir: Path,
case_id: str,
cobb_angles: Dict[str, float],
body_scan_path: Optional[str] = None,
clearance_mm: float = 8.0
) -> BraceGenerationResult:
"""
Generate a GLB brace with markers.
Args:
rigo_type: Rigo classification (A1, A2, etc.)
template_type: "regular" or "vase"
output_dir: Directory for output files
case_id: Case identifier
cobb_angles: Dict with PT, MT, TL angles
body_scan_path: Optional path to body scan STL for fitting
clearance_mm: Clearance between body and brace
Returns:
BraceGenerationResult with paths and marker info
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Load template
mesh = load_glb_template(rigo_type, template_type)
marker_data = load_template_markers(rigo_type, template_type)
markers = marker_data.get("markers", {})
basis = marker_data.get("basis", {})
transform_matrix = np.eye(4)
transform_info = None
# If body scan provided, fit to body
if body_scan_path and Path(body_scan_path).exists():
mesh, transform_matrix, transform_info = fit_brace_to_body(
mesh, body_scan_path, clearance_mm, brace_basis=basis, template_type=template_type,
markers=markers # Pass markers for zone-aware ironing
)
# Transform markers to match the body-fitted mesh
markers = transform_markers(markers, transform_matrix)
# Calculate pressure zones
pressure_zones = calculate_pressure_zones(markers, rigo_type, cobb_angles)
# Output file names
type_suffix = "vase" if template_type == "vase" else "regular"
glb_filename = f"{case_id}_{rigo_type}_{type_suffix}.glb"
stl_filename = f"{case_id}_{rigo_type}_{type_suffix}.stl"
json_filename = f"{case_id}_{rigo_type}_{type_suffix}_markers.json"
glb_path = output_dir / glb_filename
stl_path = output_dir / stl_filename
json_path = output_dir / json_filename
# Export GLB
mesh.export(str(glb_path))
# Export STL
mesh.export(str(stl_path))
# Build output JSON with markers and zones
output_data = {
"case_id": case_id,
"rigo_type": rigo_type,
"template_type": template_type,
"cobb_angles": cobb_angles,
"markers": {k: list(v) if isinstance(v, tuple) else v for k, v in markers.items()},
"basis": basis,
"pressure_zones": [
{
"name": z.name,
"marker_name": z.marker_name,
"position": list(z.position),
"zone_type": z.zone_type,
"function": z.function,
"direction": z.direction,
"depth_mm": z.depth_mm,
"radius_mm": list(z.radius_mm)
}
for z in pressure_zones
],
"mesh_stats": {
"vertices": len(mesh.vertices),
"faces": len(mesh.faces)
},
"outputs": {
"glb": str(glb_path),
"stl": str(stl_path),
"json": str(json_path)
}
}
if transform_info:
output_data["body_fitting"] = transform_info
# Save JSON
with open(json_path, "w") as f:
json.dump(output_data, f, indent=2)
return BraceGenerationResult(
glb_path=str(glb_path),
stl_path=str(stl_path),
json_path=str(json_path),
template_type=template_type,
rigo_type=rigo_type,
markers=markers,
basis=basis,
pressure_zones=[asdict(z) for z in pressure_zones],
mesh_stats=output_data["mesh_stats"],
transform_applied=transform_info
)
def iron_brace_to_body(
brace_mesh: trimesh.Trimesh,
body_mesh: trimesh.Trimesh,
min_clearance_mm: float = 3.0,
max_clearance_mm: float = 15.0,
smoothing_iterations: int = 2,
up_axis: int = 2,
markers: Optional[Dict[str, Any]] = None
) -> trimesh.Trimesh:
"""
Iron the brace surface to conform to the body scan surface.
This ensures the brace follows the body contour without excessive gaps.
Uses zone-aware ironing:
- FRONT (belly) and BACK: Aggressive ironing for tight fit
- SIDES (where pads/bays are): Preserve correction zones, moderate ironing
Args:
brace_mesh: The brace mesh to iron
body_mesh: The body scan mesh to conform to
min_clearance_mm: Minimum distance from body surface
max_clearance_mm: Maximum distance from body surface (trigger ironing)
smoothing_iterations: Number of Laplacian smoothing passes after ironing
up_axis: Which axis is "up" (0=X, 1=Y, 2=Z)
markers: Optional dict of marker positions to preserve pressure zones
Returns:
Ironed brace mesh
"""
from scipy.spatial import cKDTree
import math
print(f"Ironing brace to body surface (clearance: {min_clearance_mm}-{max_clearance_mm}mm)")
# Create a copy to modify
ironed_mesh = brace_mesh.copy()
vertices = ironed_mesh.vertices.copy()
# Get body center and bounds
body_center = body_mesh.centroid
body_bounds = body_mesh.bounds
# Determine the torso region (process middle 80% of body height)
body_height = body_bounds[1, up_axis] - body_bounds[0, up_axis]
torso_bottom = body_bounds[0, up_axis] + body_height * 0.10
torso_top = body_bounds[0, up_axis] + body_height * 0.90
# Build KD-tree from body mesh vertices for fast nearest neighbor queries
body_tree = cKDTree(body_mesh.vertices)
# Find closest points on body for ALL brace vertices at once
distances, closest_indices = body_tree.query(vertices, k=1)
closest_points = body_mesh.vertices[closest_indices]
# Determine horizontal axes (perpendicular to up axis)
horiz_axes = [i for i in range(3) if i != up_axis]
# Calculate brace center for angle computation
brace_center = np.mean(vertices, axis=0)
# Identify marker exclusion zones (preserve correction areas)
exclusion_zones = []
if markers:
# Pad and bay markers need preservation
for marker_name in ['LM_PAD_TH', 'LM_PAD_LUM', 'LM_BAY_TH', 'LM_BAY_LUM']:
if marker_name in markers:
pos = markers[marker_name]
if isinstance(pos, (list, tuple)) and len(pos) >= 3:
exclusion_zones.append({
'center': np.array(pos),
'radius': 60.0, # 60mm exclusion radius around markers
'name': marker_name
})
# Process each brace vertex
adjusted_count = 0
pulled_in_count = 0
pushed_out_count = 0
skipped_zone_count = 0
# Height normalization
brace_min_z = vertices[:, up_axis].min()
brace_max_z = vertices[:, up_axis].max()
brace_height_range = max(brace_max_z - brace_min_z, 1.0)
for i in range(len(vertices)):
vertex = vertices[i]
closest_pt = closest_points[i]
dist = distances[i]
# Only process vertices in the torso region
if vertex[up_axis] < torso_bottom or vertex[up_axis] > torso_top:
continue
# Check if vertex is in an exclusion zone (near pad/bay markers)
in_exclusion = False
for zone in exclusion_zones:
zone_dist = np.linalg.norm(vertex - zone['center'])
if zone_dist < zone['radius']:
in_exclusion = True
skipped_zone_count += 1
break
if in_exclusion:
continue
# Calculate angular position around body center (horizontal plane)
# 0° = front (belly), 90° = right side, 180° = back, 270° = left side
rel_pos = vertex - body_center
angle = math.atan2(rel_pos[horiz_axes[1]], rel_pos[horiz_axes[0]])
angle_deg = math.degrees(angle) % 360
# Determine zone based on angle:
# FRONT (belly): 315-45° - aggressive ironing
# BACK: 135-225° - aggressive ironing
# SIDES: 45-135° and 225-315° - moderate ironing (correction zones)
is_front_back = (angle_deg < 45 or angle_deg > 315) or (135 < angle_deg < 225)
# Height-based clearance adjustment
height_norm = (vertex[up_axis] - brace_min_z) / brace_height_range
# Set clearances based on zone
if is_front_back:
# FRONT/BACK: Aggressive ironing - very tight fit
local_min = min_clearance_mm * 0.5 # Allow closer to body
local_max = max_clearance_mm * 0.6 # Trigger ironing earlier
local_target = min_clearance_mm + 2.0 # Target just above minimum
else:
# SIDES: More conservative - preserve room for correction
local_min = min_clearance_mm
local_max = max_clearance_mm * 1.2 # Allow slightly more gap
local_target = (min_clearance_mm + max_clearance_mm) / 2
# Height adjustments (tighter at hips and chest)
if height_norm < 0.25 or height_norm > 0.75:
local_max *= 0.8 # Tighter at extremes
local_target *= 0.85
# Direction from body surface to brace vertex
direction = vertex - closest_pt
dir_length = np.linalg.norm(direction)
if dir_length < 1e-6:
direction = vertex - body_center
direction[up_axis] = 0
dir_length = np.linalg.norm(direction)
if dir_length < 1e-6:
continue
direction = direction / dir_length
# Determine signed distance
vertex_dist_to_center = np.linalg.norm(vertex[:2] - body_center[:2])
closest_dist_to_center = np.linalg.norm(closest_pt[:2] - body_center[:2])
if vertex_dist_to_center >= closest_dist_to_center:
signed_distance = dist
else:
signed_distance = -dist
# Determine if adjustment is needed
needs_adjustment = False
new_position = vertex.copy()
if signed_distance > local_max:
# Gap too large - pull vertex closer to body
new_position = closest_pt + direction * local_target
new_position[up_axis] = vertex[up_axis] # Preserve height
needs_adjustment = True
pulled_in_count += 1
elif signed_distance < local_min:
# Too close or inside body - push outward
offset = local_min + 1.0
outward_dir = closest_pt - body_center
outward_dir[up_axis] = 0
outward_length = np.linalg.norm(outward_dir)
if outward_length > 1e-6:
outward_dir = outward_dir / outward_length
new_position = closest_pt + outward_dir * offset
new_position[up_axis] = vertex[up_axis]
needs_adjustment = True
pushed_out_count += 1
if needs_adjustment:
vertices[i] = new_position
adjusted_count += 1
print(f"Ironing adjusted {adjusted_count} vertices (pulled in: {pulled_in_count}, pushed out: {pushed_out_count}, skipped zones: {skipped_zone_count})")
# Apply modified vertices
ironed_mesh.vertices = vertices
# Apply Laplacian smoothing to blend changes and remove artifacts
if smoothing_iterations > 0 and adjusted_count > 0:
print(f"Applying {smoothing_iterations} smoothing iterations")
try:
ironed_mesh = trimesh.smoothing.filter_laplacian(
ironed_mesh,
lamb=0.3, # Gentler smoothing to preserve shape
iterations=smoothing_iterations,
implicit_time_integration=False
)
except Exception as e:
print(f"Smoothing failed (non-critical): {e}")
# Ensure mesh is valid
ironed_mesh.fix_normals()
return ironed_mesh
def fit_brace_to_body(
brace_mesh: trimesh.Trimesh,
body_scan_path: str,
clearance_mm: float = 8.0,
brace_basis: Optional[Dict[str, Any]] = None,
template_type: str = "regular",
enable_ironing: bool = True,
markers: Optional[Dict[str, Any]] = None
) -> Tuple[trimesh.Trimesh, np.ndarray, Dict[str, Any]]:
"""
Fit brace to body scan using basis alignment.
The brace needs to be:
1. Rotated so its UP axis aligns with body's UP axis (typically Z for body scans)
2. Scaled to fit around the body with proper clearance
3. Positioned at the torso level
4. Ironed to conform to body surface (respecting correction zones)
Returns:
Tuple of (transformed_mesh, transform_matrix, fitting_info)
"""
# Load body scan
body_mesh = trimesh.load(body_scan_path, force='mesh')
# Get body dimensions
body_bounds = body_mesh.bounds
body_extents = body_mesh.extents
body_center = body_mesh.centroid
# Determine body up axis (typically the longest dimension = height)
# For human body scans, this is usually Z (from 3D scanners) or Y
body_up_axis_idx = np.argmax(body_extents)
print(f"Body up axis: {['X', 'Y', 'Z'][body_up_axis_idx]}, extents: {body_extents}")
# Get brace dimensions
brace_bounds = brace_mesh.bounds
brace_extents = brace_mesh.extents
brace_center = brace_mesh.centroid
print(f"Brace original extents: {brace_extents}, template_type: {template_type}")
# Start building transformation
transformed_mesh = brace_mesh.copy()
transform = np.eye(4)
# Step 1: Center brace at origin
T_center = np.eye(4)
T_center[:3, 3] = -brace_center
transformed_mesh.apply_transform(T_center)
transform = T_center @ transform
# Step 2: Apply rotations based on template type and body orientation
# Regular templates have: negative Y is up (inverted), need to flip
# Vase templates have: positive Y is up
# Body scan is Z-up
if body_up_axis_idx == 2: # Body is Z-up (standard for 3D scanners)
if template_type == "regular":
# Regular brace: -Y is up (inverted)
# 1. Rotate -90° around X to bring Y-up to Z-up
R1 = trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0])
transformed_mesh.apply_transform(R1)
transform = R1 @ transform
# 2. The brace is now Z-up but inverted (pelvis at top, shoulders at bottom)
# Flip 180° around X to correct (this keeps Z as up axis)
R2 = trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0])
transformed_mesh.apply_transform(R2)
transform = R2 @ transform
# 3. Rotate around Z to face forward correctly
R3 = trimesh.transformations.rotation_matrix(-np.pi/2, [0, 0, 1])
transformed_mesh.apply_transform(R3)
transform = R3 @ transform
print(f"Applied regular brace rotations: X-90°, X+180° (flip), Z-90°")
else: # vase
# Vase brace: positive Y is up
# 1. Rotate -90° around X to bring Y-up to Z-up
R1 = trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0])
transformed_mesh.apply_transform(R1)
transform = R1 @ transform
# 2. Flip 180° around Y to correct orientation (right-side up)
R2 = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
transformed_mesh.apply_transform(R2)
transform = R2 @ transform
print(f"Applied vase brace rotations: X-90°, Y+180° (flip)")
# Step 3: Get new brace dimensions after rotation
new_brace_extents = transformed_mesh.extents
new_brace_center = transformed_mesh.centroid
print(f"Brace extents after rotation: {new_brace_extents}")
# Step 4: Calculate NON-UNIFORM scaling based on body dimensions
# The brace should cover the TORSO region (~50% of body height)
# AND wrap around the body with proper girth
body_height = body_extents[body_up_axis_idx]
brace_height = new_brace_extents[body_up_axis_idx] # After rotation, this is the height
# Body horizontal dimensions (girth at torso level)
horizontal_axes = [i for i in range(3) if i != body_up_axis_idx]
body_width = body_extents[horizontal_axes[0]] # X width
body_depth = body_extents[horizontal_axes[1]] # Y depth
# Brace horizontal dimensions
brace_width = new_brace_extents[horizontal_axes[0]]
brace_depth = new_brace_extents[horizontal_axes[1]]
# Target: brace height should cover ~65% of body height (full torso coverage)
target_height = body_height * 0.65
height_scale = target_height / brace_height if brace_height > 0 else 1.0
# Target: brace width/depth should be LARGER than body to wrap AROUND it
# The brace sits OUTSIDE the body, only pressure points push inward
# Add ~25% extra + clearance so brace externals are visible outside body
target_width = body_width * 1.25 + clearance_mm * 2
target_depth = body_depth * 1.25 + clearance_mm * 2
width_scale = target_width / brace_width if brace_width > 0 else 1.0
depth_scale = target_depth / brace_depth if brace_depth > 0 else 1.0
# Apply non-uniform scaling
# Determine which axis is which after rotation
S = np.eye(4)
if body_up_axis_idx == 2: # Z is up
S[0, 0] = width_scale # X scale
S[1, 1] = depth_scale # Y scale
S[2, 2] = height_scale # Z scale (height)
elif body_up_axis_idx == 1: # Y is up
S[0, 0] = width_scale # X scale
S[1, 1] = height_scale # Y scale (height)
S[2, 2] = depth_scale # Z scale
else: # X is up (unusual)
S[0, 0] = height_scale # X scale (height)
S[1, 1] = width_scale # Y scale
S[2, 2] = depth_scale # Z scale
# Limit scales to reasonable range
S[0, 0] = max(0.5, min(S[0, 0], 50.0))
S[1, 1] = max(0.5, min(S[1, 1], 50.0))
S[2, 2] = max(0.5, min(S[2, 2], 50.0))
transformed_mesh.apply_transform(S)
transform = S @ transform
print(f"Applied non-uniform scale: width={S[0,0]:.2f}, depth={S[1,1]:.2f}, height={S[2,2]:.2f}")
print(f"Target dimensions: width={target_width:.1f}, depth={target_depth:.1f}, height={target_height:.1f}")
# For fitting_info, use average scale
scale = (S[0, 0] + S[1, 1] + S[2, 2]) / 3
# Step 6: Position brace at torso level
# Calculate where the torso is (middle portion of body height)
body_height = body_extents[body_up_axis_idx]
body_bottom = body_bounds[0, body_up_axis_idx]
body_top = body_bounds[1, body_up_axis_idx]
# Torso is roughly the middle 40% of body height (from ~30% to ~70%)
torso_center_ratio = 0.5 # Middle of body
torso_center_height = body_bottom + body_height * torso_center_ratio
# Target position: center horizontally on body, at torso height vertically
target_center = body_center.copy()
target_center[body_up_axis_idx] = torso_center_height
# Current brace center after transformations
current_center = transformed_mesh.centroid
T_position = np.eye(4)
T_position[:3, 3] = target_center - current_center
transformed_mesh.apply_transform(T_position)
transform = T_position @ transform
# Step 7: Iron brace to conform to body surface (eliminate gaps and humps)
# Transform markers so we can exclude correction zones from ironing
transformed_markers = None
if markers:
transformed_markers = transform_markers(markers, transform)
ironing_info = {}
if enable_ironing:
try:
print(f"Starting brace ironing to body surface...")
pre_iron_extents = transformed_mesh.extents.copy()
transformed_mesh = iron_brace_to_body(
brace_mesh=transformed_mesh,
body_mesh=body_mesh,
min_clearance_mm=clearance_mm * 0.4, # Allow closer for tight fit
max_clearance_mm=clearance_mm * 1.5, # Iron areas with gaps > 1.5x clearance
smoothing_iterations=3,
up_axis=body_up_axis_idx,
markers=transformed_markers
)
post_iron_extents = transformed_mesh.extents
ironing_info = {
"enabled": True,
"pre_iron_extents": pre_iron_extents.tolist(),
"post_iron_extents": post_iron_extents.tolist(),
"min_clearance_mm": clearance_mm * 0.5,
"max_clearance_mm": clearance_mm * 2.0,
}
print(f"Ironing complete. Extents changed from {pre_iron_extents} to {post_iron_extents}")
except Exception as e:
print(f"Ironing failed (non-critical): {e}")
ironing_info = {"enabled": False, "error": str(e)}
else:
ironing_info = {"enabled": False}
fitting_info = {
"scale_avg": float(scale),
"scale_x": float(S[0, 0]),
"scale_y": float(S[1, 1]),
"scale_z": float(S[2, 2]),
"template_type": template_type,
"body_extents": body_extents.tolist(),
"brace_extents_original": brace_extents.tolist(),
"brace_extents_final": transformed_mesh.extents.tolist(),
"clearance_mm": clearance_mm,
"body_center": body_center.tolist(),
"final_center": transformed_mesh.centroid.tolist(),
"body_up_axis": int(body_up_axis_idx),
"ironing": ironing_info,
}
return transformed_mesh, transform, fitting_info
def generate_both_brace_types(
rigo_type: str,
output_dir: Path,
case_id: str,
cobb_angles: Dict[str, float],
body_scan_path: Optional[str] = None,
clearance_mm: float = 8.0
) -> Dict[str, BraceGenerationResult]:
"""
Generate both regular and vase brace types for comparison.
Returns:
Dict with "regular" and "vase" results
"""
results = {}
# Generate regular brace
try:
results["regular"] = generate_glb_brace(
rigo_type=rigo_type,
template_type="regular",
output_dir=output_dir,
case_id=case_id,
cobb_angles=cobb_angles,
body_scan_path=body_scan_path,
clearance_mm=clearance_mm
)
except FileNotFoundError as e:
results["regular"] = {"error": str(e)}
# Generate vase brace
try:
results["vase"] = generate_glb_brace(
rigo_type=rigo_type,
template_type="vase",
output_dir=output_dir,
case_id=case_id,
cobb_angles=cobb_angles,
body_scan_path=body_scan_path,
clearance_mm=clearance_mm
)
except FileNotFoundError as e:
results["vase"] = {"error": str(e)}
return results
# Available templates
AVAILABLE_RIGO_TYPES = ["A1", "A2", "A3", "B1", "B2", "C1", "C2", "E1", "E2"]
def list_available_templates() -> Dict[str, list]:
"""List all available template files."""
regular = []
vase = []
for rigo_type in AVAILABLE_RIGO_TYPES:
glb_path, _ = get_template_paths(rigo_type, "regular")
if glb_path.exists():
regular.append(rigo_type)
glb_path, _ = get_template_paths(rigo_type, "vase")
if glb_path.exists():
vase.append(rigo_type)
return {
"regular": regular,
"vase": vase
}

View File

@@ -0,0 +1,26 @@
# Server dependencies
fastapi>=0.100.0
uvicorn[standard]>=0.22.0
python-multipart>=0.0.6
pydantic>=2.0.0
requests>=2.28.0
# AWS SDK
boto3>=1.26.0
# Core ML dependencies
torch>=2.0.0
torchvision>=0.15.0
# Image processing
numpy>=1.20.0
scipy>=1.7.0
pillow>=8.0.0
opencv-python-headless>=4.5.0
pydicom>=2.2.0
# 3D mesh processing
trimesh>=3.10.0
# Visualization
matplotlib>=3.4.0

990
brace-generator/routes.py Normal file
View File

@@ -0,0 +1,990 @@
"""
API routes for Brace Generator.
Note: S3 operations are handled by the Lambda function.
This server only handles ML inference and returns local file paths.
"""
import torch
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Request
from fastapi.responses import FileResponse
from typing import Optional
import json
from pathlib import Path
from .schemas import (
AnalysisResult, HealthResponse, ExperimentType, BraceConfigRequest
)
from .config import config
router = APIRouter()
@router.get("/", summary="Root endpoint")
async def root():
"""Welcome endpoint."""
return {
"service": "Brace Generator API",
"version": "1.0.0",
"docs": "/docs",
"health": "/health"
}
@router.get("/health", response_model=HealthResponse, summary="Health check")
async def health_check():
"""Check server health and GPU status."""
cuda_available = torch.cuda.is_available()
gpu_name = None
gpu_memory_mb = None
if cuda_available:
gpu_name = torch.cuda.get_device_name(0)
gpu_memory_mb = int(torch.cuda.get_device_properties(0).total_memory / (1024**2))
return HealthResponse(
status="healthy",
device=config.get_device(),
cuda_available=cuda_available,
model_loaded=True,
gpu_name=gpu_name,
gpu_memory_mb=gpu_memory_mb
)
@router.post("/analyze/upload", response_model=AnalysisResult, summary="Analyze uploaded X-ray")
async def analyze_upload(
req: Request,
file: UploadFile = File(..., description="X-ray image file"),
case_id: Optional[str] = Form(None, description="Case ID"),
experiment: str = Form("experiment_3", description="Experiment type"),
config_json: Optional[str] = Form(None, description="Brace config as JSON"),
landmarks_json: Optional[str] = Form(None, description="Pre-computed landmarks with manual edits")
):
"""
Analyze an uploaded X-ray image and generate brace.
This endpoint accepts multipart/form-data for direct file upload.
Returns analysis results with local file paths that can be downloaded
via the /download endpoint.
If landmarks_json is provided, it will use those landmarks (with manual edits)
instead of re-running automatic detection. This allows manual corrections
to be incorporated into the brace generation.
The Lambda function is responsible for:
1. Downloading the X-ray from S3
2. Calling this endpoint
3. Downloading output files via /download
4. Uploading files to S3
"""
# Validate file
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
# Check file size
contents = await file.read()
if len(contents) > config.MAX_IMAGE_SIZE_MB * 1024 * 1024:
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size is {config.MAX_IMAGE_SIZE_MB}MB"
)
# Parse config if provided
brace_config = None
if config_json:
try:
config_data = json.loads(config_json)
brace_config = BraceConfigRequest(**config_data)
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(status_code=400, detail=f"Invalid config: {e}")
# Parse landmarks if provided (manual edits)
landmarks_data = None
if landmarks_json:
try:
landmarks_data = json.loads(landmarks_json)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid landmarks JSON: {e}")
# Parse experiment type
try:
exp_type = ExperimentType(experiment)
except ValueError:
exp_type = ExperimentType.EXPERIMENT_3
service = req.app.state.brace_service
try:
result = await service.analyze_from_bytes(
image_data=contents,
filename=file.filename,
experiment=exp_type,
case_id=case_id,
brace_config=brace_config,
landmarks_data=landmarks_data # Pass pre-computed landmarks
)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/download/{case_id}/{filename}", summary="Download output file")
async def download_file(case_id: str, filename: str):
"""
Download a generated output file.
This endpoint is called by the Lambda function to retrieve
generated files (STL, PLY, PNG, JSON) for upload to S3.
"""
file_path = config.TEMP_DIR / case_id / filename
if not file_path.exists():
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
# Determine media type
ext = file_path.suffix.lower()
media_types = {
".stl": "application/octet-stream",
".ply": "application/octet-stream",
".obj": "application/octet-stream",
".glb": "model/gltf-binary",
".gltf": "model/gltf+json",
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".json": "application/json",
}
media_type = media_types.get(ext, "application/octet-stream")
return FileResponse(
path=str(file_path),
filename=filename,
media_type=media_type
)
@router.post("/extract-body-measurements", summary="Extract body measurements from 3D scan")
async def extract_body_measurements(
file: UploadFile = File(..., description="3D body scan file (STL/OBJ/PLY)")
):
"""
Extract body measurements from a 3D body scan.
Returns measurements needed for brace fitting:
- Total height
- Shoulder, chest, waist, hip widths and depths
- Circumferences
- Brace coverage region
"""
import tempfile
from pathlib import Path
try:
from server_DEV.body_integration import extract_measurements_from_scan
except ImportError as e:
raise HTTPException(status_code=500, detail=f"Body integration module not available: {e}")
# Validate file type
allowed_extensions = ['.stl', '.obj', '.ply', '.glb', '.gltf']
ext = Path(file.filename).suffix.lower() if file.filename else '.stl'
if ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Invalid file type. Allowed: {', '.join(allowed_extensions)}"
)
# Save to temp file
contents = await file.read()
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f:
f.write(contents)
temp_path = f.name
try:
measurements = extract_measurements_from_scan(temp_path)
return measurements
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
# Cleanup
Path(temp_path).unlink(missing_ok=True)
@router.post("/generate-with-body", summary="Generate brace with body scan fitting")
async def generate_with_body_scan(
req: Request,
xray_file: UploadFile = File(..., description="X-ray image"),
body_scan_file: UploadFile = File(..., description="3D body scan (STL/OBJ/PLY)"),
case_id: Optional[str] = Form(None, description="Case ID"),
landmarks_json: Optional[str] = Form(None, description="Pre-computed landmarks"),
clearance_mm: float = Form(8.0, description="Shell clearance in mm"),
):
"""
Generate a patient-specific brace using X-ray analysis and 3D body scan.
This endpoint:
1. Analyzes X-ray to detect spine landmarks and compute Cobb angles
2. Classifies curve type using Rigo-Cheneau system
3. Fits a shell template to the 3D body scan
4. Returns STL, GLB, and visualization files
"""
import tempfile
import uuid
from pathlib import Path
try:
from server_DEV.body_integration import generate_fitted_brace, extract_measurements_from_scan
except ImportError as e:
raise HTTPException(status_code=500, detail=f"Body integration module not available: {e}")
# Generate case ID if not provided
case_id = case_id or f"case_{uuid.uuid4().hex[:8]}"
# Save files to temp directory
temp_dir = config.TEMP_DIR / case_id
temp_dir.mkdir(parents=True, exist_ok=True)
# Save X-ray
xray_contents = await xray_file.read()
xray_ext = Path(xray_file.filename).suffix if xray_file.filename else '.jpg'
xray_path = temp_dir / f"xray{xray_ext}"
xray_path.write_bytes(xray_contents)
# Save body scan
body_contents = await body_scan_file.read()
body_ext = Path(body_scan_file.filename).suffix if body_scan_file.filename else '.stl'
body_scan_path = temp_dir / f"body_scan{body_ext}"
body_scan_path.write_bytes(body_contents)
try:
# Parse landmarks if provided
landmarks_data = None
if landmarks_json:
import json
landmarks_data = json.loads(landmarks_json)
# Step 1: Analyze X-ray to get Rigo classification (this generates the brace)
service = req.app.state.brace_service
xray_result = await service.analyze_from_bytes(
image_data=xray_contents,
filename=xray_file.filename,
experiment=ExperimentType.EXPERIMENT_3,
case_id=case_id,
landmarks_data=landmarks_data
)
rigo_type = xray_result.rigo_classification.type if xray_result.rigo_classification else "A1"
# Step 2: Try to extract body measurements (optional - EXPERIMENT_10 may not be deployed)
body_measurements = None
fitting_result = None
body_scan_error = None
try:
body_measurements = extract_measurements_from_scan(str(body_scan_path))
# Step 3: Generate fitted brace (only if measurements worked)
fitting_result = generate_fitted_brace(
body_scan_path=str(body_scan_path),
rigo_type=rigo_type,
output_dir=str(temp_dir),
case_id=case_id,
clearance_mm=clearance_mm
)
except Exception as body_err:
print(f"Warning: Body scan processing failed, using X-ray only: {body_err}")
body_scan_error = str(body_err)
# If body fitting worked, return full result
if fitting_result:
return {
"case_id": case_id,
"experiment": "experiment_10",
"model_used": xray_result.model_used,
"vertebrae_detected": xray_result.vertebrae_detected,
"cobb_angles": {
"PT": xray_result.cobb_angles.PT,
"MT": xray_result.cobb_angles.MT,
"TL": xray_result.cobb_angles.TL,
},
"curve_type": xray_result.curve_type,
"rigo_classification": {
"type": rigo_type,
"description": xray_result.rigo_classification.description if xray_result.rigo_classification else ""
},
"body_scan": {
"measurements": body_measurements,
},
"brace_fitting": fitting_result,
"outputs": {
"shell_stl": fitting_result["outputs"]["shell_stl"],
"shell_glb": fitting_result["outputs"]["shell_glb"],
"combined_stl": fitting_result["outputs"]["combined_stl"],
"visualization": fitting_result["outputs"].get("visualization"),
"feedback_json": fitting_result["outputs"]["feedback_json"],
"xray_visualization": str(xray_result.outputs.get("visualization", "")),
},
"mesh_vertices": fitting_result["mesh_stats"]["vertices"],
"mesh_faces": fitting_result["mesh_stats"]["faces"],
"processing_time_ms": xray_result.processing_time_ms,
}
# Fallback: return X-ray only result (body scan processing not available)
return {
"case_id": case_id,
"experiment": "experiment_3_fallback",
"model_used": xray_result.model_used,
"vertebrae_detected": xray_result.vertebrae_detected,
"cobb_angles": {
"PT": xray_result.cobb_angles.PT,
"MT": xray_result.cobb_angles.MT,
"TL": xray_result.cobb_angles.TL,
},
"curve_type": xray_result.curve_type,
"rigo_classification": {
"type": rigo_type,
"description": xray_result.rigo_classification.description if xray_result.rigo_classification else ""
},
"body_scan": {
"error": body_scan_error or "Body scan processing not available",
"fallback": "Using X-ray only brace generation"
},
"outputs": xray_result.outputs,
"mesh_vertices": xray_result.mesh_vertices,
"mesh_faces": xray_result.mesh_faces,
"processing_time_ms": xray_result.processing_time_ms,
}
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=str(e))
@router.get("/experiments", summary="List available experiments")
async def list_experiments():
"""List available brace generation experiments."""
return {
"experiments": [
{
"id": "standard",
"name": "Standard Pipeline",
"description": "Original template-based brace generation using Rigo classification"
},
{
"id": "experiment_3",
"name": "Research-Based Adaptive",
"description": "Adaptive brace generation based on Guy et al. (2024) with patch-based deformation optimization"
},
{
"id": "experiment_10",
"name": "Patient-Specific Body Fitting",
"description": "X-ray analysis + 3D body scan for precise patient-specific brace fitting"
}
],
"default": "experiment_3"
}
@router.get("/models", summary="List available detection models")
async def list_models():
"""List available landmark detection models."""
return {
"models": [
{
"id": "scoliovis",
"name": "ScolioVis",
"description": "Keypoint R-CNN model for vertebrae detection",
"supports_gpu": True
},
{
"id": "vertebra-landmark",
"name": "Vertebra-Landmark-Detection",
"description": "SpineNet-based detection (alternative)",
"supports_gpu": True
}
],
"current": config.MODEL
}
# ============================================
# NEW ENDPOINTS FOR PIPELINE DEV
# ============================================
@router.post("/detect-landmarks", summary="Detect landmarks only (Stage 1)")
async def detect_landmarks(
req: Request,
file: UploadFile = File(..., description="X-ray image file"),
case_id: Optional[str] = Form(None, description="Case ID"),
):
"""
Detect vertebrae landmarks without generating a brace.
Returns landmarks, visualization, and vertebrae_structure for manual editing.
This is Stage 1 of the pipeline - just detection, no brace generation.
"""
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
contents = await file.read()
if len(contents) > config.MAX_IMAGE_SIZE_MB * 1024 * 1024:
raise HTTPException(
status_code=400,
detail=f"File too large. Maximum size is {config.MAX_IMAGE_SIZE_MB}MB"
)
service = req.app.state.brace_service
try:
result = await service.detect_landmarks_only(
image_data=contents,
filename=file.filename,
case_id=case_id
)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/recalculate", summary="Recalculate Cobb/Rigo from landmarks")
async def recalculate_analysis(req: Request):
"""
Recalculate Cobb angles and Rigo classification from provided landmarks.
Use this after manual landmark editing to get updated analysis.
Request body:
{
"case_id": "case-xxx",
"landmarks": { ... vertebrae_structure from detect-landmarks ... }
}
"""
body = await req.json()
case_id = body.get("case_id")
landmarks = body.get("landmarks")
if not landmarks:
raise HTTPException(status_code=400, detail="landmarks data required")
service = req.app.state.brace_service
try:
result = await service.recalculate_from_landmarks(
landmarks_data=landmarks,
case_id=case_id
)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# GLB BRACE GENERATION WITH MARKERS
# =============================================================================
from .glb_generator import (
generate_glb_brace,
generate_both_brace_types,
list_available_templates,
calculate_pressure_zones,
load_template_markers,
AVAILABLE_RIGO_TYPES
)
@router.get("/templates", summary="List available brace templates")
async def list_templates():
"""
List all available brace templates (regular and vase types).
Returns which Rigo types have templates available.
"""
return {
"available_templates": list_available_templates(),
"rigo_types": AVAILABLE_RIGO_TYPES,
"template_types": ["regular", "vase"]
}
@router.get("/templates/{rigo_type}/markers", summary="Get template markers")
async def get_template_markers(
rigo_type: str,
template_type: str = "regular"
):
"""
Get marker positions for a specific template.
Args:
rigo_type: Rigo classification (A1, A2, A3, B1, B2, C1, C2, E1, E2)
template_type: "regular" or "vase"
Returns:
Marker positions and basis vectors
"""
if rigo_type not in AVAILABLE_RIGO_TYPES:
raise HTTPException(
status_code=400,
detail=f"Invalid rigo_type. Must be one of: {AVAILABLE_RIGO_TYPES}"
)
if template_type not in ["regular", "vase"]:
raise HTTPException(
status_code=400,
detail="template_type must be 'regular' or 'vase'"
)
try:
markers = load_template_markers(rigo_type, template_type)
return {
"rigo_type": rigo_type,
"template_type": template_type,
**markers
}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/generate-glb", summary="Generate GLB brace with markers")
async def generate_glb_endpoint(
req: Request,
rigo_type: str = Form(..., description="Rigo classification (A1-E2)"),
template_type: str = Form("regular", description="Template type: 'regular' or 'vase'"),
case_id: str = Form(..., description="Case identifier"),
cobb_pt: float = Form(0.0, description="Proximal Thoracic Cobb angle"),
cobb_mt: float = Form(0.0, description="Main Thoracic Cobb angle"),
cobb_tl: float = Form(0.0, description="Thoracolumbar Cobb angle"),
body_scan: Optional[UploadFile] = File(None, description="Optional 3D body scan STL")
):
"""
Generate a GLB brace with embedded markers.
This endpoint generates a brace file that includes marker positions
for later editing. Optionally fits to a body scan.
**Pressure Zones in Output:**
- LM_PAD_TH: Thoracic pad (pushes INWARD on curve convex side)
- LM_BAY_TH: Thoracic bay (creates SPACE on curve concave side)
- LM_PAD_LUM: Lumbar pad (pushes INWARD)
- LM_BAY_LUM: Lumbar bay (creates SPACE)
- LM_ANCHOR_HIP_L/R: Hip anchors (stabilize brace)
Returns:
GLB and STL file paths, marker positions, pressure zone info
"""
if rigo_type not in AVAILABLE_RIGO_TYPES:
raise HTTPException(
status_code=400,
detail=f"Invalid rigo_type. Must be one of: {AVAILABLE_RIGO_TYPES}"
)
if template_type not in ["regular", "vase"]:
raise HTTPException(
status_code=400,
detail="template_type must be 'regular' or 'vase'"
)
import tempfile
from pathlib import Path
output_dir = Path(tempfile.gettempdir()) / "brace_generator" / case_id
output_dir.mkdir(parents=True, exist_ok=True)
body_scan_path = None
# Save body scan if provided
if body_scan:
body_ext = Path(body_scan.filename).suffix if body_scan.filename else ".stl"
body_scan_path = str(output_dir / f"body_scan{body_ext}")
with open(body_scan_path, "wb") as f:
content = await body_scan.read()
f.write(content)
cobb_angles = {
"PT": cobb_pt,
"MT": cobb_mt,
"TL": cobb_tl
}
try:
result = generate_glb_brace(
rigo_type=rigo_type,
template_type=template_type,
output_dir=output_dir,
case_id=case_id,
cobb_angles=cobb_angles,
body_scan_path=body_scan_path,
clearance_mm=8.0
)
return {
"success": True,
"case_id": case_id,
"rigo_type": rigo_type,
"template_type": template_type,
"outputs": {
"glb": result.glb_path,
"stl": result.stl_path,
"json": result.json_path
},
"markers": result.markers,
"basis": result.basis,
"pressure_zones": result.pressure_zones,
"mesh_stats": result.mesh_stats,
"body_fitting": result.transform_applied
}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/generate-both-braces", summary="Generate both brace types for comparison")
async def generate_both_braces_endpoint(
req: Request,
rigo_type: str = Form(..., description="Rigo classification (A1-E2)"),
case_id: str = Form(..., description="Case identifier"),
cobb_pt: float = Form(0.0, description="Proximal Thoracic Cobb angle"),
cobb_mt: float = Form(0.0, description="Main Thoracic Cobb angle"),
cobb_tl: float = Form(0.0, description="Thoracolumbar Cobb angle"),
body_scan: Optional[UploadFile] = File(None, description="Optional 3D body scan STL"),
body_scan_path: Optional[str] = Form(None, description="Optional path to existing body scan file"),
clearance_mm: float = Form(8.0, description="Brace clearance from body in mm")
):
"""
Generate BOTH regular and vase brace types for side-by-side comparison.
This allows the user to compare the two brace shapes and choose
the preferred design.
Returns:
Both brace files with markers and pressure zones
"""
if rigo_type not in AVAILABLE_RIGO_TYPES:
raise HTTPException(
status_code=400,
detail=f"Invalid rigo_type. Must be one of: {AVAILABLE_RIGO_TYPES}"
)
import tempfile
from pathlib import Path
output_dir = Path(tempfile.gettempdir()) / "brace_generator" / case_id
output_dir.mkdir(parents=True, exist_ok=True)
final_body_scan_path = None
# Save body scan if uploaded as file
if body_scan:
body_ext = Path(body_scan.filename).suffix if body_scan.filename else ".stl"
final_body_scan_path = str(output_dir / f"body_scan{body_ext}")
with open(final_body_scan_path, "wb") as f:
content = await body_scan.read()
f.write(content)
# Or use provided path if it exists
elif body_scan_path and Path(body_scan_path).exists():
final_body_scan_path = body_scan_path
print(f"Using existing body scan at: {body_scan_path}")
cobb_angles = {
"PT": cobb_pt,
"MT": cobb_mt,
"TL": cobb_tl
}
try:
results = generate_both_brace_types(
rigo_type=rigo_type,
output_dir=output_dir,
case_id=case_id,
cobb_angles=cobb_angles,
body_scan_path=final_body_scan_path,
clearance_mm=clearance_mm
)
response = {
"success": True,
"case_id": case_id,
"rigo_type": rigo_type,
"cobb_angles": cobb_angles,
"body_scan_used": final_body_scan_path is not None,
"braces": {}
}
for brace_type, result in results.items():
if isinstance(result, dict) and "error" in result:
response["braces"][brace_type] = result
else:
response["braces"][brace_type] = {
"outputs": {
"glb": result.glb_path,
"stl": result.stl_path,
"json": result.json_path
},
"markers": result.markers,
"pressure_zones": result.pressure_zones,
"mesh_stats": result.mesh_stats
}
return response
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/pressure-zones/{rigo_type}", summary="Get pressure zone information")
async def get_pressure_zones(
rigo_type: str,
template_type: str = "regular",
cobb_mt: float = 25.0,
cobb_tl: float = 15.0
):
"""
Get detailed pressure zone information for a Rigo type.
This explains WHERE and HOW MUCH pressure is applied based on
the Cobb angles.
**Pressure Zone Types:**
- **PAD (Push Zone)**: Pushes INWARD on the convex side of the curve
to apply corrective force. Depth increases with Cobb angle severity.
- **BAY (Expansion Zone)**: Creates SPACE on the concave side for the
body to shift into during correction. Clearance is ~1.3x pad depth.
- **ANCHOR (Stability Zone)**: Grips the pelvis to prevent the brace
from riding up. Light inward pressure.
Returns:
Detailed pressure zone descriptions with depths in mm
"""
if rigo_type not in AVAILABLE_RIGO_TYPES:
raise HTTPException(
status_code=400,
detail=f"Invalid rigo_type. Must be one of: {AVAILABLE_RIGO_TYPES}"
)
try:
markers = load_template_markers(rigo_type, template_type)
zones = calculate_pressure_zones(
markers,
rigo_type,
{"PT": 0, "MT": cobb_mt, "TL": cobb_tl}
)
return {
"rigo_type": rigo_type,
"template_type": template_type,
"cobb_angles": {"MT": cobb_mt, "TL": cobb_tl},
"pressure_zones": [
{
"name": z.name,
"marker": z.marker_name,
"position": list(z.position),
"type": z.zone_type,
"direction": z.direction,
"function": z.function,
"depth_mm": round(z.depth_mm, 1),
"radius_mm": list(z.radius_mm)
}
for z in zones
],
"explanation": {
"pad_depth": f"Based on Cobb angle severity: {cobb_mt}° MT → {round(8 + min(max((cobb_mt - 10) / 40, 0), 1) * 14, 1)}mm thoracic pad",
"bay_clearance": "Bay clearance = 1.3 × pad depth + 4-5mm to allow body movement",
"hip_anchors": "4mm inward pressure to grip pelvis and stabilize brace"
}
}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
# =============================================================================
# DEV MODE: LOCAL FILE STORAGE AND SERVING
# =============================================================================
# Local storage directory for DEV mode
DEV_STORAGE_DIR = config.TEMP_DIR / "dev_storage"
DEV_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
@router.post("/cases", summary="Create a new case (DEV)")
async def create_case():
"""Create a new case with a generated ID (DEV mode)."""
import uuid
from datetime import datetime
case_id = f"case-{datetime.now().strftime('%Y%m%d')}-{uuid.uuid4().hex[:8]}"
case_dir = DEV_STORAGE_DIR / case_id
(case_dir / "uploads").mkdir(parents=True, exist_ok=True)
(case_dir / "outputs").mkdir(parents=True, exist_ok=True)
# Save case metadata
metadata = {
"case_id": case_id,
"created_at": datetime.now().isoformat(),
"status": "created"
}
(case_dir / "case.json").write_text(json.dumps(metadata, indent=2))
return {"caseId": case_id, "status": "created"}
@router.get("/cases/{case_id}", summary="Get case details (DEV)")
async def get_case(case_id: str):
"""Get case details (DEV mode)."""
case_dir = DEV_STORAGE_DIR / case_id
if not case_dir.exists():
raise HTTPException(status_code=404, detail=f"Case not found: {case_id}")
metadata_file = case_dir / "case.json"
if metadata_file.exists():
metadata = json.loads(metadata_file.read_text())
else:
metadata = {"case_id": case_id, "status": "unknown"}
return metadata
@router.post("/cases/{case_id}/upload", summary="Upload X-ray for case (DEV)")
async def upload_xray(
case_id: str,
file: UploadFile = File(..., description="X-ray image file")
):
"""Upload X-ray image for a case (DEV mode - saves locally)."""
case_dir = DEV_STORAGE_DIR / case_id
uploads_dir = case_dir / "uploads"
uploads_dir.mkdir(parents=True, exist_ok=True)
# Determine extension from filename
ext = Path(file.filename).suffix.lower() if file.filename else ".jpg"
if ext not in [".jpg", ".jpeg", ".png", ".webp"]:
ext = ".jpg"
# Save as xray.{ext}
xray_path = uploads_dir / f"xray{ext}"
contents = await file.read()
xray_path.write_bytes(contents)
# Update case metadata
metadata_file = case_dir / "case.json"
if metadata_file.exists():
metadata = json.loads(metadata_file.read_text())
else:
metadata = {"case_id": case_id}
metadata["xray_uploaded"] = True
metadata["xray_filename"] = f"xray{ext}"
metadata_file.write_text(json.dumps(metadata, indent=2))
return {
"filename": f"xray{ext}",
"path": f"/files/uploads/{case_id}/xray{ext}"
}
@router.get("/cases/{case_id}/assets", summary="Get case assets (DEV)")
async def get_case_assets(case_id: str):
"""List all uploaded and output files for a case (DEV mode)."""
case_dir = DEV_STORAGE_DIR / case_id
if not case_dir.exists():
raise HTTPException(status_code=404, detail=f"Case not found: {case_id}")
uploads = []
outputs = []
# List uploads
uploads_dir = case_dir / "uploads"
if uploads_dir.exists():
for f in uploads_dir.iterdir():
if f.is_file():
uploads.append({
"filename": f.name,
"url": f"/files/uploads/{case_id}/{f.name}"
})
# List outputs
outputs_dir = case_dir / "outputs"
if outputs_dir.exists():
for f in outputs_dir.iterdir():
if f.is_file():
outputs.append({
"filename": f.name,
"url": f"/files/outputs/{case_id}/{f.name}"
})
return {
"caseId": case_id,
"assets": {
"uploads": uploads,
"outputs": outputs
}
}
@router.get("/files/uploads/{case_id}/{filename}", summary="Serve uploaded file (DEV)")
async def serve_upload_file(case_id: str, filename: str):
"""Serve an uploaded file (DEV mode)."""
file_path = DEV_STORAGE_DIR / case_id / "uploads" / filename
if not file_path.exists():
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
# Determine media type
ext = file_path.suffix.lower()
media_types = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
".stl": "application/octet-stream",
".glb": "model/gltf-binary",
".json": "application/json",
}
media_type = media_types.get(ext, "application/octet-stream")
return FileResponse(
path=str(file_path),
filename=filename,
media_type=media_type
)
@router.get("/files/outputs/{case_id}/{filename}", summary="Serve output file (DEV)")
async def serve_output_file(case_id: str, filename: str):
"""Serve an output file (DEV mode)."""
file_path = DEV_STORAGE_DIR / case_id / "outputs" / filename
if not file_path.exists():
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
# Determine media type
ext = file_path.suffix.lower()
media_types = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
".stl": "application/octet-stream",
".ply": "application/octet-stream",
".obj": "application/octet-stream",
".glb": "model/gltf-binary",
".gltf": "model/gltf+json",
".json": "application/json",
}
media_type = media_types.get(ext, "application/octet-stream")
return FileResponse(
path=str(file_path),
filename=filename,
media_type=media_type
)

125
brace-generator/schemas.py Normal file
View File

@@ -0,0 +1,125 @@
"""
Pydantic schemas for API request/response validation.
"""
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from enum import Enum
class ExperimentType(str, Enum):
"""Available brace generation experiments."""
STANDARD = "standard" # Original pipeline
EXPERIMENT_3 = "experiment_3" # Research-based adaptive
class BraceConfigRequest(BaseModel):
"""Brace configuration parameters."""
brace_height_mm: float = Field(default=400.0, ge=200, le=600)
torso_width_mm: float = Field(default=280.0, ge=150, le=400)
torso_depth_mm: float = Field(default=200.0, ge=100, le=350)
wall_thickness_mm: float = Field(default=4.0, ge=2, le=10)
pressure_strength_mm: float = Field(default=15.0, ge=0, le=30)
class AnalyzeRequest(BaseModel):
"""Request to analyze X-ray and generate brace."""
s3_key: Optional[str] = Field(None, description="S3 key of uploaded X-ray image")
case_id: Optional[str] = Field(None, description="Case ID for organizing outputs")
experiment: ExperimentType = Field(default=ExperimentType.EXPERIMENT_3)
config: Optional[BraceConfigRequest] = None
# Output options
save_visualization: bool = Field(default=True)
save_landmarks: bool = Field(default=True)
output_format: str = Field(default="stl", description="stl, ply, or both")
class AnalyzeFromUrlRequest(BaseModel):
"""Request with direct image URL."""
image_url: str = Field(..., description="URL to download X-ray image from")
case_id: Optional[str] = Field(None)
experiment: ExperimentType = Field(default=ExperimentType.EXPERIMENT_3)
config: Optional[BraceConfigRequest] = None
save_visualization: bool = True
save_landmarks: bool = True
output_format: str = "stl"
class Vertebra(BaseModel):
"""Single vertebra data."""
level: str
centroid_px: List[float]
orientation_deg: Optional[float] = None
confidence: Optional[float] = None
corners_px: Optional[List[List[float]]] = None
class CobbAngles(BaseModel):
"""Cobb angle measurements."""
PT: float = Field(..., description="Proximal Thoracic angle")
MT: float = Field(..., description="Main Thoracic angle")
TL: float = Field(..., description="Thoracolumbar angle")
class RigoClassification(BaseModel):
"""Rigo-Chêneau classification result."""
type: str
description: str
curve_pattern: Optional[str] = None
class DeformationReport(BaseModel):
"""Patch-based deformation report (Experiment 3)."""
patch_grid: str
deformations: Optional[List[List[float]]] = None
zones: Optional[List[Dict[str, Any]]] = None
class AnalysisResult(BaseModel):
"""Complete analysis result."""
case_id: Optional[str] = None
experiment: str
# Input
input_image: str
# Detection results
model_used: str
vertebrae_detected: int
vertebrae: Optional[List[Vertebra]] = None
# Measurements
cobb_angles: CobbAngles
curve_type: str
# Classification
rigo_classification: RigoClassification
# Brace mesh info
mesh_vertices: int
mesh_faces: int
# Deformation (Experiment 3)
deformation_report: Optional[DeformationReport] = None
# Output URLs/paths
outputs: Dict[str, str] = Field(default_factory=dict)
# Timing
processing_time_ms: float
class HealthResponse(BaseModel):
"""Health check response."""
status: str
device: str
cuda_available: bool
model_loaded: bool
gpu_name: Optional[str] = None
gpu_memory_mb: Optional[int] = None
class ErrorResponse(BaseModel):
"""Error response."""
error: str
detail: Optional[str] = None

884
brace-generator/services.py Normal file
View File

@@ -0,0 +1,884 @@
"""
Business logic service for brace generation.
This service handles ML inference and file generation.
S3 operations are handled by the Lambda function, not here.
"""
import time
import uuid
import tempfile
import numpy as np
import trimesh
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from io import BytesIO
from .config import config
from .schemas import (
AnalyzeRequest, AnalyzeFromUrlRequest, BraceConfigRequest,
AnalysisResult, CobbAngles, RigoClassification, Vertebra,
DeformationReport, ExperimentType
)
class BraceService:
"""
Service for X-ray analysis and brace generation.
Handles:
- Model loading and inference
- Pipeline orchestration
- Local file management
Note: S3 operations are handled by Lambda, not here.
"""
def __init__(self, device: str = "cuda", model: str = "scoliovis"):
self.device = device
self.model_name = model
# Initialize pipelines
self._init_pipelines()
def _init_pipelines(self):
"""Initialize brace generation pipelines."""
from brace_generator.data_models import BraceConfig
from brace_generator.pipeline import BracePipeline
# Standard pipeline
self.standard_pipeline = BracePipeline(
model=self.model_name,
device=self.device
)
# Experiment 3 pipeline (lazy load)
self._exp3_pipeline = None
def _get_exp3_pipeline(self):
"""Return standard pipeline (EXPERIMENT_3 not deployed)."""
return self.standard_pipeline
@property
def model_loaded(self) -> bool:
"""Check if model is loaded."""
return self.standard_pipeline is not None
async def analyze_from_bytes(
self,
image_data: bytes,
filename: str,
experiment: ExperimentType = ExperimentType.EXPERIMENT_3,
case_id: Optional[str] = None,
brace_config: Optional[BraceConfigRequest] = None,
landmarks_data: Optional[Dict[str, Any]] = None
) -> AnalysisResult:
"""
Analyze X-ray from raw bytes.
If landmarks_data is provided, it will use those landmarks (with manual edits)
instead of re-running automatic detection.
"""
start_time = time.time()
# Generate case ID if not provided
case_id = case_id or str(uuid.uuid4())[:8]
# Save image to temp file
suffix = Path(filename).suffix or ".jpg"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
f.write(image_data)
input_path = f.name
try:
# Prepare output directory
output_dir = config.TEMP_DIR / case_id
output_dir.mkdir(parents=True, exist_ok=True)
output_base = output_dir / f"brace_{case_id}"
# Select pipeline based on experiment
if experiment == ExperimentType.EXPERIMENT_3:
result = await self._run_experiment_3(
input_path, output_base, brace_config, landmarks_data
)
else:
result = await self._run_standard(input_path, output_base, brace_config)
# Add timing and case ID
result.processing_time_ms = (time.time() - start_time) * 1000
result.case_id = case_id
return result
finally:
# Cleanup temp input
Path(input_path).unlink(missing_ok=True)
async def _run_standard(
self,
input_path: str,
output_base: Path,
brace_config: Optional[BraceConfigRequest]
) -> AnalysisResult:
"""Run standard pipeline."""
from brace_generator.data_models import BraceConfig
# Configure
if brace_config:
self.standard_pipeline.config = BraceConfig(
brace_height_mm=brace_config.brace_height_mm,
torso_width_mm=brace_config.torso_width_mm,
torso_depth_mm=brace_config.torso_depth_mm,
wall_thickness_mm=brace_config.wall_thickness_mm,
pressure_strength_mm=brace_config.pressure_strength_mm,
)
# Run pipeline
results = self.standard_pipeline.process(
input_path,
str(output_base) + ".stl",
visualize=True,
save_landmarks=True
)
# Build response with local file paths
outputs = {
"stl": str(output_base) + ".stl",
}
vis_path = str(output_base) + ".png"
json_path = str(output_base) + ".json"
if Path(vis_path).exists():
outputs["visualization"] = vis_path
if Path(json_path).exists():
outputs["landmarks"] = json_path
return AnalysisResult(
experiment="standard",
input_image=input_path,
model_used=results["model"],
vertebrae_detected=results["vertebrae_detected"],
cobb_angles=CobbAngles(
PT=results["cobb_angles"]["PT"],
MT=results["cobb_angles"]["MT"],
TL=results["cobb_angles"]["TL"],
),
curve_type=results["curve_type"],
rigo_classification=RigoClassification(
type=results["rigo_type"],
description=results.get("rigo_description", "")
),
mesh_vertices=results["mesh_vertices"],
mesh_faces=results["mesh_faces"],
outputs=outputs,
processing_time_ms=0 # Will be set by caller
)
async def _run_experiment_3(
self,
input_path: str,
output_base: Path,
brace_config: Optional[BraceConfigRequest],
landmarks_data: Optional[Dict[str, Any]] = None
) -> AnalysisResult:
"""
Run Experiment 3 (research-based adaptive) pipeline.
If landmarks_data is provided, it uses those landmarks (with manual edits)
instead of running automatic detection.
"""
import sys
from brace_generator.data_models import BraceConfig
pipeline = self._get_exp3_pipeline()
# Configure
if brace_config:
pipeline.config = BraceConfig(
brace_height_mm=brace_config.brace_height_mm,
torso_width_mm=brace_config.torso_width_mm,
torso_depth_mm=brace_config.torso_depth_mm,
wall_thickness_mm=brace_config.wall_thickness_mm,
pressure_strength_mm=brace_config.pressure_strength_mm,
)
# If landmarks_data is provided, use it instead of running detection
if landmarks_data:
results = await self._run_experiment_3_with_landmarks(
input_path, output_base, pipeline, landmarks_data
)
else:
# Run full pipeline with automatic detection
results = pipeline.process(
input_path,
str(output_base),
visualize=True,
save_landmarks=True
)
# Build deformation report
deformation_report = None
if results.get("deformation_report"):
dr = results["deformation_report"]
deformation_report = DeformationReport(
patch_grid=dr.get("patch_grid", "6x8"),
deformations=dr.get("deformations"),
zones=dr.get("zones")
)
# Collect output file paths
outputs = {}
if results.get("output_stl"):
outputs["stl"] = results["output_stl"]
if results.get("output_ply"):
outputs["ply"] = results["output_ply"]
# Check for visualization and landmarks files
vis_path = str(output_base) + ".png"
json_path = str(output_base) + ".json"
if Path(vis_path).exists():
outputs["visualization"] = vis_path
if Path(json_path).exists():
outputs["landmarks"] = json_path
return AnalysisResult(
experiment="experiment_3",
input_image=input_path,
model_used=results.get("model", "manual_landmarks"),
vertebrae_detected=results.get("vertebrae_detected", 0),
cobb_angles=CobbAngles(
PT=results["cobb_angles"]["PT"],
MT=results["cobb_angles"]["MT"],
TL=results["cobb_angles"]["TL"],
),
curve_type=results["curve_type"],
rigo_classification=RigoClassification(
type=results["rigo_type"],
description=results.get("rigo_description", "")
),
mesh_vertices=results.get("mesh_vertices", 0),
mesh_faces=results.get("mesh_faces", 0),
deformation_report=deformation_report,
outputs=outputs,
processing_time_ms=0
)
async def _run_experiment_3_with_landmarks(
self,
input_path: str,
output_base: Path,
pipeline,
landmarks_data: Dict[str, Any]
) -> Dict[str, Any]:
"""
Run experiment 3 brace generation using pre-computed landmarks.
Uses final_values from landmarks_data (which may include manual edits).
"""
import sys
import json
# Load analysis modules from brace_generator root
from brace_generator.data_models import Spine2D, VertebraLandmark
from brace_generator.spine_analysis import compute_cobb_angles, find_apex_vertebrae, classify_rigo_type, get_curve_severity
from image_loader import load_xray_rgb
# Load the image for visualization
image_rgb, pixel_spacing = load_xray_rgb(input_path)
# Build Spine2D from landmarks_data final_values
vertebrae_structure = landmarks_data.get("vertebrae_structure", landmarks_data)
vertebrae_list = vertebrae_structure.get("vertebrae", [])
spine = Spine2D()
for vdata in vertebrae_list:
final = vdata.get("final_values", {})
centroid = final.get("centroid_px")
if centroid is None:
continue
v = VertebraLandmark(
level=vdata.get("level"),
centroid_px=np.array(centroid, dtype=np.float32),
confidence=float(final.get("confidence", 0.5))
)
corners = final.get("corners_px")
if corners:
v.corners_px = np.array(corners, dtype=np.float32)
spine.vertebrae.append(v)
if len(spine.vertebrae) < 3:
raise ValueError("Need at least 3 vertebrae for brace generation")
spine.pixel_spacing_mm = pixel_spacing
spine.image_shape = image_rgb.shape[:2]
spine.sort_vertebrae()
# Compute Cobb angles and classification
compute_cobb_angles(spine)
apex_indices = find_apex_vertebrae(spine)
rigo_result = classify_rigo_type(spine)
# Generate adaptive brace using the pipeline's brace generator directly
# This uses our manually-edited spine instead of re-detecting
brace_mesh = pipeline.brace_generator.generate(spine)
mesh_vertices = 0
mesh_faces = 0
output_stl = None
output_ply = None
deformation_report = None
if brace_mesh is not None:
mesh_vertices = len(brace_mesh.vertices)
mesh_faces = len(brace_mesh.faces)
# Get deformation report if available
if hasattr(pipeline.brace_generator, 'get_deformation_report'):
deformation_report = pipeline.brace_generator.get_deformation_report()
# Export STL and PLY
output_base_path = Path(output_base)
output_stl = str(output_base_path.with_suffix('.stl'))
output_ply = str(output_base_path.with_suffix('.ply'))
brace_mesh.export(output_stl)
# Export PLY if method available
if hasattr(pipeline.brace_generator, 'export_ply'):
pipeline.brace_generator.export_ply(brace_mesh, output_ply)
else:
brace_mesh.export(output_ply)
print(f" Exported: {output_stl}, {output_ply}")
# Save visualization with the manual/combined landmarks and deformation heatmap
vis_path = str(output_base) + ".png"
self._save_landmarks_visualization_with_spine(
image_rgb, spine, rigo_result, deformation_report, vis_path
)
# Build result dict
result = {
"model": "manual_landmarks",
"vertebrae_detected": len(spine.vertebrae),
"cobb_angles": {
"PT": float(spine.cobb_pt or 0),
"MT": float(spine.cobb_mt or 0),
"TL": float(spine.cobb_tl or 0),
},
"curve_type": spine.curve_type or "Unknown",
"rigo_type": rigo_result["rigo_type"],
"rigo_description": rigo_result.get("description", ""),
"mesh_vertices": mesh_vertices,
"mesh_faces": mesh_faces,
"output_stl": output_stl,
"output_ply": output_ply,
"deformation_report": deformation_report,
}
# Save landmarks JSON
json_path = str(output_base) + ".json"
with open(json_path, "w") as f:
json.dump({
"source": "manual_landmarks",
"vertebrae_count": len(spine.vertebrae),
"cobb_angles": result["cobb_angles"],
"rigo_type": result["rigo_type"],
"curve_type": result["curve_type"],
"deformation_report": deformation_report,
}, f, indent=2, default=lambda x: x.tolist() if hasattr(x, 'tolist') else x)
return result
def _save_landmarks_visualization_with_spine(self, image, spine, rigo_result, deformation_report, path):
"""Save visualization using a pre-built Spine2D object with deformation heatmap."""
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
except ImportError:
return
# 3-panel layout like the original pipeline
fig, axes = plt.subplots(1, 3, figsize=(18, 10))
# Left: landmarks with X-shaped markers
ax1 = axes[0]
ax1.imshow(image)
# Draw green X-shaped vertebra markers and red centroids
for v in spine.vertebrae:
if v.corners_px is not None:
corners = v.corners_px
for i in range(4):
j = (i + 1) % 4
ax1.plot([corners[i, 0], corners[j, 0]],
[corners[i, 1], corners[j, 1]],
'g-', linewidth=1.5, zorder=4)
if v.centroid_px is not None:
ax1.scatter(v.centroid_px[0], v.centroid_px[1], c='red', s=40, zorder=5)
# Add labels
for v in spine.vertebrae:
if v.centroid_px is not None:
label = v.level or "?"
ax1.annotate(
label, (v.centroid_px[0] + 8, v.centroid_px[1]),
fontsize=7, color='yellow', fontweight='bold',
bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.6)
)
ax1.set_title(f"Landmarks ({len(spine.vertebrae)} vertebrae)")
ax1.axis('off')
# Middle: analysis with spine curve
ax2 = axes[1]
ax2.imshow(image, alpha=0.5)
# Draw spine curve line through centroids
centroids = [v.centroid_px for v in spine.vertebrae if v.centroid_px is not None]
if len(centroids) > 1:
centroids_arr = np.array(centroids)
ax2.plot(centroids_arr[:, 0], centroids_arr[:, 1], 'b-', linewidth=2, alpha=0.8)
text = f"Cobb Angles:\n"
text += f"PT: {spine.cobb_pt:.1f}°\n"
text += f"MT: {spine.cobb_mt:.1f}°\n"
text += f"TL: {spine.cobb_tl:.1f}°\n\n"
text += f"Curve: {spine.curve_type}\n"
text += f"Rigo: {rigo_result['rigo_type']}"
ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
verticalalignment='top', bbox=dict(facecolor='white', alpha=0.8))
ax2.set_title("Spine Analysis")
ax2.axis('off')
# Right: deformation heatmap
ax3 = axes[2]
if deformation_report and deformation_report.get('deformations'):
deform_array = np.array(deformation_report['deformations'])
# Create heatmap with diverging colormap
vmax = max(abs(deform_array.min()), abs(deform_array.max()), 1)
norm = TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
im = ax3.imshow(deform_array, cmap='RdBu_r', aspect='auto',
norm=norm, origin='upper')
# Add colorbar
cbar = plt.colorbar(im, ax=ax3, shrink=0.8)
cbar.set_label('Radial deformation (mm)')
# Labels
ax3.set_xlabel('Angular Position (patches)')
ax3.set_ylabel('Height (patches)')
ax3.set_title('Patch Deformations (mm)\nBlue=Relief, Red=Pressure')
# Add zone labels on y-axis
height_labels = ['Pelvis', 'Low Lumb', 'Up Lumb', 'Low Thor', 'Up Thor', 'Shoulder']
if deform_array.shape[0] <= len(height_labels):
ax3.set_yticks(range(deform_array.shape[0]))
ax3.set_yticklabels(height_labels[:deform_array.shape[0]])
# Angular position labels
angle_labels = ['BR', 'R', 'FR', 'F', 'FL', 'L', 'BL', 'B']
if deform_array.shape[1] <= len(angle_labels):
ax3.set_xticks(range(deform_array.shape[1]))
ax3.set_xticklabels(angle_labels[:deform_array.shape[1]])
else:
ax3.text(0.5, 0.5, 'No deformation data', ha='center', va='center',
transform=ax3.transAxes, fontsize=14, color='gray')
ax3.set_title('Patch Deformations')
ax3.axis('off')
plt.tight_layout()
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
# ============================================
# NEW METHODS FOR PIPELINE DEV
# ============================================
async def detect_landmarks_only(
self,
image_data: bytes,
filename: str,
case_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Detect landmarks only, without generating a brace.
Returns full vertebrae_structure with manual_override support.
"""
import sys
from pathlib import Path
start_time = time.time()
case_id = case_id or str(uuid.uuid4())[:8]
# Save image to temp file
suffix = Path(filename).suffix or ".jpg"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
f.write(image_data)
input_path = f.name
try:
# Import from brace_generator root (modules are already in PYTHONPATH)
# Note: In Docker, PYTHONPATH includes /app/brace_generator
from image_loader import load_xray_rgb
from adapters import ScolioVisAdapter
from spine_analysis import compute_cobb_angles, find_apex_vertebrae, classify_rigo_type, get_curve_severity
from data_models import Spine2D
# Load full image
image_rgb_full, pixel_spacing = load_xray_rgb(input_path)
image_h, image_w = image_rgb_full.shape[:2]
# Smart cropping: Only crop to middle 1/3 if image is wide enough
# Wide images (e.g., full chest X-rays) benefit from cropping to spine area
# Narrow images (e.g., already cropped to spine) should not be cropped further
MIN_WIDTH_FOR_CROPPING = 500 # Only crop if wider than 500 pixels
CROPPED_MIN_WIDTH = 200 # Ensure cropped width is at least 200 pixels
left_margin = 0 # Initialize offset for coordinate mapping
if image_w >= MIN_WIDTH_FOR_CROPPING:
# Image is wide - crop to middle 1/3 for better spine detection
left_margin = image_w // 3
right_margin = 2 * image_w // 3
cropped_width = right_margin - left_margin
# Ensure minimum width after cropping
if cropped_width >= CROPPED_MIN_WIDTH:
image_rgb_for_detection = image_rgb_full[:, left_margin:right_margin]
print(f"[SpineCrop] Full image: {image_w}x{image_h}, Cropped to middle 1/3: {cropped_width}x{image_h}")
else:
# Cropped would be too narrow, use full image
image_rgb_for_detection = image_rgb_full
left_margin = 0
print(f"[SpineCrop] Full image: {image_w}x{image_h}, Cropped would be too narrow ({cropped_width}px), using full image")
else:
# Image is already narrow - use full image
image_rgb_for_detection = image_rgb_full
print(f"[SpineCrop] Full image: {image_w}x{image_h}, Already narrow (< {MIN_WIDTH_FOR_CROPPING}px), using full image")
# Detect landmarks (on cropped or full image depending on width)
adapter = ScolioVisAdapter(device=self.device)
spine = adapter.predict(image_rgb_for_detection)
spine.pixel_spacing_mm = pixel_spacing
# Offset all detected coordinates back to full image space if cropping was applied
if left_margin > 0:
for v in spine.vertebrae:
if v.centroid_px is not None:
# Offset centroid X coordinate
v.centroid_px[0] += left_margin
if v.corners_px is not None:
# Offset all corner X coordinates
v.corners_px[:, 0] += left_margin
# Keep reference to full image for visualization
image_rgb = image_rgb_full
# Compute analysis
compute_cobb_angles(spine)
apex_indices = find_apex_vertebrae(spine)
rigo_result = classify_rigo_type(spine)
# Prepare output directory
output_dir = config.TEMP_DIR / case_id
output_dir.mkdir(parents=True, exist_ok=True)
# Save visualization
vis_path = output_dir / "visualization.png"
self._save_landmarks_visualization(image_rgb, spine, rigo_result, str(vis_path))
# Build full vertebrae structure (all T1-L5)
ALL_LEVELS = ["T1", "T2", "T3", "T4", "T5", "T6", "T7", "T8", "T9", "T10", "T11", "T12", "L1", "L2", "L3", "L4", "L5"]
# ScolioVis doesn't assign levels - assign based on Y position (top to bottom)
# Sort detected vertebrae by Y coordinate (centroid)
detected_verts = sorted(
[v for v in spine.vertebrae if v.centroid_px is not None],
key=lambda v: v.centroid_px[1] # Sort by Y (top to bottom)
)
# Assign levels based on count
# If we detect 17 vertebrae, assign T1-L5
# If fewer, we need to figure out which ones are missing
num_detected = len(detected_verts)
if num_detected >= 17:
# All vertebrae detected - assign directly
for i, v in enumerate(detected_verts[:17]):
v.level = ALL_LEVELS[i]
elif num_detected > 0:
# Fewer than 17 - assign from T1 onwards (assuming top vertebrae visible)
# This is a simplification - ideally we'd use anatomical features
for i, v in enumerate(detected_verts):
if i < len(ALL_LEVELS):
v.level = ALL_LEVELS[i]
# Build detected_map with assigned levels
detected_map = {v.level: v for v in detected_verts if v.level}
vertebrae_list = []
for level in ALL_LEVELS:
if level in detected_map:
v = detected_map[level]
centroid = v.centroid_px.tolist() if v.centroid_px is not None else None
corners = v.corners_px.tolist() if v.corners_px is not None else None
orientation = float(v.compute_orientation()) if centroid else None
vertebrae_list.append({
"level": level,
"detected": True,
"scoliovis_data": {
"centroid_px": centroid,
"corners_px": corners,
"orientation_deg": orientation,
"confidence": float(v.confidence),
},
"manual_override": {
"enabled": False,
"centroid_px": None,
"corners_px": None,
"orientation_deg": None,
"confidence": None,
"notes": None,
},
"final_values": {
"centroid_px": centroid,
"corners_px": corners,
"orientation_deg": orientation,
"confidence": float(v.confidence),
"source": "scoliovis",
},
})
else:
vertebrae_list.append({
"level": level,
"detected": False,
"scoliovis_data": {
"centroid_px": None,
"corners_px": None,
"orientation_deg": None,
"confidence": 0.0,
},
"manual_override": {
"enabled": False,
"centroid_px": None,
"corners_px": None,
"orientation_deg": None,
"confidence": None,
"notes": None,
},
"final_values": {
"centroid_px": None,
"corners_px": None,
"orientation_deg": None,
"confidence": 0.0,
"source": "undetected",
},
})
# Build result
result = {
"case_id": case_id,
"status": "landmarks_detected",
"input": {
"image_dimensions": {"width": image_w, "height": image_h},
"pixel_spacing_mm": pixel_spacing,
},
"detection_quality": {
"vertebrae_count": len(spine.vertebrae),
"average_confidence": float(np.mean([v.confidence for v in spine.vertebrae])) if spine.vertebrae else 0.0,
},
"cobb_angles": {
"PT": float(spine.cobb_pt),
"MT": float(spine.cobb_mt),
"TL": float(spine.cobb_tl),
"max": float(max(spine.cobb_pt, spine.cobb_mt, spine.cobb_tl)),
"PT_severity": get_curve_severity(spine.cobb_pt),
"MT_severity": get_curve_severity(spine.cobb_mt),
"TL_severity": get_curve_severity(spine.cobb_tl),
},
"rigo_classification": {
"type": rigo_result["rigo_type"],
"description": rigo_result["description"],
},
"curve_type": spine.curve_type,
"vertebrae_structure": {
"all_levels": ALL_LEVELS,
"detected_count": len(spine.vertebrae),
"total_count": len(ALL_LEVELS),
"vertebrae": vertebrae_list,
"manual_edit_instructions": {
"to_override": "Set manual_override.enabled=true and fill manual_override fields",
"final_values_rule": "When manual_override.enabled=true, final_values uses manual values",
},
},
"visualization_path": str(vis_path),
"processing_time_ms": (time.time() - start_time) * 1000,
}
# Save JSON
json_path = output_dir / "landmarks.json"
import json
with open(json_path, "w") as f:
json.dump(result, f, indent=2)
result["json_path"] = str(json_path)
return result
finally:
Path(input_path).unlink(missing_ok=True)
def _save_landmarks_visualization(self, image, spine, rigo_result, path):
"""Save visualization with landmarks and green quadrilateral boxes."""
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
except ImportError:
return
fig, axes = plt.subplots(1, 2, figsize=(14, 10))
# Left: landmarks with green boxes
ax1 = axes[0]
ax1.imshow(image)
# Draw green X-shaped vertebra markers and red centroids
for v in spine.vertebrae:
# Draw green X-shape if corners exist
# Corner order: [0]=top_left, [1]=top_right, [2]=bottom_left, [3]=bottom_right
# Drawing 0→1→2→3→0 creates the X pattern showing endplate orientations
if v.corners_px is not None:
corners = v.corners_px
for i in range(4):
j = (i + 1) % 4
ax1.plot([corners[i, 0], corners[j, 0]],
[corners[i, 1], corners[j, 1]],
'g-', linewidth=1.5, zorder=4)
# Draw red centroid dot
if v.centroid_px is not None:
ax1.scatter(v.centroid_px[0], v.centroid_px[1], c='red', s=40, zorder=5)
# Add labels
for i, v in enumerate(spine.vertebrae):
if v.centroid_px is not None:
label = v.level or str(i)
ax1.annotate(
label, (v.centroid_px[0] + 8, v.centroid_px[1]),
fontsize=7, color='yellow', fontweight='bold',
bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.6)
)
ax1.set_title(f"Automatic Detection ({len(spine.vertebrae)} vertebrae)")
ax1.axis('off')
# Right: analysis
ax2 = axes[1]
ax2.imshow(image, alpha=0.5)
text = f"Cobb Angles:\n"
text += f"PT: {spine.cobb_pt:.1f}°\n"
text += f"MT: {spine.cobb_mt:.1f}°\n"
text += f"TL: {spine.cobb_tl:.1f}°\n\n"
text += f"Curve: {spine.curve_type}\n"
text += f"Rigo: {rigo_result['rigo_type']}"
ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
verticalalignment='top', bbox=dict(facecolor='white', alpha=0.8))
ax2.set_title("Spine Analysis")
ax2.axis('off')
plt.tight_layout()
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
async def recalculate_from_landmarks(
self,
landmarks_data: Dict[str, Any],
case_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Recalculate Cobb angles and Rigo classification from landmarks data.
Uses final_values from each vertebra (which may be manual overrides).
"""
import sys
start_time = time.time()
case_id = case_id or str(uuid.uuid4())[:8]
# Load analysis modules from brace_generator root
from brace_generator.data_models import Spine2D, VertebraLandmark
from brace_generator.spine_analysis import compute_cobb_angles, find_apex_vertebrae, classify_rigo_type, get_curve_severity
# Reconstruct spine from landmarks data
vertebrae_structure = landmarks_data.get("vertebrae_structure", landmarks_data)
vertebrae_list = vertebrae_structure.get("vertebrae", [])
# Build Spine2D from final_values
spine = Spine2D()
for vdata in vertebrae_list:
final = vdata.get("final_values", {})
centroid = final.get("centroid_px")
if centroid is None:
continue # Skip undetected/empty vertebrae
v = VertebraLandmark(
level=vdata.get("level"),
centroid_px=np.array(centroid, dtype=np.float32),
confidence=float(final.get("confidence", 0.5))
)
corners = final.get("corners_px")
if corners:
v.corners_px = np.array(corners, dtype=np.float32)
spine.vertebrae.append(v)
if len(spine.vertebrae) < 3:
raise ValueError("Need at least 3 vertebrae for analysis")
# Sort by Y position (top to bottom)
spine.sort_vertebrae()
# Compute Cobb angles and Rigo
compute_cobb_angles(spine)
apex_indices = find_apex_vertebrae(spine)
rigo_result = classify_rigo_type(spine)
result = {
"case_id": case_id,
"status": "analysis_recalculated",
"cobb_angles": {
"PT": float(spine.cobb_pt),
"MT": float(spine.cobb_mt),
"TL": float(spine.cobb_tl),
"max": float(max(spine.cobb_pt, spine.cobb_mt, spine.cobb_tl)),
"PT_severity": get_curve_severity(spine.cobb_pt),
"MT_severity": get_curve_severity(spine.cobb_mt),
"TL_severity": get_curve_severity(spine.cobb_tl),
},
"rigo_classification": {
"type": rigo_result["rigo_type"],
"description": rigo_result["description"],
},
"curve_type": spine.curve_type,
"apex_indices": apex_indices,
"vertebrae_used": len(spine.vertebrae),
"processing_time_ms": (time.time() - start_time) * 1000,
}
return result

View File

@@ -0,0 +1,456 @@
"""
Simple FastAPI server for brace generation - CPU optimized.
Designed to work standalone with minimal dependencies.
"""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '' # Force CPU
import sys
import time
import json
import shutil
import tempfile
from pathlib import Path
from typing import Optional
from contextlib import asynccontextmanager
import numpy as np
import cv2
import torch
import trimesh
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel
# Paths
BASE_DIR = Path(__file__).parent
MODELS_DIR = BASE_DIR / "models"
OUTPUTS_DIR = BASE_DIR / "outputs"
TEMPLATES_DIR = BASE_DIR / "templates"
OUTPUTS_DIR.mkdir(exist_ok=True)
# Global model (loaded once)
model = None
model_loaded = False
def get_kprcnn_model():
"""Load Keypoint RCNN model."""
from torchvision.models.detection.rpn import AnchorGenerator
import torchvision
model_path = MODELS_DIR / "keypointsrcnn_weights.pt"
if not model_path.exists():
raise FileNotFoundError(f"Model not found: {model_path}")
num_keypoints = 4
anchor_generator = AnchorGenerator(
sizes=(32, 64, 128, 256, 512),
aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
)
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
weights=None,
weights_backbone='IMAGENET1K_V1',
num_keypoints=num_keypoints,
num_classes=2,
rpn_anchor_generator=anchor_generator
)
state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(state_dict)
model.eval()
return model
def predict_keypoints(model, image_rgb):
"""Run keypoint detection."""
from torchvision.transforms import functional as F
import torchvision
# Convert to tensor
image_tensor = F.to_tensor(image_rgb).unsqueeze(0)
# Inference
with torch.no_grad():
outputs = model(image_tensor)
output = outputs[0]
# Filter results
scores = output['scores'].cpu().numpy()
high_scores_idxs = np.where(scores > 0.5)[0].tolist()
if not high_scores_idxs:
return [], [], []
post_nms_idxs = torchvision.ops.nms(
output['boxes'][high_scores_idxs],
output['scores'][high_scores_idxs],
0.3
).cpu().numpy()
np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].cpu().numpy()
np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].cpu().numpy()
np_scores = scores[high_scores_idxs][post_nms_idxs]
# Sort by score, take top 18
sorted_idxs = np.argsort(-np_scores)[:18]
np_keypoints = np_keypoints[sorted_idxs]
np_bboxes = np_bboxes[sorted_idxs]
np_scores = np_scores[sorted_idxs]
# Sort by y position
ymins = np.array([kps[0][1] for kps in np_keypoints])
sorted_ymin_idxs = np.argsort(ymins)
np_keypoints = np_keypoints[sorted_ymin_idxs]
np_bboxes = np_bboxes[sorted_ymin_idxs]
np_scores = np_scores[sorted_ymin_idxs]
# Convert to list format
keypoints_list = [[list(map(float, kp[:2])) for kp in kps] for kps in np_keypoints]
bboxes_list = [list(map(int, bbox.tolist())) for bbox in np_bboxes]
scores_list = np_scores.tolist()
return bboxes_list, keypoints_list, scores_list
def compute_cobb_angles(keypoints):
"""Compute Cobb angles from keypoints."""
if len(keypoints) < 5:
return {"pt": 0, "mt": 0, "tl": 0}
# Calculate midpoints and angles
midpoints = []
angles = []
for kps in keypoints:
# kps is list of [x, y] for 4 corners
corners = np.array(kps)
# Top midpoint (average of corners 0 and 1)
top = (corners[0] + corners[1]) / 2
# Bottom midpoint (average of corners 2 and 3)
bottom = (corners[2] + corners[3]) / 2
midpoints.append((top, bottom))
# Vertebra angle
dx = bottom[0] - top[0]
dy = bottom[1] - top[1]
angle = np.degrees(np.arctan2(dx, dy))
angles.append(angle)
angles = np.array(angles)
# Find inflection points for curve regions
n = len(angles)
# Simple approach: divide into 3 regions
third = n // 3
pt_region = angles[:third] if third > 0 else angles[:2]
mt_region = angles[third:2*third] if 2*third > third else angles[2:4]
tl_region = angles[2*third:] if n > 2*third else angles[-2:]
# Cobb angle = difference between max and min tilt in region
pt_angle = float(np.max(pt_region) - np.min(pt_region)) if len(pt_region) > 1 else 0
mt_angle = float(np.max(mt_region) - np.min(mt_region)) if len(mt_region) > 1 else 0
tl_angle = float(np.max(tl_region) - np.min(tl_region)) if len(tl_region) > 1 else 0
return {"pt": pt_angle, "mt": mt_angle, "tl": tl_angle}
def classify_rigo_type(cobb_angles):
"""Classify Rigo-Chêneau type based on Cobb angles."""
pt = abs(cobb_angles['pt'])
mt = abs(cobb_angles['mt'])
tl = abs(cobb_angles['tl'])
max_angle = max(pt, mt, tl)
if max_angle < 10:
return "Normal"
elif mt >= tl and mt >= pt:
if mt < 25:
return "A1"
elif mt < 40:
return "A2"
else:
return "A3"
elif tl >= mt:
if tl < 30:
return "C1"
else:
return "C2"
else:
return "B1"
def generate_brace(rigo_type: str, case_id: str):
"""Load brace template and export."""
if rigo_type == "Normal":
return None, None
template_path = TEMPLATES_DIR / f"{rigo_type}.obj"
if not template_path.exists():
# Try fallback templates
fallback = {"A1": "A2", "A2": "A1", "A3": "A2", "C1": "C2", "C2": "C1", "B1": "A1", "B2": "A1"}
if rigo_type in fallback:
template_path = TEMPLATES_DIR / f"{fallback[rigo_type]}.obj"
if not template_path.exists():
return None, None
mesh = trimesh.load(str(template_path))
if isinstance(mesh, trimesh.Scene):
meshes = [g for g in mesh.geometry.values() if isinstance(g, trimesh.Trimesh)]
if meshes:
mesh = trimesh.util.concatenate(meshes)
else:
return None, None
# Export
case_dir = OUTPUTS_DIR / case_id
case_dir.mkdir(exist_ok=True)
stl_path = case_dir / f"{case_id}_brace.stl"
mesh.export(str(stl_path))
return str(stl_path), mesh
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup."""
global model, model_loaded
print("Loading ScolioVis model...")
start = time.time()
try:
model = get_kprcnn_model()
model_loaded = True
print(f"Model loaded in {time.time() - start:.1f}s")
except Exception as e:
print(f"Failed to load model: {e}")
model_loaded = False
yield
app = FastAPI(
title="Brace Generator API",
description="CPU-based scoliosis brace generation",
lifespan=lifespan
)
class AnalysisResult(BaseModel):
case_id: str
experiment: str = "experiment_4"
model_used: str = "ScolioVis"
vertebrae_detected: int
cobb_angles: dict
curve_type: str = "Unknown"
rigo_classification: dict
mesh_vertices: int = 0
mesh_faces: int = 0
timing_ms: float
outputs: dict
@app.get("/health")
async def health():
return {
"status": "healthy",
"model_loaded": model_loaded,
"device": "CPU"
}
@app.post("/analyze/upload", response_model=AnalysisResult)
async def analyze_upload(
file: UploadFile = File(...),
case_id: str = Form(None)
):
"""Analyze X-ray and generate brace."""
global model
if not model_loaded:
raise HTTPException(status_code=503, detail="Model not loaded")
start_time = time.time()
# Generate case ID if not provided
if not case_id:
case_id = f"case_{int(time.time())}"
# Save uploaded file
case_dir = OUTPUTS_DIR / case_id
case_dir.mkdir(exist_ok=True)
input_path = case_dir / file.filename
with open(input_path, "wb") as f:
content = await file.read()
f.write(content)
# Load image
img = cv2.imread(str(input_path))
if img is None:
raise HTTPException(status_code=400, detail="Could not read image")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Detect keypoints
bboxes, keypoints, scores = predict_keypoints(model, img_rgb)
n_vertebrae = len(keypoints)
if n_vertebrae < 3:
raise HTTPException(status_code=400, detail=f"Insufficient vertebrae detected: {n_vertebrae}")
# Compute Cobb angles
cobb_angles = compute_cobb_angles(keypoints)
# Classify Rigo type
rigo_type = classify_rigo_type(cobb_angles)
# Generate brace
stl_path, mesh = generate_brace(rigo_type, case_id)
outputs = {}
mesh_vertices = 0
mesh_faces = 0
if stl_path:
outputs["stl"] = stl_path
if mesh is not None:
mesh_vertices = len(mesh.vertices)
mesh_faces = len(mesh.faces)
# Determine curve type
curve_type = "Normal"
if cobb_angles['pt'] > 10 or cobb_angles['mt'] > 10 or cobb_angles['tl'] > 10:
if (cobb_angles['mt'] > 10 and cobb_angles['tl'] > 10) or (cobb_angles['pt'] > 10 and cobb_angles['tl'] > 10):
curve_type = "S-shaped"
else:
curve_type = "C-shaped"
# Rigo classification details
rigo_descriptions = {
"Normal": "No significant curve - normal spine",
"A1": "Main thoracic curve (mild) - 3C pattern",
"A2": "Main thoracic curve (moderate) - 3C pattern",
"A3": "Main thoracic curve (severe) - 3C pattern",
"B1": "Double curve (thoracic dominant) - 4C pattern",
"B2": "Double curve (lumbar dominant) - 4C pattern",
"C1": "Main thoracolumbar/lumbar (mild) - 4C pattern",
"C2": "Main thoracolumbar/lumbar (moderate-severe) - 4C pattern",
"E1": "Not structural - functional curve",
"E2": "Not structural - compensatory curve",
}
rigo_classification = {
"type": rigo_type,
"description": rigo_descriptions.get(rigo_type, "Unknown classification"),
}
# Generate visualization
vis_path = None
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 1, figsize=(10, 12))
ax.imshow(img_rgb)
# Draw keypoints
for kps in keypoints:
corners = np.array(kps) # List of [x, y]
centroid = corners.mean(axis=0)
ax.scatter(centroid[0], centroid[1], c='red', s=50, zorder=5)
# Draw corners (vertebra outline)
for i in range(4):
j = (i + 1) % 4
ax.plot([corners[i][0], corners[j][0]],
[corners[i][1], corners[j][1]], 'g-', linewidth=1)
# Add text overlay
text = f"ScolioVis Analysis\n"
text += f"-" * 20 + "\n"
text += f"Vertebrae: {n_vertebrae}\n"
text += f"Cobb Angles:\n"
text += f" PT: {cobb_angles['pt']:.1f}°\n"
text += f" MT: {cobb_angles['mt']:.1f}°\n"
text += f" TL: {cobb_angles['tl']:.1f}°\n"
text += f"Curve: {curve_type}\n"
text += f"Rigo: {rigo_type}\n"
text += f"{rigo_classification['description']}"
ax.text(0.02, 0.98, text, transform=ax.transAxes, fontsize=10,
verticalalignment='top', fontfamily='monospace',
bbox=dict(facecolor='white', alpha=0.9))
ax.set_title(f"Case: {case_id}")
ax.axis('off')
vis_path = case_dir / f"{case_id}_visualization.png"
plt.savefig(str(vis_path), dpi=150, bbox_inches='tight')
plt.close()
outputs["visualization"] = str(vis_path)
except Exception as e:
print(f"Visualization failed: {e}")
# Save analysis JSON
analysis = {
"case_id": case_id,
"input_file": file.filename,
"experiment": "experiment_4",
"model_used": "ScolioVis",
"vertebrae_detected": n_vertebrae,
"cobb_angles": cobb_angles,
"curve_type": curve_type,
"rigo_type": rigo_type,
"rigo_classification": rigo_classification,
"keypoints": keypoints,
"bboxes": bboxes,
"scores": scores,
"mesh_vertices": mesh_vertices,
"mesh_faces": mesh_faces,
}
analysis_path = case_dir / f"{case_id}_analysis.json"
with open(analysis_path, "w") as f:
json.dump(analysis, f, indent=2)
outputs["analysis"] = str(analysis_path)
elapsed_ms = (time.time() - start_time) * 1000
return AnalysisResult(
case_id=case_id,
experiment="experiment_4",
model_used="ScolioVis",
vertebrae_detected=n_vertebrae,
cobb_angles=cobb_angles,
curve_type=curve_type,
rigo_classification=rigo_classification,
mesh_vertices=mesh_vertices,
mesh_faces=mesh_faces,
timing_ms=elapsed_ms,
outputs=outputs
)
@app.get("/download/{case_id}/{filename}")
async def download_file(case_id: str, filename: str):
"""Download generated file."""
file_path = OUTPUTS_DIR / case_id / filename
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(str(file_path), filename=filename)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)