Files
braceiqmed/brace-generator/adapters.py

509 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Model adapters that convert different model outputs to unified Spine2D format.
Each adapter wraps a specific model and produces consistent output.
"""
import sys
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any
from abc import ABC, abstractmethod
from data_models import VertebraLandmark, Spine2D
class BaseLandmarkAdapter(ABC):
    """
    Abstract interface implemented by every landmark detection model adapter.

    A concrete adapter must provide the ``name`` property and the
    ``predict`` method declared below.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of the wrapped model."""
        ...

    @abstractmethod
    def predict(self, image: np.ndarray) -> Spine2D:
        """
        Detect spine landmarks in a single image.

        Args:
            image: Input image as numpy array (grayscale or RGB)

        Returns:
            Spine2D object with detected landmarks
        """
        ...
class ScolioVisAdapter(BaseLandmarkAdapter):
    """
    Adapter for ScolioVis-API (Keypoint R-CNN model).

    Uses the original ScolioVis inference code for best accuracy.
    Outputs: 4 keypoints per vertebra + Cobb angles (PT, MT, TL) + curve type (S/C)
    """

    def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'):
        """
        Initialize ScolioVis model.

        The model is built and its weights are loaded immediately.

        Args:
            weights_path: Path to keypointsrcnn_weights.pt (auto-detects if None)
            device: 'cpu' or 'cuda'

        Raises:
            FileNotFoundError: If no weights file can be located.
        """
        self.device = device
        self.model = None
        self.weights_path = weights_path
        self._scoliovis_path = None
        self._load_model()

    def _load_model(self):
        """
        Load the Keypoint R-CNN model.

        When ``self.weights_path`` is None, a few locations under the sibling
        ``scoliovis-api`` directory are probed and the first existing file is used.

        Raises:
            FileNotFoundError: If the weights file does not exist.
        """
        import torch
        from torchvision.models.detection import keypointrcnn_resnet50_fpn
        from torchvision.models.detection.rpn import AnchorGenerator

        # Find weights and scoliovis module
        scoliovis_api_path = Path(__file__).parent.parent / 'scoliovis-api'
        if self.weights_path is None:
            candidates = [
                scoliovis_api_path / 'models' / 'keypointsrcnn_weights.pt',
                scoliovis_api_path / 'keypointsrcnn_weights.pt',
                scoliovis_api_path / 'weights' / 'keypointsrcnn_weights.pt',
            ]
            self.weights_path = next(
                (str(p) for p in candidates if p.exists()), None
            )

        if self.weights_path is None or not Path(self.weights_path).exists():
            raise FileNotFoundError(
                "ScolioVis weights not found. Please provide weights_path or ensure "
                "scoliovis-api/models/keypointsrcnn_weights.pt exists."
            )

        # Store path to scoliovis module for Cobb angle calculation
        self._scoliovis_path = scoliovis_api_path

        # Create model with same anchor generator as original training
        anchor_generator = AnchorGenerator(
            sizes=(32, 64, 128, 256, 512),
            aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0)
        )
        self.model = keypointrcnn_resnet50_fpn(
            weights=None,
            weights_backbone=None,
            num_classes=2,    # background + vertebra
            num_keypoints=4,  # 4 corners per vertebra
            rpn_anchor_generator=anchor_generator
        )

        # Load weights.
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
        # so only load weight files from a trusted source. The checkpoint is fed
        # straight into load_state_dict, so weights_only=True may be sufficient
        # -- TODO confirm against the shipped weights file before switching.
        checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()
        print(f"ScolioVis model loaded from {self.weights_path}")

    @property
    def name(self) -> str:
        """Model name for identification."""
        return "ScolioVis-API"

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """
        Return ``image`` as a 3-channel (H, W, 3) array.

        2-D input is stacked three times, a single-channel (H, W, 1) input is
        repeated, and a 4-channel input keeps only its first three channels
        (presumably RGB followed by alpha -- verify against callers). Any
        other input is returned unchanged.
        """
        if image.ndim == 2:
            return np.stack([image, image, image], axis=-1)
        if image.ndim == 3 and image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        if image.ndim == 3 and image.shape[2] == 4:
            return np.ascontiguousarray(image[:, :, :3])
        return image

    def _filter_output(self, output, max_verts: int = 17,
                       score_threshold: float = 0.5, iou_threshold: float = 0.3):
        """
        Filter model output using NMS and score threshold.

        Matches the original ScolioVis filtering logic.

        Args:
            output: Detection dict for one image with 'scores', 'boxes' and
                'keypoints' tensors.
            max_verts: Maximum number of detections to keep (highest scores first).
            score_threshold: Detections with a score not above this are dropped.
            iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.

        Returns:
            Tuple ``(bboxes, keypoints, scores)`` of plain Python lists ordered
            by the y coordinate of each detection's first keypoint. Boxes are
            lists of ints, keypoints are lists of ``[x, y]`` float pairs. All
            three lists are empty when no score exceeds the threshold.
        """
        import torchvision

        scores = output['scores'].detach().cpu().numpy()

        # Get indices of scores over threshold
        high_scores_idxs = np.where(scores > score_threshold)[0].tolist()
        if not high_scores_idxs:
            return [], [], []

        cand_boxes = output['boxes'][high_scores_idxs]
        cand_scores = output['scores'][high_scores_idxs]
        cand_keypoints = output['keypoints'][high_scores_idxs]

        # Apply NMS; the result indexes into the score-filtered candidates
        post_nms_idxs = torchvision.ops.nms(cand_boxes, cand_scores, iou_threshold)

        # Get filtered results
        np_keypoints = cand_keypoints[post_nms_idxs].detach().cpu().numpy()
        np_bboxes = cand_boxes[post_nms_idxs].detach().cpu().numpy()
        np_scores = cand_scores[post_nms_idxs].detach().cpu().numpy()

        # Take top N by score (usually 17 for full spine)
        by_score = np.argsort(-1 * np_scores)[:max_verts]
        np_scores = np_scores[by_score]
        np_keypoints = np_keypoints[by_score]
        np_bboxes = np_bboxes[by_score]

        # Sort top to bottom using the y coordinate of the first keypoint
        by_y = np.argsort(np_keypoints[:, 0, 1])
        np_scores = np_scores[by_y]
        np_keypoints = np_keypoints[by_y]
        np_bboxes = np_bboxes[by_y]

        # Convert to lists; only the first two values (x, y) of each keypoint are kept
        keypoints_list = [
            [[float(v) for v in kp[:2]] for kp in kps] for kps in np_keypoints
        ]
        bboxes_list = [[int(v) for v in bbox.tolist()] for bbox in np_bboxes]
        scores_list = np_scores.tolist()
        return bboxes_list, keypoints_list, scores_list

    def predict(self, image: np.ndarray) -> Spine2D:
        """
        Run inference and return unified landmarks with ScolioVis Cobb angles.

        Args:
            image: Input image as numpy array. Grayscale (H, W) or (H, W, 1),
                RGB (H, W, 3) and 4-channel (H, W, 4) inputs are accepted.

        Returns:
            Spine2D with one VertebraLandmark per detection kept by
            ``_filter_output``. With at least 5 detections, Cobb angles are
            taken from the original ScolioVis code when it can be imported,
            otherwise from ``spine_analysis.compute_cobb_angles``.
        """
        import torch
        from torchvision.transforms import functional as F

        # Ensure RGB
        image_rgb = self._to_rgb(image)
        image_shape = image_rgb.shape  # (H, W, C)

        # Convert to tensor (ScolioVis uses torchvision's to_tensor)
        img_tensor = F.to_tensor(image_rgb).to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.model([img_tensor])

        # Filter output using original ScolioVis logic
        bboxes, keypoints, scores = self._filter_output(outputs[0])
        if not keypoints:
            return Spine2D(
                vertebrae=[],
                image_shape=image_shape[:2],
                source_model=self.name
            )

        # Convert to unified format
        vertebrae = []
        for box, kps_xy, score in zip(bboxes, keypoints, scores):
            # Corners order from ScolioVis: [top_left, top_right, bottom_left, bottom_right]
            corners = np.array(kps_xy, dtype=np.float32)  # (4, 2)
            centroid = np.mean(corners, axis=0)

            # Compute orientation from top edge (corners[0] to corners[1])
            top_left, top_right = corners[0], corners[1]
            dx = top_right[0] - top_left[0]
            dy = top_right[1] - top_left[1]
            orientation = np.degrees(np.arctan2(dy, dx))

            vertebrae.append(VertebraLandmark(
                level=None,  # ScolioVis doesn't assign levels
                centroid_px=centroid,
                corners_px=corners,
                endplate_upper_px=corners[:2],  # top-left, top-right
                endplate_lower_px=corners[2:],  # bottom-left, bottom-right
                orientation_deg=orientation,
                confidence=float(score),
                meta={'box': box}
            ))

        # Create Spine2D
        spine = Spine2D(
            vertebrae=vertebrae,
            image_shape=image_shape[:2],
            source_model=self.name
        )

        # Use original ScolioVis Cobb angle calculation if available
        if len(keypoints) >= 5:
            try:
                # Import original ScolioVis cobb_angle_cal
                if str(self._scoliovis_path) not in sys.path:
                    sys.path.insert(0, str(self._scoliovis_path))
                from scoliovis.cobb_angle_cal import cobb_angle_cal, keypoints_to_landmark_xy

                landmark_xy = keypoints_to_landmark_xy(keypoints)
                cobb_angles_list, angles_with_pos, curve_type, midpoint_lines = cobb_angle_cal(
                    landmark_xy, image_shape
                )

                # Store Cobb angles in spine object
                spine.cobb_angles = {
                    'PT': cobb_angles_list[0],
                    'MT': cobb_angles_list[1],
                    'TL': cobb_angles_list[2]
                }
                spine.curve_type = curve_type
                spine.meta = {
                    'angles_with_pos': angles_with_pos,
                    'midpoint_lines': midpoint_lines
                }
            except Exception as e:
                # Deliberate best-effort: any failure in the ScolioVis code path
                # (including a failed import) falls back to the local calculation.
                print(f"Warning: Could not use ScolioVis Cobb calculation: {e}")
                # Fallback to our own calculation
                from spine_analysis import compute_cobb_angles
                compute_cobb_angles(spine)

        return spine
class VertLandmarkAdapter(BaseLandmarkAdapter):
    """
    Adapter for Vertebra-Landmark-Detection (SpineNet model).

    Outputs: 68 landmarks (4 corners × 17 vertebrae)
    """

    # Network input resolution as (height, width).
    INPUT_SIZE = (1024, 512)
    # Detections with a heatmap score below this value are discarded in predict().
    SCORE_THRESHOLD = 0.3
    # Names given to detections from top to bottom (T1-L5 = 17 vertebrae typically).
    LEVEL_NAMES = (
        'T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7', 'T8', 'T9', 'T10', 'T11', 'T12',
        'L1', 'L2', 'L3', 'L4', 'L5',
    )

    def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'):
        """
        Initialize SpineNet model.

        The model is built and its weights are loaded immediately.

        Args:
            weights_path: Path to model_last.pth (auto-detects if None)
            device: 'cpu' or 'cuda'

        Raises:
            FileNotFoundError: If no weights file can be located.
        """
        self.device = device
        self.model = None
        self.weights_path = weights_path
        self._load_model()

    def _load_model(self):
        """
        Load the SpineNet model.

        When ``self.weights_path`` is None, the default location inside the
        sibling ``Vertebra-Landmark-Detection`` directory is tried.

        Raises:
            FileNotFoundError: If the weights file does not exist.
        """
        import torch

        repo_path = Path(__file__).parent.parent / 'Vertebra-Landmark-Detection'

        # Find weights
        if self.weights_path is None:
            default_weights = repo_path / 'weights_spinal' / 'model_last.pth'
            if default_weights.exists():
                self.weights_path = str(default_weights)

        if self.weights_path is None or not Path(self.weights_path).exists():
            raise FileNotFoundError(
                "Vertebra-Landmark-Detection weights not found. "
                "Download from Google Drive and place in weights_spinal/model_last.pth"
            )

        # Add repo to path to import model
        if str(repo_path) not in sys.path:
            sys.path.insert(0, str(repo_path))
        # NOTE(review): 'models' is a very generic top-level package name; if
        # another 'models' package is already imported this picks up the wrong
        # module -- confirm no such collision exists in this project.
        from models import spinal_net

        # Create model
        heads = {'hm': 1, 'reg': 2, 'wh': 8}
        self.model = spinal_net.SpineNet(
            heads=heads,
            pretrained=False,
            down_ratio=4,
            final_kernel=1,
            head_conv=256
        )

        # Load weights.
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
        # so only load checkpoint files from a trusted source.
        checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False)
        # strict=False tolerates key mismatches; report them instead of hiding
        # them, because unmatched layers keep their freshly initialised weights.
        load_result = self.model.load_state_dict(checkpoint['state_dict'], strict=False)
        missing = list(getattr(load_result, 'missing_keys', []))
        unexpected = list(getattr(load_result, 'unexpected_keys', []))
        if missing or unexpected:
            print(
                f"Warning: checkpoint does not fully match the model "
                f"({len(missing)} missing keys, {len(unexpected)} unexpected keys)"
            )
        self.model.to(self.device)
        self.model.eval()
        print(f"Vertebra-Landmark-Detection model loaded from {self.weights_path}")

    @property
    def name(self) -> str:
        """Model name for identification."""
        return "Vertebra-Landmark-Detection"

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """
        Return ``image`` as a 3-channel (H, W, 3) array.

        2-D input is stacked three times, a single-channel (H, W, 1) input is
        repeated, and a 4-channel input keeps only its first three channels
        (presumably RGB followed by alpha -- verify against callers). Any
        other input is returned unchanged.
        """
        if image.ndim == 2:
            return np.stack([image, image, image], axis=-1)
        if image.ndim == 3 and image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        if image.ndim == 3 and image.shape[2] == 4:
            return np.ascontiguousarray(image[:, :, :3])
        return image

    def _nms(self, heat, kernel=3):
        """
        Apply NMS using max pooling.

        Values of ``heat`` that are not equal to the maximum of their
        ``kernel`` x ``kernel`` neighbourhood are set to zero.
        """
        import torch.nn.functional as F

        hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=(kernel - 1) // 2)
        keep = (hmax == heat).float()
        return heat * keep

    def _gather_feat(self, feat, ind):
        """
        Gather features by index.

        ``feat`` is indexed as (batch, N, dim) and ``ind`` as (batch, K); the
        result holds the rows of ``feat`` selected by ``ind`` along dimension 1.
        """
        dim = feat.size(2)
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        return feat.gather(1, ind)

    def _tranpose_and_gather_feat(self, feat, ind):
        """
        Transpose and gather features - matches original decoder.

        Moves the channel axis of a (batch, C, H, W) map last, flattens the
        spatial axes, then gathers the rows selected by ``ind``.
        """
        feat = feat.permute(0, 2, 3, 1).contiguous()
        feat = feat.view(feat.size(0), -1, feat.size(3))
        return self._gather_feat(feat, ind)

    def _decode_predictions(self, output: Dict, down_ratio: int = 4, K: int = 17):
        """
        Decode model output using original decoder logic.

        Args:
            output: Model output dict with 'hm', 'reg' and 'wh' tensors.
            down_ratio: Factor applied to the decoded coordinates.
            K: Number of heatmap peaks to decode.

        Returns:
            numpy array of shape (K, 11) for the first batch element, laid out as
            [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score].
        """
        import torch

        hm = output['hm'].sigmoid()
        reg = output['reg']
        wh = output['wh']
        batch, cat, height, width = hm.size()

        # Apply NMS
        hm = self._nms(hm)

        # Get top K from heatmap
        topk_scores, topk_inds = torch.topk(hm.view(batch, cat, -1), K)
        topk_inds = topk_inds % (height * width)
        topk_ys = (topk_inds // width).float()
        topk_xs = (topk_inds % width).float()

        # Get overall top K
        topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
        topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
        topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
        topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
        scores = topk_score.view(batch, K, 1)

        # Get regression offset and apply
        reg = self._tranpose_and_gather_feat(reg, topk_inds).view(batch, K, 2)
        xs = topk_xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = topk_ys.view(batch, K, 1) + reg[:, :, 1:2]

        # Get corner offsets
        wh = self._tranpose_and_gather_feat(wh, topk_inds).view(batch, K, 8)

        # Calculate corners by SUBTRACTING offsets (original decoder logic).
        # The centre is tiled to [x, y, x, y, x, y, x, y] so a single subtraction
        # yields [tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y].
        centers = torch.cat([xs, ys], dim=2)
        corners = centers.repeat(1, 1, 4) - wh

        # Combine into output format: [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score]
        pts = torch.cat([centers, corners, scores], dim=2)

        # Scale the ten coordinate columns (not the score) by down_ratio
        pts[:, :, :10] *= down_ratio
        return pts[0].cpu().numpy()  # (K, 11)

    def predict(self, image: np.ndarray) -> Spine2D:
        """
        Run inference and return unified landmarks.

        Args:
            image: Input image as numpy array. Grayscale (H, W) or (H, W, 1),
                RGB (H, W, 3) and 4-channel (H, W, 4) inputs are accepted.

        Returns:
            Spine2D whose vertebrae are sorted by centroid y coordinate and
            named from ``LEVEL_NAMES`` in that order. Cobb angles are computed
            via ``spine_analysis.compute_cobb_angles`` when at least 7
            vertebrae pass the score threshold.
        """
        import torch
        import cv2

        # Ensure RGB
        image_rgb = self._to_rgb(image)
        orig_h, orig_w = image_rgb.shape[:2]

        # Resize to model input size (cv2.resize takes (width, height))
        input_h, input_w = self.INPUT_SIZE
        img_resized = cv2.resize(image_rgb, (input_w, input_h))

        # Normalize and convert to tensor - use original preprocessing!
        # Original: out_image = image / 255. - 0.5 (NOT ImageNet stats)
        img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0 - 0.5
        img_tensor = img_tensor.unsqueeze(0).to(self.device)

        # Run inference
        with torch.no_grad():
            output = self.model(img_tensor)

        # Decode predictions - returns (K, 11) array
        # Format: [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score]
        pts = self._decode_predictions(output, down_ratio=4, K=17)

        # Scale coordinates back to original image size
        scale_x = orig_w / input_w
        scale_y = orig_h / input_h

        # Convert to unified format
        vertebrae = []
        for row in pts:
            score = row[10]
            if score < self.SCORE_THRESHOLD:
                continue

            # Get center and corners (already scaled by down_ratio in decoder)
            cx = row[0] * scale_x
            cy = row[1] * scale_y

            # Corners: tl, tr, bl, br
            tl = np.array([row[2] * scale_x, row[3] * scale_y])
            tr = np.array([row[4] * scale_x, row[5] * scale_y])
            bl = np.array([row[6] * scale_x, row[7] * scale_y])
            br = np.array([row[8] * scale_x, row[9] * scale_y])

            # Stored in perimeter order [tl, tr, br, bl].
            # NOTE(review): ScolioVisAdapter stores corners_px as
            # [tl, tr, bl, br], so the two adapters currently disagree on
            # corner order -- confirm which convention consumers expect.
            corners = np.array([tl, tr, br, bl], dtype=np.float32)
            centroid = np.array([cx, cy], dtype=np.float32)

            # Compute orientation from top edge
            dx = tr[0] - tl[0]
            dy = tr[1] - tl[1]
            orientation = np.degrees(np.arctan2(dy, dx))

            vertebrae.append(VertebraLandmark(
                level=None,
                centroid_px=centroid,
                corners_px=corners,
                endplate_upper_px=np.array([tl, tr]),  # top edge
                endplate_lower_px=np.array([bl, br]),  # bottom edge
                orientation_deg=orientation,
                confidence=float(score),
                meta={'raw_pts': row.tolist()}
            ))

        # Sort by y-coordinate (top to bottom)
        vertebrae.sort(key=lambda v: float(v.centroid_px[1]))

        # Assign vertebra levels in top-to-bottom order.
        # NOTE(review): the topmost detection is always named T1, even when
        # fewer than 17 vertebrae pass the threshold, so names may be shifted
        # for partial detections -- confirm this is acceptable downstream.
        for vert, level in zip(vertebrae, self.LEVEL_NAMES):
            vert.level = level

        # Create Spine2D
        spine = Spine2D(
            vertebrae=vertebrae,
            image_shape=(orig_h, orig_w),
            source_model=self.name
        )

        # Compute Cobb angles
        if len(vertebrae) >= 7:
            from spine_analysis import compute_cobb_angles
            compute_cobb_angles(spine)

        return spine