Initial commit - BraceIQMed platform with frontend, API, and brace generator
This commit is contained in:
456
brace-generator/simple_server.py
Normal file
456
brace-generator/simple_server.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""
|
||||
Simple FastAPI server for brace generation - CPU optimized.
|
||||
Designed to work standalone with minimal dependencies.
|
||||
"""
|
||||
import os
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '' # Force CPU
|
||||
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import trimesh
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Paths
|
||||
BASE_DIR = Path(__file__).parent
|
||||
MODELS_DIR = BASE_DIR / "models"
|
||||
OUTPUTS_DIR = BASE_DIR / "outputs"
|
||||
TEMPLATES_DIR = BASE_DIR / "templates"
|
||||
|
||||
OUTPUTS_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Global model (loaded once)
|
||||
model = None
|
||||
model_loaded = False
|
||||
|
||||
|
||||
def get_kprcnn_model():
|
||||
"""Load Keypoint RCNN model."""
|
||||
from torchvision.models.detection.rpn import AnchorGenerator
|
||||
import torchvision
|
||||
|
||||
model_path = MODELS_DIR / "keypointsrcnn_weights.pt"
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model not found: {model_path}")
|
||||
|
||||
num_keypoints = 4
|
||||
anchor_generator = AnchorGenerator(
|
||||
sizes=(32, 64, 128, 256, 512),
|
||||
aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
|
||||
)
|
||||
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
|
||||
weights=None,
|
||||
weights_backbone='IMAGENET1K_V1',
|
||||
num_keypoints=num_keypoints,
|
||||
num_classes=2,
|
||||
rpn_anchor_generator=anchor_generator
|
||||
)
|
||||
|
||||
state_dict = torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def predict_keypoints(model, image_rgb):
|
||||
"""Run keypoint detection."""
|
||||
from torchvision.transforms import functional as F
|
||||
import torchvision
|
||||
|
||||
# Convert to tensor
|
||||
image_tensor = F.to_tensor(image_rgb).unsqueeze(0)
|
||||
|
||||
# Inference
|
||||
with torch.no_grad():
|
||||
outputs = model(image_tensor)
|
||||
|
||||
output = outputs[0]
|
||||
|
||||
# Filter results
|
||||
scores = output['scores'].cpu().numpy()
|
||||
high_scores_idxs = np.where(scores > 0.5)[0].tolist()
|
||||
|
||||
if not high_scores_idxs:
|
||||
return [], [], []
|
||||
|
||||
post_nms_idxs = torchvision.ops.nms(
|
||||
output['boxes'][high_scores_idxs],
|
||||
output['scores'][high_scores_idxs],
|
||||
0.3
|
||||
).cpu().numpy()
|
||||
|
||||
np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].cpu().numpy()
|
||||
np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].cpu().numpy()
|
||||
np_scores = scores[high_scores_idxs][post_nms_idxs]
|
||||
|
||||
# Sort by score, take top 18
|
||||
sorted_idxs = np.argsort(-np_scores)[:18]
|
||||
np_keypoints = np_keypoints[sorted_idxs]
|
||||
np_bboxes = np_bboxes[sorted_idxs]
|
||||
np_scores = np_scores[sorted_idxs]
|
||||
|
||||
# Sort by y position
|
||||
ymins = np.array([kps[0][1] for kps in np_keypoints])
|
||||
sorted_ymin_idxs = np.argsort(ymins)
|
||||
|
||||
np_keypoints = np_keypoints[sorted_ymin_idxs]
|
||||
np_bboxes = np_bboxes[sorted_ymin_idxs]
|
||||
np_scores = np_scores[sorted_ymin_idxs]
|
||||
|
||||
# Convert to list format
|
||||
keypoints_list = [[list(map(float, kp[:2])) for kp in kps] for kps in np_keypoints]
|
||||
bboxes_list = [list(map(int, bbox.tolist())) for bbox in np_bboxes]
|
||||
scores_list = np_scores.tolist()
|
||||
|
||||
return bboxes_list, keypoints_list, scores_list
|
||||
|
||||
|
||||
def compute_cobb_angles(keypoints):
|
||||
"""Compute Cobb angles from keypoints."""
|
||||
if len(keypoints) < 5:
|
||||
return {"pt": 0, "mt": 0, "tl": 0}
|
||||
|
||||
# Calculate midpoints and angles
|
||||
midpoints = []
|
||||
angles = []
|
||||
|
||||
for kps in keypoints:
|
||||
# kps is list of [x, y] for 4 corners
|
||||
corners = np.array(kps)
|
||||
|
||||
# Top midpoint (average of corners 0 and 1)
|
||||
top = (corners[0] + corners[1]) / 2
|
||||
# Bottom midpoint (average of corners 2 and 3)
|
||||
bottom = (corners[2] + corners[3]) / 2
|
||||
|
||||
midpoints.append((top, bottom))
|
||||
|
||||
# Vertebra angle
|
||||
dx = bottom[0] - top[0]
|
||||
dy = bottom[1] - top[1]
|
||||
angle = np.degrees(np.arctan2(dx, dy))
|
||||
angles.append(angle)
|
||||
|
||||
angles = np.array(angles)
|
||||
|
||||
# Find inflection points for curve regions
|
||||
n = len(angles)
|
||||
|
||||
# Simple approach: divide into 3 regions
|
||||
third = n // 3
|
||||
|
||||
pt_region = angles[:third] if third > 0 else angles[:2]
|
||||
mt_region = angles[third:2*third] if 2*third > third else angles[2:4]
|
||||
tl_region = angles[2*third:] if n > 2*third else angles[-2:]
|
||||
|
||||
# Cobb angle = difference between max and min tilt in region
|
||||
pt_angle = float(np.max(pt_region) - np.min(pt_region)) if len(pt_region) > 1 else 0
|
||||
mt_angle = float(np.max(mt_region) - np.min(mt_region)) if len(mt_region) > 1 else 0
|
||||
tl_angle = float(np.max(tl_region) - np.min(tl_region)) if len(tl_region) > 1 else 0
|
||||
|
||||
return {"pt": pt_angle, "mt": mt_angle, "tl": tl_angle}
|
||||
|
||||
|
||||
def classify_rigo_type(cobb_angles):
|
||||
"""Classify Rigo-Chêneau type based on Cobb angles."""
|
||||
pt = abs(cobb_angles['pt'])
|
||||
mt = abs(cobb_angles['mt'])
|
||||
tl = abs(cobb_angles['tl'])
|
||||
|
||||
max_angle = max(pt, mt, tl)
|
||||
|
||||
if max_angle < 10:
|
||||
return "Normal"
|
||||
elif mt >= tl and mt >= pt:
|
||||
if mt < 25:
|
||||
return "A1"
|
||||
elif mt < 40:
|
||||
return "A2"
|
||||
else:
|
||||
return "A3"
|
||||
elif tl >= mt:
|
||||
if tl < 30:
|
||||
return "C1"
|
||||
else:
|
||||
return "C2"
|
||||
else:
|
||||
return "B1"
|
||||
|
||||
|
||||
def generate_brace(rigo_type: str, case_id: str):
|
||||
"""Load brace template and export."""
|
||||
if rigo_type == "Normal":
|
||||
return None, None
|
||||
|
||||
template_path = TEMPLATES_DIR / f"{rigo_type}.obj"
|
||||
if not template_path.exists():
|
||||
# Try fallback templates
|
||||
fallback = {"A1": "A2", "A2": "A1", "A3": "A2", "C1": "C2", "C2": "C1", "B1": "A1", "B2": "A1"}
|
||||
if rigo_type in fallback:
|
||||
template_path = TEMPLATES_DIR / f"{fallback[rigo_type]}.obj"
|
||||
|
||||
if not template_path.exists():
|
||||
return None, None
|
||||
|
||||
mesh = trimesh.load(str(template_path))
|
||||
if isinstance(mesh, trimesh.Scene):
|
||||
meshes = [g for g in mesh.geometry.values() if isinstance(g, trimesh.Trimesh)]
|
||||
if meshes:
|
||||
mesh = trimesh.util.concatenate(meshes)
|
||||
else:
|
||||
return None, None
|
||||
|
||||
# Export
|
||||
case_dir = OUTPUTS_DIR / case_id
|
||||
case_dir.mkdir(exist_ok=True)
|
||||
|
||||
stl_path = case_dir / f"{case_id}_brace.stl"
|
||||
mesh.export(str(stl_path))
|
||||
|
||||
return str(stl_path), mesh
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Load model on startup."""
|
||||
global model, model_loaded
|
||||
print("Loading ScolioVis model...")
|
||||
start = time.time()
|
||||
try:
|
||||
model = get_kprcnn_model()
|
||||
model_loaded = True
|
||||
print(f"Model loaded in {time.time() - start:.1f}s")
|
||||
except Exception as e:
|
||||
print(f"Failed to load model: {e}")
|
||||
model_loaded = False
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Brace Generator API",
|
||||
description="CPU-based scoliosis brace generation",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
|
||||
class AnalysisResult(BaseModel):
|
||||
case_id: str
|
||||
experiment: str = "experiment_4"
|
||||
model_used: str = "ScolioVis"
|
||||
vertebrae_detected: int
|
||||
cobb_angles: dict
|
||||
curve_type: str = "Unknown"
|
||||
rigo_classification: dict
|
||||
mesh_vertices: int = 0
|
||||
mesh_faces: int = 0
|
||||
timing_ms: float
|
||||
outputs: dict
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"model_loaded": model_loaded,
|
||||
"device": "CPU"
|
||||
}
|
||||
|
||||
|
||||
@app.post("/analyze/upload", response_model=AnalysisResult)
|
||||
async def analyze_upload(
|
||||
file: UploadFile = File(...),
|
||||
case_id: str = Form(None)
|
||||
):
|
||||
"""Analyze X-ray and generate brace."""
|
||||
global model
|
||||
|
||||
if not model_loaded:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Generate case ID if not provided
|
||||
if not case_id:
|
||||
case_id = f"case_{int(time.time())}"
|
||||
|
||||
# Save uploaded file
|
||||
case_dir = OUTPUTS_DIR / case_id
|
||||
case_dir.mkdir(exist_ok=True)
|
||||
|
||||
input_path = case_dir / file.filename
|
||||
with open(input_path, "wb") as f:
|
||||
content = await file.read()
|
||||
f.write(content)
|
||||
|
||||
# Load image
|
||||
img = cv2.imread(str(input_path))
|
||||
if img is None:
|
||||
raise HTTPException(status_code=400, detail="Could not read image")
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Detect keypoints
|
||||
bboxes, keypoints, scores = predict_keypoints(model, img_rgb)
|
||||
n_vertebrae = len(keypoints)
|
||||
|
||||
if n_vertebrae < 3:
|
||||
raise HTTPException(status_code=400, detail=f"Insufficient vertebrae detected: {n_vertebrae}")
|
||||
|
||||
# Compute Cobb angles
|
||||
cobb_angles = compute_cobb_angles(keypoints)
|
||||
|
||||
# Classify Rigo type
|
||||
rigo_type = classify_rigo_type(cobb_angles)
|
||||
|
||||
# Generate brace
|
||||
stl_path, mesh = generate_brace(rigo_type, case_id)
|
||||
|
||||
outputs = {}
|
||||
mesh_vertices = 0
|
||||
mesh_faces = 0
|
||||
|
||||
if stl_path:
|
||||
outputs["stl"] = stl_path
|
||||
if mesh is not None:
|
||||
mesh_vertices = len(mesh.vertices)
|
||||
mesh_faces = len(mesh.faces)
|
||||
|
||||
# Determine curve type
|
||||
curve_type = "Normal"
|
||||
if cobb_angles['pt'] > 10 or cobb_angles['mt'] > 10 or cobb_angles['tl'] > 10:
|
||||
if (cobb_angles['mt'] > 10 and cobb_angles['tl'] > 10) or (cobb_angles['pt'] > 10 and cobb_angles['tl'] > 10):
|
||||
curve_type = "S-shaped"
|
||||
else:
|
||||
curve_type = "C-shaped"
|
||||
|
||||
# Rigo classification details
|
||||
rigo_descriptions = {
|
||||
"Normal": "No significant curve - normal spine",
|
||||
"A1": "Main thoracic curve (mild) - 3C pattern",
|
||||
"A2": "Main thoracic curve (moderate) - 3C pattern",
|
||||
"A3": "Main thoracic curve (severe) - 3C pattern",
|
||||
"B1": "Double curve (thoracic dominant) - 4C pattern",
|
||||
"B2": "Double curve (lumbar dominant) - 4C pattern",
|
||||
"C1": "Main thoracolumbar/lumbar (mild) - 4C pattern",
|
||||
"C2": "Main thoracolumbar/lumbar (moderate-severe) - 4C pattern",
|
||||
"E1": "Not structural - functional curve",
|
||||
"E2": "Not structural - compensatory curve",
|
||||
}
|
||||
|
||||
rigo_classification = {
|
||||
"type": rigo_type,
|
||||
"description": rigo_descriptions.get(rigo_type, "Unknown classification"),
|
||||
}
|
||||
|
||||
# Generate visualization
|
||||
vis_path = None
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 12))
|
||||
ax.imshow(img_rgb)
|
||||
|
||||
# Draw keypoints
|
||||
for kps in keypoints:
|
||||
corners = np.array(kps) # List of [x, y]
|
||||
centroid = corners.mean(axis=0)
|
||||
ax.scatter(centroid[0], centroid[1], c='red', s=50, zorder=5)
|
||||
|
||||
# Draw corners (vertebra outline)
|
||||
for i in range(4):
|
||||
j = (i + 1) % 4
|
||||
ax.plot([corners[i][0], corners[j][0]],
|
||||
[corners[i][1], corners[j][1]], 'g-', linewidth=1)
|
||||
|
||||
# Add text overlay
|
||||
text = f"ScolioVis Analysis\n"
|
||||
text += f"-" * 20 + "\n"
|
||||
text += f"Vertebrae: {n_vertebrae}\n"
|
||||
text += f"Cobb Angles:\n"
|
||||
text += f" PT: {cobb_angles['pt']:.1f}°\n"
|
||||
text += f" MT: {cobb_angles['mt']:.1f}°\n"
|
||||
text += f" TL: {cobb_angles['tl']:.1f}°\n"
|
||||
text += f"Curve: {curve_type}\n"
|
||||
text += f"Rigo: {rigo_type}\n"
|
||||
text += f"{rigo_classification['description']}"
|
||||
|
||||
ax.text(0.02, 0.98, text, transform=ax.transAxes, fontsize=10,
|
||||
verticalalignment='top', fontfamily='monospace',
|
||||
bbox=dict(facecolor='white', alpha=0.9))
|
||||
|
||||
ax.set_title(f"Case: {case_id}")
|
||||
ax.axis('off')
|
||||
|
||||
vis_path = case_dir / f"{case_id}_visualization.png"
|
||||
plt.savefig(str(vis_path), dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
outputs["visualization"] = str(vis_path)
|
||||
except Exception as e:
|
||||
print(f"Visualization failed: {e}")
|
||||
|
||||
# Save analysis JSON
|
||||
analysis = {
|
||||
"case_id": case_id,
|
||||
"input_file": file.filename,
|
||||
"experiment": "experiment_4",
|
||||
"model_used": "ScolioVis",
|
||||
"vertebrae_detected": n_vertebrae,
|
||||
"cobb_angles": cobb_angles,
|
||||
"curve_type": curve_type,
|
||||
"rigo_type": rigo_type,
|
||||
"rigo_classification": rigo_classification,
|
||||
"keypoints": keypoints,
|
||||
"bboxes": bboxes,
|
||||
"scores": scores,
|
||||
"mesh_vertices": mesh_vertices,
|
||||
"mesh_faces": mesh_faces,
|
||||
}
|
||||
|
||||
analysis_path = case_dir / f"{case_id}_analysis.json"
|
||||
with open(analysis_path, "w") as f:
|
||||
json.dump(analysis, f, indent=2)
|
||||
outputs["analysis"] = str(analysis_path)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return AnalysisResult(
|
||||
case_id=case_id,
|
||||
experiment="experiment_4",
|
||||
model_used="ScolioVis",
|
||||
vertebrae_detected=n_vertebrae,
|
||||
cobb_angles=cobb_angles,
|
||||
curve_type=curve_type,
|
||||
rigo_classification=rigo_classification,
|
||||
mesh_vertices=mesh_vertices,
|
||||
mesh_faces=mesh_faces,
|
||||
timing_ms=elapsed_ms,
|
||||
outputs=outputs
|
||||
)
|
||||
|
||||
|
||||
@app.get("/download/{case_id}/{filename}")
|
||||
async def download_file(case_id: str, filename: str):
|
||||
"""Download generated file."""
|
||||
file_path = OUTPUTS_DIR / case_id / filename
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return FileResponse(str(file_path), filename=filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user