"""
Simple FastAPI server for brace generation - CPU optimized.
Designed to work standalone with minimal dependencies.
"""
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # Force CPU before torch is imported
import time
import json
from pathlib import Path
from typing import Optional
from contextlib import asynccontextmanager
import numpy as np
import cv2
import torch
import trimesh
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
# Paths
BASE_DIR = Path(__file__).parent
MODELS_DIR = BASE_DIR / "models"
OUTPUTS_DIR = BASE_DIR / "outputs"
TEMPLATES_DIR = BASE_DIR / "templates"
OUTPUTS_DIR.mkdir(exist_ok=True)
# Global model (loaded once)
model = None
model_loaded = False
def get_kprcnn_model():
"""Load Keypoint RCNN model."""
from torchvision.models.detection.rpn import AnchorGenerator
import torchvision
model_path = MODELS_DIR / "keypointsrcnn_weights.pt"
if not model_path.exists():
raise FileNotFoundError(f"Model not found: {model_path}")
num_keypoints = 4
anchor_generator = AnchorGenerator(
sizes=(32, 64, 128, 256, 512),
aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
)
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
weights=None,
weights_backbone='IMAGENET1K_V1',
num_keypoints=num_keypoints,
num_classes=2,
rpn_anchor_generator=anchor_generator
)
state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(state_dict)
model.eval()
return model
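# Note: torchvision's AnchorGenerator expands the flat `sizes` tuple above so that each
# FPN level gets one anchor size combined with all seven aspect ratios; this must match
# the configuration the checkpoint was trained with, or load_state_dict will fail on the
# RPN head. A hedged sanity check for the loader (assumes models/keypointsrcnn_weights.pt
# is present):
#
#   detector = get_kprcnn_model()
#   print(sum(p.numel() for p in detector.parameters()), "parameters")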
def predict_keypoints(model, image_rgb):
"""Run keypoint detection."""
from torchvision.transforms import functional as F
import torchvision
# Convert to tensor
image_tensor = F.to_tensor(image_rgb).unsqueeze(0)
# Inference
with torch.no_grad():
outputs = model(image_tensor)
output = outputs[0]
# Filter results
scores = output['scores'].cpu().numpy()
high_scores_idxs = np.where(scores > 0.5)[0].tolist()
if not high_scores_idxs:
return [], [], []
post_nms_idxs = torchvision.ops.nms(
output['boxes'][high_scores_idxs],
output['scores'][high_scores_idxs],
0.3
).cpu().numpy()
np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].cpu().numpy()
np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].cpu().numpy()
np_scores = scores[high_scores_idxs][post_nms_idxs]
# Sort by score, take top 18
sorted_idxs = np.argsort(-np_scores)[:18]
np_keypoints = np_keypoints[sorted_idxs]
np_bboxes = np_bboxes[sorted_idxs]
np_scores = np_scores[sorted_idxs]
# Sort by y position
ymins = np.array([kps[0][1] for kps in np_keypoints])
sorted_ymin_idxs = np.argsort(ymins)
np_keypoints = np_keypoints[sorted_ymin_idxs]
np_bboxes = np_bboxes[sorted_ymin_idxs]
np_scores = np_scores[sorted_ymin_idxs]
# Convert to list format
keypoints_list = [[list(map(float, kp[:2])) for kp in kps] for kps in np_keypoints]
bboxes_list = [list(map(int, bbox.tolist())) for bbox in np_bboxes]
scores_list = np_scores.tolist()
return bboxes_list, keypoints_list, scores_list
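# Illustrative call (a sketch; "spine.png" is a placeholder path, not shipped with this
# repo). Each entry in `keypoints` is a list of four [x, y] corner points for one detected
# vertebra, sorted top-to-bottom:
#
#   img = cv2.cvtColor(cv2.imread("spine.png"), cv2.COLOR_BGR2RGB)
#   bboxes, keypoints, scores = predict_keypoints(get_kprcnn_model(), img)
#   print(len(keypoints), "vertebrae;", keypoints[0])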
def compute_cobb_angles(keypoints):
"""Compute Cobb angles from keypoints."""
if len(keypoints) < 5:
return {"pt": 0, "mt": 0, "tl": 0}
# Calculate midpoints and angles
midpoints = []
angles = []
for kps in keypoints:
# kps is list of [x, y] for 4 corners
corners = np.array(kps)
# Top midpoint (average of corners 0 and 1)
top = (corners[0] + corners[1]) / 2
# Bottom midpoint (average of corners 2 and 3)
bottom = (corners[2] + corners[3]) / 2
midpoints.append((top, bottom))
# Vertebra angle
dx = bottom[0] - top[0]
dy = bottom[1] - top[1]
angle = np.degrees(np.arctan2(dx, dy))
angles.append(angle)
angles = np.array(angles)
    # Approximate the PT / MT / TL regions by splitting the detected column into thirds
    # (a simplification; no inflection-point search is performed)
    n = len(angles)
    third = n // 3
pt_region = angles[:third] if third > 0 else angles[:2]
mt_region = angles[third:2*third] if 2*third > third else angles[2:4]
tl_region = angles[2*third:] if n > 2*third else angles[-2:]
# Cobb angle = difference between max and min tilt in region
pt_angle = float(np.max(pt_region) - np.min(pt_region)) if len(pt_region) > 1 else 0
mt_angle = float(np.max(mt_region) - np.min(mt_region)) if len(mt_region) > 1 else 0
tl_angle = float(np.max(tl_region) - np.min(tl_region)) if len(tl_region) > 1 else 0
return {"pt": pt_angle, "mt": mt_angle, "tl": tl_angle}
def classify_rigo_type(cobb_angles):
"""Classify Rigo-Chêneau type based on Cobb angles."""
pt = abs(cobb_angles['pt'])
mt = abs(cobb_angles['mt'])
tl = abs(cobb_angles['tl'])
max_angle = max(pt, mt, tl)
if max_angle < 10:
return "Normal"
elif mt >= tl and mt >= pt:
if mt < 25:
return "A1"
elif mt < 40:
return "A2"
else:
return "A3"
elif tl >= mt:
if tl < 30:
return "C1"
else:
return "C2"
else:
return "B1"
def generate_brace(rigo_type: str, case_id: str):
"""Load brace template and export."""
if rigo_type == "Normal":
return None, None
template_path = TEMPLATES_DIR / f"{rigo_type}.obj"
if not template_path.exists():
# Try fallback templates
fallback = {"A1": "A2", "A2": "A1", "A3": "A2", "C1": "C2", "C2": "C1", "B1": "A1", "B2": "A1"}
if rigo_type in fallback:
template_path = TEMPLATES_DIR / f"{fallback[rigo_type]}.obj"
if not template_path.exists():
return None, None
mesh = trimesh.load(str(template_path))
if isinstance(mesh, trimesh.Scene):
meshes = [g for g in mesh.geometry.values() if isinstance(g, trimesh.Trimesh)]
if meshes:
mesh = trimesh.util.concatenate(meshes)
else:
return None, None
# Export
case_dir = OUTPUTS_DIR / case_id
case_dir.mkdir(exist_ok=True)
stl_path = case_dir / f"{case_id}_brace.stl"
mesh.export(str(stl_path))
return str(stl_path), mesh
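# Hedged sketch of direct use (assumes a template such as templates/A2.obj exists; the
# case id is arbitrary):
#
#   stl_path, mesh = generate_brace("A2", "demo_001")
#   if mesh is not None:
#       print(stl_path, len(mesh.vertices), "vertices")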
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup."""
global model, model_loaded
print("Loading ScolioVis model...")
start = time.time()
try:
model = get_kprcnn_model()
model_loaded = True
print(f"Model loaded in {time.time() - start:.1f}s")
except Exception as e:
print(f"Failed to load model: {e}")
model_loaded = False
yield
app = FastAPI(
title="Brace Generator API",
description="CPU-based scoliosis brace generation",
lifespan=lifespan
)
class AnalysisResult(BaseModel):
case_id: str
experiment: str = "experiment_4"
model_used: str = "ScolioVis"
vertebrae_detected: int
cobb_angles: dict
curve_type: str = "Unknown"
rigo_classification: dict
mesh_vertices: int = 0
mesh_faces: int = 0
timing_ms: float
outputs: dict
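# Illustrative shape of the AnalysisResult JSON returned by /analyze/upload (all values
# below are placeholders, not real model output):
#
#   {
#     "case_id": "demo_001", "experiment": "experiment_4", "model_used": "ScolioVis",
#     "vertebrae_detected": 17,
#     "cobb_angles": {"pt": 8.2, "mt": 31.5, "tl": 9.8},
#     "curve_type": "C-shaped",
#     "rigo_classification": {"type": "A2", "description": "Main thoracic curve (moderate) - 3C pattern"},
#     "mesh_vertices": 24512, "mesh_faces": 49020,
#     "timing_ms": 1500.0,
#     "outputs": {"stl": "...", "visualization": "...", "analysis": "..."}
#   }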
@app.get("/health")
async def health():
return {
"status": "healthy",
"model_loaded": model_loaded,
"device": "CPU"
}
@app.post("/analyze/upload", response_model=AnalysisResult)
async def analyze_upload(
file: UploadFile = File(...),
    case_id: Optional[str] = Form(None)
):
"""Analyze X-ray and generate brace."""
global model
if not model_loaded:
raise HTTPException(status_code=503, detail="Model not loaded")
start_time = time.time()
# Generate case ID if not provided
if not case_id:
case_id = f"case_{int(time.time())}"
# Save uploaded file
case_dir = OUTPUTS_DIR / case_id
case_dir.mkdir(exist_ok=True)
    # Use only the basename of the uploaded filename to avoid path traversal
    input_path = case_dir / Path(file.filename).name
with open(input_path, "wb") as f:
content = await file.read()
f.write(content)
# Load image
img = cv2.imread(str(input_path))
if img is None:
raise HTTPException(status_code=400, detail="Could not read image")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Detect keypoints
bboxes, keypoints, scores = predict_keypoints(model, img_rgb)
n_vertebrae = len(keypoints)
if n_vertebrae < 3:
raise HTTPException(status_code=400, detail=f"Insufficient vertebrae detected: {n_vertebrae}")
# Compute Cobb angles
cobb_angles = compute_cobb_angles(keypoints)
# Classify Rigo type
rigo_type = classify_rigo_type(cobb_angles)
# Generate brace
stl_path, mesh = generate_brace(rigo_type, case_id)
outputs = {}
mesh_vertices = 0
mesh_faces = 0
if stl_path:
outputs["stl"] = stl_path
if mesh is not None:
mesh_vertices = len(mesh.vertices)
mesh_faces = len(mesh.faces)
# Determine curve type
curve_type = "Normal"
if cobb_angles['pt'] > 10 or cobb_angles['mt'] > 10 or cobb_angles['tl'] > 10:
if (cobb_angles['mt'] > 10 and cobb_angles['tl'] > 10) or (cobb_angles['pt'] > 10 and cobb_angles['tl'] > 10):
curve_type = "S-shaped"
else:
curve_type = "C-shaped"
# Rigo classification details
rigo_descriptions = {
"Normal": "No significant curve - normal spine",
"A1": "Main thoracic curve (mild) - 3C pattern",
"A2": "Main thoracic curve (moderate) - 3C pattern",
"A3": "Main thoracic curve (severe) - 3C pattern",
"B1": "Double curve (thoracic dominant) - 4C pattern",
"B2": "Double curve (lumbar dominant) - 4C pattern",
"C1": "Main thoracolumbar/lumbar (mild) - 4C pattern",
"C2": "Main thoracolumbar/lumbar (moderate-severe) - 4C pattern",
"E1": "Not structural - functional curve",
"E2": "Not structural - compensatory curve",
}
rigo_classification = {
"type": rigo_type,
"description": rigo_descriptions.get(rigo_type, "Unknown classification"),
}
# Generate visualization
vis_path = None
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 1, figsize=(10, 12))
ax.imshow(img_rgb)
# Draw keypoints
for kps in keypoints:
corners = np.array(kps) # List of [x, y]
centroid = corners.mean(axis=0)
ax.scatter(centroid[0], centroid[1], c='red', s=50, zorder=5)
# Draw corners (vertebra outline)
for i in range(4):
j = (i + 1) % 4
ax.plot([corners[i][0], corners[j][0]],
[corners[i][1], corners[j][1]], 'g-', linewidth=1)
# Add text overlay
text = f"ScolioVis Analysis\n"
text += f"-" * 20 + "\n"
text += f"Vertebrae: {n_vertebrae}\n"
text += f"Cobb Angles:\n"
text += f" PT: {cobb_angles['pt']:.1f}°\n"
text += f" MT: {cobb_angles['mt']:.1f}°\n"
text += f" TL: {cobb_angles['tl']:.1f}°\n"
text += f"Curve: {curve_type}\n"
text += f"Rigo: {rigo_type}\n"
text += f"{rigo_classification['description']}"
ax.text(0.02, 0.98, text, transform=ax.transAxes, fontsize=10,
verticalalignment='top', fontfamily='monospace',
bbox=dict(facecolor='white', alpha=0.9))
ax.set_title(f"Case: {case_id}")
ax.axis('off')
vis_path = case_dir / f"{case_id}_visualization.png"
plt.savefig(str(vis_path), dpi=150, bbox_inches='tight')
plt.close()
outputs["visualization"] = str(vis_path)
except Exception as e:
print(f"Visualization failed: {e}")
# Save analysis JSON
analysis = {
"case_id": case_id,
"input_file": file.filename,
"experiment": "experiment_4",
"model_used": "ScolioVis",
"vertebrae_detected": n_vertebrae,
"cobb_angles": cobb_angles,
"curve_type": curve_type,
"rigo_type": rigo_type,
"rigo_classification": rigo_classification,
"keypoints": keypoints,
"bboxes": bboxes,
"scores": scores,
"mesh_vertices": mesh_vertices,
"mesh_faces": mesh_faces,
}
analysis_path = case_dir / f"{case_id}_analysis.json"
with open(analysis_path, "w") as f:
json.dump(analysis, f, indent=2)
outputs["analysis"] = str(analysis_path)
elapsed_ms = (time.time() - start_time) * 1000
return AnalysisResult(
case_id=case_id,
experiment="experiment_4",
model_used="ScolioVis",
vertebrae_detected=n_vertebrae,
cobb_angles=cobb_angles,
curve_type=curve_type,
rigo_classification=rigo_classification,
mesh_vertices=mesh_vertices,
mesh_faces=mesh_faces,
timing_ms=elapsed_ms,
outputs=outputs
)
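# Hedged client-side sketch for the endpoint above, using the `requests` package (an
# extra dependency, not imported by this server; URL and filename are illustrative):
#
#   import requests
#   with open("spine_xray.png", "rb") as fh:
#       r = requests.post("http://localhost:8000/analyze/upload",
#                         files={"file": fh}, data={"case_id": "demo_001"})
#   r.raise_for_status()
#   print(r.json()["cobb_angles"], r.json()["rigo_classification"]["type"])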
@app.get("/download/{case_id}/{filename}")
async def download_file(case_id: str, filename: str):
"""Download generated file."""
    # Restrict both parameters to basenames so the path cannot escape the outputs directory
    file_path = OUTPUTS_DIR / Path(case_id).name / Path(filename).name
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(str(file_path), filename=filename)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)