Files
braceiqmed/scoliovis-api/test_subset5.py

142 lines
4.4 KiB
Python

"""
Test ScolioVis API with Spinal-AI2024 subset5 images
"""
import sys
import json
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import cv2
import numpy as np
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def load_model():
    """Build the ScolioVis Keypoint R-CNN model and switch it to eval mode.

    Returns:
        The model produced by ``scoliovis.get_model.get_kprcnn_model`` after
        ``.eval()`` has been called on it.
    """
    print("Loading Keypoint R-CNN model...")
    # Deferred import: the scoliovis package is only imported when a model
    # is actually requested.
    from scoliovis.get_model import get_kprcnn_model as build_kprcnn

    kprcnn = build_kprcnn()
    kprcnn.eval()
    print("Model loaded!")
    return kprcnn
def predict_single(model, image_path):
    """Run the Keypoint R-CNN model on one image and convert its output.

    Args:
        model: Model returned by ``load_model``; it is moved to the CPU here.
        image_path: Path (``str`` or ``Path``) of the image file to analyse.

    Returns:
        A tuple ``(result, image_cv, keypoints)``:
        ``result`` is the value of ``kprcnn_to_scoliovis_api_format``,
        ``image_cv`` is the image as a BGR ``numpy`` array, and
        ``keypoints`` is the keypoints returned by ``_filter_output``.
    """
    import torch
    from torchvision.transforms import functional as F
    from scoliovis.kprcnn import _filter_output, kprcnn_to_scoliovis_api_format

    # Context manager guarantees the file handle opened by Image.open is
    # released even if decoding fails or the file is multi-frame; the pixel
    # data is copied into a numpy array before the file is closed.
    with Image.open(image_path) as pil_image:
        image_rgb = np.array(pil_image.convert('RGB'))
    image_cv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    device = torch.device('cpu')
    model.to(device)

    # NOTE(review): the tensor is built from the BGR array, so channel order
    # is B, G, R. Confirm this matches the preprocessing used in training.
    image_tensor = F.to_tensor(image_cv).to(device)

    with torch.no_grad():
        outputs = model([image_tensor])

    bboxes, keypoints, scores = _filter_output(outputs[0])
    result = kprcnn_to_scoliovis_api_format(bboxes, keypoints, scores, image_cv.shape)
    return result, image_cv, keypoints
def visualize_result(image_cv, keypoints, result, output_path):
    """Draw the detected vertebra outlines on the image and save it to disk.

    Args:
        image_cv: Image as a BGR ``numpy`` array (as returned by
            ``predict_single``).
        keypoints: Sequence of per-vertebra point lists; each entry is read
            as ``kps[i][0]`` (x) / ``kps[i][1]`` (y) for ``i`` in 0..3.
        result: Dict returned by ``predict_single``. When it has a truthy
            ``'angles'`` entry, ``angles['pt'|'mt'|'tl']['angle']`` and
            ``result.get('curve_type')`` are shown in the title.
        output_path: Destination passed to ``Figure.savefig``.
    """
    # Corner indices visited to draw a closed quadrilateral per vertebra.
    # NOTE(review): assumes the 0-1-3-2 order traces the outline
    # (e.g. TL, TR, BL, BR) — confirm against the kprcnn output order.
    outline = (0, 1, 3, 2, 0)

    img_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
    fig, ax = plt.subplots(1, 1, figsize=(8, 12))
    try:
        ax.imshow(img_rgb)

        # One distinct colour per vertebra, spread over the rainbow colormap.
        colors = plt.cm.rainbow(np.linspace(0, 1, len(keypoints)))
        for kps, color in zip(keypoints, colors):
            xs = [p[0] for p in kps]
            ys = [p[1] for p in kps]
            for start, end in zip(outline, outline[1:]):
                ax.plot([xs[start], xs[end]], [ys[start], ys[end]],
                        color=color, linewidth=2)
            # Mark the mean of the keypoints as the vertebra centre.
            ax.plot(np.mean(xs), np.mean(ys), 'o', color=color, markersize=4)

        angles = result.get('angles')
        if angles:
            title = f"Type: {result.get('curve_type', 'N/A')}\n"
            title += f"PT: {angles['pt']['angle']:.1f}° MT: {angles['mt']['angle']:.1f}° TL: {angles['tl']['angle']:.1f}°"
        else:
            title = "Could not calculate angles"
        ax.set_title(title, fontsize=10)
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure, even on error, so figures do not
        # accumulate in pyplot's registry when called in a loop.
        plt.close(fig)
def main():
test_images = [
"../data/Spinal-AI2024/Spinal-AI2024-subset5/016001.jpg",
"../data/Spinal-AI2024/Spinal-AI2024-subset5/016002.jpg",
"../data/Spinal-AI2024/Spinal-AI2024-subset5/016003.jpg",
"../data/Spinal-AI2024/Spinal-AI2024-subset5/016004.jpg",
"../data/Spinal-AI2024/Spinal-AI2024-subset5/016005.jpg",
]
output_dir = Path("OUTPUT_TEST_1")
output_dir.mkdir(exist_ok=True)
model = load_model()
results = []
for img_path in test_images:
img_name = Path(img_path).stem
print(f"\nProcessing {img_name}...")
result, image_cv, keypoints = predict_single(model, img_path)
# Save visualization
output_path = output_dir / f"{img_name}_result.png"
visualize_result(image_cv, keypoints, result, output_path)
print(f" Saved: {output_path}")
# Collect results
if result.get('angles'):
angles = result['angles']
results.append({
"image": img_name + ".jpg",
"vertebrae_detected": len(keypoints),
"curve_type": result.get('curve_type'),
"pt": round(angles['pt']['angle'], 2),
"mt": round(angles['mt']['angle'], 2),
"tl": round(angles['tl']['angle'], 2)
})
print(f" Vertebrae: {len(keypoints)}, PT: {angles['pt']['angle']:.1f}°, MT: {angles['mt']['angle']:.1f}°, TL: {angles['tl']['angle']:.1f}°")
else:
results.append({
"image": img_name + ".jpg",
"vertebrae_detected": len(keypoints),
"error": "Could not calculate angles"
})
print(f" Vertebrae: {len(keypoints)}, Could not calculate angles")
# Save JSON results
with open(output_dir / "results.json", 'w') as f:
json.dump(results, f, indent=2)
print(f"\nResults saved to {output_dir}/results.json")
if __name__ == "__main__":
main()