"""
|
||
Model adapters that convert different model outputs to unified Spine2D format.
|
||
Each adapter wraps a specific model and produces consistent output.
|
||
"""
|
||
import sys
|
||
import numpy as np
|
||
from pathlib import Path
|
||
from typing import Optional, Dict, Any
|
||
from abc import ABC, abstractmethod
|
||
|
||
from data_models import VertebraLandmark, Spine2D
|
||
|
||
|
||
class BaseLandmarkAdapter(ABC):
    """Base class for landmark detection model adapters."""

    @abstractmethod
    def predict(self, image: np.ndarray) -> Spine2D:
        """
        Run inference on an image and return unified spine landmarks.

        Args:
            image: Input image as a numpy array (grayscale or RGB).

        Returns:
            Spine2D object with detected landmarks.
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """Model name for identification."""
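
# Usage sketch (illustrative, not part of the API): every adapter is used the same
# way through the base interface. The cv2 call and file name below are assumptions.
#
#     import cv2
#     adapter: BaseLandmarkAdapter = ScolioVisAdapter(device='cpu')  # or VertLandmarkAdapter()
#     image = cv2.imread('spine_xray.png', cv2.IMREAD_GRAYSCALE)     # hypothetical file
#     spine = adapter.predict(image)
#     print(adapter.name, len(spine.vertebrae))

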
class ScolioVisAdapter(BaseLandmarkAdapter):
    """
    Adapter for ScolioVis-API (Keypoint R-CNN model).

    Uses the original ScolioVis inference code for best accuracy.
    Outputs: 4 keypoints per vertebra + Cobb angles (PT, MT, TL) + curve type (S/C).
    """

    def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'):
        """
        Initialize the ScolioVis model.

        Args:
            weights_path: Path to keypointsrcnn_weights.pt (auto-detected if None).
            device: 'cpu' or 'cuda'.
        """
        self.device = device
        self.model = None
        self.weights_path = weights_path
        self._scoliovis_path = None
        self._load_model()

    def _load_model(self):
        """Load the Keypoint R-CNN model."""
        import torch
        from torchvision.models.detection import keypointrcnn_resnet50_fpn
        from torchvision.models.detection.rpn import AnchorGenerator

        # Find the weights and the scoliovis module
        scoliovis_api_path = Path(__file__).parent.parent / 'scoliovis-api'
        if self.weights_path is None:
            possible_paths = [
                scoliovis_api_path / 'models' / 'keypointsrcnn_weights.pt',
                scoliovis_api_path / 'keypointsrcnn_weights.pt',
                scoliovis_api_path / 'weights' / 'keypointsrcnn_weights.pt',
            ]
            for p in possible_paths:
                if p.exists():
                    self.weights_path = str(p)
                    break

        if self.weights_path is None or not Path(self.weights_path).exists():
            raise FileNotFoundError(
                "ScolioVis weights not found. Please provide weights_path or ensure "
                "scoliovis-api/models/keypointsrcnn_weights.pt exists."
            )

        # Store the path to the scoliovis module for Cobb angle calculation
        self._scoliovis_path = scoliovis_api_path

        # Create the model with the same anchor generator used in the original training
        anchor_generator = AnchorGenerator(
            sizes=(32, 64, 128, 256, 512),
            aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
        )

        self.model = keypointrcnn_resnet50_fpn(
            weights=None,
            weights_backbone=None,
            num_classes=2,    # background + vertebra
            num_keypoints=4,  # 4 corners per vertebra
            rpn_anchor_generator=anchor_generator
        )

        # Load weights
        checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()

        print(f"ScolioVis model loaded from {self.weights_path}")

    @property
    def name(self) -> str:
        return "ScolioVis-API"

    def _filter_output(self, output, max_verts: int = 17):
        """
        Filter model output using NMS and a score threshold.

        Matches the original ScolioVis filtering logic.
        """
        import torchvision

        scores = output['scores'].detach().cpu().numpy()

        # Keep indices whose score exceeds the 0.5 threshold
        high_scores_idxs = np.where(scores > 0.5)[0].tolist()
        if len(high_scores_idxs) == 0:
            return [], [], []

        # Apply NMS with an IoU threshold of 0.3
        post_nms_idxs = torchvision.ops.nms(
            output['boxes'][high_scores_idxs],
            output['scores'][high_scores_idxs],
            0.3
        ).cpu().numpy()

        # Gather the filtered results
        np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
        np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
        np_scores = output['scores'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()

        # Take the top N detections by score (usually 17 for a full spine)
        sorted_scores_idxs = np.argsort(-np_scores)
        np_scores = np_scores[sorted_scores_idxs][:max_verts]
        np_keypoints = np_keypoints[sorted_scores_idxs][:max_verts]
        np_bboxes = np_bboxes[sorted_scores_idxs][:max_verts]

        # Sort top to bottom by the y of the first (top-left) keypoint
        if len(np_keypoints) > 0:
            ymins = np.array([kps[0][1] for kps in np_keypoints])
            sorted_ymin_idxs = np.argsort(ymins)
            np_scores = np_scores[sorted_ymin_idxs]
            np_keypoints = np_keypoints[sorted_ymin_idxs]
            np_bboxes = np_bboxes[sorted_ymin_idxs]

        # Convert to plain Python lists
        keypoints_list = [[list(map(float, kp[:2])) for kp in kps] for kps in np_keypoints]
        bboxes_list = [list(map(int, bbox.tolist())) for bbox in np_bboxes]
        scores_list = np_scores.tolist()

        return bboxes_list, keypoints_list, scores_list

    def predict(self, image: np.ndarray) -> Spine2D:
        """Run inference and return unified landmarks with ScolioVis Cobb angles."""
        import torch
        from torchvision.transforms import functional as F

        # Ensure a 3-channel RGB image
        if len(image.shape) == 2:
            image_rgb = np.stack([image, image, image], axis=-1)
        else:
            image_rgb = image

        image_shape = image_rgb.shape  # (H, W, C)

        # Convert to tensor (ScolioVis uses torchvision's to_tensor)
        img_tensor = F.to_tensor(image_rgb).to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.model([img_tensor])

        # Filter output using the original ScolioVis logic
        bboxes, keypoints, scores = self._filter_output(outputs[0])

        if len(keypoints) == 0:
            return Spine2D(
                vertebrae=[],
                image_shape=image_shape[:2],
                source_model=self.name
            )

        # Convert to the unified format
        vertebrae = []
        for i in range(len(bboxes)):
            kps = np.array(keypoints[i], dtype=np.float32)  # (4, 2)

            # Corner order from ScolioVis: [top_left, top_right, bottom_left, bottom_right]
            corners = kps
            centroid = np.mean(corners, axis=0)

            # Orientation of the superior endplate, from the top edge (kps[0] to kps[1]);
            # positive angles tilt down to the right, since image y grows downward
            top_left, top_right = corners[0], corners[1]
            dx = top_right[0] - top_left[0]
            dy = top_right[1] - top_left[1]
            orientation = np.degrees(np.arctan2(dy, dx))

            vert = VertebraLandmark(
                level=None,  # ScolioVis doesn't assign vertebra levels
                centroid_px=centroid,
                corners_px=corners,
                endplate_upper_px=corners[:2],  # top-left, top-right
                endplate_lower_px=corners[2:],  # bottom-left, bottom-right
                orientation_deg=orientation,
                confidence=float(scores[i]),
                meta={'box': bboxes[i]}
            )
            vertebrae.append(vert)

        # Create the Spine2D container
        spine = Spine2D(
            vertebrae=vertebrae,
            image_shape=image_shape[:2],
            source_model=self.name
        )

        # Use the original ScolioVis Cobb angle calculation if available
        if len(keypoints) >= 5:
            try:
                # Import the original ScolioVis cobb_angle_cal
                if str(self._scoliovis_path) not in sys.path:
                    sys.path.insert(0, str(self._scoliovis_path))

                from scoliovis.cobb_angle_cal import cobb_angle_cal, keypoints_to_landmark_xy

                landmark_xy = keypoints_to_landmark_xy(keypoints)
                cobb_angles_list, angles_with_pos, curve_type, midpoint_lines = cobb_angle_cal(
                    landmark_xy, image_shape
                )

                # Store Cobb angles on the spine object
                spine.cobb_angles = {
                    'PT': cobb_angles_list[0],
                    'MT': cobb_angles_list[1],
                    'TL': cobb_angles_list[2]
                }
                spine.curve_type = curve_type
                spine.meta = {
                    'angles_with_pos': angles_with_pos,
                    'midpoint_lines': midpoint_lines
                }

            except Exception as e:
                print(f"Warning: Could not use ScolioVis Cobb calculation: {e}")
                # Fall back to our own calculation
                from spine_analysis import compute_cobb_angles
                compute_cobb_angles(spine)

        return spine
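
# Consuming the result (sketch; attribute names follow the assignments above —
# cobb_angles/curve_type are only set when at least 5 vertebrae are detected):
#
#     spine = ScolioVisAdapter().predict(image)
#     angles = getattr(spine, 'cobb_angles', None)
#     if angles is not None:
#         print(f"PT={angles['PT']:.1f}  MT={angles['MT']:.1f}  TL={angles['TL']:.1f}")
#         print("Curve type:", spine.curve_type)

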
class VertLandmarkAdapter(BaseLandmarkAdapter):
    """
    Adapter for Vertebra-Landmark-Detection (SpineNet model).

    Outputs: 68 landmarks (4 corners × 17 vertebrae).
    """

    def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'):
        """
        Initialize the SpineNet model.

        Args:
            weights_path: Path to model_last.pth (auto-detected if None).
            device: 'cpu' or 'cuda'.
        """
        self.device = device
        self.model = None
        self.weights_path = weights_path
        self._load_model()

    def _load_model(self):
        """Load the SpineNet model."""
        import torch

        # Find weights
        if self.weights_path is None:
            possible_paths = [
                Path(__file__).parent.parent / 'Vertebra-Landmark-Detection' / 'weights_spinal' / 'model_last.pth',
            ]
            for p in possible_paths:
                if p.exists():
                    self.weights_path = str(p)
                    break

        if self.weights_path is None or not Path(self.weights_path).exists():
            raise FileNotFoundError(
                "Vertebra-Landmark-Detection weights not found. "
                "Download them from Google Drive and place them at weights_spinal/model_last.pth."
            )

        # Add the repo to sys.path so its model module can be imported
        repo_path = Path(__file__).parent.parent / 'Vertebra-Landmark-Detection'
        if str(repo_path) not in sys.path:
            sys.path.insert(0, str(repo_path))

        from models import spinal_net

        # Create the model: one heatmap channel ('hm'), 2 center-offset channels
        # ('reg'), and 8 channels for the four corner offsets ('wh': x, y each)
        heads = {'hm': 1, 'reg': 2, 'wh': 8}
        self.model = spinal_net.SpineNet(
            heads=heads,
            pretrained=False,
            down_ratio=4,
            final_kernel=1,
            head_conv=256
        )

        # Load weights
        checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        self.model.to(self.device)
        self.model.eval()

        print(f"Vertebra-Landmark-Detection model loaded from {self.weights_path}")

    @property
    def name(self) -> str:
        return "Vertebra-Landmark-Detection"

    def _nms(self, heat, kernel=3):
        """Suppress non-peak heatmap responses using max pooling."""
        import torch.nn.functional as F

        hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=(kernel - 1) // 2)
        keep = (hmax == heat).float()
        return heat * keep
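
    # Note on the max-pool NMS above (informal): a heatmap cell survives only if it
    # equals the maximum of its kernel x kernel neighbourhood, so each local peak is
    # kept and its neighbours are zeroed. Exact ties within a window all survive,
    # which matches the original decoder's behaviour.
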
    def _gather_feat(self, feat, ind):
        """Gather features by index along the flattened spatial dimension."""
        dim = feat.size(2)
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        return feat

    def _transpose_and_gather_feat(self, feat, ind):
        """Transpose (B, C, H, W) to (B, H*W, C), then gather; matches the original decoder."""
        feat = feat.permute(0, 2, 3, 1).contiguous()
        feat = feat.view(feat.size(0), -1, feat.size(3))
        feat = self._gather_feat(feat, ind)
        return feat

    def _decode_predictions(self, output: Dict, down_ratio: int = 4, K: int = 17):
        """Decode model output using the original decoder logic."""
        import torch

        hm = output['hm'].sigmoid()
        reg = output['reg']
        wh = output['wh']

        batch, cat, height, width = hm.size()

        # Suppress non-peak responses
        hm = self._nms(hm)

        # Top K peaks per class from the heatmap
        topk_scores, topk_inds = torch.topk(hm.view(batch, cat, -1), K)
        topk_inds = topk_inds % (height * width)
        topk_ys = (topk_inds // width).float()
        topk_xs = (topk_inds % width).float()

        # Overall top K across classes
        topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
        topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
        topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
        topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

        scores = topk_score.view(batch, K, 1)

        # Refine the peak locations with the regression offsets
        reg = self._transpose_and_gather_feat(reg, topk_inds)
        reg = reg.view(batch, K, 2)
        xs = topk_xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = topk_ys.view(batch, K, 1) + reg[:, :, 1:2]

        # Gather the corner offsets
        wh = self._transpose_and_gather_feat(wh, topk_inds)
        wh = wh.view(batch, K, 8)

        # Compute corners by SUBTRACTING offsets (original decoder convention)
        tl_x = xs - wh[:, :, 0:1]
        tl_y = ys - wh[:, :, 1:2]
        tr_x = xs - wh[:, :, 2:3]
        tr_y = ys - wh[:, :, 3:4]
        bl_x = xs - wh[:, :, 4:5]
        bl_y = ys - wh[:, :, 5:6]
        br_x = xs - wh[:, :, 6:7]
        br_y = ys - wh[:, :, 7:8]

        # Combine into the output format:
        # [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score]
        pts = torch.cat([xs, ys, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, scores], dim=2)

        # Scale coordinates (but not scores) back to network-input pixels
        pts[:, :, :10] *= down_ratio

        return pts[0].cpu().numpy()  # (K, 11)
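
    # Decode output sketch (illustrative numbers, not real model output): with
    # down_ratio=4, a heatmap peak at (x=60, y=120) with reg offset (0.3, -0.2)
    # and first corner offset (5, 8) yields
    #     cx = (60 + 0.3) * 4 = 241.2,   cy = (120 - 0.2) * 4 = 479.2
    #     tl = ((60.3 - 5) * 4, (119.8 - 8) * 4) = (221.2, 447.2)
    # i.e. every coordinate in a row of `pts` is already in 1024x512 input pixels.
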
    def predict(self, image: np.ndarray) -> Spine2D:
        """Run inference and return unified landmarks."""
        import torch
        import cv2

        # Ensure a 3-channel RGB image
        if len(image.shape) == 2:
            image_rgb = np.stack([image, image, image], axis=-1)
        else:
            image_rgb = image

        orig_h, orig_w = image_rgb.shape[:2]

        # Resize to the model input size (1024x512)
        input_h, input_w = 1024, 512
        img_resized = cv2.resize(image_rgb, (input_w, input_h))

        # Normalize and convert to a tensor using the original preprocessing:
        # out_image = image / 255. - 0.5 (NOT ImageNet statistics)
        img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0 - 0.5
        img_tensor = img_tensor.unsqueeze(0).to(self.device)

        # Run inference
        with torch.no_grad():
            output = self.model(img_tensor)

        # Decode predictions into a (K, 11) array:
        # [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score]
        pts = self._decode_predictions(output, down_ratio=4, K=17)

        # Scale coordinates back to the original image size
        scale_x = orig_w / input_w
        scale_y = orig_h / input_h

        # Convert to the unified format
        vertebrae = []
        threshold = 0.3

        for i in range(len(pts)):
            score = pts[i, 10]
            if score < threshold:
                continue

            # Center and corners (already scaled by down_ratio in the decoder)
            cx = pts[i, 0] * scale_x
            cy = pts[i, 1] * scale_y

            # Corners: tl, tr, bl, br
            tl = np.array([pts[i, 2] * scale_x, pts[i, 3] * scale_y])
            tr = np.array([pts[i, 4] * scale_x, pts[i, 5] * scale_y])
            bl = np.array([pts[i, 6] * scale_x, pts[i, 7] * scale_y])
            br = np.array([pts[i, 8] * scale_x, pts[i, 9] * scale_y])

            # Keep the [tl, tr, bl, br] order so corners_px matches ScolioVis
            corners = np.array([tl, tr, bl, br], dtype=np.float32)
            centroid = np.array([cx, cy], dtype=np.float32)

            # Orientation of the superior endplate, from the top edge
            dx = tr[0] - tl[0]
            dy = tr[1] - tl[1]
            orientation = np.degrees(np.arctan2(dy, dx))

            vert = VertebraLandmark(
                level=None,
                centroid_px=centroid,
                corners_px=corners,
                endplate_upper_px=np.array([tl, tr]),  # top edge
                endplate_lower_px=np.array([bl, br]),  # bottom edge
                orientation_deg=orientation,
                confidence=float(score),
                meta={'raw_pts': pts[i].tolist()}
            )
            vertebrae.append(vert)

        # Sort top to bottom by centroid y-coordinate
        vertebrae.sort(key=lambda v: float(v.centroid_px[1]))

        # Assign vertebra levels (T1-L5 = 17 vertebrae, typically)
        level_names = ['T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7', 'T8', 'T9',
                       'T10', 'T11', 'T12', 'L1', 'L2', 'L3', 'L4', 'L5']
        for i, vert in enumerate(vertebrae):
            if i < len(level_names):
                vert.level = level_names[i]

        # Create the Spine2D container
        spine = Spine2D(
            vertebrae=vertebrae,
            image_shape=(orig_h, orig_w),
            source_model=self.name
        )

        # Compute Cobb angles
        if len(vertebrae) >= 7:
            from spine_analysis import compute_cobb_angles
            compute_cobb_angles(spine)

        return spine
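

# A minimal smoke-test sketch, assuming the weights sit in the default locations
# and that an X-ray exists at the hypothetical path 'spine_xray.png'.
if __name__ == '__main__':
    import cv2

    adapter = ScolioVisAdapter(device='cpu')
    bgr = cv2.imread('spine_xray.png')  # hypothetical input image
    spine = adapter.predict(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    print(f"{adapter.name}: detected {len(spine.vertebrae)} vertebrae")
    print("Cobb angles:", getattr(spine, 'cobb_angles', 'not computed'))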