Initial commit - BraceIQMed platform with frontend, API, and brace generator
This commit is contained in:
263
scoliovis-api/test_balgrist.py
Normal file
263
scoliovis-api/test_balgrist.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Test ScolioVis API with Balgrist Patient Data
|
||||
==============================================
|
||||
Runs Keypoint R-CNN inference on all patient PNG images.
|
||||
|
||||
Usage:
|
||||
python test_balgrist.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # Non-interactive backend
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def load_model():
|
||||
"""Load the Keypoint R-CNN model."""
|
||||
print("Loading Keypoint R-CNN model...")
|
||||
from scoliovis.get_model import get_kprcnn_model
|
||||
import torch
|
||||
|
||||
model = get_kprcnn_model()
|
||||
model.eval()
|
||||
print("Model loaded successfully!")
|
||||
return model
|
||||
|
||||
|
||||
def predict_single(model, image_path):
|
||||
"""
|
||||
Run prediction on a single image.
|
||||
|
||||
Returns:
|
||||
dict with detections, landmarks, angles, curve_type, midpoint_lines
|
||||
"""
|
||||
import torch
|
||||
from torchvision.transforms import functional as F
|
||||
from scoliovis.kprcnn import _filter_output, kprcnn_to_scoliovis_api_format
|
||||
|
||||
# Load image
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# Prepare input
|
||||
device = torch.device('cpu')
|
||||
model.to(device)
|
||||
|
||||
image_tensor = F.to_tensor(image_cv)
|
||||
images_input = [image_tensor.to(device)]
|
||||
|
||||
# Inference
|
||||
with torch.no_grad():
|
||||
outputs = model(images_input)
|
||||
|
||||
# Filter output
|
||||
bboxes, keypoints, scores = _filter_output(outputs[0])
|
||||
|
||||
# Convert to API format
|
||||
result = kprcnn_to_scoliovis_api_format(bboxes, keypoints, scores, image_cv.shape)
|
||||
|
||||
return result, image_cv, bboxes, keypoints
|
||||
|
||||
|
||||
def visualize_result(image_cv, keypoints, result, output_path):
|
||||
"""
|
||||
Create visualization with detected vertebrae and angles.
|
||||
"""
|
||||
img_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 16))
|
||||
ax.imshow(img_rgb)
|
||||
|
||||
# Draw keypoints for each vertebra
|
||||
colors = plt.cm.rainbow(np.linspace(0, 1, len(keypoints)))
|
||||
|
||||
for idx, (kps, color) in enumerate(zip(keypoints, colors)):
|
||||
# Draw 4 corners of vertebra
|
||||
xs = [p[0] for p in kps]
|
||||
ys = [p[1] for p in kps]
|
||||
|
||||
# Connect corners: top-left -> top-right -> bottom-right -> bottom-left -> top-left
|
||||
order = [0, 1, 3, 2, 0]
|
||||
for i in range(4):
|
||||
ax.plot([xs[order[i]], xs[order[i+1]]],
|
||||
[ys[order[i]], ys[order[i+1]]],
|
||||
color=color, linewidth=2)
|
||||
|
||||
# Mark center
|
||||
cx = np.mean(xs)
|
||||
cy = np.mean(ys)
|
||||
ax.plot(cx, cy, 'o', color=color, markersize=5)
|
||||
ax.text(cx + 20, cy, f'V{idx+1}', color=color, fontsize=8)
|
||||
|
||||
# Draw midpoint lines if available
|
||||
if result.get('midpoint_lines'):
|
||||
for line in result['midpoint_lines']:
|
||||
ax.plot([line[0][0], line[1][0]],
|
||||
[line[0][1], line[1][1]],
|
||||
'g-', linewidth=1, alpha=0.5)
|
||||
|
||||
# Add angle info
|
||||
if result.get('angles'):
|
||||
angles = result['angles']
|
||||
title_text = f"Curve Type: {result.get('curve_type', 'N/A')}\n"
|
||||
title_text += f"PT: {angles['pt']['angle']:.1f}° "
|
||||
title_text += f"MT: {angles['mt']['angle']:.1f}° "
|
||||
title_text += f"TL: {angles['tl']['angle']:.1f}°"
|
||||
ax.set_title(title_text, fontsize=12)
|
||||
else:
|
||||
ax.set_title("Could not calculate Cobb angles", fontsize=12)
|
||||
|
||||
ax.axis('off')
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f" Visualization saved: {output_path}")
|
||||
|
||||
|
||||
def get_severity(max_angle):
|
||||
"""Classify scoliosis severity based on maximum Cobb angle."""
|
||||
if max_angle < 10:
|
||||
return "Normal"
|
||||
elif max_angle < 25:
|
||||
return "Mild"
|
||||
elif max_angle < 40:
|
||||
return "Moderate"
|
||||
else:
|
||||
return "Severe"
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("ScolioVis API - Balgrist Patient Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Paths
|
||||
balgrist_dir = Path("../PCdareSoftware/Balgrist")
|
||||
output_dir = Path("balgrist_results")
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Find patient folders
|
||||
patient_folders = sorted([
|
||||
d for d in balgrist_dir.iterdir()
|
||||
if d.is_dir() and d.name.isdigit()
|
||||
])
|
||||
|
||||
print(f"\nFound {len(patient_folders)} patient folders")
|
||||
|
||||
# Load model once
|
||||
model = load_model()
|
||||
|
||||
# Results summary
|
||||
all_results = []
|
||||
|
||||
# Process each patient
|
||||
for patient_folder in patient_folders:
|
||||
patient_id = patient_folder.name
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Processing Patient {patient_id}")
|
||||
print("=" * 60)
|
||||
|
||||
# Find PNG files (AP and Lateral)
|
||||
png_files = list(patient_folder.glob("*.png"))
|
||||
|
||||
patient_results = {"patient_id": patient_id, "images": []}
|
||||
|
||||
for png_file in png_files:
|
||||
print(f"\n Image: {png_file.name}")
|
||||
|
||||
try:
|
||||
# Run prediction
|
||||
result, image_cv, bboxes, keypoints = predict_single(model, png_file)
|
||||
|
||||
# Get angles
|
||||
if result.get('angles'):
|
||||
angles = result['angles']
|
||||
pt = angles['pt']['angle']
|
||||
mt = angles['mt']['angle']
|
||||
tl = angles['tl']['angle']
|
||||
max_angle = max(pt, mt, tl)
|
||||
severity = get_severity(max_angle)
|
||||
|
||||
print(f" Vertebrae detected: {len(keypoints)}")
|
||||
print(f" Curve type: {result.get('curve_type', 'N/A')}")
|
||||
print(f" PT: {pt:.1f}°")
|
||||
print(f" MT: {mt:.1f}°")
|
||||
print(f" TL: {tl:.1f}°")
|
||||
print(f" Max angle: {max_angle:.1f}° ({severity})")
|
||||
|
||||
image_result = {
|
||||
"filename": png_file.name,
|
||||
"vertebrae_detected": len(keypoints),
|
||||
"curve_type": result.get('curve_type'),
|
||||
"pt": round(pt, 2),
|
||||
"mt": round(mt, 2),
|
||||
"tl": round(tl, 2),
|
||||
"max_angle": round(max_angle, 2),
|
||||
"severity": severity
|
||||
}
|
||||
else:
|
||||
print(f" Vertebrae detected: {len(keypoints)}")
|
||||
print(f" Could not calculate Cobb angles")
|
||||
image_result = {
|
||||
"filename": png_file.name,
|
||||
"vertebrae_detected": len(keypoints),
|
||||
"error": "Could not calculate angles"
|
||||
}
|
||||
|
||||
patient_results["images"].append(image_result)
|
||||
|
||||
# Save visualization
|
||||
output_filename = f"patient{patient_id}_{png_file.stem}_result.png"
|
||||
output_path = output_dir / output_filename
|
||||
visualize_result(image_cv, keypoints, result, output_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
patient_results["images"].append({
|
||||
"filename": png_file.name,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
all_results.append(patient_results)
|
||||
|
||||
# Save JSON results
|
||||
results_file = output_dir / "balgrist_results.json"
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
print(f"\nResults saved to: {results_file}")
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\n{'Patient':<10} {'Image':<25} {'Verts':<8} {'Type':<6} {'PT':<8} {'MT':<8} {'TL':<8} {'Max':<8} {'Severity':<10}")
|
||||
print("-" * 100)
|
||||
|
||||
for patient in all_results:
|
||||
for img in patient["images"]:
|
||||
if "error" not in img or "vertebrae_detected" in img:
|
||||
print(f"{patient['patient_id']:<10} "
|
||||
f"{img['filename']:<25} "
|
||||
f"{img.get('vertebrae_detected', 'N/A'):<8} "
|
||||
f"{img.get('curve_type', 'N/A'):<6} "
|
||||
f"{img.get('pt', 'N/A'):<8} "
|
||||
f"{img.get('mt', 'N/A'):<8} "
|
||||
f"{img.get('tl', 'N/A'):<8} "
|
||||
f"{img.get('max_angle', 'N/A'):<8} "
|
||||
f"{img.get('severity', 'N/A'):<10}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user