457 lines
14 KiB
Python
457 lines
14 KiB
Python
"""
|
|
Simple FastAPI server for brace generation - CPU optimized.
|
|
Designed to work standalone with minimal dependencies.
|
|
"""
|
|
import os
|
|
os.environ.update(CUDA_VISIBLE_DEVICES="")  # Force CPU: expose no CUDA devices; set before torch is imported below
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
import shutil
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from contextlib import asynccontextmanager
|
|
|
|
import numpy as np
|
|
import cv2
|
|
import torch
|
|
import trimesh
|
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
|
from fastapi.responses import JSONResponse, FileResponse
|
|
from pydantic import BaseModel
|
|
|
|
# Filesystem layout, all relative to the directory containing this file:
#   models/    - detector weights (read by get_kprcnn_model)
#   outputs/   - per-case results, one sub-directory per case_id
#   templates/ - brace template meshes named <rigo_type>.obj
BASE_DIR = Path(__file__).parent
MODELS_DIR, OUTPUTS_DIR, TEMPLATES_DIR = (
    BASE_DIR / subdir for subdir in ("models", "outputs", "templates")
)

# Only the output directory is created here; models/ and templates/ are
# expected to be provided alongside the server.
OUTPUTS_DIR.mkdir(exist_ok=True)
|
|
|
|
# Process-wide detector state. The startup `lifespan` hook assigns the loaded
# model and flips the flag; request handlers only read these names.
model, model_loaded = None, False
|
|
|
|
|
|
def get_kprcnn_model():
    """Build the 4-keypoint, 2-class Keypoint R-CNN and load its trained weights on CPU.

    Returns:
        The torchvision Keypoint R-CNN in eval mode, with its state dict loaded
        from ``MODELS_DIR / "keypointsrcnn_weights.pt"``.

    Raises:
        FileNotFoundError: If the weights file does not exist.
    """
    import torchvision
    from torchvision.models.detection.rpn import AnchorGenerator

    weights_file = MODELS_DIR / "keypointsrcnn_weights.pt"
    if not weights_file.exists():
        raise FileNotFoundError(f"Model not found: {weights_file}")

    detector = torchvision.models.detection.keypointrcnn_resnet50_fpn(
        weights=None,
        # NOTE(review): this asks torchvision for ImageNet backbone weights
        # (fetched on first use) although load_state_dict below replaces every
        # parameter. A backbone weight name may also change how torchvision
        # builds the backbone, so confirm against the saved weights before
        # removing it.
        weights_backbone='IMAGENET1K_V1',
        num_keypoints=4,
        num_classes=2,
        # One anchor size per feature-map level, each with seven aspect ratios;
        # these must match the configuration the checkpoint was trained with.
        rpn_anchor_generator=AnchorGenerator(
            sizes=(32, 64, 128, 256, 512),
            aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0),
        ),
    )

    trained_state = torch.load(weights_file, map_location=torch.device('cpu'), weights_only=True)
    detector.load_state_dict(trained_state)
    detector.eval()
    return detector
|
|
|
|
|
|
def predict_keypoints(model, image_rgb, *, score_threshold=0.5, nms_iou_threshold=0.3,
                      max_detections=18):
    """Detect vertebrae in an RGB image and return them ordered top-to-bottom.

    Args:
        model: Detector callable (e.g. the Keypoint R-CNN from
            ``get_kprcnn_model``) taking a ``(1, C, H, W)`` float tensor and
            returning a list whose first element is a dict with ``'boxes'``,
            ``'scores'`` and ``'keypoints'`` tensors.
        image_rgb: Image accepted by ``torchvision.transforms.functional.to_tensor``
            (the server passes an ``H x W x 3`` RGB numpy array).
        score_threshold: Detections scoring at or below this are discarded.
        nms_iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.
        max_detections: Keep at most this many of the highest-scoring
            detections that survive NMS.

    Returns:
        ``(bboxes, keypoints, scores)``, three parallel lists sorted by the y
        coordinate of each detection's first keypoint:
        ``bboxes`` holds the detector's boxes truncated to ints,
        ``keypoints`` holds per-detection lists of ``[x, y]`` floats, and
        ``scores`` holds floats. All three are empty when nothing passes
        ``score_threshold``.
    """
    from torchvision.transforms import functional as F
    import torchvision

    image_tensor = F.to_tensor(image_rgb).unsqueeze(0)
    with torch.no_grad():
        output = model(image_tensor)[0]

    # Confidence filter, then suppress overlapping boxes of the same vertebra.
    keep_mask = output['scores'] > score_threshold
    if not bool(keep_mask.any()):
        return [], [], []

    boxes = output['boxes'][keep_mask]
    det_scores = output['scores'][keep_mask]
    det_keypoints = output['keypoints'][keep_mask]
    nms_keep = torchvision.ops.nms(boxes, det_scores, nms_iou_threshold)

    np_bboxes = boxes[nms_keep].cpu().numpy()
    np_keypoints = det_keypoints[nms_keep].cpu().numpy()
    np_scores = det_scores[nms_keep].cpu().numpy()

    # Keep the highest-scoring detections only.
    best = np.argsort(-np_scores)[:max_detections]
    np_bboxes, np_keypoints, np_scores = np_bboxes[best], np_keypoints[best], np_scores[best]

    # Order top-to-bottom by the y coordinate of each detection's first keypoint.
    by_y = np.argsort(np_keypoints[:, 0, 1])
    np_bboxes, np_keypoints, np_scores = np_bboxes[by_y], np_keypoints[by_y], np_scores[by_y]

    keypoints_list = [[[float(v) for v in kp[:2]] for kp in kps] for kps in np_keypoints]
    bboxes_list = [[int(v) for v in bbox.tolist()] for bbox in np_bboxes]
    scores_list = np_scores.tolist()

    return bboxes_list, keypoints_list, scores_list
|
|
|
|
|
|
def compute_cobb_angles(keypoints):
    """Estimate the three regional Cobb angles ``pt``, ``mt`` and ``tl`` from vertebra corners.

    Each vertebra's tilt is the angle, in degrees, between the image y axis and
    the segment joining the midpoint of corners 0/1 to the midpoint of
    corners 2/3. The vertebrae, in the order given, are split into three
    consecutive regions of ``n // 3`` vertebrae (the last region also takes any
    remainder). Each region's angle is its largest tilt minus its smallest.

    Args:
        keypoints: One entry per vertebra, each holding at least four
            ``[x, y, ...]`` corner points (only x and y are used).
            ``predict_keypoints`` returns them ordered top-to-bottom.

    Returns:
        ``{"pt": float, "mt": float, "tl": float}``. All values are ``0.0`` when
        fewer than five vertebrae are given. A region containing a single
        vertebra scores ``0.0``.
    """
    if len(keypoints) < 5:
        return {"pt": 0.0, "mt": 0.0, "tl": 0.0}

    corners = np.asarray(keypoints, dtype=float)[:, :4, :2]  # (n, 4, 2)
    top = (corners[:, 0] + corners[:, 1]) / 2
    bottom = (corners[:, 2] + corners[:, 3]) / 2
    delta = bottom - top
    # arctan2(dx, dy): 0 when the segment is parallel to the y axis; the sign
    # follows the horizontal lean.
    tilts = np.degrees(np.arctan2(delta[:, 0], delta[:, 1]))

    third = len(tilts) // 3  # >= 1 because at least 5 vertebrae are present
    regions = {
        "pt": tilts[:third],
        "mt": tilts[third:2 * third],
        "tl": tilts[2 * third:],
    }
    return {
        name: float(region.max() - region.min()) if region.size > 1 else 0.0
        for name, region in regions.items()
    }
|
|
|
|
|
|
def classify_rigo_type(cobb_angles):
    """Map regional Cobb angles to a Rigo-Chêneau brace type.

    The region with the largest absolute angle selects the family:
    ``mt`` -> A1/A2/A3, ``tl`` -> C1/C2, ``pt`` -> B1. Ties go to ``mt`` first,
    then to ``tl`` over ``pt``.

    Args:
        cobb_angles: Mapping with ``'pt'``, ``'mt'`` and ``'tl'`` angles in degrees.

    Returns:
        ``"Normal"`` when every absolute angle is below 10 degrees. Otherwise
        one of ``"A1"`` (mt < 25), ``"A2"`` (mt < 40), ``"A3"``,
        ``"C1"`` (tl < 30), ``"C2"`` or ``"B1"``.
    """
    pt = abs(cobb_angles['pt'])
    mt = abs(cobb_angles['mt'])
    tl = abs(cobb_angles['tl'])

    if max(pt, mt, tl) < 10:
        return "Normal"

    if mt >= tl and mt >= pt:
        if mt < 25:
            return "A1"
        if mt < 40:
            return "A2"
        return "A3"

    # mt is not the largest here, so the dominant curve is tl or pt. Compare
    # those two directly: checking only `tl >= mt` would label a pure
    # proximal-thoracic curve (e.g. pt=40, mt=0, tl=0) as "C1".
    if tl >= pt:
        return "C1" if tl < 30 else "C2"
    return "B1"
|
|
|
|
|
|
def generate_brace(rigo_type: str, case_id: str):
    """Export the brace template for *rigo_type* as an STL file for *case_id*.

    Looks for ``TEMPLATES_DIR/<rigo_type>.obj``. If that file is missing, it
    tries a related type's template. Files that load as a ``trimesh.Scene`` are
    merged into a single mesh from their ``trimesh.Trimesh`` geometries.

    Returns:
        ``(stl_path, mesh)``, where ``stl_path`` is the string path
        ``OUTPUTS_DIR/<case_id>/<case_id>_brace.stl``. Returns ``(None, None)``
        for a ``"Normal"`` spine, when no template file is found, or when a
        scene contains no triangle meshes.
    """
    if rigo_type == "Normal":
        return None, None

    candidate = TEMPLATES_DIR / f"{rigo_type}.obj"
    if not candidate.exists():
        # Substitute templates used when a type-specific one is not shipped.
        substitutes = {"A1": "A2", "A2": "A1", "A3": "A2", "C1": "C2", "C2": "C1", "B1": "A1", "B2": "A1"}
        alternative = substitutes.get(rigo_type)
        if alternative is not None:
            candidate = TEMPLATES_DIR / f"{alternative}.obj"
        if not candidate.exists():
            return None, None

    loaded = trimesh.load(str(candidate))
    if isinstance(loaded, trimesh.Scene):
        parts = [geom for geom in loaded.geometry.values() if isinstance(geom, trimesh.Trimesh)]
        if not parts:
            return None, None
        loaded = trimesh.util.concatenate(parts)

    out_dir = OUTPUTS_DIR / case_id
    out_dir.mkdir(exist_ok=True)
    stl_file = out_dir / f"{case_id}_brace.stl"
    loaded.export(str(stl_file))
    return str(stl_file), loaded
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: try to load the keypoint model once before serving.

    A load failure is printed and leaves ``model_loaded`` False instead of
    aborting startup. ``/health`` reports that flag, and ``/analyze/upload``
    answers 503 while it is False.
    """
    global model, model_loaded
    print("Loading ScolioVis model...")
    started = time.time()
    try:
        loaded = get_kprcnn_model()
    except Exception as err:
        print(f"Failed to load model: {err}")
        model_loaded = False
    else:
        model = loaded
        model_loaded = True
        print(f"Model loaded in {time.time() - started:.1f}s")
    yield
|
|
|
|
|
|
# The ASGI application. Its lifespan hook loads the detector at startup.
app = FastAPI(
    lifespan=lifespan,
    title="Brace Generator API",
    description="CPU-based scoliosis brace generation",
)
|
|
|
|
|
|
class AnalysisResult(BaseModel):
    # Response schema for POST /analyze/upload. Fields are documented with
    # comments (not a class docstring) so the generated OpenAPI schema is
    # left unchanged.
    case_id: str  # case identifier; also the sub-directory name under OUTPUTS_DIR
    experiment: str = "experiment_4"  # fixed label reported with every result
    model_used: str = "ScolioVis"
    vertebrae_detected: int  # number of vertebrae kept by predict_keypoints
    cobb_angles: dict  # {"pt", "mt", "tl"} regional angles in degrees, from compute_cobb_angles
    curve_type: str = "Unknown"  # analyze_upload fills in "Normal", "C-shaped" or "S-shaped"
    rigo_classification: dict  # {"type": Rigo type, "description": human-readable text}
    mesh_vertices: int = 0  # 0 when no brace mesh was generated
    mesh_faces: int = 0
    timing_ms: float  # elapsed processing time for the request, in milliseconds
    outputs: dict  # artifact name ("stl", "visualization", "analysis") -> file path
|
|
|
|
|
|
@app.get("/health")
|
|
async def health():
|
|
return {
|
|
"status": "healthy",
|
|
"model_loaded": model_loaded,
|
|
"device": "CPU"
|
|
}
|
|
|
|
|
|
@app.post("/analyze/upload", response_model=AnalysisResult)
|
|
async def analyze_upload(
|
|
file: UploadFile = File(...),
|
|
case_id: str = Form(None)
|
|
):
|
|
"""Analyze X-ray and generate brace."""
|
|
global model
|
|
|
|
if not model_loaded:
|
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
start_time = time.time()
|
|
|
|
# Generate case ID if not provided
|
|
if not case_id:
|
|
case_id = f"case_{int(time.time())}"
|
|
|
|
# Save uploaded file
|
|
case_dir = OUTPUTS_DIR / case_id
|
|
case_dir.mkdir(exist_ok=True)
|
|
|
|
input_path = case_dir / file.filename
|
|
with open(input_path, "wb") as f:
|
|
content = await file.read()
|
|
f.write(content)
|
|
|
|
# Load image
|
|
img = cv2.imread(str(input_path))
|
|
if img is None:
|
|
raise HTTPException(status_code=400, detail="Could not read image")
|
|
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
# Detect keypoints
|
|
bboxes, keypoints, scores = predict_keypoints(model, img_rgb)
|
|
n_vertebrae = len(keypoints)
|
|
|
|
if n_vertebrae < 3:
|
|
raise HTTPException(status_code=400, detail=f"Insufficient vertebrae detected: {n_vertebrae}")
|
|
|
|
# Compute Cobb angles
|
|
cobb_angles = compute_cobb_angles(keypoints)
|
|
|
|
# Classify Rigo type
|
|
rigo_type = classify_rigo_type(cobb_angles)
|
|
|
|
# Generate brace
|
|
stl_path, mesh = generate_brace(rigo_type, case_id)
|
|
|
|
outputs = {}
|
|
mesh_vertices = 0
|
|
mesh_faces = 0
|
|
|
|
if stl_path:
|
|
outputs["stl"] = stl_path
|
|
if mesh is not None:
|
|
mesh_vertices = len(mesh.vertices)
|
|
mesh_faces = len(mesh.faces)
|
|
|
|
# Determine curve type
|
|
curve_type = "Normal"
|
|
if cobb_angles['pt'] > 10 or cobb_angles['mt'] > 10 or cobb_angles['tl'] > 10:
|
|
if (cobb_angles['mt'] > 10 and cobb_angles['tl'] > 10) or (cobb_angles['pt'] > 10 and cobb_angles['tl'] > 10):
|
|
curve_type = "S-shaped"
|
|
else:
|
|
curve_type = "C-shaped"
|
|
|
|
# Rigo classification details
|
|
rigo_descriptions = {
|
|
"Normal": "No significant curve - normal spine",
|
|
"A1": "Main thoracic curve (mild) - 3C pattern",
|
|
"A2": "Main thoracic curve (moderate) - 3C pattern",
|
|
"A3": "Main thoracic curve (severe) - 3C pattern",
|
|
"B1": "Double curve (thoracic dominant) - 4C pattern",
|
|
"B2": "Double curve (lumbar dominant) - 4C pattern",
|
|
"C1": "Main thoracolumbar/lumbar (mild) - 4C pattern",
|
|
"C2": "Main thoracolumbar/lumbar (moderate-severe) - 4C pattern",
|
|
"E1": "Not structural - functional curve",
|
|
"E2": "Not structural - compensatory curve",
|
|
}
|
|
|
|
rigo_classification = {
|
|
"type": rigo_type,
|
|
"description": rigo_descriptions.get(rigo_type, "Unknown classification"),
|
|
}
|
|
|
|
# Generate visualization
|
|
vis_path = None
|
|
try:
|
|
import matplotlib
|
|
matplotlib.use('Agg')
|
|
import matplotlib.pyplot as plt
|
|
|
|
fig, ax = plt.subplots(1, 1, figsize=(10, 12))
|
|
ax.imshow(img_rgb)
|
|
|
|
# Draw keypoints
|
|
for kps in keypoints:
|
|
corners = np.array(kps) # List of [x, y]
|
|
centroid = corners.mean(axis=0)
|
|
ax.scatter(centroid[0], centroid[1], c='red', s=50, zorder=5)
|
|
|
|
# Draw corners (vertebra outline)
|
|
for i in range(4):
|
|
j = (i + 1) % 4
|
|
ax.plot([corners[i][0], corners[j][0]],
|
|
[corners[i][1], corners[j][1]], 'g-', linewidth=1)
|
|
|
|
# Add text overlay
|
|
text = f"ScolioVis Analysis\n"
|
|
text += f"-" * 20 + "\n"
|
|
text += f"Vertebrae: {n_vertebrae}\n"
|
|
text += f"Cobb Angles:\n"
|
|
text += f" PT: {cobb_angles['pt']:.1f}°\n"
|
|
text += f" MT: {cobb_angles['mt']:.1f}°\n"
|
|
text += f" TL: {cobb_angles['tl']:.1f}°\n"
|
|
text += f"Curve: {curve_type}\n"
|
|
text += f"Rigo: {rigo_type}\n"
|
|
text += f"{rigo_classification['description']}"
|
|
|
|
ax.text(0.02, 0.98, text, transform=ax.transAxes, fontsize=10,
|
|
verticalalignment='top', fontfamily='monospace',
|
|
bbox=dict(facecolor='white', alpha=0.9))
|
|
|
|
ax.set_title(f"Case: {case_id}")
|
|
ax.axis('off')
|
|
|
|
vis_path = case_dir / f"{case_id}_visualization.png"
|
|
plt.savefig(str(vis_path), dpi=150, bbox_inches='tight')
|
|
plt.close()
|
|
outputs["visualization"] = str(vis_path)
|
|
except Exception as e:
|
|
print(f"Visualization failed: {e}")
|
|
|
|
# Save analysis JSON
|
|
analysis = {
|
|
"case_id": case_id,
|
|
"input_file": file.filename,
|
|
"experiment": "experiment_4",
|
|
"model_used": "ScolioVis",
|
|
"vertebrae_detected": n_vertebrae,
|
|
"cobb_angles": cobb_angles,
|
|
"curve_type": curve_type,
|
|
"rigo_type": rigo_type,
|
|
"rigo_classification": rigo_classification,
|
|
"keypoints": keypoints,
|
|
"bboxes": bboxes,
|
|
"scores": scores,
|
|
"mesh_vertices": mesh_vertices,
|
|
"mesh_faces": mesh_faces,
|
|
}
|
|
|
|
analysis_path = case_dir / f"{case_id}_analysis.json"
|
|
with open(analysis_path, "w") as f:
|
|
json.dump(analysis, f, indent=2)
|
|
outputs["analysis"] = str(analysis_path)
|
|
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
|
|
return AnalysisResult(
|
|
case_id=case_id,
|
|
experiment="experiment_4",
|
|
model_used="ScolioVis",
|
|
vertebrae_detected=n_vertebrae,
|
|
cobb_angles=cobb_angles,
|
|
curve_type=curve_type,
|
|
rigo_classification=rigo_classification,
|
|
mesh_vertices=mesh_vertices,
|
|
mesh_faces=mesh_faces,
|
|
timing_ms=elapsed_ms,
|
|
outputs=outputs
|
|
)
|
|
|
|
|
|
@app.get("/download/{case_id}/{filename}")
|
|
async def download_file(case_id: str, filename: str):
|
|
"""Download generated file."""
|
|
file_path = OUTPUTS_DIR / case_id / filename
|
|
if not file_path.exists():
|
|
raise HTTPException(status_code=404, detail="File not found")
|
|
|
|
return FileResponse(str(file_path), filename=filename)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|