264 lines
8.5 KiB
Python
264 lines
8.5 KiB
Python
"""
|
|
Test ScolioVis API with Balgrist Patient Data
|
|
==============================================
|
|
Runs Keypoint R-CNN inference on all patient PNG images.
|
|
|
|
Usage:
|
|
python test_balgrist.py
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# Add parent to path for imports
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
import matplotlib
|
|
matplotlib.use('Agg') # Non-interactive backend
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def load_model():
    """Load the Keypoint R-CNN model and switch it to evaluation mode.

    The project import is done lazily inside the function so that merely
    importing this script does not pull in the model code.

    Returns:
        The object returned by ``scoliovis.get_model.get_kprcnn_model()``,
        after ``eval()`` has been called on it.
    """
    print("Loading Keypoint R-CNN model...")
    from scoliovis.get_model import get_kprcnn_model

    model = get_kprcnn_model()
    # Inference only: disable training-time behaviour.
    model.eval()
    print("Model loaded successfully!")
    return model
|
|
|
|
|
|
def predict_single(model, image_path):
    """
    Run prediction on a single image.

    Args:
        model: Detection model (as returned by ``load_model()``). It is
            moved to the CPU before inference.
        image_path: Path of the image file to analyse.

    Returns:
        A 4-tuple ``(result, image_cv, bboxes, keypoints)``:

        * ``result`` -- output of ``kprcnn_to_scoliovis_api_format()``;
          described by the original author as a dict with detections,
          landmarks, angles, curve_type and midpoint_lines.
        * ``image_cv`` -- the input image as a BGR NumPy array.
        * ``bboxes``, ``keypoints`` -- detections kept by
          ``_filter_output()`` (the scores are used for the API conversion
          but are not returned).
    """
    import torch
    from torchvision.transforms import functional as F
    from scoliovis.kprcnn import _filter_output, kprcnn_to_scoliovis_api_format

    # Load image; the context manager releases the file handle
    # deterministically instead of relying on garbage collection.
    with Image.open(image_path) as pil_image:
        rgb_pixels = np.array(pil_image.convert('RGB'))
    image_cv = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)

    # Prepare input
    device = torch.device('cpu')
    model.to(device)

    # NOTE(review): the tensor is built from the BGR array, presumably to
    # match the channel order the model was trained with -- confirm.
    image_tensor = F.to_tensor(image_cv)
    images_input = [image_tensor.to(device)]

    # Inference (no gradients needed)
    with torch.no_grad():
        outputs = model(images_input)

    # Filter the detections of the single input image
    bboxes, keypoints, scores = _filter_output(outputs[0])

    # Convert to API format
    result = kprcnn_to_scoliovis_api_format(bboxes, keypoints, scores, image_cv.shape)

    return result, image_cv, bboxes, keypoints
|
|
|
|
|
|
def visualize_result(image_cv, keypoints, result, output_path):
    """
    Create visualization with detected vertebrae and angles.

    Args:
        image_cv: Image as a BGR array (converted to RGB for plotting).
        keypoints: One entry per detected vertebra; each entry is indexed
            as four points whose first two components are x and y.
        result: Mapping read via ``.get()`` for the optional keys
            ``'midpoint_lines'``, ``'angles'`` and ``'curve_type'``.
        output_path: Destination of the rendered PNG.

    The figure is always closed, even when drawing or saving fails, so that
    a batch run that swallows per-image errors does not accumulate open
    figures.
    """
    img_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(1, 1, figsize=(10, 16))
    try:
        ax.imshow(img_rgb)

        # One distinct colour per vertebra
        colors = plt.cm.rainbow(np.linspace(0, 1, len(keypoints)))

        # Connect corners: top-left -> top-right -> bottom-right -> bottom-left -> top-left
        order = [0, 1, 3, 2, 0]

        for idx, (kps, color) in enumerate(zip(keypoints, colors)):
            xs = [p[0] for p in kps]
            ys = [p[1] for p in kps]

            # Draw the outline edge by edge
            for start, end in zip(order, order[1:]):
                ax.plot([xs[start], xs[end]],
                        [ys[start], ys[end]],
                        color=color, linewidth=2)

            # Mark and label the centre (mean of the corner points)
            cx = np.mean(xs)
            cy = np.mean(ys)
            ax.plot(cx, cy, 'o', color=color, markersize=5)
            ax.text(cx + 20, cy, f'V{idx+1}', color=color, fontsize=8)

        # Draw midpoint lines if available
        if result.get('midpoint_lines'):
            for line in result['midpoint_lines']:
                ax.plot([line[0][0], line[1][0]],
                        [line[0][1], line[1][1]],
                        'g-', linewidth=1, alpha=0.5)

        # Add angle info
        if result.get('angles'):
            angles = result['angles']
            title_text = f"Curve Type: {result.get('curve_type', 'N/A')}\n"
            title_text += f"PT: {angles['pt']['angle']:.1f}° "
            title_text += f"MT: {angles['mt']['angle']:.1f}° "
            title_text += f"TL: {angles['tl']['angle']:.1f}°"
            ax.set_title(title_text, fontsize=12)
        else:
            ax.set_title("Could not calculate Cobb angles", fontsize=12)

        ax.axis('off')
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure rather than whatever is "current".
        plt.close(fig)

    print(f" Visualization saved: {output_path}")
|
|
|
|
|
|
def get_severity(max_angle):
    """Map the largest Cobb angle (in degrees) to a severity label.

    Upper bounds are exclusive: below 10 is "Normal", below 25 is "Mild",
    below 40 is "Moderate"; every other value is "Severe".
    """
    bands = (
        (10, "Normal"),
        (25, "Mild"),
        (40, "Moderate"),
    )
    for upper_bound, label in bands:
        if max_angle < upper_bound:
            return label
    return "Severe"
|
|
|
|
|
|
# (column title, record key, column width) of the summary table printed by main().
_SUMMARY_COLUMNS = (
    ("Patient", "patient_id", 10),
    ("Image", "filename", 25),
    ("Verts", "vertebrae_detected", 8),
    ("Type", "curve_type", 6),
    ("PT", "pt", 8),
    ("MT", "mt", 8),
    ("TL", "tl", 8),
    ("Max", "max_angle", 8),
    ("Severity", "severity", 10),
)


def _summarize_prediction(filename, result, keypoints):
    """Print the per-image report and return its JSON-serializable record."""
    vertebrae = len(keypoints)
    print(f" Vertebrae detected: {vertebrae}")

    angles = result.get('angles')
    if not angles:
        print(" Could not calculate Cobb angles")
        return {
            "filename": filename,
            "vertebrae_detected": vertebrae,
            "error": "Could not calculate angles",
        }

    # Coerce to built-in float so that NumPy-style scalars (should the API
    # return them -- not visible from here) survive round() and json.dump().
    pt = float(angles['pt']['angle'])
    mt = float(angles['mt']['angle'])
    tl = float(angles['tl']['angle'])
    max_angle = max(pt, mt, tl)
    severity = get_severity(max_angle)

    print(f" Curve type: {result.get('curve_type', 'N/A')}")
    print(f" PT: {pt:.1f}°")
    print(f" MT: {mt:.1f}°")
    print(f" TL: {tl:.1f}°")
    print(f" Max angle: {max_angle:.1f}° ({severity})")

    return {
        "filename": filename,
        "vertebrae_detected": vertebrae,
        "curve_type": result.get('curve_type'),
        "pt": round(pt, 2),
        "mt": round(mt, 2),
        "tl": round(tl, 2),
        "max_angle": round(max_angle, 2),
        "severity": severity,
    }


def _print_summary(all_results):
    """Print one table row per image for which a prediction was obtained."""
    header = " ".join(f"{title:<{width}}" for title, _key, width in _SUMMARY_COLUMNS)
    print(f"\n{header}")
    print("-" * 100)

    for patient in all_results:
        for img in patient["images"]:
            # Records that only carry an exception message have no metrics.
            if "error" in img and "vertebrae_detected" not in img:
                continue
            row = {**img, "patient_id": patient["patient_id"]}
            cells = []
            for _title, key, width in _SUMMARY_COLUMNS:
                value = row.get(key)
                # None (e.g. a missing curve type) does not accept a width
                # format spec, so render every cell through str().
                text = "N/A" if value is None else str(value)
                cells.append(f"{text:<{width}}")
            print(" ".join(cells))


def main():
    """Run inference on every Balgrist patient folder and report the results.

    Patient folders are the sub-directories of ``../PCdareSoftware/Balgrist``
    (relative to the current working directory) whose names consist of
    digits. For each PNG inside them the prediction is printed, a
    visualization is written to ``balgrist_results/`` and the metrics are
    collected into ``balgrist_results/balgrist_results.json``; finally a
    summary table is printed.

    Raises:
        SystemExit: if the patient data directory does not exist.
    """
    print("=" * 60)
    print("ScolioVis API - Balgrist Patient Test")
    print("=" * 60)

    # Paths
    balgrist_dir = Path("../PCdareSoftware/Balgrist")
    output_dir = Path("balgrist_results")

    # Fail early with a readable message, before creating any output.
    if not balgrist_dir.is_dir():
        sys.exit(f"ERROR: patient data directory not found: {balgrist_dir.resolve()}")

    output_dir.mkdir(exist_ok=True)

    # Find patient folders
    patient_folders = sorted(
        d for d in balgrist_dir.iterdir()
        if d.is_dir() and d.name.isdigit()
    )
    print(f"\nFound {len(patient_folders)} patient folders")

    # Load model once
    model = load_model()

    all_results = []

    for patient_folder in patient_folders:
        patient_id = patient_folder.name
        print(f"\n{'='*60}")
        print(f"Processing Patient {patient_id}")
        print("=" * 60)

        # Find PNG files (AP and Lateral); sorted for a reproducible order.
        png_files = sorted(patient_folder.glob("*.png"))

        patient_results = {"patient_id": patient_id, "images": []}

        for png_file in png_files:
            print(f"\n Image: {png_file.name}")

            # Best effort per image: one failing image must not abort the batch.
            try:
                result, image_cv, _bboxes, keypoints = predict_single(model, png_file)
                image_result = _summarize_prediction(png_file.name, result, keypoints)
            except Exception as e:
                print(f" ERROR: {e}")
                patient_results["images"].append({
                    "filename": png_file.name,
                    "error": str(e),
                })
                continue

            patient_results["images"].append(image_result)

            # Save visualization. A failure here is reported but does not
            # add a second record for an image whose metrics are already
            # stored.
            output_filename = f"patient{patient_id}_{png_file.stem}_result.png"
            output_path = output_dir / output_filename
            try:
                visualize_result(image_cv, keypoints, result, output_path)
            except Exception as e:
                print(f" ERROR: {e}")

        all_results.append(patient_results)

    # Save JSON results
    results_file = output_dir / "balgrist_results.json"
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)
    print(f"\nResults saved to: {results_file}")

    # Print summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    _print_summary(all_results)
|
|
|
|
|
|
# Run the batch test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|