"""
Simple FastAPI server for brace generation - CPU optimized.

Designed to work standalone with minimal dependencies.
"""

import os

# Force CPU. This must be set before torch is imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import sys
import time
import json
import re
import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Optional
from contextlib import asynccontextmanager

import numpy as np
import cv2
import torch
import trimesh
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel

# Paths
BASE_DIR = Path(__file__).parent
MODELS_DIR = BASE_DIR / "models"
OUTPUTS_DIR = BASE_DIR / "outputs"
TEMPLATES_DIR = BASE_DIR / "templates"

OUTPUTS_DIR.mkdir(exist_ok=True)

# Global model (loaded once by `lifespan` at startup)
model = None
model_loaded = False

# Client-supplied case IDs become directory names under OUTPUTS_DIR. Restrict them
# to a conservative character set: no path separators, and they must start with an
# alphanumeric character so "." / ".." are impossible.
_CASE_ID_RE = re.compile(r"[A-Za-z0-9][A-Za-z0-9._-]{0,127}")

# Human-readable description reported for each Rigo-Chêneau type label.
RIGO_DESCRIPTIONS = {
    "Normal": "No significant curve - normal spine",
    "A1": "Main thoracic curve (mild) - 3C pattern",
    "A2": "Main thoracic curve (moderate) - 3C pattern",
    "A3": "Main thoracic curve (severe) - 3C pattern",
    "B1": "Double curve (thoracic dominant) - 4C pattern",
    "B2": "Double curve (lumbar dominant) - 4C pattern",
    "C1": "Main thoracolumbar/lumbar (mild) - 4C pattern",
    "C2": "Main thoracolumbar/lumbar (moderate-severe) - 4C pattern",
    "E1": "Not structural - functional curve",
    "E2": "Not structural - compensatory curve",
}


def get_kprcnn_model():
    """Build a Keypoint R-CNN (ResNet-50 FPN) and load weights from MODELS_DIR.

    The network is configured for 4 keypoints per detection and 2 classes, with a
    custom anchor generator. It is returned in eval mode.

    Raises:
        FileNotFoundError: if ``models/keypointsrcnn_weights.pt`` is missing.
    """
    from torchvision.models.detection.rpn import AnchorGenerator
    import torchvision

    model_path = MODELS_DIR / "keypointsrcnn_weights.pt"
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found: {model_path}")

    num_keypoints = 4
    anchor_generator = AnchorGenerator(
        sizes=(32, 64, 128, 256, 512),
        aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0),
    )

    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
        weights=None,
        weights_backbone='IMAGENET1K_V1',
        num_keypoints=num_keypoints,
        num_classes=2,
        rpn_anchor_generator=anchor_generator,
    )

    state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def predict_keypoints(model, image_rgb, score_threshold=0.5, iou_threshold=0.3,
                      max_detections=18):
    """Run keypoint detection and return detections ordered top-to-bottom.

    Detections are filtered by ``score_threshold`` and then de-duplicated with NMS
    at ``iou_threshold``. The ``max_detections`` highest-scoring ones are kept and
    finally sorted by the y coordinate of their first keypoint.

    Args:
        model: a torchvision detection model with keypoint output.
        image_rgb: RGB image array accepted by ``to_tensor`` (presumably HxWx3 uint8).
        score_threshold: minimum detection score (exclusive).
        iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.
        max_detections: maximum number of detections kept after NMS.

    Returns:
        ``(bboxes, keypoints, scores)``:
        - ``bboxes``: lists of int box coordinates;
        - ``keypoints``: per detection, a list of ``[x, y]`` float pairs;
        - ``scores``: floats.
        All three are empty lists when nothing passes the threshold.
    """
    from torchvision.transforms import functional as F
    import torchvision

    image_tensor = F.to_tensor(image_rgb).unsqueeze(0)

    with torch.no_grad():
        outputs = model(image_tensor)
    output = outputs[0]

    # Score filtering
    scores = output['scores'].cpu().numpy()
    high_scores_idxs = np.where(scores > score_threshold)[0].tolist()
    if not high_scores_idxs:
        return [], [], []

    # Non-maximum suppression over the surviving boxes
    post_nms_idxs = torchvision.ops.nms(
        output['boxes'][high_scores_idxs],
        output['scores'][high_scores_idxs],
        iou_threshold,
    ).cpu().numpy()

    np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].cpu().numpy()
    np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].cpu().numpy()
    np_scores = scores[high_scores_idxs][post_nms_idxs]

    # Keep the highest-scoring detections
    sorted_idxs = np.argsort(-np_scores)[:max_detections]
    np_keypoints = np_keypoints[sorted_idxs]
    np_bboxes = np_bboxes[sorted_idxs]
    np_scores = np_scores[sorted_idxs]

    # Order by the y coordinate of each detection's first keypoint (top to bottom)
    ymins = np.array([kps[0][1] for kps in np_keypoints])
    sorted_ymin_idxs = np.argsort(ymins)
    np_keypoints = np_keypoints[sorted_ymin_idxs]
    np_bboxes = np_bboxes[sorted_ymin_idxs]
    np_scores = np_scores[sorted_ymin_idxs]

    # Plain-Python output (JSON serialisable); keypoint visibility column dropped
    keypoints_list = [[list(map(float, kp[:2])) for kp in kps] for kps in np_keypoints]
    bboxes_list = [list(map(int, bbox.tolist())) for bbox in np_bboxes]
    scores_list = np_scores.tolist()

    return bboxes_list, keypoints_list, scores_list


def _tilt_range(region):
    """Return max - min of a run of tilt angles, or 0.0 for fewer than two."""
    if len(region) > 1:
        return float(np.max(region) - np.min(region))
    return 0.0


def compute_cobb_angles(keypoints):
    """Estimate Cobb-like angles for three spinal regions from vertebra corners.

    For each vertebra, the tilt is the angle of the line from the top-edge midpoint
    to the bottom-edge midpoint, in degrees, via ``arctan2(dx, dy)``. The midpoints
    are the mean of corners 0/1 and corners 2/3 respectively; this corner ordering
    is assumed, TODO confirm against the training annotations.

    The vertebrae (in the given order) are split into three equal-size thirds, with
    the remainder going to the last one. Each region's angle is the spread
    (max - min) of its tilts.

    Args:
        keypoints: per vertebra, a list of four ``[x, y]`` corners.

    Returns:
        ``{"pt": float, "mt": float, "tl": float}``. All values are 0.0 when fewer
        than 5 vertebrae are given.
    """
    if len(keypoints) < 5:
        return {"pt": 0.0, "mt": 0.0, "tl": 0.0}

    angles = []
    for kps in keypoints:
        corners = np.array(kps)
        top = (corners[0] + corners[1]) / 2
        bottom = (corners[2] + corners[3]) / 2
        dx = bottom[0] - top[0]
        dy = bottom[1] - top[1]
        angles.append(np.degrees(np.arctan2(dx, dy)))
    angles = np.array(angles)

    # n >= 5 here, so third >= 1 and all three slices are non-empty.
    n = len(angles)
    third = n // 3
    pt_region = angles[:third]
    mt_region = angles[third:2 * third]
    tl_region = angles[2 * third:]

    return {
        "pt": _tilt_range(pt_region),
        "mt": _tilt_range(mt_region),
        "tl": _tilt_range(tl_region),
    }


def classify_rigo_type(cobb_angles):
    """Map region Cobb angles to a Rigo-Chêneau type label.

    Returns:
        - "Normal" when every |angle| is below 10.
        - "A1"/"A2"/"A3" when MT is the largest (thresholds 25 and 40).
        - "C1"/"C2" when TL >= MT (threshold 30).
        - "B1" otherwise (PT largest).
    """
    pt = abs(cobb_angles['pt'])
    mt = abs(cobb_angles['mt'])
    tl = abs(cobb_angles['tl'])

    max_angle = max(pt, mt, tl)

    if max_angle < 10:
        return "Normal"
    elif mt >= tl and mt >= pt:
        if mt < 25:
            return "A1"
        elif mt < 40:
            return "A2"
        else:
            return "A3"
    elif tl >= mt:
        if tl < 30:
            return "C1"
        else:
            return "C2"
    else:
        return "B1"


def generate_brace(rigo_type: str, case_id: str):
    """Load the brace template for ``rigo_type`` and export it as STL.

    If ``templates/<rigo_type>.obj`` is missing, a fixed fallback template is tried.
    Scenes are flattened by concatenating their Trimesh geometries.

    Returns:
        ``(stl_path_str, mesh)`` on success. ``(None, None)`` for "Normal", when no
        template file exists, or when a loaded scene contains no Trimesh geometry.
    """
    if rigo_type == "Normal":
        return None, None

    template_path = TEMPLATES_DIR / f"{rigo_type}.obj"

    if not template_path.exists():
        # Try fallback templates
        fallback = {"A1": "A2", "A2": "A1", "A3": "A2", "C1": "C2", "C2": "C1", "B1": "A1", "B2": "A1"}
        if rigo_type in fallback:
            template_path = TEMPLATES_DIR / f"{fallback[rigo_type]}.obj"

    if not template_path.exists():
        return None, None

    mesh = trimesh.load(str(template_path))
    if isinstance(mesh, trimesh.Scene):
        meshes = [g for g in mesh.geometry.values() if isinstance(g, trimesh.Trimesh)]
        if meshes:
            mesh = trimesh.util.concatenate(meshes)
        else:
            return None, None

    # Export
    case_dir = OUTPUTS_DIR / case_id
    case_dir.mkdir(parents=True, exist_ok=True)
    stl_path = case_dir / f"{case_id}_brace.stl"
    mesh.export(str(stl_path))

    return str(stl_path), mesh


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup.

    A failure is logged and leaves ``model_loaded`` False instead of aborting
    startup, so ``/health`` can still report the failure.
    """
    global model, model_loaded
    print("Loading ScolioVis model...")
    start = time.time()
    try:
        model = get_kprcnn_model()
        model_loaded = True
        print(f"Model loaded in {time.time() - start:.1f}s")
    except Exception as e:
        print(f"Failed to load model: {e}")
        model_loaded = False
    yield


app = FastAPI(
    title="Brace Generator API",
    description="CPU-based scoliosis brace generation",
    lifespan=lifespan,
)


class AnalysisResult(BaseModel):
    """Response body of ``/analyze/upload``."""

    case_id: str
    experiment: str = "experiment_4"
    model_used: str = "ScolioVis"
    vertebrae_detected: int
    cobb_angles: dict
    curve_type: str = "Unknown"
    rigo_classification: dict
    mesh_vertices: int = 0
    mesh_faces: int = 0
    timing_ms: float
    outputs: dict


@app.get("/health")
async def health():
    """Report liveness and whether the model finished loading."""
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "device": "CPU",
    }


def _resolve_case_id(case_id: Optional[str]) -> str:
    """Return a filesystem-safe case ID, generating one when none was supplied.

    Raises:
        HTTPException: 400 if a supplied ID could escape OUTPUTS_DIR. That covers
            path separators, leading dots, and characters outside ``[A-Za-z0-9._-]``.
    """
    if not case_id:
        # The timestamp keeps IDs roughly chronological. The random suffix stops two
        # uploads in the same second from sharing (and overwriting) one directory.
        return f"case_{int(time.time())}_{uuid.uuid4().hex[:8]}"
    if not _CASE_ID_RE.fullmatch(case_id):
        raise HTTPException(status_code=400, detail="Invalid case_id")
    return case_id


def _sanitize_filename(filename: Optional[str], default: str = "input_image") -> str:
    """Reduce a client-supplied upload name to a bare file name.

    Directory components (both "/" and "\\" separators) are stripped. ``default``
    is returned for empty, ".", or ".." names.
    """
    name = (filename or "").replace("\\", "/").rsplit("/", 1)[-1]
    if name in ("", ".", ".."):
        return default
    return name


def _classify_curve_type(cobb_angles) -> str:
    """Return "Normal", "C-shaped" or "S-shaped" from per-region angles (> 10 counts)."""
    pt = cobb_angles['pt'] > 10
    mt = cobb_angles['mt'] > 10
    tl = cobb_angles['tl'] > 10
    if not (pt or mt or tl):
        return "Normal"
    if (mt and tl) or (pt and tl):
        return "S-shaped"
    return "C-shaped"


def _save_visualization(img_rgb, keypoints, overlay_text, case_id, case_dir):
    """Draw vertebra outlines/centroids and a text panel, save as PNG, return its path.

    The figure is always closed, even when drawing or saving raises, so repeated
    failures cannot accumulate open pyplot figures.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(1, 1, figsize=(10, 12))
    try:
        ax.imshow(img_rgb)

        for kps in keypoints:
            corners = np.array(kps)  # list of [x, y]
            centroid = corners.mean(axis=0)
            ax.scatter(centroid[0], centroid[1], c='red', s=50, zorder=5)

            # Closed polygon through the corners in order (vertebra outline)
            n_corners = len(corners)
            for i in range(n_corners):
                j = (i + 1) % n_corners
                ax.plot([corners[i][0], corners[j][0]],
                        [corners[i][1], corners[j][1]], 'g-', linewidth=1)

        ax.text(0.02, 0.98, overlay_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', fontfamily='monospace',
                bbox=dict(facecolor='white', alpha=0.9))
        ax.set_title(f"Case: {case_id}")
        ax.axis('off')

        vis_path = case_dir / f"{case_id}_visualization.png"
        fig.savefig(str(vis_path), dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)
    return vis_path


@app.post("/analyze/upload", response_model=AnalysisResult)
async def analyze_upload(
    file: UploadFile = File(...),
    case_id: str = Form(None),
):
    """Analyze an X-ray and generate a brace.

    Steps:
    1. Save the upload under ``outputs/<case_id>/``.
    2. Detect vertebrae and compute Cobb angles.
    3. Classify the Rigo type and export the matching brace template.
    4. Write a visualization PNG (best-effort) and an analysis JSON.

    NOTE(review): inference and plotting run synchronously inside this async
    handler, so they block the event loop for the duration of the request.

    Raises:
        HTTPException: 503 if the model is not loaded. 400 for an invalid
            ``case_id``, an unreadable image, or fewer than 3 detected vertebrae.
    """
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    start_time = time.perf_counter()

    case_id = _resolve_case_id(case_id)

    # Save uploaded file. Never trust the client's name as a path.
    case_dir = OUTPUTS_DIR / case_id
    case_dir.mkdir(exist_ok=True)
    input_name = _sanitize_filename(file.filename)
    input_path = case_dir / input_name
    content = await file.read()
    with open(input_path, "wb") as f:
        f.write(content)

    # Load image
    img = cv2.imread(str(input_path))
    if img is None:
        raise HTTPException(status_code=400, detail="Could not read image")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Detect keypoints
    bboxes, keypoints, scores = predict_keypoints(model, img_rgb)
    n_vertebrae = len(keypoints)

    if n_vertebrae < 3:
        raise HTTPException(status_code=400, detail=f"Insufficient vertebrae detected: {n_vertebrae}")

    cobb_angles = compute_cobb_angles(keypoints)
    rigo_type = classify_rigo_type(cobb_angles)

    # Generate brace
    stl_path, mesh = generate_brace(rigo_type, case_id)

    outputs = {}
    mesh_vertices = 0
    mesh_faces = 0

    if stl_path:
        outputs["stl"] = stl_path
    if mesh is not None:
        mesh_vertices = len(mesh.vertices)
        mesh_faces = len(mesh.faces)

    curve_type = _classify_curve_type(cobb_angles)

    rigo_classification = {
        "type": rigo_type,
        "description": RIGO_DESCRIPTIONS.get(rigo_type, "Unknown classification"),
    }

    # Visualization is best-effort: a failure is logged, not returned to the client.
    overlay_text = "\n".join([
        "ScolioVis Analysis",
        "-" * 20,
        f"Vertebrae: {n_vertebrae}",
        "Cobb Angles:",
        f"  PT: {cobb_angles['pt']:.1f}°",
        f"  MT: {cobb_angles['mt']:.1f}°",
        f"  TL: {cobb_angles['tl']:.1f}°",
        f"Curve: {curve_type}",
        f"Rigo: {rigo_type}",
        rigo_classification['description'],
    ])
    try:
        vis_path = _save_visualization(img_rgb, keypoints, overlay_text, case_id, case_dir)
        outputs["visualization"] = str(vis_path)
    except Exception as e:
        print(f"Visualization failed: {e}")

    # Save analysis JSON
    analysis = {
        "case_id": case_id,
        "input_file": input_name,
        "experiment": "experiment_4",
        "model_used": "ScolioVis",
        "vertebrae_detected": n_vertebrae,
        "cobb_angles": cobb_angles,
        "curve_type": curve_type,
        "rigo_type": rigo_type,
        "rigo_classification": rigo_classification,
        "keypoints": keypoints,
        "bboxes": bboxes,
        "scores": scores,
        "mesh_vertices": mesh_vertices,
        "mesh_faces": mesh_faces,
    }

    analysis_path = case_dir / f"{case_id}_analysis.json"
    with open(analysis_path, "w", encoding="utf-8") as f:
        json.dump(analysis, f, indent=2)
    outputs["analysis"] = str(analysis_path)

    elapsed_ms = (time.perf_counter() - start_time) * 1000

    return AnalysisResult(
        case_id=case_id,
        experiment="experiment_4",
        model_used="ScolioVis",
        vertebrae_detected=n_vertebrae,
        cobb_angles=cobb_angles,
        curve_type=curve_type,
        rigo_classification=rigo_classification,
        mesh_vertices=mesh_vertices,
        mesh_faces=mesh_faces,
        timing_ms=elapsed_ms,
        outputs=outputs,
    )


def _resolve_case_file(case_id: str, filename: str) -> Optional[Path]:
    """Resolve ``OUTPUTS_DIR/case_id/filename``.

    Returns None when the resolved path (symlinks and ".." followed) falls outside
    OUTPUTS_DIR.
    """
    base = OUTPUTS_DIR.resolve()
    candidate = (base / case_id / filename).resolve()
    try:
        candidate.relative_to(base)
    except ValueError:
        return None
    return candidate


@app.get("/download/{case_id}/{filename}")
async def download_file(case_id: str, filename: str):
    """Download a generated file; 404 unless it is a regular file inside OUTPUTS_DIR."""
    file_path = _resolve_case_file(case_id, filename)
    if file_path is None or not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(str(file_path), filename=filename)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)