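"""
Test the ScolioVis API on images from the Spinal-AI2024 subset5 dataset.

Runs the Keypoint R-CNN model on a handful of sample images, saves an
annotated visualization for each image, and writes a JSON summary of the
results to the OUTPUT_TEST_1 directory.
"""
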
import json
import sys
from pathlib import Path

# Prepend this script's own directory to the module search path.
# NOTE(review): presumably this lets the ``scoliovis`` package that the
# functions below import lazily resolve from the directory next to this file.
sys.path.insert(0, str(Path(__file__).parent))

import cv2
import numpy as np
from PIL import Image

import matplotlib

# Select the non-interactive Agg backend before pyplot is imported, so
# figures are rendered straight to image files without needing a display.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402 -- must follow matplotlib.use()


def load_model():
    """Build the ScolioVis Keypoint R-CNN model and switch it to eval mode.

    The ``scoliovis`` import is deferred until this function is called,
    which is after the module-level ``sys.path`` adjustment has run.

    Returns:
        The model produced by ``get_kprcnn_model()``, after ``.eval()``
        has been called on it.
    """
    print("Loading Keypoint R-CNN model...")
    from scoliovis.get_model import get_kprcnn_model

    kprcnn = get_kprcnn_model()
    kprcnn.eval()  # inference mode: no training-time layer behaviour

    print("Model loaded!")
    return kprcnn


def predict_single(model, image_path):
    """Run the Keypoint R-CNN model on a single image file.

    Args:
        model: Model returned by ``load_model()``. It is moved to the CPU
            (``model.to(torch.device('cpu'))``) as a side effect of this call.
        image_path: Path (``str`` or ``pathlib.Path``) of the image to analyse.

    Returns:
        A ``(result, image_cv, keypoints)`` tuple:
        ``result`` is the output of ``kprcnn_to_scoliovis_api_format``;
        ``image_cv`` is the image as a NumPy array in BGR (OpenCV) channel
        order; ``keypoints`` are the filtered keypoints from ``_filter_output``.

    Raises:
        OSError: Propagated from ``PIL.Image.open`` when the file is missing
            (``FileNotFoundError``) or cannot be decoded
            (``PIL.UnidentifiedImageError``).
    """
    import torch
    from torchvision.transforms import functional as F
    from scoliovis.kprcnn import _filter_output, kprcnn_to_scoliovis_api_format

    # The context manager releases the file handle deterministically. convert()
    # loads the pixel data into a new image, so `image` stays valid after the
    # source image is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    device = torch.device('cpu')
    model.to(device)

    # NOTE(review): the tensor is built from the BGR array, so the model
    # receives BGR channel order. If the inputs are grayscale radiographs
    # (R == G == B) this makes no difference; otherwise confirm it matches
    # the channel order used when the model was trained.
    image_tensor = F.to_tensor(image_cv)
    images_input = [image_tensor.to(device)]

    with torch.no_grad():
        outputs = model(images_input)

    # Single-image batch, so only the first output entry is relevant.
    bboxes, keypoints, scores = _filter_output(outputs[0])
    result = kprcnn_to_scoliovis_api_format(bboxes, keypoints, scores, image_cv.shape)

    return result, image_cv, keypoints


def visualize_result(image_cv, keypoints, result, output_path):
    """Draw the detected vertebrae over the image and save the figure.

    Each vertebra's keypoints are joined into a closed outline and the mean
    of its keypoints is marked with a dot. Every vertebra gets its own colour
    from the ``rainbow`` colormap. The title shows the curve type and the
    PT/MT/TL angles when ``result`` provides them.

    Args:
        image_cv: Image array in BGR channel order (as returned by
            ``predict_single``).
        keypoints: Sequence with one entry per vertebra. Each entry is indexed
            as ``kps[i][0]`` (x) and ``kps[i][1]`` (y) for corners ``i`` 0..3.
        result: Prediction dict. If ``result.get('angles')`` is truthy,
            ``result['angles'][k]['angle']`` is read for ``k`` in
            ``'pt'``, ``'mt'``, ``'tl'``, together with
            ``result.get('curve_type')``.
        output_path: Destination file for the rendered image.
    """
    img_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)

    # Corner visiting order for each outline: 0 -> 1 -> 3 -> 2 -> back to 0.
    # NOTE(review): presumably the corners are top-left, top-right,
    # bottom-left, bottom-right, which makes this a non-crossing quadrilateral.
    outline = [0, 1, 3, 2, 0]

    fig, ax = plt.subplots(1, 1, figsize=(8, 12))
    try:
        ax.imshow(img_rgb)

        colors = plt.cm.rainbow(np.linspace(0, 1, len(keypoints)))
        for kps, color in zip(keypoints, colors):
            xs = [p[0] for p in kps]
            ys = [p[1] for p in kps]

            for start, end in zip(outline, outline[1:]):
                ax.plot([xs[start], xs[end]],
                        [ys[start], ys[end]],
                        color=color, linewidth=2)

            # Mark the mean position of this vertebra's keypoints.
            cx, cy = np.mean(xs), np.mean(ys)
            ax.plot(cx, cy, 'o', color=color, markersize=4)

        if result.get('angles'):
            angles = result['angles']
            title = f"Type: {result.get('curve_type', 'N/A')}\n"
            title += f"PT: {angles['pt']['angle']:.1f}° MT: {angles['mt']['angle']:.1f}° TL: {angles['tl']['angle']:.1f}°"
            ax.set_title(title, fontsize=10)
        else:
            ax.set_title("Could not calculate angles", fontsize=10)

        ax.axis('off')
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release this specific figure, even if drawing or saving
        # raised, so repeated calls cannot accumulate open figures in pyplot.
        plt.close(fig)


def main():
    """Run the model on the sample images and save the outputs.

    For each path in ``test_images``, an annotated ``<stem>_result.png`` is
    written to ``OUTPUT_TEST_1``. After the loop, ``OUTPUT_TEST_1/results.json``
    holds one record per image.

    An image that cannot be opened is recorded with an ``"error"`` entry and
    skipped, so a single bad path does not abort the batch or discard the
    results gathered so far. All relative paths are resolved against the
    current working directory.
    """
    test_images = [
        "../data/Spinal-AI2024/Spinal-AI2024-subset5/016001.jpg",
        "../data/Spinal-AI2024/Spinal-AI2024-subset5/016002.jpg",
        "../data/Spinal-AI2024/Spinal-AI2024-subset5/016003.jpg",
        "../data/Spinal-AI2024/Spinal-AI2024-subset5/016004.jpg",
        "../data/Spinal-AI2024/Spinal-AI2024-subset5/016005.jpg",
    ]

    output_dir = Path("OUTPUT_TEST_1")
    output_dir.mkdir(exist_ok=True)

    model = load_model()

    results = []
    for img_path in test_images:
        img_file = Path(img_path)
        img_name = img_file.stem
        print(f"\nProcessing {img_name}...")

        try:
            result, image_cv, keypoints = predict_single(model, img_path)
        except OSError as exc:
            # Pillow raises FileNotFoundError for a missing file and
            # UnidentifiedImageError for an undecodable one. Both subclass
            # OSError, so record the failure and continue with the next image.
            message = f"Could not load image: {exc}"
            print(f" {message}")
            results.append({"image": img_file.name, "error": message})
            continue

        # Save visualization
        output_path = output_dir / f"{img_name}_result.png"
        visualize_result(image_cv, keypoints, result, output_path)
        print(f" Saved: {output_path}")

        # Collect results
        record, summary = _summarize_prediction(img_file.name, result, keypoints)
        results.append(record)
        print(summary)

    # Save JSON results
    with open(output_dir / "results.json", 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\nResults saved to {output_dir}/results.json")


def _summarize_prediction(image_name, result, keypoints):
    """Return ``(record, summary_line)`` for one successfully processed image.

    ``record`` is the JSON-serialisable dict stored in ``results.json``.
    ``summary_line`` is the console message printed for the image.
    """
    n_vertebrae = len(keypoints)
    angles = result.get('angles')
    if not angles:
        record = {
            "image": image_name,
            "vertebrae_detected": n_vertebrae,
            "error": "Could not calculate angles",
        }
        return record, f" Vertebrae: {n_vertebrae}, Could not calculate angles"

    pt, mt, tl = (angles[key]['angle'] for key in ('pt', 'mt', 'tl'))
    # float() first so json.dump also accepts NumPy scalar types such as
    # float32, which are not subclasses of the built-in float.
    record = {
        "image": image_name,
        "vertebrae_detected": n_vertebrae,
        "curve_type": result.get('curve_type'),
        "pt": round(float(pt), 2),
        "mt": round(float(mt), 2),
        "tl": round(float(tl), 2),
    }
    summary = (f" Vertebrae: {n_vertebrae}, PT: {pt:.1f}°, "
               f"MT: {mt:.1f}°, TL: {tl:.1f}°")
    return record, summary


if __name__ == "__main__":
    # Run the batch test only when executed as a script, not when imported.
    main()