""" Test ScolioVis API with Balgrist Patient Data ============================================== Runs Keypoint R-CNN inference on all patient PNG images. Usage: python test_balgrist.py """ import os import sys import json from pathlib import Path # Add parent to path for imports sys.path.insert(0, str(Path(__file__).parent)) import cv2 import numpy as np from PIL import Image import matplotlib matplotlib.use('Agg') # Non-interactive backend import matplotlib.pyplot as plt def load_model(): """Load the Keypoint R-CNN model.""" print("Loading Keypoint R-CNN model...") from scoliovis.get_model import get_kprcnn_model import torch model = get_kprcnn_model() model.eval() print("Model loaded successfully!") return model def predict_single(model, image_path): """ Run prediction on a single image. Returns: dict with detections, landmarks, angles, curve_type, midpoint_lines """ import torch from torchvision.transforms import functional as F from scoliovis.kprcnn import _filter_output, kprcnn_to_scoliovis_api_format # Load image image = Image.open(image_path).convert('RGB') image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # Prepare input device = torch.device('cpu') model.to(device) image_tensor = F.to_tensor(image_cv) images_input = [image_tensor.to(device)] # Inference with torch.no_grad(): outputs = model(images_input) # Filter output bboxes, keypoints, scores = _filter_output(outputs[0]) # Convert to API format result = kprcnn_to_scoliovis_api_format(bboxes, keypoints, scores, image_cv.shape) return result, image_cv, bboxes, keypoints def visualize_result(image_cv, keypoints, result, output_path): """ Create visualization with detected vertebrae and angles. """ img_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB) fig, ax = plt.subplots(1, 1, figsize=(10, 16)) ax.imshow(img_rgb) # Draw keypoints for each vertebra colors = plt.cm.rainbow(np.linspace(0, 1, len(keypoints))) for idx, (kps, color) in enumerate(zip(keypoints, colors)): # Draw 4 corners of vertebra xs = [p[0] for p in kps] ys = [p[1] for p in kps] # Connect corners: top-left -> top-right -> bottom-right -> bottom-left -> top-left order = [0, 1, 3, 2, 0] for i in range(4): ax.plot([xs[order[i]], xs[order[i+1]]], [ys[order[i]], ys[order[i+1]]], color=color, linewidth=2) # Mark center cx = np.mean(xs) cy = np.mean(ys) ax.plot(cx, cy, 'o', color=color, markersize=5) ax.text(cx + 20, cy, f'V{idx+1}', color=color, fontsize=8) # Draw midpoint lines if available if result.get('midpoint_lines'): for line in result['midpoint_lines']: ax.plot([line[0][0], line[1][0]], [line[0][1], line[1][1]], 'g-', linewidth=1, alpha=0.5) # Add angle info if result.get('angles'): angles = result['angles'] title_text = f"Curve Type: {result.get('curve_type', 'N/A')}\n" title_text += f"PT: {angles['pt']['angle']:.1f}° " title_text += f"MT: {angles['mt']['angle']:.1f}° " title_text += f"TL: {angles['tl']['angle']:.1f}°" ax.set_title(title_text, fontsize=12) else: ax.set_title("Could not calculate Cobb angles", fontsize=12) ax.axis('off') plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() print(f" Visualization saved: {output_path}") def get_severity(max_angle): """Classify scoliosis severity based on maximum Cobb angle.""" if max_angle < 10: return "Normal" elif max_angle < 25: return "Mild" elif max_angle < 40: return "Moderate" else: return "Severe" def main(): print("=" * 60) print("ScolioVis API - Balgrist Patient Test") print("=" * 60) # Paths balgrist_dir = Path("../PCdareSoftware/Balgrist") output_dir = Path("balgrist_results") output_dir.mkdir(exist_ok=True) # Find patient folders patient_folders = sorted([ d for d in balgrist_dir.iterdir() if d.is_dir() and d.name.isdigit() ]) print(f"\nFound {len(patient_folders)} patient folders") # Load model once model = load_model() # Results summary all_results = [] # Process each patient for patient_folder in patient_folders: patient_id = patient_folder.name print(f"\n{'='*60}") print(f"Processing Patient {patient_id}") print("=" * 60) # Find PNG files (AP and Lateral) png_files = list(patient_folder.glob("*.png")) patient_results = {"patient_id": patient_id, "images": []} for png_file in png_files: print(f"\n Image: {png_file.name}") try: # Run prediction result, image_cv, bboxes, keypoints = predict_single(model, png_file) # Get angles if result.get('angles'): angles = result['angles'] pt = angles['pt']['angle'] mt = angles['mt']['angle'] tl = angles['tl']['angle'] max_angle = max(pt, mt, tl) severity = get_severity(max_angle) print(f" Vertebrae detected: {len(keypoints)}") print(f" Curve type: {result.get('curve_type', 'N/A')}") print(f" PT: {pt:.1f}°") print(f" MT: {mt:.1f}°") print(f" TL: {tl:.1f}°") print(f" Max angle: {max_angle:.1f}° ({severity})") image_result = { "filename": png_file.name, "vertebrae_detected": len(keypoints), "curve_type": result.get('curve_type'), "pt": round(pt, 2), "mt": round(mt, 2), "tl": round(tl, 2), "max_angle": round(max_angle, 2), "severity": severity } else: print(f" Vertebrae detected: {len(keypoints)}") print(f" Could not calculate Cobb angles") image_result = { "filename": png_file.name, "vertebrae_detected": len(keypoints), "error": "Could not calculate angles" } patient_results["images"].append(image_result) # Save visualization output_filename = f"patient{patient_id}_{png_file.stem}_result.png" output_path = output_dir / output_filename visualize_result(image_cv, keypoints, result, output_path) except Exception as e: print(f" ERROR: {e}") patient_results["images"].append({ "filename": png_file.name, "error": str(e) }) all_results.append(patient_results) # Save JSON results results_file = output_dir / "balgrist_results.json" with open(results_file, 'w') as f: json.dump(all_results, f, indent=2) print(f"\nResults saved to: {results_file}") # Print summary print("\n" + "=" * 60) print("SUMMARY") print("=" * 60) print(f"\n{'Patient':<10} {'Image':<25} {'Verts':<8} {'Type':<6} {'PT':<8} {'MT':<8} {'TL':<8} {'Max':<8} {'Severity':<10}") print("-" * 100) for patient in all_results: for img in patient["images"]: if "error" not in img or "vertebrae_detected" in img: print(f"{patient['patient_id']:<10} " f"{img['filename']:<25} " f"{img.get('vertebrae_detected', 'N/A'):<8} " f"{img.get('curve_type', 'N/A'):<6} " f"{img.get('pt', 'N/A'):<8} " f"{img.get('mt', 'N/A'):<8} " f"{img.get('tl', 'N/A'):<8} " f"{img.get('max_angle', 'N/A'):<8} " f"{img.get('severity', 'N/A'):<10}") if __name__ == "__main__": main()