Initial commit - BraceIQMed platform with frontend, API, and brace generator
This commit is contained in:
68
brace-generator/Dockerfile
Normal file
68
brace-generator/Dockerfile
Normal file
@@ -0,0 +1,68 @@
|
||||
# ============================================
|
||||
# BraceIQMed Brace Generator - FastAPI + PyTorch (CPU)
|
||||
# Build context: repo root (braceiqmed/)
|
||||
# ============================================
|
||||
|
||||
FROM python:3.10-slim
|
||||
|
||||
# Prevent interactive prompts
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
libxrender-dev \
|
||||
wget \
|
||||
curl \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install PyTorch CPU version (smaller, no CUDA)
|
||||
RUN pip install --no-cache-dir --upgrade pip && \
|
||||
pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Copy and install requirements (from brace-generator folder)
|
||||
COPY brace-generator/requirements.txt /app/requirements.txt
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy scoliovis-api requirements and install
|
||||
COPY scoliovis-api/requirements.txt /app/requirements-scoliovis.txt
|
||||
RUN pip install --no-cache-dir -r requirements-scoliovis.txt || true
|
||||
|
||||
# Copy brace-generator code
|
||||
COPY brace-generator/ /app/brace_generator/server_DEV/
|
||||
|
||||
# Copy scoliovis-api
|
||||
COPY scoliovis-api/ /app/scoliovis-api/
|
||||
|
||||
# Copy templates
|
||||
COPY templates/ /app/templates/
|
||||
|
||||
# Set Python path
|
||||
ENV PYTHONPATH=/app:/app/brace_generator/server_DEV:/app/scoliovis-api
|
||||
|
||||
# Environment variables
|
||||
ENV HOST=0.0.0.0
|
||||
ENV PORT=8002
|
||||
ENV DEVICE=cpu
|
||||
ENV MODEL=scoliovis
|
||||
ENV TEMP_DIR=/tmp/brace_generator
|
||||
ENV CORS_ORIGINS=*
|
||||
|
||||
# Create directories
|
||||
RUN mkdir -p /tmp/brace_generator /app/data/uploads /app/data/outputs
|
||||
|
||||
EXPOSE 8002
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
|
||||
CMD curl -f http://localhost:8002/health || exit 1
|
||||
|
||||
# Run the server
|
||||
CMD ["python", "-m", "uvicorn", "brace_generator.server_DEV.app:app", "--host", "0.0.0.0", "--port", "8002"]
|
||||
8
brace-generator/__init__.py
Normal file
8
brace-generator/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Brace Generator Server Package.
|
||||
"""
|
||||
from .app import app
|
||||
from .config import config
|
||||
from .services import BraceService
|
||||
|
||||
__all__ = ["app", "config", "BraceService"]
|
||||
137
brace-generator/app.py
Normal file
137
brace-generator/app.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
FastAPI server for Brace Generator.
|
||||
|
||||
Provides REST API for:
|
||||
- X-ray analysis and landmark detection
|
||||
- Cobb angle measurement
|
||||
- Rigo-Chêneau classification
|
||||
- Adaptive brace generation (STL/PLY)
|
||||
|
||||
Usage:
|
||||
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
||||
|
||||
Or with Docker:
|
||||
docker run -p 8000:8000 --gpus all brace-generator
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# Add parent directories to path for imports
|
||||
server_dir = Path(__file__).parent
|
||||
brace_generator_dir = server_dir.parent
|
||||
spine_dir = brace_generator_dir.parent
|
||||
|
||||
sys.path.insert(0, str(spine_dir))
|
||||
sys.path.insert(0, str(brace_generator_dir))
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
import torch
|
||||
|
||||
from .config import config
|
||||
from .routes import router
|
||||
from .services import BraceService
|
||||
|
||||
|
||||
# Global service instance (initialized on startup)
|
||||
brace_service: BraceService = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Initialize and cleanup resources."""
|
||||
global brace_service
|
||||
|
||||
print("=" * 60)
|
||||
print("Brace Generator Server Starting")
|
||||
print("=" * 60)
|
||||
|
||||
# Ensure directories exist
|
||||
config.ensure_dirs()
|
||||
|
||||
# Initialize service with model
|
||||
device = config.get_device()
|
||||
print(f"Device: {device}")
|
||||
|
||||
if device == "cuda":
|
||||
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
||||
|
||||
print(f"Loading model: {config.MODEL}")
|
||||
brace_service = BraceService(device=device, model=config.MODEL)
|
||||
print("Model loaded successfully!")
|
||||
|
||||
# Make service available to routes
|
||||
app.state.brace_service = brace_service
|
||||
|
||||
print("=" * 60)
|
||||
print(f"Server ready at http://{config.HOST}:{config.PORT}")
|
||||
print("=" * 60)
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
print("Shutting down...")
|
||||
del brace_service
|
||||
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Brace Generator API",
|
||||
description="""
|
||||
API for generating scoliosis braces from X-ray images.
|
||||
|
||||
## Features
|
||||
- Vertebrae landmark detection (ScolioVis model)
|
||||
- Cobb angle measurement (PT, MT, TL)
|
||||
- Rigo-Chêneau classification
|
||||
- Adaptive brace generation with research-based deformations
|
||||
|
||||
## Experiments
|
||||
- **standard**: Original template-based pipeline
|
||||
- **experiment_3**: Research-based adaptive deformation (Guy et al. 2024)
|
||||
""",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routes
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request, exc):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"error": exc.detail}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request, exc):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "Internal server error", "detail": str(exc)}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"server.app:app",
|
||||
host=config.HOST,
|
||||
port=config.PORT,
|
||||
reload=config.DEBUG
|
||||
)
|
||||
411
brace-generator/body_integration.py
Normal file
411
brace-generator/body_integration.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Body scan integration for patient-specific brace fitting.
|
||||
|
||||
Based on EXPERIMENT_10's approach:
|
||||
1. Extract body measurements from 3D scan
|
||||
2. Compute body basis (coordinate frame)
|
||||
3. Select template based on Rigo classification
|
||||
4. Fit shell to body using basis alignment
|
||||
"""
|
||||
import sys
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
try:
|
||||
import trimesh
|
||||
HAS_TRIMESH = True
|
||||
except ImportError:
|
||||
HAS_TRIMESH = False
|
||||
|
||||
# Add EXPERIMENT_10 to path for imports
|
||||
EXPERIMENTS_DIR = Path(__file__).parent.parent / "EXPERIMENTS"
|
||||
EXP10_DIR = EXPERIMENTS_DIR / "EXPERIMENT_10"
|
||||
if str(EXP10_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(EXP10_DIR))
|
||||
|
||||
# Import EXPERIMENT_10 modules
|
||||
try:
|
||||
from body_measurements import extract_body_measurements, measurements_to_dict, BodyMeasurements
|
||||
from body_basis import compute_body_basis, body_basis_to_dict, BodyBasis
|
||||
from shell_fitter_v2 import (
|
||||
fit_shell_to_body_v2,
|
||||
compute_brace_basis_from_geometry,
|
||||
brace_basis_to_dict,
|
||||
RIGO_TO_VASE,
|
||||
FittingFeedback
|
||||
)
|
||||
HAS_EXP10 = True
|
||||
except ImportError as e:
|
||||
print(f"Warning: Could not import EXPERIMENT_10 modules: {e}")
|
||||
HAS_EXP10 = False
|
||||
|
||||
# Vase templates directory
|
||||
VASES_DIR = Path(__file__).parent.parent.parent / "_vase" / "_vase"
|
||||
|
||||
|
||||
def extract_measurements_from_scan(scan_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract body measurements from a 3D body scan.
|
||||
|
||||
Args:
|
||||
scan_path: Path to STL/OBJ/PLY body scan file
|
||||
|
||||
Returns:
|
||||
Dictionary with measurements suitable for API response
|
||||
"""
|
||||
if not HAS_TRIMESH:
|
||||
raise ImportError("trimesh is required for body scan processing")
|
||||
|
||||
# Try EXPERIMENT_10 first
|
||||
if HAS_EXP10:
|
||||
try:
|
||||
measurements = extract_body_measurements(scan_path)
|
||||
result = measurements_to_dict(measurements)
|
||||
|
||||
# Flatten for API-friendly format
|
||||
flat = {
|
||||
"total_height_mm": result["overall_dimensions"]["total_height_mm"],
|
||||
"shoulder_width_mm": result["widths_mm"]["shoulder_width"],
|
||||
"chest_width_mm": result["widths_mm"]["chest_width"],
|
||||
"chest_depth_mm": result["depths_mm"]["chest_depth"],
|
||||
"waist_width_mm": result["widths_mm"]["waist_width"],
|
||||
"waist_depth_mm": result["depths_mm"]["waist_depth"],
|
||||
"hip_width_mm": result["widths_mm"]["hip_width"],
|
||||
"hip_depth_mm": result["depths_mm"]["hip_depth"],
|
||||
"brace_coverage_height_mm": result["brace_coverage_region"]["coverage_height_mm"],
|
||||
"chest_circumference_mm": result["circumferences_mm"]["chest"],
|
||||
"waist_circumference_mm": result["circumferences_mm"]["waist"],
|
||||
"hip_circumference_mm": result["circumferences_mm"]["hip"],
|
||||
}
|
||||
|
||||
# Also include full detailed result
|
||||
flat["detailed"] = result
|
||||
return flat
|
||||
except Exception as e:
|
||||
print(f"EXPERIMENT_10 measurement extraction failed: {e}, using fallback")
|
||||
|
||||
# Fallback: Simple trimesh-based measurements
|
||||
return _extract_measurements_trimesh_fallback(scan_path)
|
||||
|
||||
|
||||
def _extract_measurements_trimesh_fallback(scan_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Simple fallback for body measurements using trimesh bounding box analysis.
|
||||
Less accurate than EXPERIMENT_10 but provides basic measurements.
|
||||
"""
|
||||
mesh = trimesh.load(scan_path)
|
||||
|
||||
# Get bounding box
|
||||
bounds = mesh.bounds
|
||||
min_pt, max_pt = bounds[0], bounds[1]
|
||||
|
||||
# Assuming Y is up (typical human scan orientation)
|
||||
# Try to auto-detect orientation
|
||||
extents = max_pt - min_pt
|
||||
height_axis = np.argmax(extents) # Longest axis is usually height
|
||||
|
||||
if height_axis == 1: # Y-up
|
||||
total_height = extents[1]
|
||||
width_axis, depth_axis = 0, 2
|
||||
elif height_axis == 2: # Z-up
|
||||
total_height = extents[2]
|
||||
width_axis, depth_axis = 0, 1
|
||||
else: # X-up (unusual)
|
||||
total_height = extents[0]
|
||||
width_axis, depth_axis = 1, 2
|
||||
|
||||
width = extents[width_axis]
|
||||
depth = extents[depth_axis]
|
||||
|
||||
# Estimate body segments using height percentages
|
||||
# These are approximate ratios for human body
|
||||
chest_height_ratio = 0.75 # Chest at 75% of height from bottom
|
||||
waist_height_ratio = 0.60 # Waist at 60% of height
|
||||
hip_height_ratio = 0.50 # Hips at 50% of height
|
||||
shoulder_height_ratio = 0.82 # Shoulders at 82%
|
||||
|
||||
# Get cross-sections at different heights to estimate widths
|
||||
def get_width_at_height(height_ratio):
|
||||
if height_axis == 1:
|
||||
h = min_pt[1] + total_height * height_ratio
|
||||
mask = (mesh.vertices[:, 1] > h - total_height * 0.05) & \
|
||||
(mesh.vertices[:, 1] < h + total_height * 0.05)
|
||||
elif height_axis == 2:
|
||||
h = min_pt[2] + total_height * height_ratio
|
||||
mask = (mesh.vertices[:, 2] > h - total_height * 0.05) & \
|
||||
(mesh.vertices[:, 2] < h + total_height * 0.05)
|
||||
else:
|
||||
h = min_pt[0] + total_height * height_ratio
|
||||
mask = (mesh.vertices[:, 0] > h - total_height * 0.05) & \
|
||||
(mesh.vertices[:, 0] < h + total_height * 0.05)
|
||||
|
||||
if not np.any(mask):
|
||||
return width, depth
|
||||
|
||||
slice_verts = mesh.vertices[mask]
|
||||
slice_width = np.ptp(slice_verts[:, width_axis])
|
||||
slice_depth = np.ptp(slice_verts[:, depth_axis])
|
||||
return slice_width, slice_depth
|
||||
|
||||
shoulder_w, shoulder_d = get_width_at_height(shoulder_height_ratio)
|
||||
chest_w, chest_d = get_width_at_height(chest_height_ratio)
|
||||
waist_w, waist_d = get_width_at_height(waist_height_ratio)
|
||||
hip_w, hip_d = get_width_at_height(hip_height_ratio)
|
||||
|
||||
# Estimate circumferences using ellipse approximation
|
||||
def estimate_circumference(w, d):
|
||||
a, b = w / 2, d / 2
|
||||
# Ramanujan's approximation for ellipse circumference
|
||||
h = ((a - b) ** 2) / ((a + b) ** 2)
|
||||
return np.pi * (a + b) * (1 + 3 * h / (10 + np.sqrt(4 - 3 * h)))
|
||||
|
||||
return {
|
||||
"total_height_mm": float(total_height),
|
||||
"shoulder_width_mm": float(shoulder_w),
|
||||
"chest_width_mm": float(chest_w),
|
||||
"chest_depth_mm": float(chest_d),
|
||||
"waist_width_mm": float(waist_w),
|
||||
"waist_depth_mm": float(waist_d),
|
||||
"hip_width_mm": float(hip_w),
|
||||
"hip_depth_mm": float(hip_d),
|
||||
"brace_coverage_height_mm": float(total_height * 0.55), # 55% coverage
|
||||
"chest_circumference_mm": float(estimate_circumference(chest_w, chest_d)),
|
||||
"waist_circumference_mm": float(estimate_circumference(waist_w, waist_d)),
|
||||
"hip_circumference_mm": float(estimate_circumference(hip_w, hip_d)),
|
||||
"measurement_source": "trimesh_fallback"
|
||||
}
|
||||
|
||||
|
||||
def generate_fitted_brace(
|
||||
body_scan_path: str,
|
||||
rigo_type: str,
|
||||
output_dir: str,
|
||||
case_id: str,
|
||||
clearance_mm: float = 8.0,
|
||||
wall_thickness_mm: float = 2.4
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a patient-specific brace fitted to body scan.
|
||||
|
||||
Args:
|
||||
body_scan_path: Path to 3D body scan (STL/OBJ/PLY)
|
||||
rigo_type: Rigo classification (A1, A2, B1, etc.)
|
||||
output_dir: Directory to save output files
|
||||
case_id: Case identifier for naming files
|
||||
clearance_mm: Clearance between body and shell (default 8mm)
|
||||
wall_thickness_mm: Shell wall thickness (default 2.4mm for 3D printing)
|
||||
|
||||
Returns:
|
||||
Dictionary with output file paths and fitting info
|
||||
"""
|
||||
if not HAS_TRIMESH:
|
||||
raise ImportError("trimesh is required for brace fitting")
|
||||
|
||||
if not HAS_EXP10:
|
||||
raise ImportError("EXPERIMENT_10 modules not available")
|
||||
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select template based on Rigo type
|
||||
template_file = RIGO_TO_VASE.get(rigo_type, "A1_vase.OBJ")
|
||||
template_path = VASES_DIR / template_file
|
||||
|
||||
if not template_path.exists():
|
||||
# Try alternative paths
|
||||
alt_paths = [
|
||||
EXPERIMENTS_DIR / "EXPERIMENT_10" / "_vase" / template_file,
|
||||
Path(__file__).parent.parent.parent / "_vase" / template_file,
|
||||
]
|
||||
for alt in alt_paths:
|
||||
if alt.exists():
|
||||
template_path = alt
|
||||
break
|
||||
else:
|
||||
raise FileNotFoundError(f"Template not found: {template_file}")
|
||||
|
||||
# Fit shell to body
|
||||
# Returns: (shell_mesh, body_mesh, combined_mesh, feedback)
|
||||
fitted_mesh, body_mesh, combined_mesh, feedback = fit_shell_to_body_v2(
|
||||
body_scan_path=body_scan_path,
|
||||
template_path=str(template_path),
|
||||
clearance_mm=clearance_mm
|
||||
)
|
||||
|
||||
# Generate output files
|
||||
outputs = {}
|
||||
|
||||
# Shell STL (for 3D printing)
|
||||
shell_stl = output_path / f"{case_id}_shell.stl"
|
||||
fitted_mesh.export(str(shell_stl))
|
||||
outputs["shell_stl"] = str(shell_stl)
|
||||
|
||||
# Shell GLB (for web viewing)
|
||||
shell_glb = output_path / f"{case_id}_shell.glb"
|
||||
fitted_mesh.export(str(shell_glb))
|
||||
outputs["shell_glb"] = str(shell_glb)
|
||||
|
||||
# Combined body + shell STL (for visualization)
|
||||
# combined_mesh is already returned from fit_shell_to_body_v2
|
||||
combined_stl = output_path / f"{case_id}_body_with_shell.stl"
|
||||
combined_mesh.export(str(combined_stl))
|
||||
outputs["combined_stl"] = str(combined_stl)
|
||||
|
||||
# Feedback JSON
|
||||
feedback_json = output_path / f"{case_id}_feedback.json"
|
||||
with open(feedback_json, "w") as f:
|
||||
json.dump(asdict(feedback), f, indent=2, default=_json_serializer)
|
||||
outputs["feedback_json"] = str(feedback_json)
|
||||
|
||||
# Create visualization
|
||||
try:
|
||||
viz_path = output_path / f"{case_id}_visualization.png"
|
||||
create_fitting_visualization(body_mesh, fitted_mesh, feedback, str(viz_path))
|
||||
outputs["visualization"] = str(viz_path)
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not create visualization: {e}")
|
||||
|
||||
# Return result
|
||||
return {
|
||||
"template_used": template_file,
|
||||
"rigo_type": rigo_type,
|
||||
"clearance_mm": clearance_mm,
|
||||
"fitting": {
|
||||
"scale_right": feedback.scale_right,
|
||||
"scale_up": feedback.scale_up,
|
||||
"scale_forward": feedback.scale_forward,
|
||||
"pelvis_distance_mm": feedback.pelvis_distance_mm,
|
||||
"up_alignment_dot": feedback.up_alignment_dot,
|
||||
"warnings": feedback.warnings,
|
||||
},
|
||||
"body_measurements": {
|
||||
"max_width_mm": feedback.max_body_width_mm,
|
||||
"max_depth_mm": feedback.max_body_depth_mm,
|
||||
},
|
||||
"shell_dimensions": {
|
||||
"width_mm": feedback.target_shell_width_mm,
|
||||
"depth_mm": feedback.target_shell_depth_mm,
|
||||
"bounds_min": feedback.final_bounds_min,
|
||||
"bounds_max": feedback.final_bounds_max,
|
||||
},
|
||||
"mesh_stats": {
|
||||
"vertices": len(fitted_mesh.vertices),
|
||||
"faces": len(fitted_mesh.faces),
|
||||
},
|
||||
"outputs": outputs,
|
||||
}
|
||||
|
||||
|
||||
def create_fitting_visualization(
|
||||
body_mesh: 'trimesh.Trimesh',
|
||||
shell_mesh: 'trimesh.Trimesh',
|
||||
feedback: 'FittingFeedback',
|
||||
output_path: str
|
||||
):
|
||||
"""Create a multi-panel visualization of the fitted brace."""
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
fig = plt.figure(figsize=(16, 10))
|
||||
|
||||
# Panel 1: Front view
|
||||
ax1 = fig.add_subplot(2, 3, 1, projection='3d')
|
||||
plot_mesh_silhouette(ax1, body_mesh, 'gray', alpha=0.3)
|
||||
plot_mesh_silhouette(ax1, shell_mesh, 'blue', alpha=0.6)
|
||||
ax1.set_title('Front View')
|
||||
ax1.view_init(elev=0, azim=0)
|
||||
|
||||
# Panel 2: Side view
|
||||
ax2 = fig.add_subplot(2, 3, 2, projection='3d')
|
||||
plot_mesh_silhouette(ax2, body_mesh, 'gray', alpha=0.3)
|
||||
plot_mesh_silhouette(ax2, shell_mesh, 'blue', alpha=0.6)
|
||||
ax2.set_title('Side View')
|
||||
ax2.view_init(elev=0, azim=90)
|
||||
|
||||
# Panel 3: Top view
|
||||
ax3 = fig.add_subplot(2, 3, 3, projection='3d')
|
||||
plot_mesh_silhouette(ax3, body_mesh, 'gray', alpha=0.3)
|
||||
plot_mesh_silhouette(ax3, shell_mesh, 'blue', alpha=0.6)
|
||||
ax3.set_title('Top View')
|
||||
ax3.view_init(elev=90, azim=0)
|
||||
|
||||
# Panel 4: Fitting info
|
||||
ax4 = fig.add_subplot(2, 3, 4)
|
||||
ax4.axis('off')
|
||||
info_text = f"""
|
||||
Fitting Information
|
||||
-------------------
|
||||
Template: {feedback.template_name}
|
||||
Clearance: {feedback.clearance_mm} mm
|
||||
|
||||
Scale Factors:
|
||||
Right: {feedback.scale_right:.3f}
|
||||
Up: {feedback.scale_up:.3f}
|
||||
Forward: {feedback.scale_forward:.3f}
|
||||
|
||||
Alignment:
|
||||
Pelvis Distance: {feedback.pelvis_distance_mm:.2f} mm
|
||||
Up Alignment: {feedback.up_alignment_dot:.4f}
|
||||
|
||||
Shell vs Body:
|
||||
Width Margin: {feedback.shell_minus_body_width_mm:.1f} mm
|
||||
Depth Margin: {feedback.shell_minus_body_depth_mm:.1f} mm
|
||||
"""
|
||||
ax4.text(0.1, 0.9, info_text, transform=ax4.transAxes, fontsize=10,
|
||||
verticalalignment='top', fontfamily='monospace')
|
||||
|
||||
# Panel 5: Warnings
|
||||
ax5 = fig.add_subplot(2, 3, 5)
|
||||
ax5.axis('off')
|
||||
warnings_text = "Warnings:\n" + ("\n".join(feedback.warnings) if feedback.warnings else "None")
|
||||
ax5.text(0.1, 0.9, warnings_text, transform=ax5.transAxes, fontsize=10,
|
||||
verticalalignment='top', color='orange' if feedback.warnings else 'green')
|
||||
|
||||
# Panel 6: Isometric view
|
||||
ax6 = fig.add_subplot(2, 3, 6, projection='3d')
|
||||
plot_mesh_silhouette(ax6, body_mesh, 'gray', alpha=0.3)
|
||||
plot_mesh_silhouette(ax6, shell_mesh, 'blue', alpha=0.6)
|
||||
ax6.set_title('Isometric View')
|
||||
ax6.view_init(elev=20, azim=45)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_mesh_silhouette(ax, mesh, color, alpha=0.5):
|
||||
"""Plot a simplified mesh representation."""
|
||||
# Sample vertices for plotting
|
||||
verts = mesh.vertices
|
||||
if len(verts) > 5000:
|
||||
indices = np.random.choice(len(verts), 5000, replace=False)
|
||||
verts = verts[indices]
|
||||
|
||||
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2],
|
||||
c=color, alpha=alpha, s=1)
|
||||
|
||||
# Set equal aspect ratio
|
||||
max_range = np.max(mesh.extents) / 2
|
||||
mid = mesh.centroid
|
||||
ax.set_xlim(mid[0] - max_range, mid[0] + max_range)
|
||||
ax.set_ylim(mid[1] - max_range, mid[1] + max_range)
|
||||
ax.set_zlim(mid[2] - max_range, mid[2] + max_range)
|
||||
|
||||
|
||||
def _json_serializer(obj):
|
||||
"""JSON serializer for numpy types."""
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
if isinstance(obj, (np.float32, np.float64)):
|
||||
return float(obj)
|
||||
if isinstance(obj, (np.int32, np.int64)):
|
||||
return int(obj)
|
||||
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
||||
58
brace-generator/config.py
Normal file
58
brace-generator/config.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Server configuration for Brace Generator API.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Config:
|
||||
"""Server configuration loaded from environment variables."""
|
||||
|
||||
# Server settings (DEV uses port 8001)
|
||||
HOST: str = os.getenv("HOST", "0.0.0.0")
|
||||
PORT: int = int(os.getenv("PORT", "8001"))
|
||||
DEBUG: bool = os.getenv("DEBUG", "true").lower() == "true"
|
||||
|
||||
# Model settings
|
||||
DEVICE: str = os.getenv("DEVICE", "cuda") # 'cuda' or 'cpu'
|
||||
MODEL: str = os.getenv("MODEL", "scoliovis") # 'scoliovis' or 'vertebra-landmark'
|
||||
|
||||
# Paths
|
||||
BASE_DIR: Path = Path(__file__).parent.parent
|
||||
TEMPLATES_DIR: Path = BASE_DIR / "rigoBrace(2)"
|
||||
WEIGHTS_DIR: Path = BASE_DIR.parent / "scoliovis-api" / "models"
|
||||
|
||||
# TEMP_DIR: Use system temp on Windows, /tmp on Linux
|
||||
@staticmethod
|
||||
def _get_temp_dir() -> Path:
|
||||
env_temp = os.getenv("TEMP_DIR")
|
||||
if env_temp:
|
||||
return Path(env_temp)
|
||||
# Use system temp directory (works on both Windows and Linux)
|
||||
import tempfile
|
||||
return Path(tempfile.gettempdir()) / "brace_generator"
|
||||
|
||||
TEMP_DIR: Path = _get_temp_dir()
|
||||
|
||||
# CORS
|
||||
CORS_ORIGINS: list = os.getenv("CORS_ORIGINS", "*").split(",")
|
||||
|
||||
# Request limits
|
||||
MAX_IMAGE_SIZE_MB: int = int(os.getenv("MAX_IMAGE_SIZE_MB", "50"))
|
||||
REQUEST_TIMEOUT_SECONDS: int = int(os.getenv("REQUEST_TIMEOUT_SECONDS", "120"))
|
||||
|
||||
@classmethod
|
||||
def ensure_dirs(cls):
|
||||
"""Create necessary directories."""
|
||||
cls.TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@classmethod
|
||||
def get_device(cls) -> str:
|
||||
"""Get device, falling back to CPU if CUDA unavailable."""
|
||||
import torch
|
||||
if cls.DEVICE == "cuda" and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
|
||||
config = Config()
|
||||
906
brace-generator/glb_generator.py
Normal file
906
brace-generator/glb_generator.py
Normal file
@@ -0,0 +1,906 @@
|
||||
"""
|
||||
GLB Brace Generator with Markers
|
||||
|
||||
This module generates GLB brace files with embedded markers for editing.
|
||||
Supports both regular (fitted) and vase-shaped templates.
|
||||
|
||||
PRESSURE ZONES EXPLANATION:
|
||||
===========================
|
||||
|
||||
The brace has 4 main pressure/expansion zones that correct spinal curvature:
|
||||
|
||||
1. THORACIC PAD (LM_PAD_TH) - PUSH ZONE
|
||||
- Location: On the CONVEX side of the thoracic curve (the side that bulges out)
|
||||
- Function: Pushes INWARD to correct the thoracic curvature
|
||||
- For right thoracic curves: pad is on the RIGHT back
|
||||
- Depth: 8-25mm depending on Cobb angle severity
|
||||
|
||||
2. THORACIC BAY (LM_BAY_TH) - EXPANSION ZONE
|
||||
- Location: OPPOSITE the thoracic pad (concave side)
|
||||
- Function: Creates SPACE for the body to move INTO during correction
|
||||
- The ribs/body shift into this space as the pad pushes
|
||||
- Clearance: 10-35mm
|
||||
|
||||
3. LUMBAR PAD (LM_PAD_LUM) - PUSH ZONE
|
||||
- Location: On the CONVEX side of the lumbar curve
|
||||
- Function: Pushes INWARD to correct lumbar curvature
|
||||
- Usually on the opposite side of thoracic pad (for S-curves)
|
||||
- Depth: 6-20mm
|
||||
|
||||
4. LUMBAR BAY (LM_BAY_LUM) - EXPANSION ZONE
|
||||
- Location: OPPOSITE the lumbar pad
|
||||
- Function: Creates SPACE for lumbar correction
|
||||
- Clearance: 8-25mm
|
||||
|
||||
5. HIP ANCHORS (LM_ANCHOR_HIP_L/R) - STABILITY ZONES
|
||||
- Location: Around the hip/pelvis area on both sides
|
||||
- Function: Grip the pelvis to prevent brace from riding up
|
||||
- Slight inward pressure to anchor the brace
|
||||
|
||||
The Rigo classification determines which zones are primary:
|
||||
- A types (3-curve): Strong thoracic pad, minor lumbar
|
||||
- B types (4-curve): Both thoracic and lumbar pads are primary
|
||||
- C types (non-3-non-4): Balanced thoracic, neutral pelvis
|
||||
- E types (single lumbar/TL): Strong lumbar/TL pad, counter-thoracic
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import trimesh
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Tuple, Literal
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
# Paths to template directories
|
||||
BASE_DIR = Path(__file__).parent.parent
|
||||
BRACES_DIR = BASE_DIR / "braces"
|
||||
REGULAR_TEMPLATES_DIR = BRACES_DIR / "brace_templates"
|
||||
VASE_TEMPLATES_DIR = BRACES_DIR / "vase_brace_templates"
|
||||
|
||||
# Template types
|
||||
TemplateType = Literal["regular", "vase"]
|
||||
|
||||
@dataclass
|
||||
class MarkerPositions:
|
||||
"""Marker positions for a brace template."""
|
||||
LM_PELVIS_CENTER: Tuple[float, float, float]
|
||||
LM_TOP_CENTER: Tuple[float, float, float]
|
||||
LM_PAD_TH: Tuple[float, float, float]
|
||||
LM_BAY_TH: Tuple[float, float, float]
|
||||
LM_PAD_LUM: Tuple[float, float, float]
|
||||
LM_BAY_LUM: Tuple[float, float, float]
|
||||
LM_ANCHOR_HIP_L: Tuple[float, float, float]
|
||||
LM_ANCHOR_HIP_R: Tuple[float, float, float]
|
||||
|
||||
@dataclass
|
||||
class PressureZone:
|
||||
"""Describes a pressure or expansion zone on the brace."""
|
||||
name: str
|
||||
marker_name: str
|
||||
position: Tuple[float, float, float]
|
||||
zone_type: Literal["pad", "bay", "anchor"]
|
||||
function: str
|
||||
direction: Literal["inward", "outward", "grip"]
|
||||
depth_mm: float = 0.0
|
||||
radius_mm: Tuple[float, float, float] = (50.0, 80.0, 40.0)
|
||||
|
||||
@dataclass
|
||||
class BraceGenerationResult:
|
||||
"""Result of brace generation with markers."""
|
||||
glb_path: str
|
||||
stl_path: str
|
||||
json_path: str
|
||||
template_type: str
|
||||
rigo_type: str
|
||||
markers: Dict[str, Tuple[float, float, float]]
|
||||
basis: Dict[str, Any]
|
||||
pressure_zones: list
|
||||
mesh_stats: Dict[str, int]
|
||||
transform_applied: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def get_template_paths(rigo_type: str, template_type: TemplateType) -> Tuple[Path, Path]:
|
||||
"""
|
||||
Get paths to GLB template and markers JSON.
|
||||
|
||||
Args:
|
||||
rigo_type: Rigo classification (A1, A2, A3, B1, B2, C1, C2, E1, E2)
|
||||
template_type: "regular" or "vase"
|
||||
|
||||
Returns:
|
||||
Tuple of (glb_path, markers_json_path)
|
||||
"""
|
||||
if template_type == "regular":
|
||||
glb_path = REGULAR_TEMPLATES_DIR / f"{rigo_type}_marked_v3.glb"
|
||||
json_path = REGULAR_TEMPLATES_DIR / f"{rigo_type}_marked_v3.markers.json"
|
||||
else: # vase
|
||||
glb_path = VASE_TEMPLATES_DIR / "glb" / f"{rigo_type}_vase_marked.glb"
|
||||
json_path = VASE_TEMPLATES_DIR / "markers_json" / f"{rigo_type}_vase_marked.markers.json"
|
||||
|
||||
return glb_path, json_path
|
||||
|
||||
|
||||
def load_template_markers(rigo_type: str, template_type: TemplateType) -> Dict[str, Any]:
|
||||
"""Load markers from JSON file for a template."""
|
||||
_, json_path = get_template_paths(rigo_type, template_type)
|
||||
|
||||
if not json_path.exists():
|
||||
raise FileNotFoundError(f"Markers JSON not found: {json_path}")
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def load_glb_template(rigo_type: str, template_type: TemplateType) -> trimesh.Trimesh:
|
||||
"""Load GLB template as trimesh."""
|
||||
glb_path, _ = get_template_paths(rigo_type, template_type)
|
||||
|
||||
if not glb_path.exists():
|
||||
raise FileNotFoundError(f"GLB template not found: {glb_path}")
|
||||
|
||||
scene = trimesh.load(str(glb_path))
|
||||
|
||||
# If it's a scene, concatenate all meshes
|
||||
if isinstance(scene, trimesh.Scene):
|
||||
meshes = [g for g in scene.geometry.values() if isinstance(g, trimesh.Trimesh)]
|
||||
if meshes:
|
||||
mesh = trimesh.util.concatenate(meshes)
|
||||
else:
|
||||
raise ValueError(f"No valid meshes found in GLB: {glb_path}")
|
||||
else:
|
||||
mesh = scene
|
||||
|
||||
return mesh
|
||||
|
||||
|
||||
def calculate_pressure_zones(
|
||||
markers: Dict[str, Any],
|
||||
rigo_type: str,
|
||||
cobb_angles: Dict[str, float]
|
||||
) -> list:
|
||||
"""
|
||||
Calculate pressure zone parameters based on markers and analysis.
|
||||
|
||||
Args:
|
||||
markers: Marker positions from template
|
||||
rigo_type: Rigo classification
|
||||
cobb_angles: Dict with PT, MT, TL Cobb angles
|
||||
|
||||
Returns:
|
||||
List of PressureZone objects
|
||||
"""
|
||||
marker_pos = markers.get("markers", markers)
|
||||
|
||||
# Calculate severity from Cobb angles
|
||||
mt_angle = cobb_angles.get("MT", 0)
|
||||
tl_angle = cobb_angles.get("TL", 0)
|
||||
|
||||
# Severity mapping: Cobb -> depth
|
||||
def cobb_to_depth(angle: float, min_depth: float = 6.0, max_depth: float = 22.0) -> float:
|
||||
severity = min(max((angle - 10) / 40, 0), 1) # 0-1 range
|
||||
return min_depth + severity * (max_depth - min_depth)
|
||||
|
||||
th_depth = cobb_to_depth(mt_angle, 8.0, 22.0)
|
||||
lum_depth = cobb_to_depth(tl_angle, 6.0, 18.0)
|
||||
|
||||
# Bay clearance is typically 1.2-1.5x pad depth
|
||||
th_clearance = th_depth * 1.3 + 5
|
||||
lum_clearance = lum_depth * 1.3 + 4
|
||||
|
||||
zones = [
|
||||
PressureZone(
|
||||
name="Thoracic Pad",
|
||||
marker_name="LM_PAD_TH",
|
||||
position=tuple(marker_pos.get("LM_PAD_TH", [0, 0, 0])),
|
||||
zone_type="pad",
|
||||
function="Pushes INWARD on thoracic curve convex side to correct curvature",
|
||||
direction="inward",
|
||||
depth_mm=th_depth,
|
||||
radius_mm=(50.0, 90.0, 40.0)
|
||||
),
|
||||
PressureZone(
|
||||
name="Thoracic Bay",
|
||||
marker_name="LM_BAY_TH",
|
||||
position=tuple(marker_pos.get("LM_BAY_TH", [0, 0, 0])),
|
||||
zone_type="bay",
|
||||
function="Creates SPACE on thoracic concave side for body to shift into",
|
||||
direction="outward",
|
||||
depth_mm=th_clearance,
|
||||
radius_mm=(65.0, 110.0, 55.0)
|
||||
),
|
||||
PressureZone(
|
||||
name="Lumbar Pad",
|
||||
marker_name="LM_PAD_LUM",
|
||||
position=tuple(marker_pos.get("LM_PAD_LUM", [0, 0, 0])),
|
||||
zone_type="pad",
|
||||
function="Pushes INWARD on lumbar curve convex side to correct curvature",
|
||||
direction="inward",
|
||||
depth_mm=lum_depth,
|
||||
radius_mm=(55.0, 85.0, 45.0)
|
||||
),
|
||||
PressureZone(
|
||||
name="Lumbar Bay",
|
||||
marker_name="LM_BAY_LUM",
|
||||
position=tuple(marker_pos.get("LM_BAY_LUM", [0, 0, 0])),
|
||||
zone_type="bay",
|
||||
function="Creates SPACE on lumbar concave side for body to shift into",
|
||||
direction="outward",
|
||||
depth_mm=lum_clearance,
|
||||
radius_mm=(70.0, 100.0, 60.0)
|
||||
),
|
||||
PressureZone(
|
||||
name="Left Hip Anchor",
|
||||
marker_name="LM_ANCHOR_HIP_L",
|
||||
position=tuple(marker_pos.get("LM_ANCHOR_HIP_L", [0, 0, 0])),
|
||||
zone_type="anchor",
|
||||
function="Grips left hip/pelvis to stabilize brace and prevent riding up",
|
||||
direction="grip",
|
||||
depth_mm=4.0,
|
||||
radius_mm=(40.0, 60.0, 40.0)
|
||||
),
|
||||
PressureZone(
|
||||
name="Right Hip Anchor",
|
||||
marker_name="LM_ANCHOR_HIP_R",
|
||||
position=tuple(marker_pos.get("LM_ANCHOR_HIP_R", [0, 0, 0])),
|
||||
zone_type="anchor",
|
||||
function="Grips right hip/pelvis to stabilize brace and prevent riding up",
|
||||
direction="grip",
|
||||
depth_mm=4.0,
|
||||
radius_mm=(40.0, 60.0, 40.0)
|
||||
),
|
||||
]
|
||||
|
||||
return zones
|
||||
|
||||
|
||||
def transform_markers(
|
||||
markers: Dict[str, list],
|
||||
transform_matrix: np.ndarray
|
||||
) -> Dict[str, Tuple[float, float, float]]:
|
||||
"""Apply transformation matrix to all marker positions."""
|
||||
transformed = {}
|
||||
|
||||
for name, pos in markers.items():
|
||||
if isinstance(pos, (list, tuple)) and len(pos) == 3:
|
||||
# Convert to homogeneous coordinates
|
||||
pos_h = np.array([pos[0], pos[1], pos[2], 1.0])
|
||||
# Apply transform
|
||||
new_pos = transform_matrix @ pos_h
|
||||
transformed[name] = (float(new_pos[0]), float(new_pos[1]), float(new_pos[2]))
|
||||
|
||||
return transformed
|
||||
|
||||
|
||||
def generate_glb_brace(
|
||||
rigo_type: str,
|
||||
template_type: TemplateType,
|
||||
output_dir: Path,
|
||||
case_id: str,
|
||||
cobb_angles: Dict[str, float],
|
||||
body_scan_path: Optional[str] = None,
|
||||
clearance_mm: float = 8.0
|
||||
) -> BraceGenerationResult:
|
||||
"""
|
||||
Generate a GLB brace with markers.
|
||||
|
||||
Args:
|
||||
rigo_type: Rigo classification (A1, A2, etc.)
|
||||
template_type: "regular" or "vase"
|
||||
output_dir: Directory for output files
|
||||
case_id: Case identifier
|
||||
cobb_angles: Dict with PT, MT, TL angles
|
||||
body_scan_path: Optional path to body scan STL for fitting
|
||||
clearance_mm: Clearance between body and brace
|
||||
|
||||
Returns:
|
||||
BraceGenerationResult with paths and marker info
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load template
|
||||
mesh = load_glb_template(rigo_type, template_type)
|
||||
marker_data = load_template_markers(rigo_type, template_type)
|
||||
|
||||
markers = marker_data.get("markers", {})
|
||||
basis = marker_data.get("basis", {})
|
||||
|
||||
transform_matrix = np.eye(4)
|
||||
transform_info = None
|
||||
|
||||
# If body scan provided, fit to body
|
||||
if body_scan_path and Path(body_scan_path).exists():
|
||||
mesh, transform_matrix, transform_info = fit_brace_to_body(
|
||||
mesh, body_scan_path, clearance_mm, brace_basis=basis, template_type=template_type,
|
||||
markers=markers # Pass markers for zone-aware ironing
|
||||
)
|
||||
# Transform markers to match the body-fitted mesh
|
||||
markers = transform_markers(markers, transform_matrix)
|
||||
|
||||
# Calculate pressure zones
|
||||
pressure_zones = calculate_pressure_zones(markers, rigo_type, cobb_angles)
|
||||
|
||||
# Output file names
|
||||
type_suffix = "vase" if template_type == "vase" else "regular"
|
||||
glb_filename = f"{case_id}_{rigo_type}_{type_suffix}.glb"
|
||||
stl_filename = f"{case_id}_{rigo_type}_{type_suffix}.stl"
|
||||
json_filename = f"{case_id}_{rigo_type}_{type_suffix}_markers.json"
|
||||
|
||||
glb_path = output_dir / glb_filename
|
||||
stl_path = output_dir / stl_filename
|
||||
json_path = output_dir / json_filename
|
||||
|
||||
# Export GLB
|
||||
mesh.export(str(glb_path))
|
||||
|
||||
# Export STL
|
||||
mesh.export(str(stl_path))
|
||||
|
||||
# Build output JSON with markers and zones
|
||||
output_data = {
|
||||
"case_id": case_id,
|
||||
"rigo_type": rigo_type,
|
||||
"template_type": template_type,
|
||||
"cobb_angles": cobb_angles,
|
||||
"markers": {k: list(v) if isinstance(v, tuple) else v for k, v in markers.items()},
|
||||
"basis": basis,
|
||||
"pressure_zones": [
|
||||
{
|
||||
"name": z.name,
|
||||
"marker_name": z.marker_name,
|
||||
"position": list(z.position),
|
||||
"zone_type": z.zone_type,
|
||||
"function": z.function,
|
||||
"direction": z.direction,
|
||||
"depth_mm": z.depth_mm,
|
||||
"radius_mm": list(z.radius_mm)
|
||||
}
|
||||
for z in pressure_zones
|
||||
],
|
||||
"mesh_stats": {
|
||||
"vertices": len(mesh.vertices),
|
||||
"faces": len(mesh.faces)
|
||||
},
|
||||
"outputs": {
|
||||
"glb": str(glb_path),
|
||||
"stl": str(stl_path),
|
||||
"json": str(json_path)
|
||||
}
|
||||
}
|
||||
|
||||
if transform_info:
|
||||
output_data["body_fitting"] = transform_info
|
||||
|
||||
# Save JSON
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(output_data, f, indent=2)
|
||||
|
||||
return BraceGenerationResult(
|
||||
glb_path=str(glb_path),
|
||||
stl_path=str(stl_path),
|
||||
json_path=str(json_path),
|
||||
template_type=template_type,
|
||||
rigo_type=rigo_type,
|
||||
markers=markers,
|
||||
basis=basis,
|
||||
pressure_zones=[asdict(z) for z in pressure_zones],
|
||||
mesh_stats=output_data["mesh_stats"],
|
||||
transform_applied=transform_info
|
||||
)
|
||||
|
||||
|
||||
def iron_brace_to_body(
|
||||
brace_mesh: trimesh.Trimesh,
|
||||
body_mesh: trimesh.Trimesh,
|
||||
min_clearance_mm: float = 3.0,
|
||||
max_clearance_mm: float = 15.0,
|
||||
smoothing_iterations: int = 2,
|
||||
up_axis: int = 2,
|
||||
markers: Optional[Dict[str, Any]] = None
|
||||
) -> trimesh.Trimesh:
|
||||
"""
|
||||
Iron the brace surface to conform to the body scan surface.
|
||||
|
||||
This ensures the brace follows the body contour without excessive gaps.
|
||||
Uses zone-aware ironing:
|
||||
- FRONT (belly) and BACK: Aggressive ironing for tight fit
|
||||
- SIDES (where pads/bays are): Preserve correction zones, moderate ironing
|
||||
|
||||
Args:
|
||||
brace_mesh: The brace mesh to iron
|
||||
body_mesh: The body scan mesh to conform to
|
||||
min_clearance_mm: Minimum distance from body surface
|
||||
max_clearance_mm: Maximum distance from body surface (trigger ironing)
|
||||
smoothing_iterations: Number of Laplacian smoothing passes after ironing
|
||||
up_axis: Which axis is "up" (0=X, 1=Y, 2=Z)
|
||||
markers: Optional dict of marker positions to preserve pressure zones
|
||||
|
||||
Returns:
|
||||
Ironed brace mesh
|
||||
"""
|
||||
from scipy.spatial import cKDTree
|
||||
import math
|
||||
|
||||
print(f"Ironing brace to body surface (clearance: {min_clearance_mm}-{max_clearance_mm}mm)")
|
||||
|
||||
# Create a copy to modify
|
||||
ironed_mesh = brace_mesh.copy()
|
||||
vertices = ironed_mesh.vertices.copy()
|
||||
|
||||
# Get body center and bounds
|
||||
body_center = body_mesh.centroid
|
||||
body_bounds = body_mesh.bounds
|
||||
|
||||
# Determine the torso region (process middle 80% of body height)
|
||||
body_height = body_bounds[1, up_axis] - body_bounds[0, up_axis]
|
||||
torso_bottom = body_bounds[0, up_axis] + body_height * 0.10
|
||||
torso_top = body_bounds[0, up_axis] + body_height * 0.90
|
||||
|
||||
# Build KD-tree from body mesh vertices for fast nearest neighbor queries
|
||||
body_tree = cKDTree(body_mesh.vertices)
|
||||
|
||||
# Find closest points on body for ALL brace vertices at once
|
||||
distances, closest_indices = body_tree.query(vertices, k=1)
|
||||
closest_points = body_mesh.vertices[closest_indices]
|
||||
|
||||
# Determine horizontal axes (perpendicular to up axis)
|
||||
horiz_axes = [i for i in range(3) if i != up_axis]
|
||||
|
||||
# Calculate brace center for angle computation
|
||||
brace_center = np.mean(vertices, axis=0)
|
||||
|
||||
# Identify marker exclusion zones (preserve correction areas)
|
||||
exclusion_zones = []
|
||||
if markers:
|
||||
# Pad and bay markers need preservation
|
||||
for marker_name in ['LM_PAD_TH', 'LM_PAD_LUM', 'LM_BAY_TH', 'LM_BAY_LUM']:
|
||||
if marker_name in markers:
|
||||
pos = markers[marker_name]
|
||||
if isinstance(pos, (list, tuple)) and len(pos) >= 3:
|
||||
exclusion_zones.append({
|
||||
'center': np.array(pos),
|
||||
'radius': 60.0, # 60mm exclusion radius around markers
|
||||
'name': marker_name
|
||||
})
|
||||
|
||||
# Process each brace vertex
|
||||
adjusted_count = 0
|
||||
pulled_in_count = 0
|
||||
pushed_out_count = 0
|
||||
skipped_zone_count = 0
|
||||
|
||||
# Height normalization
|
||||
brace_min_z = vertices[:, up_axis].min()
|
||||
brace_max_z = vertices[:, up_axis].max()
|
||||
brace_height_range = max(brace_max_z - brace_min_z, 1.0)
|
||||
|
||||
for i in range(len(vertices)):
|
||||
vertex = vertices[i]
|
||||
closest_pt = closest_points[i]
|
||||
dist = distances[i]
|
||||
|
||||
# Only process vertices in the torso region
|
||||
if vertex[up_axis] < torso_bottom or vertex[up_axis] > torso_top:
|
||||
continue
|
||||
|
||||
# Check if vertex is in an exclusion zone (near pad/bay markers)
|
||||
in_exclusion = False
|
||||
for zone in exclusion_zones:
|
||||
zone_dist = np.linalg.norm(vertex - zone['center'])
|
||||
if zone_dist < zone['radius']:
|
||||
in_exclusion = True
|
||||
skipped_zone_count += 1
|
||||
break
|
||||
|
||||
if in_exclusion:
|
||||
continue
|
||||
|
||||
# Calculate angular position around body center (horizontal plane)
|
||||
# 0° = front (belly), 90° = right side, 180° = back, 270° = left side
|
||||
rel_pos = vertex - body_center
|
||||
angle = math.atan2(rel_pos[horiz_axes[1]], rel_pos[horiz_axes[0]])
|
||||
angle_deg = math.degrees(angle) % 360
|
||||
|
||||
# Determine zone based on angle:
|
||||
# FRONT (belly): 315-45° - aggressive ironing
|
||||
# BACK: 135-225° - aggressive ironing
|
||||
# SIDES: 45-135° and 225-315° - moderate ironing (correction zones)
|
||||
is_front_back = (angle_deg < 45 or angle_deg > 315) or (135 < angle_deg < 225)
|
||||
|
||||
# Height-based clearance adjustment
|
||||
height_norm = (vertex[up_axis] - brace_min_z) / brace_height_range
|
||||
|
||||
# Set clearances based on zone
|
||||
if is_front_back:
|
||||
# FRONT/BACK: Aggressive ironing - very tight fit
|
||||
local_min = min_clearance_mm * 0.5 # Allow closer to body
|
||||
local_max = max_clearance_mm * 0.6 # Trigger ironing earlier
|
||||
local_target = min_clearance_mm + 2.0 # Target just above minimum
|
||||
else:
|
||||
# SIDES: More conservative - preserve room for correction
|
||||
local_min = min_clearance_mm
|
||||
local_max = max_clearance_mm * 1.2 # Allow slightly more gap
|
||||
local_target = (min_clearance_mm + max_clearance_mm) / 2
|
||||
|
||||
# Height adjustments (tighter at hips and chest)
|
||||
if height_norm < 0.25 or height_norm > 0.75:
|
||||
local_max *= 0.8 # Tighter at extremes
|
||||
local_target *= 0.85
|
||||
|
||||
# Direction from body surface to brace vertex
|
||||
direction = vertex - closest_pt
|
||||
dir_length = np.linalg.norm(direction)
|
||||
|
||||
if dir_length < 1e-6:
|
||||
direction = vertex - body_center
|
||||
direction[up_axis] = 0
|
||||
dir_length = np.linalg.norm(direction)
|
||||
if dir_length < 1e-6:
|
||||
continue
|
||||
|
||||
direction = direction / dir_length
|
||||
|
||||
# Determine signed distance
|
||||
vertex_dist_to_center = np.linalg.norm(vertex[:2] - body_center[:2])
|
||||
closest_dist_to_center = np.linalg.norm(closest_pt[:2] - body_center[:2])
|
||||
|
||||
if vertex_dist_to_center >= closest_dist_to_center:
|
||||
signed_distance = dist
|
||||
else:
|
||||
signed_distance = -dist
|
||||
|
||||
# Determine if adjustment is needed
|
||||
needs_adjustment = False
|
||||
new_position = vertex.copy()
|
||||
|
||||
if signed_distance > local_max:
|
||||
# Gap too large - pull vertex closer to body
|
||||
new_position = closest_pt + direction * local_target
|
||||
new_position[up_axis] = vertex[up_axis] # Preserve height
|
||||
needs_adjustment = True
|
||||
pulled_in_count += 1
|
||||
|
||||
elif signed_distance < local_min:
|
||||
# Too close or inside body - push outward
|
||||
offset = local_min + 1.0
|
||||
outward_dir = closest_pt - body_center
|
||||
outward_dir[up_axis] = 0
|
||||
outward_length = np.linalg.norm(outward_dir)
|
||||
if outward_length > 1e-6:
|
||||
outward_dir = outward_dir / outward_length
|
||||
new_position = closest_pt + outward_dir * offset
|
||||
new_position[up_axis] = vertex[up_axis]
|
||||
needs_adjustment = True
|
||||
pushed_out_count += 1
|
||||
|
||||
if needs_adjustment:
|
||||
vertices[i] = new_position
|
||||
adjusted_count += 1
|
||||
|
||||
print(f"Ironing adjusted {adjusted_count} vertices (pulled in: {pulled_in_count}, pushed out: {pushed_out_count}, skipped zones: {skipped_zone_count})")
|
||||
|
||||
# Apply modified vertices
|
||||
ironed_mesh.vertices = vertices
|
||||
|
||||
# Apply Laplacian smoothing to blend changes and remove artifacts
|
||||
if smoothing_iterations > 0 and adjusted_count > 0:
|
||||
print(f"Applying {smoothing_iterations} smoothing iterations")
|
||||
try:
|
||||
ironed_mesh = trimesh.smoothing.filter_laplacian(
|
||||
ironed_mesh,
|
||||
lamb=0.3, # Gentler smoothing to preserve shape
|
||||
iterations=smoothing_iterations,
|
||||
implicit_time_integration=False
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Smoothing failed (non-critical): {e}")
|
||||
|
||||
# Ensure mesh is valid
|
||||
ironed_mesh.fix_normals()
|
||||
|
||||
return ironed_mesh
|
||||
|
||||
|
||||
def fit_brace_to_body(
|
||||
brace_mesh: trimesh.Trimesh,
|
||||
body_scan_path: str,
|
||||
clearance_mm: float = 8.0,
|
||||
brace_basis: Optional[Dict[str, Any]] = None,
|
||||
template_type: str = "regular",
|
||||
enable_ironing: bool = True,
|
||||
markers: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[trimesh.Trimesh, np.ndarray, Dict[str, Any]]:
|
||||
"""
|
||||
Fit brace to body scan using basis alignment.
|
||||
|
||||
The brace needs to be:
|
||||
1. Rotated so its UP axis aligns with body's UP axis (typically Z for body scans)
|
||||
2. Scaled to fit around the body with proper clearance
|
||||
3. Positioned at the torso level
|
||||
4. Ironed to conform to body surface (respecting correction zones)
|
||||
|
||||
Returns:
|
||||
Tuple of (transformed_mesh, transform_matrix, fitting_info)
|
||||
"""
|
||||
# Load body scan
|
||||
body_mesh = trimesh.load(body_scan_path, force='mesh')
|
||||
|
||||
# Get body dimensions
|
||||
body_bounds = body_mesh.bounds
|
||||
body_extents = body_mesh.extents
|
||||
body_center = body_mesh.centroid
|
||||
|
||||
# Determine body up axis (typically the longest dimension = height)
|
||||
# For human body scans, this is usually Z (from 3D scanners) or Y
|
||||
body_up_axis_idx = np.argmax(body_extents)
|
||||
print(f"Body up axis: {['X', 'Y', 'Z'][body_up_axis_idx]}, extents: {body_extents}")
|
||||
|
||||
# Get brace dimensions
|
||||
brace_bounds = brace_mesh.bounds
|
||||
brace_extents = brace_mesh.extents
|
||||
brace_center = brace_mesh.centroid
|
||||
|
||||
print(f"Brace original extents: {brace_extents}, template_type: {template_type}")
|
||||
|
||||
# Start building transformation
|
||||
transformed_mesh = brace_mesh.copy()
|
||||
transform = np.eye(4)
|
||||
|
||||
# Step 1: Center brace at origin
|
||||
T_center = np.eye(4)
|
||||
T_center[:3, 3] = -brace_center
|
||||
transformed_mesh.apply_transform(T_center)
|
||||
transform = T_center @ transform
|
||||
|
||||
# Step 2: Apply rotations based on template type and body orientation
|
||||
# Regular templates have: negative Y is up (inverted), need to flip
|
||||
# Vase templates have: positive Y is up
|
||||
# Body scan is Z-up
|
||||
|
||||
if body_up_axis_idx == 2: # Body is Z-up (standard for 3D scanners)
|
||||
if template_type == "regular":
|
||||
# Regular brace: -Y is up (inverted)
|
||||
# 1. Rotate -90° around X to bring Y-up to Z-up
|
||||
R1 = trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0])
|
||||
transformed_mesh.apply_transform(R1)
|
||||
transform = R1 @ transform
|
||||
|
||||
# 2. The brace is now Z-up but inverted (pelvis at top, shoulders at bottom)
|
||||
# Flip 180° around X to correct (this keeps Z as up axis)
|
||||
R2 = trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0])
|
||||
transformed_mesh.apply_transform(R2)
|
||||
transform = R2 @ transform
|
||||
|
||||
# 3. Rotate around Z to face forward correctly
|
||||
R3 = trimesh.transformations.rotation_matrix(-np.pi/2, [0, 0, 1])
|
||||
transformed_mesh.apply_transform(R3)
|
||||
transform = R3 @ transform
|
||||
|
||||
print(f"Applied regular brace rotations: X-90°, X+180° (flip), Z-90°")
|
||||
|
||||
else: # vase
|
||||
# Vase brace: positive Y is up
|
||||
# 1. Rotate -90° around X to bring Y-up to Z-up
|
||||
R1 = trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0])
|
||||
transformed_mesh.apply_transform(R1)
|
||||
transform = R1 @ transform
|
||||
|
||||
# 2. Flip 180° around Y to correct orientation (right-side up)
|
||||
R2 = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
|
||||
transformed_mesh.apply_transform(R2)
|
||||
transform = R2 @ transform
|
||||
|
||||
print(f"Applied vase brace rotations: X-90°, Y+180° (flip)")
|
||||
|
||||
# Step 3: Get new brace dimensions after rotation
|
||||
new_brace_extents = transformed_mesh.extents
|
||||
new_brace_center = transformed_mesh.centroid
|
||||
print(f"Brace extents after rotation: {new_brace_extents}")
|
||||
|
||||
# Step 4: Calculate NON-UNIFORM scaling based on body dimensions
|
||||
# The brace should cover the TORSO region (~50% of body height)
|
||||
# AND wrap around the body with proper girth
|
||||
|
||||
body_height = body_extents[body_up_axis_idx]
|
||||
brace_height = new_brace_extents[body_up_axis_idx] # After rotation, this is the height
|
||||
|
||||
# Body horizontal dimensions (girth at torso level)
|
||||
horizontal_axes = [i for i in range(3) if i != body_up_axis_idx]
|
||||
body_width = body_extents[horizontal_axes[0]] # X width
|
||||
body_depth = body_extents[horizontal_axes[1]] # Y depth
|
||||
|
||||
# Brace horizontal dimensions
|
||||
brace_width = new_brace_extents[horizontal_axes[0]]
|
||||
brace_depth = new_brace_extents[horizontal_axes[1]]
|
||||
|
||||
# Target: brace height should cover ~65% of body height (full torso coverage)
|
||||
target_height = body_height * 0.65
|
||||
height_scale = target_height / brace_height if brace_height > 0 else 1.0
|
||||
|
||||
# Target: brace width/depth should be LARGER than body to wrap AROUND it
|
||||
# The brace sits OUTSIDE the body, only pressure points push inward
|
||||
# Add ~25% extra + clearance so brace externals are visible outside body
|
||||
target_width = body_width * 1.25 + clearance_mm * 2
|
||||
target_depth = body_depth * 1.25 + clearance_mm * 2
|
||||
|
||||
width_scale = target_width / brace_width if brace_width > 0 else 1.0
|
||||
depth_scale = target_depth / brace_depth if brace_depth > 0 else 1.0
|
||||
|
||||
# Apply non-uniform scaling
|
||||
# Determine which axis is which after rotation
|
||||
S = np.eye(4)
|
||||
if body_up_axis_idx == 2: # Z is up
|
||||
S[0, 0] = width_scale # X scale
|
||||
S[1, 1] = depth_scale # Y scale
|
||||
S[2, 2] = height_scale # Z scale (height)
|
||||
elif body_up_axis_idx == 1: # Y is up
|
||||
S[0, 0] = width_scale # X scale
|
||||
S[1, 1] = height_scale # Y scale (height)
|
||||
S[2, 2] = depth_scale # Z scale
|
||||
else: # X is up (unusual)
|
||||
S[0, 0] = height_scale # X scale (height)
|
||||
S[1, 1] = width_scale # Y scale
|
||||
S[2, 2] = depth_scale # Z scale
|
||||
|
||||
# Limit scales to reasonable range
|
||||
S[0, 0] = max(0.5, min(S[0, 0], 50.0))
|
||||
S[1, 1] = max(0.5, min(S[1, 1], 50.0))
|
||||
S[2, 2] = max(0.5, min(S[2, 2], 50.0))
|
||||
|
||||
transformed_mesh.apply_transform(S)
|
||||
transform = S @ transform
|
||||
|
||||
print(f"Applied non-uniform scale: width={S[0,0]:.2f}, depth={S[1,1]:.2f}, height={S[2,2]:.2f}")
|
||||
print(f"Target dimensions: width={target_width:.1f}, depth={target_depth:.1f}, height={target_height:.1f}")
|
||||
|
||||
# For fitting_info, use average scale
|
||||
scale = (S[0, 0] + S[1, 1] + S[2, 2]) / 3
|
||||
|
||||
# Step 6: Position brace at torso level
|
||||
# Calculate where the torso is (middle portion of body height)
|
||||
body_height = body_extents[body_up_axis_idx]
|
||||
body_bottom = body_bounds[0, body_up_axis_idx]
|
||||
body_top = body_bounds[1, body_up_axis_idx]
|
||||
|
||||
# Torso is roughly the middle 40% of body height (from ~30% to ~70%)
|
||||
torso_center_ratio = 0.5 # Middle of body
|
||||
torso_center_height = body_bottom + body_height * torso_center_ratio
|
||||
|
||||
# Target position: center horizontally on body, at torso height vertically
|
||||
target_center = body_center.copy()
|
||||
target_center[body_up_axis_idx] = torso_center_height
|
||||
|
||||
# Current brace center after transformations
|
||||
current_center = transformed_mesh.centroid
|
||||
|
||||
T_position = np.eye(4)
|
||||
T_position[:3, 3] = target_center - current_center
|
||||
transformed_mesh.apply_transform(T_position)
|
||||
transform = T_position @ transform
|
||||
|
||||
# Step 7: Iron brace to conform to body surface (eliminate gaps and humps)
|
||||
# Transform markers so we can exclude correction zones from ironing
|
||||
transformed_markers = None
|
||||
if markers:
|
||||
transformed_markers = transform_markers(markers, transform)
|
||||
|
||||
ironing_info = {}
|
||||
if enable_ironing:
|
||||
try:
|
||||
print(f"Starting brace ironing to body surface...")
|
||||
pre_iron_extents = transformed_mesh.extents.copy()
|
||||
|
||||
transformed_mesh = iron_brace_to_body(
|
||||
brace_mesh=transformed_mesh,
|
||||
body_mesh=body_mesh,
|
||||
min_clearance_mm=clearance_mm * 0.4, # Allow closer for tight fit
|
||||
max_clearance_mm=clearance_mm * 1.5, # Iron areas with gaps > 1.5x clearance
|
||||
smoothing_iterations=3,
|
||||
up_axis=body_up_axis_idx,
|
||||
markers=transformed_markers
|
||||
)
|
||||
|
||||
post_iron_extents = transformed_mesh.extents
|
||||
ironing_info = {
|
||||
"enabled": True,
|
||||
"pre_iron_extents": pre_iron_extents.tolist(),
|
||||
"post_iron_extents": post_iron_extents.tolist(),
|
||||
"min_clearance_mm": clearance_mm * 0.5,
|
||||
"max_clearance_mm": clearance_mm * 2.0,
|
||||
}
|
||||
print(f"Ironing complete. Extents changed from {pre_iron_extents} to {post_iron_extents}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Ironing failed (non-critical): {e}")
|
||||
ironing_info = {"enabled": False, "error": str(e)}
|
||||
else:
|
||||
ironing_info = {"enabled": False}
|
||||
|
||||
fitting_info = {
|
||||
"scale_avg": float(scale),
|
||||
"scale_x": float(S[0, 0]),
|
||||
"scale_y": float(S[1, 1]),
|
||||
"scale_z": float(S[2, 2]),
|
||||
"template_type": template_type,
|
||||
"body_extents": body_extents.tolist(),
|
||||
"brace_extents_original": brace_extents.tolist(),
|
||||
"brace_extents_final": transformed_mesh.extents.tolist(),
|
||||
"clearance_mm": clearance_mm,
|
||||
"body_center": body_center.tolist(),
|
||||
"final_center": transformed_mesh.centroid.tolist(),
|
||||
"body_up_axis": int(body_up_axis_idx),
|
||||
"ironing": ironing_info,
|
||||
}
|
||||
|
||||
return transformed_mesh, transform, fitting_info
|
||||
|
||||
|
||||
def generate_both_brace_types(
|
||||
rigo_type: str,
|
||||
output_dir: Path,
|
||||
case_id: str,
|
||||
cobb_angles: Dict[str, float],
|
||||
body_scan_path: Optional[str] = None,
|
||||
clearance_mm: float = 8.0
|
||||
) -> Dict[str, BraceGenerationResult]:
|
||||
"""
|
||||
Generate both regular and vase brace types for comparison.
|
||||
|
||||
Returns:
|
||||
Dict with "regular" and "vase" results
|
||||
"""
|
||||
results = {}
|
||||
|
||||
# Generate regular brace
|
||||
try:
|
||||
results["regular"] = generate_glb_brace(
|
||||
rigo_type=rigo_type,
|
||||
template_type="regular",
|
||||
output_dir=output_dir,
|
||||
case_id=case_id,
|
||||
cobb_angles=cobb_angles,
|
||||
body_scan_path=body_scan_path,
|
||||
clearance_mm=clearance_mm
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
results["regular"] = {"error": str(e)}
|
||||
|
||||
# Generate vase brace
|
||||
try:
|
||||
results["vase"] = generate_glb_brace(
|
||||
rigo_type=rigo_type,
|
||||
template_type="vase",
|
||||
output_dir=output_dir,
|
||||
case_id=case_id,
|
||||
cobb_angles=cobb_angles,
|
||||
body_scan_path=body_scan_path,
|
||||
clearance_mm=clearance_mm
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
results["vase"] = {"error": str(e)}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Available templates
|
||||
AVAILABLE_RIGO_TYPES = ["A1", "A2", "A3", "B1", "B2", "C1", "C2", "E1", "E2"]
|
||||
|
||||
def list_available_templates() -> Dict[str, list]:
|
||||
"""List all available template files."""
|
||||
regular = []
|
||||
vase = []
|
||||
|
||||
for rigo_type in AVAILABLE_RIGO_TYPES:
|
||||
glb_path, _ = get_template_paths(rigo_type, "regular")
|
||||
if glb_path.exists():
|
||||
regular.append(rigo_type)
|
||||
|
||||
glb_path, _ = get_template_paths(rigo_type, "vase")
|
||||
if glb_path.exists():
|
||||
vase.append(rigo_type)
|
||||
|
||||
return {
|
||||
"regular": regular,
|
||||
"vase": vase
|
||||
}
|
||||
26
brace-generator/requirements.txt
Normal file
26
brace-generator/requirements.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
# Server dependencies
|
||||
fastapi>=0.100.0
|
||||
uvicorn[standard]>=0.22.0
|
||||
python-multipart>=0.0.6
|
||||
pydantic>=2.0.0
|
||||
requests>=2.28.0
|
||||
|
||||
# AWS SDK
|
||||
boto3>=1.26.0
|
||||
|
||||
# Core ML dependencies
|
||||
torch>=2.0.0
|
||||
torchvision>=0.15.0
|
||||
|
||||
# Image processing
|
||||
numpy>=1.20.0
|
||||
scipy>=1.7.0
|
||||
pillow>=8.0.0
|
||||
opencv-python-headless>=4.5.0
|
||||
pydicom>=2.2.0
|
||||
|
||||
# 3D mesh processing
|
||||
trimesh>=3.10.0
|
||||
|
||||
# Visualization
|
||||
matplotlib>=3.4.0
|
||||
990
brace-generator/routes.py
Normal file
990
brace-generator/routes.py
Normal file
@@ -0,0 +1,990 @@
|
||||
"""
|
||||
API routes for Brace Generator.
|
||||
|
||||
Note: S3 operations are handled by the Lambda function.
|
||||
This server only handles ML inference and returns local file paths.
|
||||
"""
|
||||
import torch
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from typing import Optional
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from .schemas import (
|
||||
AnalysisResult, HealthResponse, ExperimentType, BraceConfigRequest
|
||||
)
|
||||
from .config import config
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", summary="Root endpoint")
|
||||
async def root():
|
||||
"""Welcome endpoint."""
|
||||
return {
|
||||
"service": "Brace Generator API",
|
||||
"version": "1.0.0",
|
||||
"docs": "/docs",
|
||||
"health": "/health"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse, summary="Health check")
|
||||
async def health_check():
|
||||
"""Check server health and GPU status."""
|
||||
cuda_available = torch.cuda.is_available()
|
||||
|
||||
gpu_name = None
|
||||
gpu_memory_mb = None
|
||||
if cuda_available:
|
||||
gpu_name = torch.cuda.get_device_name(0)
|
||||
gpu_memory_mb = int(torch.cuda.get_device_properties(0).total_memory / (1024**2))
|
||||
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
device=config.get_device(),
|
||||
cuda_available=cuda_available,
|
||||
model_loaded=True,
|
||||
gpu_name=gpu_name,
|
||||
gpu_memory_mb=gpu_memory_mb
|
||||
)
|
||||
|
||||
|
||||
@router.post("/analyze/upload", response_model=AnalysisResult, summary="Analyze uploaded X-ray")
|
||||
async def analyze_upload(
|
||||
req: Request,
|
||||
file: UploadFile = File(..., description="X-ray image file"),
|
||||
case_id: Optional[str] = Form(None, description="Case ID"),
|
||||
experiment: str = Form("experiment_3", description="Experiment type"),
|
||||
config_json: Optional[str] = Form(None, description="Brace config as JSON"),
|
||||
landmarks_json: Optional[str] = Form(None, description="Pre-computed landmarks with manual edits")
|
||||
):
|
||||
"""
|
||||
Analyze an uploaded X-ray image and generate brace.
|
||||
|
||||
This endpoint accepts multipart/form-data for direct file upload.
|
||||
Returns analysis results with local file paths that can be downloaded
|
||||
via the /download endpoint.
|
||||
|
||||
If landmarks_json is provided, it will use those landmarks (with manual edits)
|
||||
instead of re-running automatic detection. This allows manual corrections
|
||||
to be incorporated into the brace generation.
|
||||
|
||||
The Lambda function is responsible for:
|
||||
1. Downloading the X-ray from S3
|
||||
2. Calling this endpoint
|
||||
3. Downloading output files via /download
|
||||
4. Uploading files to S3
|
||||
"""
|
||||
# Validate file
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
# Check file size
|
||||
contents = await file.read()
|
||||
if len(contents) > config.MAX_IMAGE_SIZE_MB * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File too large. Maximum size is {config.MAX_IMAGE_SIZE_MB}MB"
|
||||
)
|
||||
|
||||
# Parse config if provided
|
||||
brace_config = None
|
||||
if config_json:
|
||||
try:
|
||||
config_data = json.loads(config_json)
|
||||
brace_config = BraceConfigRequest(**config_data)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid config: {e}")
|
||||
|
||||
# Parse landmarks if provided (manual edits)
|
||||
landmarks_data = None
|
||||
if landmarks_json:
|
||||
try:
|
||||
landmarks_data = json.loads(landmarks_json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid landmarks JSON: {e}")
|
||||
|
||||
# Parse experiment type
|
||||
try:
|
||||
exp_type = ExperimentType(experiment)
|
||||
except ValueError:
|
||||
exp_type = ExperimentType.EXPERIMENT_3
|
||||
|
||||
service = req.app.state.brace_service
|
||||
|
||||
try:
|
||||
result = await service.analyze_from_bytes(
|
||||
image_data=contents,
|
||||
filename=file.filename,
|
||||
experiment=exp_type,
|
||||
case_id=case_id,
|
||||
brace_config=brace_config,
|
||||
landmarks_data=landmarks_data # Pass pre-computed landmarks
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/download/{case_id}/{filename}", summary="Download output file")
|
||||
async def download_file(case_id: str, filename: str):
|
||||
"""
|
||||
Download a generated output file.
|
||||
|
||||
This endpoint is called by the Lambda function to retrieve
|
||||
generated files (STL, PLY, PNG, JSON) for upload to S3.
|
||||
"""
|
||||
file_path = config.TEMP_DIR / case_id / filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
|
||||
|
||||
# Determine media type
|
||||
ext = file_path.suffix.lower()
|
||||
media_types = {
|
||||
".stl": "application/octet-stream",
|
||||
".ply": "application/octet-stream",
|
||||
".obj": "application/octet-stream",
|
||||
".glb": "model/gltf-binary",
|
||||
".gltf": "model/gltf+json",
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".json": "application/json",
|
||||
}
|
||||
media_type = media_types.get(ext, "application/octet-stream")
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=filename,
|
||||
media_type=media_type
|
||||
)
|
||||
|
||||
|
||||
@router.post("/extract-body-measurements", summary="Extract body measurements from 3D scan")
|
||||
async def extract_body_measurements(
|
||||
file: UploadFile = File(..., description="3D body scan file (STL/OBJ/PLY)")
|
||||
):
|
||||
"""
|
||||
Extract body measurements from a 3D body scan.
|
||||
|
||||
Returns measurements needed for brace fitting:
|
||||
- Total height
|
||||
- Shoulder, chest, waist, hip widths and depths
|
||||
- Circumferences
|
||||
- Brace coverage region
|
||||
"""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from server_DEV.body_integration import extract_measurements_from_scan
|
||||
except ImportError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Body integration module not available: {e}")
|
||||
|
||||
# Validate file type
|
||||
allowed_extensions = ['.stl', '.obj', '.ply', '.glb', '.gltf']
|
||||
ext = Path(file.filename).suffix.lower() if file.filename else '.stl'
|
||||
if ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type. Allowed: {', '.join(allowed_extensions)}"
|
||||
)
|
||||
|
||||
# Save to temp file
|
||||
contents = await file.read()
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f:
|
||||
f.write(contents)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
measurements = extract_measurements_from_scan(temp_path)
|
||||
return measurements
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
# Cleanup
|
||||
Path(temp_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@router.post("/generate-with-body", summary="Generate brace with body scan fitting")
|
||||
async def generate_with_body_scan(
|
||||
req: Request,
|
||||
xray_file: UploadFile = File(..., description="X-ray image"),
|
||||
body_scan_file: UploadFile = File(..., description="3D body scan (STL/OBJ/PLY)"),
|
||||
case_id: Optional[str] = Form(None, description="Case ID"),
|
||||
landmarks_json: Optional[str] = Form(None, description="Pre-computed landmarks"),
|
||||
clearance_mm: float = Form(8.0, description="Shell clearance in mm"),
|
||||
):
|
||||
"""
|
||||
Generate a patient-specific brace using X-ray analysis and 3D body scan.
|
||||
|
||||
This endpoint:
|
||||
1. Analyzes X-ray to detect spine landmarks and compute Cobb angles
|
||||
2. Classifies curve type using Rigo-Cheneau system
|
||||
3. Fits a shell template to the 3D body scan
|
||||
4. Returns STL, GLB, and visualization files
|
||||
"""
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from server_DEV.body_integration import generate_fitted_brace, extract_measurements_from_scan
|
||||
except ImportError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Body integration module not available: {e}")
|
||||
|
||||
# Generate case ID if not provided
|
||||
case_id = case_id or f"case_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Save files to temp directory
|
||||
temp_dir = config.TEMP_DIR / case_id
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save X-ray
|
||||
xray_contents = await xray_file.read()
|
||||
xray_ext = Path(xray_file.filename).suffix if xray_file.filename else '.jpg'
|
||||
xray_path = temp_dir / f"xray{xray_ext}"
|
||||
xray_path.write_bytes(xray_contents)
|
||||
|
||||
# Save body scan
|
||||
body_contents = await body_scan_file.read()
|
||||
body_ext = Path(body_scan_file.filename).suffix if body_scan_file.filename else '.stl'
|
||||
body_scan_path = temp_dir / f"body_scan{body_ext}"
|
||||
body_scan_path.write_bytes(body_contents)
|
||||
|
||||
try:
|
||||
# Parse landmarks if provided
|
||||
landmarks_data = None
|
||||
if landmarks_json:
|
||||
import json
|
||||
landmarks_data = json.loads(landmarks_json)
|
||||
|
||||
# Step 1: Analyze X-ray to get Rigo classification (this generates the brace)
|
||||
service = req.app.state.brace_service
|
||||
|
||||
xray_result = await service.analyze_from_bytes(
|
||||
image_data=xray_contents,
|
||||
filename=xray_file.filename,
|
||||
experiment=ExperimentType.EXPERIMENT_3,
|
||||
case_id=case_id,
|
||||
landmarks_data=landmarks_data
|
||||
)
|
||||
|
||||
rigo_type = xray_result.rigo_classification.type if xray_result.rigo_classification else "A1"
|
||||
|
||||
# Step 2: Try to extract body measurements (optional - EXPERIMENT_10 may not be deployed)
|
||||
body_measurements = None
|
||||
fitting_result = None
|
||||
body_scan_error = None
|
||||
|
||||
try:
|
||||
body_measurements = extract_measurements_from_scan(str(body_scan_path))
|
||||
|
||||
# Step 3: Generate fitted brace (only if measurements worked)
|
||||
fitting_result = generate_fitted_brace(
|
||||
body_scan_path=str(body_scan_path),
|
||||
rigo_type=rigo_type,
|
||||
output_dir=str(temp_dir),
|
||||
case_id=case_id,
|
||||
clearance_mm=clearance_mm
|
||||
)
|
||||
except Exception as body_err:
|
||||
print(f"Warning: Body scan processing failed, using X-ray only: {body_err}")
|
||||
body_scan_error = str(body_err)
|
||||
|
||||
# If body fitting worked, return full result
|
||||
if fitting_result:
|
||||
return {
|
||||
"case_id": case_id,
|
||||
"experiment": "experiment_10",
|
||||
"model_used": xray_result.model_used,
|
||||
"vertebrae_detected": xray_result.vertebrae_detected,
|
||||
"cobb_angles": {
|
||||
"PT": xray_result.cobb_angles.PT,
|
||||
"MT": xray_result.cobb_angles.MT,
|
||||
"TL": xray_result.cobb_angles.TL,
|
||||
},
|
||||
"curve_type": xray_result.curve_type,
|
||||
"rigo_classification": {
|
||||
"type": rigo_type,
|
||||
"description": xray_result.rigo_classification.description if xray_result.rigo_classification else ""
|
||||
},
|
||||
"body_scan": {
|
||||
"measurements": body_measurements,
|
||||
},
|
||||
"brace_fitting": fitting_result,
|
||||
"outputs": {
|
||||
"shell_stl": fitting_result["outputs"]["shell_stl"],
|
||||
"shell_glb": fitting_result["outputs"]["shell_glb"],
|
||||
"combined_stl": fitting_result["outputs"]["combined_stl"],
|
||||
"visualization": fitting_result["outputs"].get("visualization"),
|
||||
"feedback_json": fitting_result["outputs"]["feedback_json"],
|
||||
"xray_visualization": str(xray_result.outputs.get("visualization", "")),
|
||||
},
|
||||
"mesh_vertices": fitting_result["mesh_stats"]["vertices"],
|
||||
"mesh_faces": fitting_result["mesh_stats"]["faces"],
|
||||
"processing_time_ms": xray_result.processing_time_ms,
|
||||
}
|
||||
|
||||
# Fallback: return X-ray only result (body scan processing not available)
|
||||
return {
|
||||
"case_id": case_id,
|
||||
"experiment": "experiment_3_fallback",
|
||||
"model_used": xray_result.model_used,
|
||||
"vertebrae_detected": xray_result.vertebrae_detected,
|
||||
"cobb_angles": {
|
||||
"PT": xray_result.cobb_angles.PT,
|
||||
"MT": xray_result.cobb_angles.MT,
|
||||
"TL": xray_result.cobb_angles.TL,
|
||||
},
|
||||
"curve_type": xray_result.curve_type,
|
||||
"rigo_classification": {
|
||||
"type": rigo_type,
|
||||
"description": xray_result.rigo_classification.description if xray_result.rigo_classification else ""
|
||||
},
|
||||
"body_scan": {
|
||||
"error": body_scan_error or "Body scan processing not available",
|
||||
"fallback": "Using X-ray only brace generation"
|
||||
},
|
||||
"outputs": xray_result.outputs,
|
||||
"mesh_vertices": xray_result.mesh_vertices,
|
||||
"mesh_faces": xray_result.mesh_faces,
|
||||
"processing_time_ms": xray_result.processing_time_ms,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments", summary="List available experiments")
|
||||
async def list_experiments():
|
||||
"""List available brace generation experiments."""
|
||||
return {
|
||||
"experiments": [
|
||||
{
|
||||
"id": "standard",
|
||||
"name": "Standard Pipeline",
|
||||
"description": "Original template-based brace generation using Rigo classification"
|
||||
},
|
||||
{
|
||||
"id": "experiment_3",
|
||||
"name": "Research-Based Adaptive",
|
||||
"description": "Adaptive brace generation based on Guy et al. (2024) with patch-based deformation optimization"
|
||||
},
|
||||
{
|
||||
"id": "experiment_10",
|
||||
"name": "Patient-Specific Body Fitting",
|
||||
"description": "X-ray analysis + 3D body scan for precise patient-specific brace fitting"
|
||||
}
|
||||
],
|
||||
"default": "experiment_3"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/models", summary="List available detection models")
|
||||
async def list_models():
|
||||
"""List available landmark detection models."""
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
"id": "scoliovis",
|
||||
"name": "ScolioVis",
|
||||
"description": "Keypoint R-CNN model for vertebrae detection",
|
||||
"supports_gpu": True
|
||||
},
|
||||
{
|
||||
"id": "vertebra-landmark",
|
||||
"name": "Vertebra-Landmark-Detection",
|
||||
"description": "SpineNet-based detection (alternative)",
|
||||
"supports_gpu": True
|
||||
}
|
||||
],
|
||||
"current": config.MODEL
|
||||
}
|
||||
|
||||
|
||||
# ============================================
|
||||
# NEW ENDPOINTS FOR PIPELINE DEV
|
||||
# ============================================
|
||||
|
||||
@router.post("/detect-landmarks", summary="Detect landmarks only (Stage 1)")
|
||||
async def detect_landmarks(
|
||||
req: Request,
|
||||
file: UploadFile = File(..., description="X-ray image file"),
|
||||
case_id: Optional[str] = Form(None, description="Case ID"),
|
||||
):
|
||||
"""
|
||||
Detect vertebrae landmarks without generating a brace.
|
||||
Returns landmarks, visualization, and vertebrae_structure for manual editing.
|
||||
|
||||
This is Stage 1 of the pipeline - just detection, no brace generation.
|
||||
"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
contents = await file.read()
|
||||
if len(contents) > config.MAX_IMAGE_SIZE_MB * 1024 * 1024:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File too large. Maximum size is {config.MAX_IMAGE_SIZE_MB}MB"
|
||||
)
|
||||
|
||||
service = req.app.state.brace_service
|
||||
|
||||
try:
|
||||
result = await service.detect_landmarks_only(
|
||||
image_data=contents,
|
||||
filename=file.filename,
|
||||
case_id=case_id
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/recalculate", summary="Recalculate Cobb/Rigo from landmarks")
|
||||
async def recalculate_analysis(req: Request):
|
||||
"""
|
||||
Recalculate Cobb angles and Rigo classification from provided landmarks.
|
||||
|
||||
Use this after manual landmark editing to get updated analysis.
|
||||
|
||||
Request body:
|
||||
{
|
||||
"case_id": "case-xxx",
|
||||
"landmarks": { ... vertebrae_structure from detect-landmarks ... }
|
||||
}
|
||||
"""
|
||||
body = await req.json()
|
||||
case_id = body.get("case_id")
|
||||
landmarks = body.get("landmarks")
|
||||
|
||||
if not landmarks:
|
||||
raise HTTPException(status_code=400, detail="landmarks data required")
|
||||
|
||||
service = req.app.state.brace_service
|
||||
|
||||
try:
|
||||
result = await service.recalculate_from_landmarks(
|
||||
landmarks_data=landmarks,
|
||||
case_id=case_id
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GLB BRACE GENERATION WITH MARKERS
|
||||
# =============================================================================
|
||||
|
||||
from .glb_generator import (
|
||||
generate_glb_brace,
|
||||
generate_both_brace_types,
|
||||
list_available_templates,
|
||||
calculate_pressure_zones,
|
||||
load_template_markers,
|
||||
AVAILABLE_RIGO_TYPES
|
||||
)
|
||||
|
||||
|
||||
@router.get("/templates", summary="List available brace templates")
|
||||
async def list_templates():
|
||||
"""
|
||||
List all available brace templates (regular and vase types).
|
||||
|
||||
Returns which Rigo types have templates available.
|
||||
"""
|
||||
return {
|
||||
"available_templates": list_available_templates(),
|
||||
"rigo_types": AVAILABLE_RIGO_TYPES,
|
||||
"template_types": ["regular", "vase"]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/templates/{rigo_type}/markers", summary="Get template markers")
|
||||
async def get_template_markers(
|
||||
rigo_type: str,
|
||||
template_type: str = "regular"
|
||||
):
|
||||
"""
|
||||
Get marker positions for a specific template.
|
||||
|
||||
Args:
|
||||
rigo_type: Rigo classification (A1, A2, A3, B1, B2, C1, C2, E1, E2)
|
||||
template_type: "regular" or "vase"
|
||||
|
||||
Returns:
|
||||
Marker positions and basis vectors
|
||||
"""
|
||||
if rigo_type not in AVAILABLE_RIGO_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid rigo_type. Must be one of: {AVAILABLE_RIGO_TYPES}"
|
||||
)
|
||||
|
||||
if template_type not in ["regular", "vase"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="template_type must be 'regular' or 'vase'"
|
||||
)
|
||||
|
||||
try:
|
||||
markers = load_template_markers(rigo_type, template_type)
|
||||
return {
|
||||
"rigo_type": rigo_type,
|
||||
"template_type": template_type,
|
||||
**markers
|
||||
}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/generate-glb", summary="Generate GLB brace with markers")
|
||||
async def generate_glb_endpoint(
|
||||
req: Request,
|
||||
rigo_type: str = Form(..., description="Rigo classification (A1-E2)"),
|
||||
template_type: str = Form("regular", description="Template type: 'regular' or 'vase'"),
|
||||
case_id: str = Form(..., description="Case identifier"),
|
||||
cobb_pt: float = Form(0.0, description="Proximal Thoracic Cobb angle"),
|
||||
cobb_mt: float = Form(0.0, description="Main Thoracic Cobb angle"),
|
||||
cobb_tl: float = Form(0.0, description="Thoracolumbar Cobb angle"),
|
||||
body_scan: Optional[UploadFile] = File(None, description="Optional 3D body scan STL")
|
||||
):
|
||||
"""
|
||||
Generate a GLB brace with embedded markers.
|
||||
|
||||
This endpoint generates a brace file that includes marker positions
|
||||
for later editing. Optionally fits to a body scan.
|
||||
|
||||
**Pressure Zones in Output:**
|
||||
- LM_PAD_TH: Thoracic pad (pushes INWARD on curve convex side)
|
||||
- LM_BAY_TH: Thoracic bay (creates SPACE on curve concave side)
|
||||
- LM_PAD_LUM: Lumbar pad (pushes INWARD)
|
||||
- LM_BAY_LUM: Lumbar bay (creates SPACE)
|
||||
- LM_ANCHOR_HIP_L/R: Hip anchors (stabilize brace)
|
||||
|
||||
Returns:
|
||||
GLB and STL file paths, marker positions, pressure zone info
|
||||
"""
|
||||
if rigo_type not in AVAILABLE_RIGO_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid rigo_type. Must be one of: {AVAILABLE_RIGO_TYPES}"
|
||||
)
|
||||
|
||||
if template_type not in ["regular", "vase"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="template_type must be 'regular' or 'vase'"
|
||||
)
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
output_dir = Path(tempfile.gettempdir()) / "brace_generator" / case_id
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
body_scan_path = None
|
||||
|
||||
# Save body scan if provided
|
||||
if body_scan:
|
||||
body_ext = Path(body_scan.filename).suffix if body_scan.filename else ".stl"
|
||||
body_scan_path = str(output_dir / f"body_scan{body_ext}")
|
||||
with open(body_scan_path, "wb") as f:
|
||||
content = await body_scan.read()
|
||||
f.write(content)
|
||||
|
||||
cobb_angles = {
|
||||
"PT": cobb_pt,
|
||||
"MT": cobb_mt,
|
||||
"TL": cobb_tl
|
||||
}
|
||||
|
||||
try:
|
||||
result = generate_glb_brace(
|
||||
rigo_type=rigo_type,
|
||||
template_type=template_type,
|
||||
output_dir=output_dir,
|
||||
case_id=case_id,
|
||||
cobb_angles=cobb_angles,
|
||||
body_scan_path=body_scan_path,
|
||||
clearance_mm=8.0
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"case_id": case_id,
|
||||
"rigo_type": rigo_type,
|
||||
"template_type": template_type,
|
||||
"outputs": {
|
||||
"glb": result.glb_path,
|
||||
"stl": result.stl_path,
|
||||
"json": result.json_path
|
||||
},
|
||||
"markers": result.markers,
|
||||
"basis": result.basis,
|
||||
"pressure_zones": result.pressure_zones,
|
||||
"mesh_stats": result.mesh_stats,
|
||||
"body_fitting": result.transform_applied
|
||||
}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/generate-both-braces", summary="Generate both brace types for comparison")
|
||||
async def generate_both_braces_endpoint(
|
||||
req: Request,
|
||||
rigo_type: str = Form(..., description="Rigo classification (A1-E2)"),
|
||||
case_id: str = Form(..., description="Case identifier"),
|
||||
cobb_pt: float = Form(0.0, description="Proximal Thoracic Cobb angle"),
|
||||
cobb_mt: float = Form(0.0, description="Main Thoracic Cobb angle"),
|
||||
cobb_tl: float = Form(0.0, description="Thoracolumbar Cobb angle"),
|
||||
body_scan: Optional[UploadFile] = File(None, description="Optional 3D body scan STL"),
|
||||
body_scan_path: Optional[str] = Form(None, description="Optional path to existing body scan file"),
|
||||
clearance_mm: float = Form(8.0, description="Brace clearance from body in mm")
|
||||
):
|
||||
"""
|
||||
Generate BOTH regular and vase brace types for side-by-side comparison.
|
||||
|
||||
This allows the user to compare the two brace shapes and choose
|
||||
the preferred design.
|
||||
|
||||
Returns:
|
||||
Both brace files with markers and pressure zones
|
||||
"""
|
||||
if rigo_type not in AVAILABLE_RIGO_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid rigo_type. Must be one of: {AVAILABLE_RIGO_TYPES}"
|
||||
)
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
output_dir = Path(tempfile.gettempdir()) / "brace_generator" / case_id
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
final_body_scan_path = None
|
||||
|
||||
# Save body scan if uploaded as file
|
||||
if body_scan:
|
||||
body_ext = Path(body_scan.filename).suffix if body_scan.filename else ".stl"
|
||||
final_body_scan_path = str(output_dir / f"body_scan{body_ext}")
|
||||
with open(final_body_scan_path, "wb") as f:
|
||||
content = await body_scan.read()
|
||||
f.write(content)
|
||||
# Or use provided path if it exists
|
||||
elif body_scan_path and Path(body_scan_path).exists():
|
||||
final_body_scan_path = body_scan_path
|
||||
print(f"Using existing body scan at: {body_scan_path}")
|
||||
|
||||
cobb_angles = {
|
||||
"PT": cobb_pt,
|
||||
"MT": cobb_mt,
|
||||
"TL": cobb_tl
|
||||
}
|
||||
|
||||
try:
|
||||
results = generate_both_brace_types(
|
||||
rigo_type=rigo_type,
|
||||
output_dir=output_dir,
|
||||
case_id=case_id,
|
||||
cobb_angles=cobb_angles,
|
||||
body_scan_path=final_body_scan_path,
|
||||
clearance_mm=clearance_mm
|
||||
)
|
||||
|
||||
response = {
|
||||
"success": True,
|
||||
"case_id": case_id,
|
||||
"rigo_type": rigo_type,
|
||||
"cobb_angles": cobb_angles,
|
||||
"body_scan_used": final_body_scan_path is not None,
|
||||
"braces": {}
|
||||
}
|
||||
|
||||
for brace_type, result in results.items():
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
response["braces"][brace_type] = result
|
||||
else:
|
||||
response["braces"][brace_type] = {
|
||||
"outputs": {
|
||||
"glb": result.glb_path,
|
||||
"stl": result.stl_path,
|
||||
"json": result.json_path
|
||||
},
|
||||
"markers": result.markers,
|
||||
"pressure_zones": result.pressure_zones,
|
||||
"mesh_stats": result.mesh_stats
|
||||
}
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/pressure-zones/{rigo_type}", summary="Get pressure zone information")
|
||||
async def get_pressure_zones(
|
||||
rigo_type: str,
|
||||
template_type: str = "regular",
|
||||
cobb_mt: float = 25.0,
|
||||
cobb_tl: float = 15.0
|
||||
):
|
||||
"""
|
||||
Get detailed pressure zone information for a Rigo type.
|
||||
|
||||
This explains WHERE and HOW MUCH pressure is applied based on
|
||||
the Cobb angles.
|
||||
|
||||
**Pressure Zone Types:**
|
||||
|
||||
- **PAD (Push Zone)**: Pushes INWARD on the convex side of the curve
|
||||
to apply corrective force. Depth increases with Cobb angle severity.
|
||||
|
||||
- **BAY (Expansion Zone)**: Creates SPACE on the concave side for the
|
||||
body to shift into during correction. Clearance is ~1.3x pad depth.
|
||||
|
||||
- **ANCHOR (Stability Zone)**: Grips the pelvis to prevent the brace
|
||||
from riding up. Light inward pressure.
|
||||
|
||||
Returns:
|
||||
Detailed pressure zone descriptions with depths in mm
|
||||
"""
|
||||
if rigo_type not in AVAILABLE_RIGO_TYPES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid rigo_type. Must be one of: {AVAILABLE_RIGO_TYPES}"
|
||||
)
|
||||
|
||||
try:
|
||||
markers = load_template_markers(rigo_type, template_type)
|
||||
zones = calculate_pressure_zones(
|
||||
markers,
|
||||
rigo_type,
|
||||
{"PT": 0, "MT": cobb_mt, "TL": cobb_tl}
|
||||
)
|
||||
|
||||
return {
|
||||
"rigo_type": rigo_type,
|
||||
"template_type": template_type,
|
||||
"cobb_angles": {"MT": cobb_mt, "TL": cobb_tl},
|
||||
"pressure_zones": [
|
||||
{
|
||||
"name": z.name,
|
||||
"marker": z.marker_name,
|
||||
"position": list(z.position),
|
||||
"type": z.zone_type,
|
||||
"direction": z.direction,
|
||||
"function": z.function,
|
||||
"depth_mm": round(z.depth_mm, 1),
|
||||
"radius_mm": list(z.radius_mm)
|
||||
}
|
||||
for z in zones
|
||||
],
|
||||
"explanation": {
|
||||
"pad_depth": f"Based on Cobb angle severity: {cobb_mt}° MT → {round(8 + min(max((cobb_mt - 10) / 40, 0), 1) * 14, 1)}mm thoracic pad",
|
||||
"bay_clearance": "Bay clearance = 1.3 × pad depth + 4-5mm to allow body movement",
|
||||
"hip_anchors": "4mm inward pressure to grip pelvis and stabilize brace"
|
||||
}
|
||||
}
|
||||
except FileNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DEV MODE: LOCAL FILE STORAGE AND SERVING
|
||||
# =============================================================================
|
||||
|
||||
# Local storage directory for DEV mode
|
||||
DEV_STORAGE_DIR = config.TEMP_DIR / "dev_storage"
|
||||
DEV_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@router.post("/cases", summary="Create a new case (DEV)")
|
||||
async def create_case():
|
||||
"""Create a new case with a generated ID (DEV mode)."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
case_id = f"case-{datetime.now().strftime('%Y%m%d')}-{uuid.uuid4().hex[:8]}"
|
||||
case_dir = DEV_STORAGE_DIR / case_id
|
||||
(case_dir / "uploads").mkdir(parents=True, exist_ok=True)
|
||||
(case_dir / "outputs").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save case metadata
|
||||
metadata = {
|
||||
"case_id": case_id,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"status": "created"
|
||||
}
|
||||
(case_dir / "case.json").write_text(json.dumps(metadata, indent=2))
|
||||
|
||||
return {"caseId": case_id, "status": "created"}
|
||||
|
||||
|
||||
@router.get("/cases/{case_id}", summary="Get case details (DEV)")
|
||||
async def get_case(case_id: str):
|
||||
"""Get case details (DEV mode)."""
|
||||
case_dir = DEV_STORAGE_DIR / case_id
|
||||
|
||||
if not case_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Case not found: {case_id}")
|
||||
|
||||
metadata_file = case_dir / "case.json"
|
||||
if metadata_file.exists():
|
||||
metadata = json.loads(metadata_file.read_text())
|
||||
else:
|
||||
metadata = {"case_id": case_id, "status": "unknown"}
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
@router.post("/cases/{case_id}/upload", summary="Upload X-ray for case (DEV)")
|
||||
async def upload_xray(
|
||||
case_id: str,
|
||||
file: UploadFile = File(..., description="X-ray image file")
|
||||
):
|
||||
"""Upload X-ray image for a case (DEV mode - saves locally)."""
|
||||
case_dir = DEV_STORAGE_DIR / case_id
|
||||
uploads_dir = case_dir / "uploads"
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Determine extension from filename
|
||||
ext = Path(file.filename).suffix.lower() if file.filename else ".jpg"
|
||||
if ext not in [".jpg", ".jpeg", ".png", ".webp"]:
|
||||
ext = ".jpg"
|
||||
|
||||
# Save as xray.{ext}
|
||||
xray_path = uploads_dir / f"xray{ext}"
|
||||
contents = await file.read()
|
||||
xray_path.write_bytes(contents)
|
||||
|
||||
# Update case metadata
|
||||
metadata_file = case_dir / "case.json"
|
||||
if metadata_file.exists():
|
||||
metadata = json.loads(metadata_file.read_text())
|
||||
else:
|
||||
metadata = {"case_id": case_id}
|
||||
|
||||
metadata["xray_uploaded"] = True
|
||||
metadata["xray_filename"] = f"xray{ext}"
|
||||
metadata_file.write_text(json.dumps(metadata, indent=2))
|
||||
|
||||
return {
|
||||
"filename": f"xray{ext}",
|
||||
"path": f"/files/uploads/{case_id}/xray{ext}"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/cases/{case_id}/assets", summary="Get case assets (DEV)")
|
||||
async def get_case_assets(case_id: str):
|
||||
"""List all uploaded and output files for a case (DEV mode)."""
|
||||
case_dir = DEV_STORAGE_DIR / case_id
|
||||
|
||||
if not case_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Case not found: {case_id}")
|
||||
|
||||
uploads = []
|
||||
outputs = []
|
||||
|
||||
# List uploads
|
||||
uploads_dir = case_dir / "uploads"
|
||||
if uploads_dir.exists():
|
||||
for f in uploads_dir.iterdir():
|
||||
if f.is_file():
|
||||
uploads.append({
|
||||
"filename": f.name,
|
||||
"url": f"/files/uploads/{case_id}/{f.name}"
|
||||
})
|
||||
|
||||
# List outputs
|
||||
outputs_dir = case_dir / "outputs"
|
||||
if outputs_dir.exists():
|
||||
for f in outputs_dir.iterdir():
|
||||
if f.is_file():
|
||||
outputs.append({
|
||||
"filename": f.name,
|
||||
"url": f"/files/outputs/{case_id}/{f.name}"
|
||||
})
|
||||
|
||||
return {
|
||||
"caseId": case_id,
|
||||
"assets": {
|
||||
"uploads": uploads,
|
||||
"outputs": outputs
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/files/uploads/{case_id}/{filename}", summary="Serve uploaded file (DEV)")
|
||||
async def serve_upload_file(case_id: str, filename: str):
|
||||
"""Serve an uploaded file (DEV mode)."""
|
||||
file_path = DEV_STORAGE_DIR / case_id / "uploads" / filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
|
||||
|
||||
# Determine media type
|
||||
ext = file_path.suffix.lower()
|
||||
media_types = {
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".webp": "image/webp",
|
||||
".stl": "application/octet-stream",
|
||||
".glb": "model/gltf-binary",
|
||||
".json": "application/json",
|
||||
}
|
||||
media_type = media_types.get(ext, "application/octet-stream")
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=filename,
|
||||
media_type=media_type
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/outputs/{case_id}/{filename}", summary="Serve output file (DEV)")
|
||||
async def serve_output_file(case_id: str, filename: str):
|
||||
"""Serve an output file (DEV mode)."""
|
||||
file_path = DEV_STORAGE_DIR / case_id / "outputs" / filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"File not found: {filename}")
|
||||
|
||||
# Determine media type
|
||||
ext = file_path.suffix.lower()
|
||||
media_types = {
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".webp": "image/webp",
|
||||
".stl": "application/octet-stream",
|
||||
".ply": "application/octet-stream",
|
||||
".obj": "application/octet-stream",
|
||||
".glb": "model/gltf-binary",
|
||||
".gltf": "model/gltf+json",
|
||||
".json": "application/json",
|
||||
}
|
||||
media_type = media_types.get(ext, "application/octet-stream")
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
filename=filename,
|
||||
media_type=media_type
|
||||
)
|
||||
125
brace-generator/schemas.py
Normal file
125
brace-generator/schemas.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Pydantic schemas for API request/response validation.
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ExperimentType(str, Enum):
|
||||
"""Available brace generation experiments."""
|
||||
STANDARD = "standard" # Original pipeline
|
||||
EXPERIMENT_3 = "experiment_3" # Research-based adaptive
|
||||
|
||||
|
||||
class BraceConfigRequest(BaseModel):
|
||||
"""Brace configuration parameters."""
|
||||
brace_height_mm: float = Field(default=400.0, ge=200, le=600)
|
||||
torso_width_mm: float = Field(default=280.0, ge=150, le=400)
|
||||
torso_depth_mm: float = Field(default=200.0, ge=100, le=350)
|
||||
wall_thickness_mm: float = Field(default=4.0, ge=2, le=10)
|
||||
pressure_strength_mm: float = Field(default=15.0, ge=0, le=30)
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
|
||||
"""Request to analyze X-ray and generate brace."""
|
||||
s3_key: Optional[str] = Field(None, description="S3 key of uploaded X-ray image")
|
||||
case_id: Optional[str] = Field(None, description="Case ID for organizing outputs")
|
||||
experiment: ExperimentType = Field(default=ExperimentType.EXPERIMENT_3)
|
||||
config: Optional[BraceConfigRequest] = None
|
||||
|
||||
# Output options
|
||||
save_visualization: bool = Field(default=True)
|
||||
save_landmarks: bool = Field(default=True)
|
||||
output_format: str = Field(default="stl", description="stl, ply, or both")
|
||||
|
||||
|
||||
class AnalyzeFromUrlRequest(BaseModel):
|
||||
"""Request with direct image URL."""
|
||||
image_url: str = Field(..., description="URL to download X-ray image from")
|
||||
case_id: Optional[str] = Field(None)
|
||||
experiment: ExperimentType = Field(default=ExperimentType.EXPERIMENT_3)
|
||||
config: Optional[BraceConfigRequest] = None
|
||||
save_visualization: bool = True
|
||||
save_landmarks: bool = True
|
||||
output_format: str = "stl"
|
||||
|
||||
|
||||
class Vertebra(BaseModel):
|
||||
"""Single vertebra data."""
|
||||
level: str
|
||||
centroid_px: List[float]
|
||||
orientation_deg: Optional[float] = None
|
||||
confidence: Optional[float] = None
|
||||
corners_px: Optional[List[List[float]]] = None
|
||||
|
||||
|
||||
class CobbAngles(BaseModel):
|
||||
"""Cobb angle measurements."""
|
||||
PT: float = Field(..., description="Proximal Thoracic angle")
|
||||
MT: float = Field(..., description="Main Thoracic angle")
|
||||
TL: float = Field(..., description="Thoracolumbar angle")
|
||||
|
||||
|
||||
class RigoClassification(BaseModel):
|
||||
"""Rigo-Chêneau classification result."""
|
||||
type: str
|
||||
description: str
|
||||
curve_pattern: Optional[str] = None
|
||||
|
||||
|
||||
class DeformationReport(BaseModel):
|
||||
"""Patch-based deformation report (Experiment 3)."""
|
||||
patch_grid: str
|
||||
deformations: Optional[List[List[float]]] = None
|
||||
zones: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
|
||||
class AnalysisResult(BaseModel):
|
||||
"""Complete analysis result."""
|
||||
case_id: Optional[str] = None
|
||||
experiment: str
|
||||
|
||||
# Input
|
||||
input_image: str
|
||||
|
||||
# Detection results
|
||||
model_used: str
|
||||
vertebrae_detected: int
|
||||
vertebrae: Optional[List[Vertebra]] = None
|
||||
|
||||
# Measurements
|
||||
cobb_angles: CobbAngles
|
||||
curve_type: str
|
||||
|
||||
# Classification
|
||||
rigo_classification: RigoClassification
|
||||
|
||||
# Brace mesh info
|
||||
mesh_vertices: int
|
||||
mesh_faces: int
|
||||
|
||||
# Deformation (Experiment 3)
|
||||
deformation_report: Optional[DeformationReport] = None
|
||||
|
||||
# Output URLs/paths
|
||||
outputs: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
# Timing
|
||||
processing_time_ms: float
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
status: str
|
||||
device: str
|
||||
cuda_available: bool
|
||||
model_loaded: bool
|
||||
gpu_name: Optional[str] = None
|
||||
gpu_memory_mb: Optional[int] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response."""
|
||||
error: str
|
||||
detail: Optional[str] = None
|
||||
884
brace-generator/services.py
Normal file
884
brace-generator/services.py
Normal file
@@ -0,0 +1,884 @@
|
||||
"""
|
||||
Business logic service for brace generation.
|
||||
|
||||
This service handles ML inference and file generation.
|
||||
S3 operations are handled by the Lambda function, not here.
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import trimesh
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from io import BytesIO
|
||||
|
||||
from .config import config
|
||||
from .schemas import (
|
||||
AnalyzeRequest, AnalyzeFromUrlRequest, BraceConfigRequest,
|
||||
AnalysisResult, CobbAngles, RigoClassification, Vertebra,
|
||||
DeformationReport, ExperimentType
|
||||
)
|
||||
|
||||
|
||||
class BraceService:
|
||||
"""
|
||||
Service for X-ray analysis and brace generation.
|
||||
|
||||
Handles:
|
||||
- Model loading and inference
|
||||
- Pipeline orchestration
|
||||
- Local file management
|
||||
|
||||
Note: S3 operations are handled by Lambda, not here.
|
||||
"""
|
||||
|
||||
def __init__(self, device: str = "cuda", model: str = "scoliovis"):
|
||||
self.device = device
|
||||
self.model_name = model
|
||||
|
||||
# Initialize pipelines
|
||||
self._init_pipelines()
|
||||
|
||||
def _init_pipelines(self):
|
||||
"""Initialize brace generation pipelines."""
|
||||
from brace_generator.data_models import BraceConfig
|
||||
from brace_generator.pipeline import BracePipeline
|
||||
|
||||
# Standard pipeline
|
||||
self.standard_pipeline = BracePipeline(
|
||||
model=self.model_name,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# Experiment 3 pipeline (lazy load)
|
||||
self._exp3_pipeline = None
|
||||
|
||||
def _get_exp3_pipeline(self):
|
||||
"""Return standard pipeline (EXPERIMENT_3 not deployed)."""
|
||||
return self.standard_pipeline
|
||||
|
||||
@property
|
||||
def model_loaded(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self.standard_pipeline is not None
|
||||
|
||||
async def analyze_from_bytes(
|
||||
self,
|
||||
image_data: bytes,
|
||||
filename: str,
|
||||
experiment: ExperimentType = ExperimentType.EXPERIMENT_3,
|
||||
case_id: Optional[str] = None,
|
||||
brace_config: Optional[BraceConfigRequest] = None,
|
||||
landmarks_data: Optional[Dict[str, Any]] = None
|
||||
) -> AnalysisResult:
|
||||
"""
|
||||
Analyze X-ray from raw bytes.
|
||||
|
||||
If landmarks_data is provided, it will use those landmarks (with manual edits)
|
||||
instead of re-running automatic detection.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Generate case ID if not provided
|
||||
case_id = case_id or str(uuid.uuid4())[:8]
|
||||
|
||||
# Save image to temp file
|
||||
suffix = Path(filename).suffix or ".jpg"
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
|
||||
f.write(image_data)
|
||||
input_path = f.name
|
||||
|
||||
try:
|
||||
# Prepare output directory
|
||||
output_dir = config.TEMP_DIR / case_id
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
output_base = output_dir / f"brace_{case_id}"
|
||||
|
||||
# Select pipeline based on experiment
|
||||
if experiment == ExperimentType.EXPERIMENT_3:
|
||||
result = await self._run_experiment_3(
|
||||
input_path, output_base, brace_config, landmarks_data
|
||||
)
|
||||
else:
|
||||
result = await self._run_standard(input_path, output_base, brace_config)
|
||||
|
||||
# Add timing and case ID
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
result.case_id = case_id
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
# Cleanup temp input
|
||||
Path(input_path).unlink(missing_ok=True)
|
||||
|
||||
async def _run_standard(
|
||||
self,
|
||||
input_path: str,
|
||||
output_base: Path,
|
||||
brace_config: Optional[BraceConfigRequest]
|
||||
) -> AnalysisResult:
|
||||
"""Run standard pipeline."""
|
||||
from brace_generator.data_models import BraceConfig
|
||||
|
||||
# Configure
|
||||
if brace_config:
|
||||
self.standard_pipeline.config = BraceConfig(
|
||||
brace_height_mm=brace_config.brace_height_mm,
|
||||
torso_width_mm=brace_config.torso_width_mm,
|
||||
torso_depth_mm=brace_config.torso_depth_mm,
|
||||
wall_thickness_mm=brace_config.wall_thickness_mm,
|
||||
pressure_strength_mm=brace_config.pressure_strength_mm,
|
||||
)
|
||||
|
||||
# Run pipeline
|
||||
results = self.standard_pipeline.process(
|
||||
input_path,
|
||||
str(output_base) + ".stl",
|
||||
visualize=True,
|
||||
save_landmarks=True
|
||||
)
|
||||
|
||||
# Build response with local file paths
|
||||
outputs = {
|
||||
"stl": str(output_base) + ".stl",
|
||||
}
|
||||
|
||||
vis_path = str(output_base) + ".png"
|
||||
json_path = str(output_base) + ".json"
|
||||
if Path(vis_path).exists():
|
||||
outputs["visualization"] = vis_path
|
||||
if Path(json_path).exists():
|
||||
outputs["landmarks"] = json_path
|
||||
|
||||
return AnalysisResult(
|
||||
experiment="standard",
|
||||
input_image=input_path,
|
||||
model_used=results["model"],
|
||||
vertebrae_detected=results["vertebrae_detected"],
|
||||
cobb_angles=CobbAngles(
|
||||
PT=results["cobb_angles"]["PT"],
|
||||
MT=results["cobb_angles"]["MT"],
|
||||
TL=results["cobb_angles"]["TL"],
|
||||
),
|
||||
curve_type=results["curve_type"],
|
||||
rigo_classification=RigoClassification(
|
||||
type=results["rigo_type"],
|
||||
description=results.get("rigo_description", "")
|
||||
),
|
||||
mesh_vertices=results["mesh_vertices"],
|
||||
mesh_faces=results["mesh_faces"],
|
||||
outputs=outputs,
|
||||
processing_time_ms=0 # Will be set by caller
|
||||
)
|
||||
|
||||
async def _run_experiment_3(
|
||||
self,
|
||||
input_path: str,
|
||||
output_base: Path,
|
||||
brace_config: Optional[BraceConfigRequest],
|
||||
landmarks_data: Optional[Dict[str, Any]] = None
|
||||
) -> AnalysisResult:
|
||||
"""
|
||||
Run Experiment 3 (research-based adaptive) pipeline.
|
||||
|
||||
If landmarks_data is provided, it uses those landmarks (with manual edits)
|
||||
instead of running automatic detection.
|
||||
"""
|
||||
import sys
|
||||
from brace_generator.data_models import BraceConfig
|
||||
|
||||
pipeline = self._get_exp3_pipeline()
|
||||
|
||||
# Configure
|
||||
if brace_config:
|
||||
pipeline.config = BraceConfig(
|
||||
brace_height_mm=brace_config.brace_height_mm,
|
||||
torso_width_mm=brace_config.torso_width_mm,
|
||||
torso_depth_mm=brace_config.torso_depth_mm,
|
||||
wall_thickness_mm=brace_config.wall_thickness_mm,
|
||||
pressure_strength_mm=brace_config.pressure_strength_mm,
|
||||
)
|
||||
|
||||
# If landmarks_data is provided, use it instead of running detection
|
||||
if landmarks_data:
|
||||
results = await self._run_experiment_3_with_landmarks(
|
||||
input_path, output_base, pipeline, landmarks_data
|
||||
)
|
||||
else:
|
||||
# Run full pipeline with automatic detection
|
||||
results = pipeline.process(
|
||||
input_path,
|
||||
str(output_base),
|
||||
visualize=True,
|
||||
save_landmarks=True
|
||||
)
|
||||
|
||||
# Build deformation report
|
||||
deformation_report = None
|
||||
if results.get("deformation_report"):
|
||||
dr = results["deformation_report"]
|
||||
deformation_report = DeformationReport(
|
||||
patch_grid=dr.get("patch_grid", "6x8"),
|
||||
deformations=dr.get("deformations"),
|
||||
zones=dr.get("zones")
|
||||
)
|
||||
|
||||
# Collect output file paths
|
||||
outputs = {}
|
||||
|
||||
if results.get("output_stl"):
|
||||
outputs["stl"] = results["output_stl"]
|
||||
if results.get("output_ply"):
|
||||
outputs["ply"] = results["output_ply"]
|
||||
|
||||
# Check for visualization and landmarks files
|
||||
vis_path = str(output_base) + ".png"
|
||||
json_path = str(output_base) + ".json"
|
||||
if Path(vis_path).exists():
|
||||
outputs["visualization"] = vis_path
|
||||
if Path(json_path).exists():
|
||||
outputs["landmarks"] = json_path
|
||||
|
||||
return AnalysisResult(
|
||||
experiment="experiment_3",
|
||||
input_image=input_path,
|
||||
model_used=results.get("model", "manual_landmarks"),
|
||||
vertebrae_detected=results.get("vertebrae_detected", 0),
|
||||
cobb_angles=CobbAngles(
|
||||
PT=results["cobb_angles"]["PT"],
|
||||
MT=results["cobb_angles"]["MT"],
|
||||
TL=results["cobb_angles"]["TL"],
|
||||
),
|
||||
curve_type=results["curve_type"],
|
||||
rigo_classification=RigoClassification(
|
||||
type=results["rigo_type"],
|
||||
description=results.get("rigo_description", "")
|
||||
),
|
||||
mesh_vertices=results.get("mesh_vertices", 0),
|
||||
mesh_faces=results.get("mesh_faces", 0),
|
||||
deformation_report=deformation_report,
|
||||
outputs=outputs,
|
||||
processing_time_ms=0
|
||||
)
|
||||
|
||||
async def _run_experiment_3_with_landmarks(
|
||||
self,
|
||||
input_path: str,
|
||||
output_base: Path,
|
||||
pipeline,
|
||||
landmarks_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run experiment 3 brace generation using pre-computed landmarks.
|
||||
Uses final_values from landmarks_data (which may include manual edits).
|
||||
"""
|
||||
import sys
|
||||
import json
|
||||
|
||||
# Load analysis modules from brace_generator root
|
||||
from brace_generator.data_models import Spine2D, VertebraLandmark
|
||||
from brace_generator.spine_analysis import compute_cobb_angles, find_apex_vertebrae, classify_rigo_type, get_curve_severity
|
||||
from image_loader import load_xray_rgb
|
||||
|
||||
# Load the image for visualization
|
||||
image_rgb, pixel_spacing = load_xray_rgb(input_path)
|
||||
|
||||
# Build Spine2D from landmarks_data final_values
|
||||
vertebrae_structure = landmarks_data.get("vertebrae_structure", landmarks_data)
|
||||
vertebrae_list = vertebrae_structure.get("vertebrae", [])
|
||||
|
||||
spine = Spine2D()
|
||||
for vdata in vertebrae_list:
|
||||
final = vdata.get("final_values", {})
|
||||
centroid = final.get("centroid_px")
|
||||
|
||||
if centroid is None:
|
||||
continue
|
||||
|
||||
v = VertebraLandmark(
|
||||
level=vdata.get("level"),
|
||||
centroid_px=np.array(centroid, dtype=np.float32),
|
||||
confidence=float(final.get("confidence", 0.5))
|
||||
)
|
||||
|
||||
corners = final.get("corners_px")
|
||||
if corners:
|
||||
v.corners_px = np.array(corners, dtype=np.float32)
|
||||
|
||||
spine.vertebrae.append(v)
|
||||
|
||||
if len(spine.vertebrae) < 3:
|
||||
raise ValueError("Need at least 3 vertebrae for brace generation")
|
||||
|
||||
spine.pixel_spacing_mm = pixel_spacing
|
||||
spine.image_shape = image_rgb.shape[:2]
|
||||
spine.sort_vertebrae()
|
||||
|
||||
# Compute Cobb angles and classification
|
||||
compute_cobb_angles(spine)
|
||||
apex_indices = find_apex_vertebrae(spine)
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
# Generate adaptive brace using the pipeline's brace generator directly
|
||||
# This uses our manually-edited spine instead of re-detecting
|
||||
brace_mesh = pipeline.brace_generator.generate(spine)
|
||||
|
||||
mesh_vertices = 0
|
||||
mesh_faces = 0
|
||||
output_stl = None
|
||||
output_ply = None
|
||||
deformation_report = None
|
||||
|
||||
if brace_mesh is not None:
|
||||
mesh_vertices = len(brace_mesh.vertices)
|
||||
mesh_faces = len(brace_mesh.faces)
|
||||
|
||||
# Get deformation report if available
|
||||
if hasattr(pipeline.brace_generator, 'get_deformation_report'):
|
||||
deformation_report = pipeline.brace_generator.get_deformation_report()
|
||||
|
||||
# Export STL and PLY
|
||||
output_base_path = Path(output_base)
|
||||
output_stl = str(output_base_path.with_suffix('.stl'))
|
||||
output_ply = str(output_base_path.with_suffix('.ply'))
|
||||
|
||||
brace_mesh.export(output_stl)
|
||||
|
||||
# Export PLY if method available
|
||||
if hasattr(pipeline.brace_generator, 'export_ply'):
|
||||
pipeline.brace_generator.export_ply(brace_mesh, output_ply)
|
||||
else:
|
||||
brace_mesh.export(output_ply)
|
||||
print(f" Exported: {output_stl}, {output_ply}")
|
||||
|
||||
# Save visualization with the manual/combined landmarks and deformation heatmap
|
||||
vis_path = str(output_base) + ".png"
|
||||
self._save_landmarks_visualization_with_spine(
|
||||
image_rgb, spine, rigo_result, deformation_report, vis_path
|
||||
)
|
||||
|
||||
# Build result dict
|
||||
result = {
|
||||
"model": "manual_landmarks",
|
||||
"vertebrae_detected": len(spine.vertebrae),
|
||||
"cobb_angles": {
|
||||
"PT": float(spine.cobb_pt or 0),
|
||||
"MT": float(spine.cobb_mt or 0),
|
||||
"TL": float(spine.cobb_tl or 0),
|
||||
},
|
||||
"curve_type": spine.curve_type or "Unknown",
|
||||
"rigo_type": rigo_result["rigo_type"],
|
||||
"rigo_description": rigo_result.get("description", ""),
|
||||
"mesh_vertices": mesh_vertices,
|
||||
"mesh_faces": mesh_faces,
|
||||
"output_stl": output_stl,
|
||||
"output_ply": output_ply,
|
||||
"deformation_report": deformation_report,
|
||||
}
|
||||
|
||||
# Save landmarks JSON
|
||||
json_path = str(output_base) + ".json"
|
||||
with open(json_path, "w") as f:
|
||||
json.dump({
|
||||
"source": "manual_landmarks",
|
||||
"vertebrae_count": len(spine.vertebrae),
|
||||
"cobb_angles": result["cobb_angles"],
|
||||
"rigo_type": result["rigo_type"],
|
||||
"curve_type": result["curve_type"],
|
||||
"deformation_report": deformation_report,
|
||||
}, f, indent=2, default=lambda x: x.tolist() if hasattr(x, 'tolist') else x)
|
||||
|
||||
return result
|
||||
|
||||
def _save_landmarks_visualization_with_spine(self, image, spine, rigo_result, deformation_report, path):
|
||||
"""Save visualization using a pre-built Spine2D object with deformation heatmap."""
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.colors import TwoSlopeNorm
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
# 3-panel layout like the original pipeline
|
||||
fig, axes = plt.subplots(1, 3, figsize=(18, 10))
|
||||
|
||||
# Left: landmarks with X-shaped markers
|
||||
ax1 = axes[0]
|
||||
ax1.imshow(image)
|
||||
|
||||
# Draw green X-shaped vertebra markers and red centroids
|
||||
for v in spine.vertebrae:
|
||||
if v.corners_px is not None:
|
||||
corners = v.corners_px
|
||||
for i in range(4):
|
||||
j = (i + 1) % 4
|
||||
ax1.plot([corners[i, 0], corners[j, 0]],
|
||||
[corners[i, 1], corners[j, 1]],
|
||||
'g-', linewidth=1.5, zorder=4)
|
||||
|
||||
if v.centroid_px is not None:
|
||||
ax1.scatter(v.centroid_px[0], v.centroid_px[1], c='red', s=40, zorder=5)
|
||||
|
||||
# Add labels
|
||||
for v in spine.vertebrae:
|
||||
if v.centroid_px is not None:
|
||||
label = v.level or "?"
|
||||
ax1.annotate(
|
||||
label, (v.centroid_px[0] + 8, v.centroid_px[1]),
|
||||
fontsize=7, color='yellow', fontweight='bold',
|
||||
bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.6)
|
||||
)
|
||||
|
||||
ax1.set_title(f"Landmarks ({len(spine.vertebrae)} vertebrae)")
|
||||
ax1.axis('off')
|
||||
|
||||
# Middle: analysis with spine curve
|
||||
ax2 = axes[1]
|
||||
ax2.imshow(image, alpha=0.5)
|
||||
|
||||
# Draw spine curve line through centroids
|
||||
centroids = [v.centroid_px for v in spine.vertebrae if v.centroid_px is not None]
|
||||
if len(centroids) > 1:
|
||||
centroids_arr = np.array(centroids)
|
||||
ax2.plot(centroids_arr[:, 0], centroids_arr[:, 1], 'b-', linewidth=2, alpha=0.8)
|
||||
|
||||
text = f"Cobb Angles:\n"
|
||||
text += f"PT: {spine.cobb_pt:.1f}°\n"
|
||||
text += f"MT: {spine.cobb_mt:.1f}°\n"
|
||||
text += f"TL: {spine.cobb_tl:.1f}°\n\n"
|
||||
text += f"Curve: {spine.curve_type}\n"
|
||||
text += f"Rigo: {rigo_result['rigo_type']}"
|
||||
|
||||
ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
|
||||
verticalalignment='top', bbox=dict(facecolor='white', alpha=0.8))
|
||||
ax2.set_title("Spine Analysis")
|
||||
ax2.axis('off')
|
||||
|
||||
# Right: deformation heatmap
|
||||
ax3 = axes[2]
|
||||
|
||||
if deformation_report and deformation_report.get('deformations'):
|
||||
deform_array = np.array(deformation_report['deformations'])
|
||||
|
||||
# Create heatmap with diverging colormap
|
||||
vmax = max(abs(deform_array.min()), abs(deform_array.max()), 1)
|
||||
norm = TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
|
||||
|
||||
im = ax3.imshow(deform_array, cmap='RdBu_r', aspect='auto',
|
||||
norm=norm, origin='upper')
|
||||
|
||||
# Add colorbar
|
||||
cbar = plt.colorbar(im, ax=ax3, shrink=0.8)
|
||||
cbar.set_label('Radial deformation (mm)')
|
||||
|
||||
# Labels
|
||||
ax3.set_xlabel('Angular Position (patches)')
|
||||
ax3.set_ylabel('Height (patches)')
|
||||
ax3.set_title('Patch Deformations (mm)\nBlue=Relief, Red=Pressure')
|
||||
|
||||
# Add zone labels on y-axis
|
||||
height_labels = ['Pelvis', 'Low Lumb', 'Up Lumb', 'Low Thor', 'Up Thor', 'Shoulder']
|
||||
if deform_array.shape[0] <= len(height_labels):
|
||||
ax3.set_yticks(range(deform_array.shape[0]))
|
||||
ax3.set_yticklabels(height_labels[:deform_array.shape[0]])
|
||||
|
||||
# Angular position labels
|
||||
angle_labels = ['BR', 'R', 'FR', 'F', 'FL', 'L', 'BL', 'B']
|
||||
if deform_array.shape[1] <= len(angle_labels):
|
||||
ax3.set_xticks(range(deform_array.shape[1]))
|
||||
ax3.set_xticklabels(angle_labels[:deform_array.shape[1]])
|
||||
else:
|
||||
ax3.text(0.5, 0.5, 'No deformation data', ha='center', va='center',
|
||||
transform=ax3.transAxes, fontsize=14, color='gray')
|
||||
ax3.set_title('Patch Deformations')
|
||||
ax3.axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
# ============================================
|
||||
# NEW METHODS FOR PIPELINE DEV
|
||||
# ============================================
|
||||
|
||||
async def detect_landmarks_only(
|
||||
self,
|
||||
image_data: bytes,
|
||||
filename: str,
|
||||
case_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Detect landmarks only, without generating a brace.
|
||||
Returns full vertebrae_structure with manual_override support.
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
start_time = time.time()
|
||||
|
||||
case_id = case_id or str(uuid.uuid4())[:8]
|
||||
|
||||
# Save image to temp file
|
||||
suffix = Path(filename).suffix or ".jpg"
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
|
||||
f.write(image_data)
|
||||
input_path = f.name
|
||||
|
||||
try:
|
||||
# Import from brace_generator root (modules are already in PYTHONPATH)
|
||||
# Note: In Docker, PYTHONPATH includes /app/brace_generator
|
||||
from image_loader import load_xray_rgb
|
||||
from adapters import ScolioVisAdapter
|
||||
from spine_analysis import compute_cobb_angles, find_apex_vertebrae, classify_rigo_type, get_curve_severity
|
||||
from data_models import Spine2D
|
||||
|
||||
# Load full image
|
||||
image_rgb_full, pixel_spacing = load_xray_rgb(input_path)
|
||||
image_h, image_w = image_rgb_full.shape[:2]
|
||||
|
||||
# Smart cropping: Only crop to middle 1/3 if image is wide enough
|
||||
# Wide images (e.g., full chest X-rays) benefit from cropping to spine area
|
||||
# Narrow images (e.g., already cropped to spine) should not be cropped further
|
||||
MIN_WIDTH_FOR_CROPPING = 500 # Only crop if wider than 500 pixels
|
||||
CROPPED_MIN_WIDTH = 200 # Ensure cropped width is at least 200 pixels
|
||||
|
||||
left_margin = 0 # Initialize offset for coordinate mapping
|
||||
|
||||
if image_w >= MIN_WIDTH_FOR_CROPPING:
|
||||
# Image is wide - crop to middle 1/3 for better spine detection
|
||||
left_margin = image_w // 3
|
||||
right_margin = 2 * image_w // 3
|
||||
cropped_width = right_margin - left_margin
|
||||
|
||||
# Ensure minimum width after cropping
|
||||
if cropped_width >= CROPPED_MIN_WIDTH:
|
||||
image_rgb_for_detection = image_rgb_full[:, left_margin:right_margin]
|
||||
print(f"[SpineCrop] Full image: {image_w}x{image_h}, Cropped to middle 1/3: {cropped_width}x{image_h}")
|
||||
else:
|
||||
# Cropped would be too narrow, use full image
|
||||
image_rgb_for_detection = image_rgb_full
|
||||
left_margin = 0
|
||||
print(f"[SpineCrop] Full image: {image_w}x{image_h}, Cropped would be too narrow ({cropped_width}px), using full image")
|
||||
else:
|
||||
# Image is already narrow - use full image
|
||||
image_rgb_for_detection = image_rgb_full
|
||||
print(f"[SpineCrop] Full image: {image_w}x{image_h}, Already narrow (< {MIN_WIDTH_FOR_CROPPING}px), using full image")
|
||||
|
||||
# Detect landmarks (on cropped or full image depending on width)
|
||||
adapter = ScolioVisAdapter(device=self.device)
|
||||
spine = adapter.predict(image_rgb_for_detection)
|
||||
spine.pixel_spacing_mm = pixel_spacing
|
||||
|
||||
# Offset all detected coordinates back to full image space if cropping was applied
|
||||
if left_margin > 0:
|
||||
for v in spine.vertebrae:
|
||||
if v.centroid_px is not None:
|
||||
# Offset centroid X coordinate
|
||||
v.centroid_px[0] += left_margin
|
||||
if v.corners_px is not None:
|
||||
# Offset all corner X coordinates
|
||||
v.corners_px[:, 0] += left_margin
|
||||
|
||||
# Keep reference to full image for visualization
|
||||
image_rgb = image_rgb_full
|
||||
|
||||
# Compute analysis
|
||||
compute_cobb_angles(spine)
|
||||
apex_indices = find_apex_vertebrae(spine)
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
# Prepare output directory
|
||||
output_dir = config.TEMP_DIR / case_id
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save visualization
|
||||
vis_path = output_dir / "visualization.png"
|
||||
self._save_landmarks_visualization(image_rgb, spine, rigo_result, str(vis_path))
|
||||
|
||||
# Build full vertebrae structure (all T1-L5)
|
||||
ALL_LEVELS = ["T1", "T2", "T3", "T4", "T5", "T6", "T7", "T8", "T9", "T10", "T11", "T12", "L1", "L2", "L3", "L4", "L5"]
|
||||
|
||||
# ScolioVis doesn't assign levels - assign based on Y position (top to bottom)
|
||||
# Sort detected vertebrae by Y coordinate (centroid)
|
||||
detected_verts = sorted(
|
||||
[v for v in spine.vertebrae if v.centroid_px is not None],
|
||||
key=lambda v: v.centroid_px[1] # Sort by Y (top to bottom)
|
||||
)
|
||||
|
||||
# Assign levels based on count
|
||||
# If we detect 17 vertebrae, assign T1-L5
|
||||
# If fewer, we need to figure out which ones are missing
|
||||
num_detected = len(detected_verts)
|
||||
|
||||
if num_detected >= 17:
|
||||
# All vertebrae detected - assign directly
|
||||
for i, v in enumerate(detected_verts[:17]):
|
||||
v.level = ALL_LEVELS[i]
|
||||
elif num_detected > 0:
|
||||
# Fewer than 17 - assign from T1 onwards (assuming top vertebrae visible)
|
||||
# This is a simplification - ideally we'd use anatomical features
|
||||
for i, v in enumerate(detected_verts):
|
||||
if i < len(ALL_LEVELS):
|
||||
v.level = ALL_LEVELS[i]
|
||||
|
||||
# Build detected_map with assigned levels
|
||||
detected_map = {v.level: v for v in detected_verts if v.level}
|
||||
|
||||
vertebrae_list = []
|
||||
for level in ALL_LEVELS:
|
||||
if level in detected_map:
|
||||
v = detected_map[level]
|
||||
centroid = v.centroid_px.tolist() if v.centroid_px is not None else None
|
||||
corners = v.corners_px.tolist() if v.corners_px is not None else None
|
||||
orientation = float(v.compute_orientation()) if centroid else None
|
||||
|
||||
vertebrae_list.append({
|
||||
"level": level,
|
||||
"detected": True,
|
||||
"scoliovis_data": {
|
||||
"centroid_px": centroid,
|
||||
"corners_px": corners,
|
||||
"orientation_deg": orientation,
|
||||
"confidence": float(v.confidence),
|
||||
},
|
||||
"manual_override": {
|
||||
"enabled": False,
|
||||
"centroid_px": None,
|
||||
"corners_px": None,
|
||||
"orientation_deg": None,
|
||||
"confidence": None,
|
||||
"notes": None,
|
||||
},
|
||||
"final_values": {
|
||||
"centroid_px": centroid,
|
||||
"corners_px": corners,
|
||||
"orientation_deg": orientation,
|
||||
"confidence": float(v.confidence),
|
||||
"source": "scoliovis",
|
||||
},
|
||||
})
|
||||
else:
|
||||
vertebrae_list.append({
|
||||
"level": level,
|
||||
"detected": False,
|
||||
"scoliovis_data": {
|
||||
"centroid_px": None,
|
||||
"corners_px": None,
|
||||
"orientation_deg": None,
|
||||
"confidence": 0.0,
|
||||
},
|
||||
"manual_override": {
|
||||
"enabled": False,
|
||||
"centroid_px": None,
|
||||
"corners_px": None,
|
||||
"orientation_deg": None,
|
||||
"confidence": None,
|
||||
"notes": None,
|
||||
},
|
||||
"final_values": {
|
||||
"centroid_px": None,
|
||||
"corners_px": None,
|
||||
"orientation_deg": None,
|
||||
"confidence": 0.0,
|
||||
"source": "undetected",
|
||||
},
|
||||
})
|
||||
|
||||
# Build result
|
||||
result = {
|
||||
"case_id": case_id,
|
||||
"status": "landmarks_detected",
|
||||
"input": {
|
||||
"image_dimensions": {"width": image_w, "height": image_h},
|
||||
"pixel_spacing_mm": pixel_spacing,
|
||||
},
|
||||
"detection_quality": {
|
||||
"vertebrae_count": len(spine.vertebrae),
|
||||
"average_confidence": float(np.mean([v.confidence for v in spine.vertebrae])) if spine.vertebrae else 0.0,
|
||||
},
|
||||
"cobb_angles": {
|
||||
"PT": float(spine.cobb_pt),
|
||||
"MT": float(spine.cobb_mt),
|
||||
"TL": float(spine.cobb_tl),
|
||||
"max": float(max(spine.cobb_pt, spine.cobb_mt, spine.cobb_tl)),
|
||||
"PT_severity": get_curve_severity(spine.cobb_pt),
|
||||
"MT_severity": get_curve_severity(spine.cobb_mt),
|
||||
"TL_severity": get_curve_severity(spine.cobb_tl),
|
||||
},
|
||||
"rigo_classification": {
|
||||
"type": rigo_result["rigo_type"],
|
||||
"description": rigo_result["description"],
|
||||
},
|
||||
"curve_type": spine.curve_type,
|
||||
"vertebrae_structure": {
|
||||
"all_levels": ALL_LEVELS,
|
||||
"detected_count": len(spine.vertebrae),
|
||||
"total_count": len(ALL_LEVELS),
|
||||
"vertebrae": vertebrae_list,
|
||||
"manual_edit_instructions": {
|
||||
"to_override": "Set manual_override.enabled=true and fill manual_override fields",
|
||||
"final_values_rule": "When manual_override.enabled=true, final_values uses manual values",
|
||||
},
|
||||
},
|
||||
"visualization_path": str(vis_path),
|
||||
"processing_time_ms": (time.time() - start_time) * 1000,
|
||||
}
|
||||
|
||||
# Save JSON
|
||||
json_path = output_dir / "landmarks.json"
|
||||
import json
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(result, f, indent=2)
|
||||
|
||||
result["json_path"] = str(json_path)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
Path(input_path).unlink(missing_ok=True)
|
||||
|
||||
def _save_landmarks_visualization(self, image, spine, rigo_result, path):
|
||||
"""Save visualization with landmarks and green quadrilateral boxes."""
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 10))
|
||||
|
||||
# Left: landmarks with green boxes
|
||||
ax1 = axes[0]
|
||||
ax1.imshow(image)
|
||||
|
||||
# Draw green X-shaped vertebra markers and red centroids
|
||||
for v in spine.vertebrae:
|
||||
# Draw green X-shape if corners exist
|
||||
# Corner order: [0]=top_left, [1]=top_right, [2]=bottom_left, [3]=bottom_right
|
||||
# Drawing 0→1→2→3→0 creates the X pattern showing endplate orientations
|
||||
if v.corners_px is not None:
|
||||
corners = v.corners_px
|
||||
for i in range(4):
|
||||
j = (i + 1) % 4
|
||||
ax1.plot([corners[i, 0], corners[j, 0]],
|
||||
[corners[i, 1], corners[j, 1]],
|
||||
'g-', linewidth=1.5, zorder=4)
|
||||
|
||||
# Draw red centroid dot
|
||||
if v.centroid_px is not None:
|
||||
ax1.scatter(v.centroid_px[0], v.centroid_px[1], c='red', s=40, zorder=5)
|
||||
|
||||
# Add labels
|
||||
for i, v in enumerate(spine.vertebrae):
|
||||
if v.centroid_px is not None:
|
||||
label = v.level or str(i)
|
||||
ax1.annotate(
|
||||
label, (v.centroid_px[0] + 8, v.centroid_px[1]),
|
||||
fontsize=7, color='yellow', fontweight='bold',
|
||||
bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.6)
|
||||
)
|
||||
|
||||
ax1.set_title(f"Automatic Detection ({len(spine.vertebrae)} vertebrae)")
|
||||
ax1.axis('off')
|
||||
|
||||
# Right: analysis
|
||||
ax2 = axes[1]
|
||||
ax2.imshow(image, alpha=0.5)
|
||||
|
||||
text = f"Cobb Angles:\n"
|
||||
text += f"PT: {spine.cobb_pt:.1f}°\n"
|
||||
text += f"MT: {spine.cobb_mt:.1f}°\n"
|
||||
text += f"TL: {spine.cobb_tl:.1f}°\n\n"
|
||||
text += f"Curve: {spine.curve_type}\n"
|
||||
text += f"Rigo: {rigo_result['rigo_type']}"
|
||||
|
||||
ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
|
||||
verticalalignment='top', bbox=dict(facecolor='white', alpha=0.8))
|
||||
ax2.set_title("Spine Analysis")
|
||||
ax2.axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
async def recalculate_from_landmarks(
|
||||
self,
|
||||
landmarks_data: Dict[str, Any],
|
||||
case_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Recalculate Cobb angles and Rigo classification from landmarks data.
|
||||
Uses final_values from each vertebra (which may be manual overrides).
|
||||
"""
|
||||
import sys
|
||||
start_time = time.time()
|
||||
|
||||
case_id = case_id or str(uuid.uuid4())[:8]
|
||||
|
||||
# Load analysis modules from brace_generator root
|
||||
from brace_generator.data_models import Spine2D, VertebraLandmark
|
||||
from brace_generator.spine_analysis import compute_cobb_angles, find_apex_vertebrae, classify_rigo_type, get_curve_severity
|
||||
|
||||
# Reconstruct spine from landmarks data
|
||||
vertebrae_structure = landmarks_data.get("vertebrae_structure", landmarks_data)
|
||||
vertebrae_list = vertebrae_structure.get("vertebrae", [])
|
||||
|
||||
# Build Spine2D from final_values
|
||||
spine = Spine2D()
|
||||
for vdata in vertebrae_list:
|
||||
final = vdata.get("final_values", {})
|
||||
centroid = final.get("centroid_px")
|
||||
|
||||
if centroid is None:
|
||||
continue # Skip undetected/empty vertebrae
|
||||
|
||||
v = VertebraLandmark(
|
||||
level=vdata.get("level"),
|
||||
centroid_px=np.array(centroid, dtype=np.float32),
|
||||
confidence=float(final.get("confidence", 0.5))
|
||||
)
|
||||
|
||||
corners = final.get("corners_px")
|
||||
if corners:
|
||||
v.corners_px = np.array(corners, dtype=np.float32)
|
||||
|
||||
spine.vertebrae.append(v)
|
||||
|
||||
if len(spine.vertebrae) < 3:
|
||||
raise ValueError("Need at least 3 vertebrae for analysis")
|
||||
|
||||
# Sort by Y position (top to bottom)
|
||||
spine.sort_vertebrae()
|
||||
|
||||
# Compute Cobb angles and Rigo
|
||||
compute_cobb_angles(spine)
|
||||
apex_indices = find_apex_vertebrae(spine)
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
result = {
|
||||
"case_id": case_id,
|
||||
"status": "analysis_recalculated",
|
||||
"cobb_angles": {
|
||||
"PT": float(spine.cobb_pt),
|
||||
"MT": float(spine.cobb_mt),
|
||||
"TL": float(spine.cobb_tl),
|
||||
"max": float(max(spine.cobb_pt, spine.cobb_mt, spine.cobb_tl)),
|
||||
"PT_severity": get_curve_severity(spine.cobb_pt),
|
||||
"MT_severity": get_curve_severity(spine.cobb_mt),
|
||||
"TL_severity": get_curve_severity(spine.cobb_tl),
|
||||
},
|
||||
"rigo_classification": {
|
||||
"type": rigo_result["rigo_type"],
|
||||
"description": rigo_result["description"],
|
||||
},
|
||||
"curve_type": spine.curve_type,
|
||||
"apex_indices": apex_indices,
|
||||
"vertebrae_used": len(spine.vertebrae),
|
||||
"processing_time_ms": (time.time() - start_time) * 1000,
|
||||
}
|
||||
|
||||
return result
|
||||
456
brace-generator/simple_server.py
Normal file
456
brace-generator/simple_server.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""
|
||||
Simple FastAPI server for brace generation - CPU optimized.
|
||||
Designed to work standalone with minimal dependencies.
|
||||
"""
|
||||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '' # Force CPU
|
||||
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import trimesh
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Paths
|
||||
BASE_DIR = Path(__file__).parent
|
||||
MODELS_DIR = BASE_DIR / "models"
|
||||
OUTPUTS_DIR = BASE_DIR / "outputs"
|
||||
TEMPLATES_DIR = BASE_DIR / "templates"
|
||||
|
||||
OUTPUTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Global model (loaded once)
|
||||
model = None
|
||||
model_loaded = False
|
||||
|
||||
|
||||
def get_kprcnn_model():
|
||||
"""Load Keypoint RCNN model."""
|
||||
from torchvision.models.detection.rpn import AnchorGenerator
|
||||
import torchvision
|
||||
|
||||
model_path = MODELS_DIR / "keypointsrcnn_weights.pt"
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model not found: {model_path}")
|
||||
|
||||
num_keypoints = 4
|
||||
anchor_generator = AnchorGenerator(
|
||||
sizes=(32, 64, 128, 256, 512),
|
||||
aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
|
||||
)
|
||||
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
|
||||
weights=None,
|
||||
weights_backbone='IMAGENET1K_V1',
|
||||
num_keypoints=num_keypoints,
|
||||
num_classes=2,
|
||||
rpn_anchor_generator=anchor_generator
|
||||
)
|
||||
|
||||
state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def predict_keypoints(model, image_rgb):
|
||||
"""Run keypoint detection."""
|
||||
from torchvision.transforms import functional as F
|
||||
import torchvision
|
||||
|
||||
# Convert to tensor
|
||||
image_tensor = F.to_tensor(image_rgb).unsqueeze(0)
|
||||
|
||||
# Inference
|
||||
with torch.no_grad():
|
||||
outputs = model(image_tensor)
|
||||
|
||||
output = outputs[0]
|
||||
|
||||
# Filter results
|
||||
scores = output['scores'].cpu().numpy()
|
||||
high_scores_idxs = np.where(scores > 0.5)[0].tolist()
|
||||
|
||||
if not high_scores_idxs:
|
||||
return [], [], []
|
||||
|
||||
post_nms_idxs = torchvision.ops.nms(
|
||||
output['boxes'][high_scores_idxs],
|
||||
output['scores'][high_scores_idxs],
|
||||
0.3
|
||||
).cpu().numpy()
|
||||
|
||||
np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].cpu().numpy()
|
||||
np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].cpu().numpy()
|
||||
np_scores = scores[high_scores_idxs][post_nms_idxs]
|
||||
|
||||
# Sort by score, take top 18
|
||||
sorted_idxs = np.argsort(-np_scores)[:18]
|
||||
np_keypoints = np_keypoints[sorted_idxs]
|
||||
np_bboxes = np_bboxes[sorted_idxs]
|
||||
np_scores = np_scores[sorted_idxs]
|
||||
|
||||
# Sort by y position
|
||||
ymins = np.array([kps[0][1] for kps in np_keypoints])
|
||||
sorted_ymin_idxs = np.argsort(ymins)
|
||||
|
||||
np_keypoints = np_keypoints[sorted_ymin_idxs]
|
||||
np_bboxes = np_bboxes[sorted_ymin_idxs]
|
||||
np_scores = np_scores[sorted_ymin_idxs]
|
||||
|
||||
# Convert to list format
|
||||
keypoints_list = [[list(map(float, kp[:2])) for kp in kps] for kps in np_keypoints]
|
||||
bboxes_list = [list(map(int, bbox.tolist())) for bbox in np_bboxes]
|
||||
scores_list = np_scores.tolist()
|
||||
|
||||
return bboxes_list, keypoints_list, scores_list
|
||||
|
||||
|
||||
def compute_cobb_angles(keypoints):
|
||||
"""Compute Cobb angles from keypoints."""
|
||||
if len(keypoints) < 5:
|
||||
return {"pt": 0, "mt": 0, "tl": 0}
|
||||
|
||||
# Calculate midpoints and angles
|
||||
midpoints = []
|
||||
angles = []
|
||||
|
||||
for kps in keypoints:
|
||||
# kps is list of [x, y] for 4 corners
|
||||
corners = np.array(kps)
|
||||
|
||||
# Top midpoint (average of corners 0 and 1)
|
||||
top = (corners[0] + corners[1]) / 2
|
||||
# Bottom midpoint (average of corners 2 and 3)
|
||||
bottom = (corners[2] + corners[3]) / 2
|
||||
|
||||
midpoints.append((top, bottom))
|
||||
|
||||
# Vertebra angle
|
||||
dx = bottom[0] - top[0]
|
||||
dy = bottom[1] - top[1]
|
||||
angle = np.degrees(np.arctan2(dx, dy))
|
||||
angles.append(angle)
|
||||
|
||||
angles = np.array(angles)
|
||||
|
||||
# Find inflection points for curve regions
|
||||
n = len(angles)
|
||||
|
||||
# Simple approach: divide into 3 regions
|
||||
third = n // 3
|
||||
|
||||
pt_region = angles[:third] if third > 0 else angles[:2]
|
||||
mt_region = angles[third:2*third] if 2*third > third else angles[2:4]
|
||||
tl_region = angles[2*third:] if n > 2*third else angles[-2:]
|
||||
|
||||
# Cobb angle = difference between max and min tilt in region
|
||||
pt_angle = float(np.max(pt_region) - np.min(pt_region)) if len(pt_region) > 1 else 0
|
||||
mt_angle = float(np.max(mt_region) - np.min(mt_region)) if len(mt_region) > 1 else 0
|
||||
tl_angle = float(np.max(tl_region) - np.min(tl_region)) if len(tl_region) > 1 else 0
|
||||
|
||||
return {"pt": pt_angle, "mt": mt_angle, "tl": tl_angle}
|
||||
|
||||
|
||||
def classify_rigo_type(cobb_angles):
|
||||
"""Classify Rigo-Chêneau type based on Cobb angles."""
|
||||
pt = abs(cobb_angles['pt'])
|
||||
mt = abs(cobb_angles['mt'])
|
||||
tl = abs(cobb_angles['tl'])
|
||||
|
||||
max_angle = max(pt, mt, tl)
|
||||
|
||||
if max_angle < 10:
|
||||
return "Normal"
|
||||
elif mt >= tl and mt >= pt:
|
||||
if mt < 25:
|
||||
return "A1"
|
||||
elif mt < 40:
|
||||
return "A2"
|
||||
else:
|
||||
return "A3"
|
||||
elif tl >= mt:
|
||||
if tl < 30:
|
||||
return "C1"
|
||||
else:
|
||||
return "C2"
|
||||
else:
|
||||
return "B1"
|
||||
|
||||
|
||||
def generate_brace(rigo_type: str, case_id: str):
|
||||
"""Load brace template and export."""
|
||||
if rigo_type == "Normal":
|
||||
return None, None
|
||||
|
||||
template_path = TEMPLATES_DIR / f"{rigo_type}.obj"
|
||||
if not template_path.exists():
|
||||
# Try fallback templates
|
||||
fallback = {"A1": "A2", "A2": "A1", "A3": "A2", "C1": "C2", "C2": "C1", "B1": "A1", "B2": "A1"}
|
||||
if rigo_type in fallback:
|
||||
template_path = TEMPLATES_DIR / f"{fallback[rigo_type]}.obj"
|
||||
|
||||
if not template_path.exists():
|
||||
return None, None
|
||||
|
||||
mesh = trimesh.load(str(template_path))
|
||||
if isinstance(mesh, trimesh.Scene):
|
||||
meshes = [g for g in mesh.geometry.values() if isinstance(g, trimesh.Trimesh)]
|
||||
if meshes:
|
||||
mesh = trimesh.util.concatenate(meshes)
|
||||
else:
|
||||
return None, None
|
||||
|
||||
# Export
|
||||
case_dir = OUTPUTS_DIR / case_id
|
||||
case_dir.mkdir(exist_ok=True)
|
||||
|
||||
stl_path = case_dir / f"{case_id}_brace.stl"
|
||||
mesh.export(str(stl_path))
|
||||
|
||||
return str(stl_path), mesh
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Load model on startup."""
|
||||
global model, model_loaded
|
||||
print("Loading ScolioVis model...")
|
||||
start = time.time()
|
||||
try:
|
||||
model = get_kprcnn_model()
|
||||
model_loaded = True
|
||||
print(f"Model loaded in {time.time() - start:.1f}s")
|
||||
except Exception as e:
|
||||
print(f"Failed to load model: {e}")
|
||||
model_loaded = False
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Brace Generator API",
|
||||
description="CPU-based scoliosis brace generation",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
|
||||
class AnalysisResult(BaseModel):
|
||||
case_id: str
|
||||
experiment: str = "experiment_4"
|
||||
model_used: str = "ScolioVis"
|
||||
vertebrae_detected: int
|
||||
cobb_angles: dict
|
||||
curve_type: str = "Unknown"
|
||||
rigo_classification: dict
|
||||
mesh_vertices: int = 0
|
||||
mesh_faces: int = 0
|
||||
timing_ms: float
|
||||
outputs: dict
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"model_loaded": model_loaded,
|
||||
"device": "CPU"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/analyze/upload", response_model=AnalysisResult)
|
||||
async def analyze_upload(
|
||||
file: UploadFile = File(...),
|
||||
case_id: str = Form(None)
|
||||
):
|
||||
"""Analyze X-ray and generate brace."""
|
||||
global model
|
||||
|
||||
if not model_loaded:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Generate case ID if not provided
|
||||
if not case_id:
|
||||
case_id = f"case_{int(time.time())}"
|
||||
|
||||
# Save uploaded file
|
||||
case_dir = OUTPUTS_DIR / case_id
|
||||
case_dir.mkdir(exist_ok=True)
|
||||
|
||||
input_path = case_dir / file.filename
|
||||
with open(input_path, "wb") as f:
|
||||
content = await file.read()
|
||||
f.write(content)
|
||||
|
||||
# Load image
|
||||
img = cv2.imread(str(input_path))
|
||||
if img is None:
|
||||
raise HTTPException(status_code=400, detail="Could not read image")
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Detect keypoints
|
||||
bboxes, keypoints, scores = predict_keypoints(model, img_rgb)
|
||||
n_vertebrae = len(keypoints)
|
||||
|
||||
if n_vertebrae < 3:
|
||||
raise HTTPException(status_code=400, detail=f"Insufficient vertebrae detected: {n_vertebrae}")
|
||||
|
||||
# Compute Cobb angles
|
||||
cobb_angles = compute_cobb_angles(keypoints)
|
||||
|
||||
# Classify Rigo type
|
||||
rigo_type = classify_rigo_type(cobb_angles)
|
||||
|
||||
# Generate brace
|
||||
stl_path, mesh = generate_brace(rigo_type, case_id)
|
||||
|
||||
outputs = {}
|
||||
mesh_vertices = 0
|
||||
mesh_faces = 0
|
||||
|
||||
if stl_path:
|
||||
outputs["stl"] = stl_path
|
||||
if mesh is not None:
|
||||
mesh_vertices = len(mesh.vertices)
|
||||
mesh_faces = len(mesh.faces)
|
||||
|
||||
# Determine curve type
|
||||
curve_type = "Normal"
|
||||
if cobb_angles['pt'] > 10 or cobb_angles['mt'] > 10 or cobb_angles['tl'] > 10:
|
||||
if (cobb_angles['mt'] > 10 and cobb_angles['tl'] > 10) or (cobb_angles['pt'] > 10 and cobb_angles['tl'] > 10):
|
||||
curve_type = "S-shaped"
|
||||
else:
|
||||
curve_type = "C-shaped"
|
||||
|
||||
# Rigo classification details
|
||||
rigo_descriptions = {
|
||||
"Normal": "No significant curve - normal spine",
|
||||
"A1": "Main thoracic curve (mild) - 3C pattern",
|
||||
"A2": "Main thoracic curve (moderate) - 3C pattern",
|
||||
"A3": "Main thoracic curve (severe) - 3C pattern",
|
||||
"B1": "Double curve (thoracic dominant) - 4C pattern",
|
||||
"B2": "Double curve (lumbar dominant) - 4C pattern",
|
||||
"C1": "Main thoracolumbar/lumbar (mild) - 4C pattern",
|
||||
"C2": "Main thoracolumbar/lumbar (moderate-severe) - 4C pattern",
|
||||
"E1": "Not structural - functional curve",
|
||||
"E2": "Not structural - compensatory curve",
|
||||
}
|
||||
|
||||
rigo_classification = {
|
||||
"type": rigo_type,
|
||||
"description": rigo_descriptions.get(rigo_type, "Unknown classification"),
|
||||
}
|
||||
|
||||
# Generate visualization
|
||||
vis_path = None
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 12))
|
||||
ax.imshow(img_rgb)
|
||||
|
||||
# Draw keypoints
|
||||
for kps in keypoints:
|
||||
corners = np.array(kps) # List of [x, y]
|
||||
centroid = corners.mean(axis=0)
|
||||
ax.scatter(centroid[0], centroid[1], c='red', s=50, zorder=5)
|
||||
|
||||
# Draw corners (vertebra outline)
|
||||
for i in range(4):
|
||||
j = (i + 1) % 4
|
||||
ax.plot([corners[i][0], corners[j][0]],
|
||||
[corners[i][1], corners[j][1]], 'g-', linewidth=1)
|
||||
|
||||
# Add text overlay
|
||||
text = f"ScolioVis Analysis\n"
|
||||
text += f"-" * 20 + "\n"
|
||||
text += f"Vertebrae: {n_vertebrae}\n"
|
||||
text += f"Cobb Angles:\n"
|
||||
text += f" PT: {cobb_angles['pt']:.1f}°\n"
|
||||
text += f" MT: {cobb_angles['mt']:.1f}°\n"
|
||||
text += f" TL: {cobb_angles['tl']:.1f}°\n"
|
||||
text += f"Curve: {curve_type}\n"
|
||||
text += f"Rigo: {rigo_type}\n"
|
||||
text += f"{rigo_classification['description']}"
|
||||
|
||||
ax.text(0.02, 0.98, text, transform=ax.transAxes, fontsize=10,
|
||||
verticalalignment='top', fontfamily='monospace',
|
||||
bbox=dict(facecolor='white', alpha=0.9))
|
||||
|
||||
ax.set_title(f"Case: {case_id}")
|
||||
ax.axis('off')
|
||||
|
||||
vis_path = case_dir / f"{case_id}_visualization.png"
|
||||
plt.savefig(str(vis_path), dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
outputs["visualization"] = str(vis_path)
|
||||
except Exception as e:
|
||||
print(f"Visualization failed: {e}")
|
||||
|
||||
# Save analysis JSON
|
||||
analysis = {
|
||||
"case_id": case_id,
|
||||
"input_file": file.filename,
|
||||
"experiment": "experiment_4",
|
||||
"model_used": "ScolioVis",
|
||||
"vertebrae_detected": n_vertebrae,
|
||||
"cobb_angles": cobb_angles,
|
||||
"curve_type": curve_type,
|
||||
"rigo_type": rigo_type,
|
||||
"rigo_classification": rigo_classification,
|
||||
"keypoints": keypoints,
|
||||
"bboxes": bboxes,
|
||||
"scores": scores,
|
||||
"mesh_vertices": mesh_vertices,
|
||||
"mesh_faces": mesh_faces,
|
||||
}
|
||||
|
||||
analysis_path = case_dir / f"{case_id}_analysis.json"
|
||||
with open(analysis_path, "w") as f:
|
||||
json.dump(analysis, f, indent=2)
|
||||
outputs["analysis"] = str(analysis_path)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return AnalysisResult(
|
||||
case_id=case_id,
|
||||
experiment="experiment_4",
|
||||
model_used="ScolioVis",
|
||||
vertebrae_detected=n_vertebrae,
|
||||
cobb_angles=cobb_angles,
|
||||
curve_type=curve_type,
|
||||
rigo_classification=rigo_classification,
|
||||
mesh_vertices=mesh_vertices,
|
||||
mesh_faces=mesh_faces,
|
||||
timing_ms=elapsed_ms,
|
||||
outputs=outputs
|
||||
)
|
||||
|
||||
|
||||
@app.get("/download/{case_id}/{filename}")
|
||||
async def download_file(case_id: str, filename: str):
|
||||
"""Download generated file."""
|
||||
file_path = OUTPUTS_DIR / case_id / filename
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return FileResponse(str(file_path), filename=filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user