"""
Business logic service for brace generation.

This service handles ML inference and file generation.
S3 operations are handled by the Lambda function, not here.
"""
import json
import tempfile
import time
import uuid
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import numpy as np
import trimesh

from .config import config
from .schemas import (
    AnalyzeRequest,
    AnalyzeFromUrlRequest,
    BraceConfigRequest,
    AnalysisResult,
    CobbAngles,
    RigoClassification,
    Vertebra,
    DeformationReport,
    ExperimentType,
)


class BraceService:
    """
    Service for X-ray analysis and brace generation.

    Handles:
    - Model loading and inference
    - Pipeline orchestration
    - Local file management

    Note: S3 operations are handled by Lambda, not here.

    A single ``BracePipeline`` instance is shared by every request. A request's
    ``brace_config`` is applied to it only while that request runs; the previous
    configuration is restored afterwards so settings never leak between requests.
    """

    # Vertebral levels reported by ``detect_landmarks_only``, top to bottom.
    _ALL_LEVELS = (
        "T1", "T2", "T3", "T4", "T5", "T6", "T7", "T8", "T9", "T10", "T11", "T12",
        "L1", "L2", "L3", "L4", "L5",
    )

    # Spine-crop heuristic for automatic detection: images at least this wide (px)
    # are cropped to their middle third (wide chest X-rays benefit from it)...
    _MIN_WIDTH_FOR_CROPPING = 500
    # ...unless that middle third would be narrower than this (px).
    _CROPPED_MIN_WIDTH = 200

    # Minimum number of vertebrae with a centroid needed to analyse landmarks.
    _MIN_VERTEBRAE = 3

    def __init__(self, device: str = "cuda", model: str = "scoliovis"):
        """
        Create the service and build the standard pipeline.

        Args:
            device: Device string forwarded to the pipeline and detector.
            model: Landmark model name forwarded to ``BracePipeline``.
        """
        self.device = device
        self.model_name = model
        self._init_pipelines()

    def _init_pipelines(self) -> None:
        """Initialize brace generation pipelines."""
        from brace_generator.pipeline import BracePipeline

        # Standard pipeline, shared by all requests.
        self.standard_pipeline = BracePipeline(model=self.model_name, device=self.device)

        # Experiment 3 pipeline slot (lazy); currently unused because
        # _get_exp3_pipeline falls back to the standard pipeline.
        self._exp3_pipeline = None

    def _get_exp3_pipeline(self):
        """Return standard pipeline (EXPERIMENT_3 not deployed)."""
        return self.standard_pipeline

    @property
    def model_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.standard_pipeline is not None

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _new_case_id() -> str:
        """Return a short random case id (the first 8 hex characters of a UUID4)."""
        return uuid.uuid4().hex[:8]

    @staticmethod
    def _case_output_dir(case_id: str) -> Path:
        """
        Create (if needed) and return ``config.TEMP_DIR / case_id``.

        Raises:
            ValueError: if ``case_id`` would place output outside ``config.TEMP_DIR``
                (for example an absolute path or one containing ``..``).
        """
        base_dir = Path(config.TEMP_DIR)
        output_dir = base_dir / case_id
        try:
            output_dir.resolve().relative_to(base_dir.resolve())
        except ValueError:
            raise ValueError(
                f"Invalid case_id {case_id!r}: resolves outside the output directory"
            ) from None
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir

    @staticmethod
    def _write_temp_image(image_data: bytes, filename: str) -> str:
        """
        Write ``image_data`` to a named temporary file and return its path.

        The file keeps the extension of ``filename`` (``.jpg`` when it has none).
        The caller owns the returned file and must delete it; if writing fails
        the partial file is removed here before the error propagates.
        """
        suffix = Path(filename).suffix or ".jpg"
        tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        try:
            with tmp:
                tmp.write(image_data)
        except BaseException:
            Path(tmp.name).unlink(missing_ok=True)
            raise
        return tmp.name

    @staticmethod
    def _build_brace_config(brace_config: BraceConfigRequest):
        """Translate the API request model into a pipeline ``BraceConfig``."""
        from brace_generator.data_models import BraceConfig

        return BraceConfig(
            brace_height_mm=brace_config.brace_height_mm,
            torso_width_mm=brace_config.torso_width_mm,
            torso_depth_mm=brace_config.torso_depth_mm,
            wall_thickness_mm=brace_config.wall_thickness_mm,
            pressure_strength_mm=brace_config.pressure_strength_mm,
        )

    @contextmanager
    def _pipeline_config(self, pipeline, brace_config: Optional[BraceConfigRequest]):
        """
        Temporarily apply ``brace_config`` to ``pipeline.config``.

        When ``brace_config`` is falsy the pipeline is left untouched. Otherwise the
        previous ``config`` is restored when the ``with`` block exits (also on
        error), so one request's settings are not reused by the next request.
        """
        if not brace_config:
            yield
            return

        missing = object()
        previous = getattr(pipeline, "config", missing)
        pipeline.config = self._build_brace_config(brace_config)
        try:
            yield
        finally:
            if previous is missing:
                del pipeline.config
            else:
                pipeline.config = previous

    @staticmethod
    def _sidecar_outputs(output_base: Path) -> Dict[str, str]:
        """Return the ``<output_base>.png`` / ``.json`` files that exist, keyed by role."""
        outputs = {}
        for key, extension in (("visualization", ".png"), ("landmarks", ".json")):
            candidate = str(output_base) + extension
            if Path(candidate).exists():
                outputs[key] = candidate
        return outputs

    @staticmethod
    def _cobb_and_rigo(results: Dict[str, Any]) -> Tuple[CobbAngles, RigoClassification]:
        """Build the response's Cobb-angle and Rigo models from a pipeline result dict."""
        cobb = results["cobb_angles"]
        return (
            CobbAngles(PT=cobb["PT"], MT=cobb["MT"], TL=cobb["TL"]),
            RigoClassification(
                type=results["rigo_type"],
                description=results.get("rigo_description", ""),
            ),
        )

    @staticmethod
    def _angle_or_zero(value) -> float:
        """Return ``value`` as a float, treating a missing (``None``) angle as 0."""
        return float(value or 0)

    @classmethod
    def _cobb_summary(cls, spine, get_curve_severity) -> Dict[str, Any]:
        """Cobb angles of ``spine`` plus their maximum and per-curve severity labels."""
        pt = cls._angle_or_zero(spine.cobb_pt)
        mt = cls._angle_or_zero(spine.cobb_mt)
        tl = cls._angle_or_zero(spine.cobb_tl)
        return {
            "PT": pt,
            "MT": mt,
            "TL": tl,
            "max": max(pt, mt, tl),
            "PT_severity": get_curve_severity(pt),
            "MT_severity": get_curve_severity(mt),
            "TL_severity": get_curve_severity(tl),
        }

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback: serialize array-likes (numpy arrays/scalars) via ``tolist``."""
        if hasattr(obj, "tolist"):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @classmethod
    def _spine_from_landmarks(cls, landmarks_data: Dict[str, Any], purpose: str):
        """
        Build an (unsorted) ``Spine2D`` from each vertebra's ``final_values``.

        ``landmarks_data`` may be the full ``detect_landmarks_only`` result (with a
        ``vertebrae_structure`` key) or the ``vertebrae_structure`` dict itself.
        Entries without ``final_values.centroid_px`` are skipped; a missing
        confidence defaults to 0.5.

        Raises:
            ValueError: if fewer than ``_MIN_VERTEBRAE`` vertebrae have a centroid.
        """
        from brace_generator.data_models import Spine2D, VertebraLandmark

        structure = landmarks_data.get("vertebrae_structure") or landmarks_data
        spine = Spine2D()
        for vdata in structure.get("vertebrae") or []:
            final = vdata.get("final_values") or {}
            centroid = final.get("centroid_px")
            if centroid is None:
                continue  # Skip undetected/empty vertebrae

            confidence = final.get("confidence")
            vertebra = VertebraLandmark(
                level=vdata.get("level"),
                centroid_px=np.array(centroid, dtype=np.float32),
                confidence=float(0.5 if confidence is None else confidence),
            )

            corners = final.get("corners_px")
            if corners is not None and len(corners) > 0:
                vertebra.corners_px = np.array(corners, dtype=np.float32)

            spine.vertebrae.append(vertebra)

        if len(spine.vertebrae) < cls._MIN_VERTEBRAE:
            raise ValueError(f"Need at least {cls._MIN_VERTEBRAE} vertebrae for {purpose}")
        return spine

    # ------------------------------------------------------------------
    # Analysis + brace generation
    # ------------------------------------------------------------------

    async def analyze_from_bytes(
        self,
        image_data: bytes,
        filename: str,
        experiment: ExperimentType = ExperimentType.EXPERIMENT_3,
        case_id: Optional[str] = None,
        brace_config: Optional[BraceConfigRequest] = None,
        landmarks_data: Optional[Dict[str, Any]] = None,
    ) -> AnalysisResult:
        """
        Analyze an X-ray supplied as raw bytes and generate a brace.

        If landmarks_data is provided (Experiment 3 only), those landmarks (with
        manual edits) are used instead of re-running automatic detection.

        Args:
            image_data: Encoded image bytes.
            filename: Original file name; only its extension is used.
            experiment: Which pipeline variant to run.
            case_id: Output sub-directory / file stem; a random id when omitted.
            brace_config: Optional per-request brace dimensions.
            landmarks_data: Optional landmark structure from ``detect_landmarks_only``.

        Returns:
            The pipeline result with ``case_id`` and ``processing_time_ms`` set.

        Raises:
            ValueError: for an unsafe ``case_id`` or too few usable landmarks.
        """
        start = time.perf_counter()
        case_id = case_id or self._new_case_id()
        input_path = self._write_temp_image(image_data, filename)

        try:
            output_base = self._case_output_dir(case_id) / f"brace_{case_id}"

            if experiment == ExperimentType.EXPERIMENT_3:
                result = await self._run_experiment_3(
                    input_path, output_base, brace_config, landmarks_data
                )
            else:
                result = await self._run_standard(input_path, output_base, brace_config)

            result.processing_time_ms = (time.perf_counter() - start) * 1000
            result.case_id = case_id
            return result
        finally:
            # The uploaded image is only needed while the pipeline runs.
            Path(input_path).unlink(missing_ok=True)

    async def _run_standard(
        self,
        input_path: str,
        output_base: Path,
        brace_config: Optional[BraceConfigRequest],
    ) -> AnalysisResult:
        """Run the standard pipeline, writing ``<output_base>.stl`` plus any sidecar files."""
        pipeline = self.standard_pipeline
        stl_path = str(output_base) + ".stl"

        with self._pipeline_config(pipeline, brace_config):
            results = pipeline.process(
                input_path,
                stl_path,
                visualize=True,
                save_landmarks=True,
            )

        outputs = {"stl": stl_path}
        outputs.update(self._sidecar_outputs(output_base))

        cobb_angles, rigo = self._cobb_and_rigo(results)
        return AnalysisResult(
            experiment="standard",
            input_image=input_path,
            model_used=results["model"],
            vertebrae_detected=results["vertebrae_detected"],
            cobb_angles=cobb_angles,
            curve_type=results["curve_type"],
            rigo_classification=rigo,
            mesh_vertices=results["mesh_vertices"],
            mesh_faces=results["mesh_faces"],
            outputs=outputs,
            processing_time_ms=0,  # Set by the caller.
        )

    async def _run_experiment_3(
        self,
        input_path: str,
        output_base: Path,
        brace_config: Optional[BraceConfigRequest],
        landmarks_data: Optional[Dict[str, Any]] = None,
    ) -> AnalysisResult:
        """
        Run the Experiment 3 (research-based adaptive) pipeline.

        If landmarks_data is provided, those landmarks (with manual edits) are used
        instead of running automatic detection.
        """
        pipeline = self._get_exp3_pipeline()

        # NOTE: neither branch suspends on a real awaitable today, so no other
        # request can observe the temporary config; revisit if that changes.
        with self._pipeline_config(pipeline, brace_config):
            if landmarks_data:
                results = await self._run_experiment_3_with_landmarks(
                    input_path, output_base, pipeline, landmarks_data
                )
            else:
                # Full pipeline with automatic detection.
                results = pipeline.process(
                    input_path,
                    str(output_base),
                    visualize=True,
                    save_landmarks=True,
                )

        deformation_report = None
        report = results.get("deformation_report")
        if report:
            deformation_report = DeformationReport(
                patch_grid=report.get("patch_grid", "6x8"),
                deformations=report.get("deformations"),
                zones=report.get("zones"),
            )

        outputs = {}
        if results.get("output_stl"):
            outputs["stl"] = results["output_stl"]
        if results.get("output_ply"):
            outputs["ply"] = results["output_ply"]
        outputs.update(self._sidecar_outputs(output_base))

        cobb_angles, rigo = self._cobb_and_rigo(results)
        return AnalysisResult(
            experiment="experiment_3",
            input_image=input_path,
            model_used=results.get("model", "manual_landmarks"),
            vertebrae_detected=results.get("vertebrae_detected", 0),
            cobb_angles=cobb_angles,
            curve_type=results["curve_type"],
            rigo_classification=rigo,
            mesh_vertices=results.get("mesh_vertices", 0),
            mesh_faces=results.get("mesh_faces", 0),
            deformation_report=deformation_report,
            outputs=outputs,
            processing_time_ms=0,
        )

    async def _run_experiment_3_with_landmarks(
        self,
        input_path: str,
        output_base: Path,
        pipeline,
        landmarks_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Generate an Experiment 3 brace from pre-computed landmarks.

        Uses ``final_values`` from landmarks_data (which may include manual edits)
        and calls the pipeline's brace generator directly instead of re-detecting.
        Writes ``<output_base>.stl``/``.ply`` (when a mesh is produced), ``.png``
        and ``.json``.

        Returns:
            A dict shaped like the pipeline's own result dict.

        Raises:
            ValueError: if fewer than three vertebrae have a centroid.
        """
        # NOTE(review): image_loader is imported as a top-level module while the
        # analysis modules come from the brace_generator package; this relies on
        # PYTHONPATH containing the brace_generator directory.
        from brace_generator.spine_analysis import (
            classify_rigo_type,
            compute_cobb_angles,
            find_apex_vertebrae,
        )
        from image_loader import load_xray_rgb

        # The image is needed for the spine metadata and the visualization.
        image_rgb, pixel_spacing = load_xray_rgb(input_path)

        spine = self._spine_from_landmarks(landmarks_data, "brace generation")
        spine.pixel_spacing_mm = pixel_spacing
        spine.image_shape = image_rgb.shape[:2]
        spine.sort_vertebrae()

        compute_cobb_angles(spine)
        # Called for parity with the detection path (it may annotate ``spine``);
        # its return value is not needed here.
        find_apex_vertebrae(spine)
        rigo_result = classify_rigo_type(spine)

        generator = pipeline.brace_generator
        brace_mesh = generator.generate(spine)

        base = str(output_base)
        mesh_vertices = mesh_faces = 0
        output_stl = output_ply = None
        deformation_report = None

        if brace_mesh is not None:
            mesh_vertices = len(brace_mesh.vertices)
            mesh_faces = len(brace_mesh.faces)

            if hasattr(generator, "get_deformation_report"):
                deformation_report = generator.get_deformation_report()

            # Append extensions (not Path.with_suffix) so a case id containing a
            # dot is not truncated and names match the .png/.json sidecars.
            output_stl = base + ".stl"
            output_ply = base + ".ply"
            brace_mesh.export(output_stl)
            if hasattr(generator, "export_ply"):
                generator.export_ply(brace_mesh, output_ply)
            else:
                brace_mesh.export(output_ply)
            print(f" Exported: {output_stl}, {output_ply}")

        # Visualization with the manual/combined landmarks and deformation heatmap.
        self._save_landmarks_visualization_with_spine(
            image_rgb, spine, rigo_result, deformation_report, base + ".png"
        )

        result = {
            "model": "manual_landmarks",
            "vertebrae_detected": len(spine.vertebrae),
            "cobb_angles": {
                "PT": self._angle_or_zero(spine.cobb_pt),
                "MT": self._angle_or_zero(spine.cobb_mt),
                "TL": self._angle_or_zero(spine.cobb_tl),
            },
            "curve_type": spine.curve_type or "Unknown",
            "rigo_type": rigo_result["rigo_type"],
            "rigo_description": rigo_result.get("description", ""),
            "mesh_vertices": mesh_vertices,
            "mesh_faces": mesh_faces,
            "output_stl": output_stl,
            "output_ply": output_ply,
            "deformation_report": deformation_report,
        }

        with open(base + ".json", "w") as fh:
            json.dump(
                {
                    "source": "manual_landmarks",
                    "vertebrae_count": len(spine.vertebrae),
                    "cobb_angles": result["cobb_angles"],
                    "rigo_type": result["rigo_type"],
                    "curve_type": result["curve_type"],
                    "deformation_report": deformation_report,
                },
                fh,
                indent=2,
                default=self._json_default,
            )

        return result

    # ------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------

    @staticmethod
    def _import_pyplot():
        """Return ``matplotlib.pyplot`` using the headless Agg backend, or None if unavailable."""
        try:
            import matplotlib
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt
        except ImportError:
            return None
        return plt

    @staticmethod
    def _draw_vertebra_landmarks(ax, spine, fallback_label) -> None:
        """
        Draw each vertebra's corner outline (green), centroid (red) and level label.

        Corners are joined in stored order 0→1→2→3→0; with the corner order
        [0]=top_left, [1]=top_right, [2]=bottom_left, [3]=bottom_right this gives
        the X pattern showing endplate orientations.

        Args:
            fallback_label: ``(index) -> str`` used when a vertebra has no level.
        """
        for vertebra in spine.vertebrae:
            if vertebra.corners_px is not None:
                corners = vertebra.corners_px
                for i in range(4):
                    j = (i + 1) % 4
                    ax.plot([corners[i, 0], corners[j, 0]],
                            [corners[i, 1], corners[j, 1]],
                            'g-', linewidth=1.5, zorder=4)
            if vertebra.centroid_px is not None:
                ax.scatter(vertebra.centroid_px[0], vertebra.centroid_px[1],
                           c='red', s=40, zorder=5)

        # Labels go on after all markers.
        for index, vertebra in enumerate(spine.vertebrae):
            if vertebra.centroid_px is not None:
                ax.annotate(
                    vertebra.level or fallback_label(index),
                    (vertebra.centroid_px[0] + 8, vertebra.centroid_px[1]),
                    fontsize=7, color='yellow', fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.6),
                )

    @classmethod
    def _draw_analysis_panel(cls, ax, spine, rigo_result) -> None:
        """Overlay the Cobb angles, curve type and Rigo type as a text box on ``ax``."""
        text = (
            "Cobb Angles:\n"
            f"PT: {cls._angle_or_zero(spine.cobb_pt):.1f}°\n"
            f"MT: {cls._angle_or_zero(spine.cobb_mt):.1f}°\n"
            f"TL: {cls._angle_or_zero(spine.cobb_tl):.1f}°\n\n"
            f"Curve: {spine.curve_type}\n"
            f"Rigo: {rigo_result['rigo_type']}"
        )
        ax.text(0.02, 0.98, text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(facecolor='white', alpha=0.8))
        ax.set_title("Spine Analysis")
        ax.axis('off')

    @staticmethod
    def _draw_deformation_heatmap(ax, deformation_report, two_slope_norm) -> None:
        """
        Draw the patch deformation grid as a diverging heatmap (or a placeholder).

        Expects ``deformation_report['deformations']`` to be a non-empty 2-D grid
        (rows = height bands, columns = angular positions); anything else shows
        "No deformation data".
        """
        deformations = deformation_report.get('deformations') if deformation_report else None
        deform_array = None
        if deformations is not None:
            deform_array = np.asarray(deformations, dtype=float)

        if deform_array is None or deform_array.ndim != 2 or deform_array.size == 0:
            ax.text(0.5, 0.5, 'No deformation data', ha='center', va='center',
                    transform=ax.transAxes, fontsize=14, color='gray')
            ax.set_title('Patch Deformations')
            ax.axis('off')
            return

        # Symmetric colour range around 0 (at least ±1 mm).
        vmax = max(abs(deform_array.min()), abs(deform_array.max()), 1)
        norm = two_slope_norm(vmin=-vmax, vcenter=0, vmax=vmax)
        image = ax.imshow(deform_array, cmap='RdBu_r', aspect='auto',
                          norm=norm, origin='upper')

        colorbar = ax.figure.colorbar(image, ax=ax, shrink=0.8)
        colorbar.set_label('Radial deformation (mm)')

        ax.set_xlabel('Angular Position (patches)')
        ax.set_ylabel('Height (patches)')
        ax.set_title('Patch Deformations (mm)\nBlue=Relief, Red=Pressure')

        rows, cols = deform_array.shape
        height_labels = ['Pelvis', 'Low Lumb', 'Up Lumb', 'Low Thor', 'Up Thor', 'Shoulder']
        if rows <= len(height_labels):
            ax.set_yticks(range(rows))
            ax.set_yticklabels(height_labels[:rows])

        angle_labels = ['BR', 'R', 'FR', 'F', 'FL', 'L', 'BL', 'B']
        if cols <= len(angle_labels):
            ax.set_xticks(range(cols))
            ax.set_xticklabels(angle_labels[:cols])

    def _save_landmarks_visualization_with_spine(self, image, spine, rigo_result, deformation_report, path):
        """
        Save a 3-panel PNG: landmarks, spine analysis, and deformation heatmap.

        Does nothing when matplotlib is not installed.
        """
        plt = self._import_pyplot()
        if plt is None:
            return
        try:
            from matplotlib.colors import TwoSlopeNorm
        except ImportError:
            return

        fig, axes = plt.subplots(1, 3, figsize=(18, 10))
        try:
            # Left: landmarks.
            landmarks_ax = axes[0]
            landmarks_ax.imshow(image)
            self._draw_vertebra_landmarks(landmarks_ax, spine, lambda _index: "?")
            landmarks_ax.set_title(f"Landmarks ({len(spine.vertebrae)} vertebrae)")
            landmarks_ax.axis('off')

            # Middle: spine curve through the centroids plus analysis text.
            analysis_ax = axes[1]
            analysis_ax.imshow(image, alpha=0.5)
            centroids = [v.centroid_px for v in spine.vertebrae if v.centroid_px is not None]
            if len(centroids) > 1:
                centroid_arr = np.array(centroids)
                analysis_ax.plot(centroid_arr[:, 0], centroid_arr[:, 1], 'b-',
                                 linewidth=2, alpha=0.8)
            self._draw_analysis_panel(analysis_ax, spine, rigo_result)

            # Right: deformation heatmap.
            self._draw_deformation_heatmap(axes[2], deformation_report, TwoSlopeNorm)

            fig.tight_layout()
            fig.savefig(path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

    def _save_landmarks_visualization(self, image, spine, rigo_result, path):
        """
        Save a 2-panel PNG: detected landmarks and spine analysis.

        Unlabelled vertebrae are labelled with their index. Does nothing when
        matplotlib is not installed.
        """
        plt = self._import_pyplot()
        if plt is None:
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 10))
        try:
            landmarks_ax = axes[0]
            landmarks_ax.imshow(image)
            self._draw_vertebra_landmarks(landmarks_ax, spine, str)
            landmarks_ax.set_title(f"Automatic Detection ({len(spine.vertebrae)} vertebrae)")
            landmarks_ax.axis('off')

            analysis_ax = axes[1]
            analysis_ax.imshow(image, alpha=0.5)
            self._draw_analysis_panel(analysis_ax, spine, rigo_result)

            fig.tight_layout()
            fig.savefig(path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

    # ============================================
    # NEW METHODS FOR PIPELINE DEV
    # ============================================

    def _crop_to_spine_region(self, image_rgb):
        """
        Return ``(image_for_detection, left_margin)``.

        Images at least ``_MIN_WIDTH_FOR_CROPPING`` px wide are cropped to their
        middle third unless that would be narrower than ``_CROPPED_MIN_WIDTH`` px;
        ``left_margin`` is the x-offset of the crop (0 when not cropped).
        """
        image_h, image_w = image_rgb.shape[:2]

        if image_w < self._MIN_WIDTH_FOR_CROPPING:
            print(f"[SpineCrop] Full image: {image_w}x{image_h}, Already narrow (< {self._MIN_WIDTH_FOR_CROPPING}px), using full image")
            return image_rgb, 0

        left_margin = image_w // 3
        right_margin = 2 * image_w // 3
        cropped_width = right_margin - left_margin

        if cropped_width < self._CROPPED_MIN_WIDTH:
            print(f"[SpineCrop] Full image: {image_w}x{image_h}, Cropped would be too narrow ({cropped_width}px), using full image")
            return image_rgb, 0

        print(f"[SpineCrop] Full image: {image_w}x{image_h}, Cropped to middle 1/3: {cropped_width}x{image_h}")
        return image_rgb[:, left_margin:right_margin], left_margin

    @staticmethod
    def _shift_landmarks_x(spine, dx: int) -> None:
        """Shift every centroid and corner x-coordinate by ``dx`` px, in place."""
        for vertebra in spine.vertebrae:
            if vertebra.centroid_px is not None:
                vertebra.centroid_px[0] += dx
            if vertebra.corners_px is not None:
                vertebra.corners_px[:, 0] += dx

    @staticmethod
    def _vertebra_entry(level: str, vertebra) -> Dict[str, Any]:
        """
        Build one ``vertebrae_structure`` entry for ``level``.

        ``vertebra`` is the detected landmark for that level, or None when the
        level was not detected. ``final_values`` mirrors the detection until a
        manual override is applied by the client.
        """
        manual_override = {
            "enabled": False,
            "centroid_px": None,
            "corners_px": None,
            "orientation_deg": None,
            "confidence": None,
            "notes": None,
        }

        if vertebra is None:
            return {
                "level": level,
                "detected": False,
                "scoliovis_data": {
                    "centroid_px": None,
                    "corners_px": None,
                    "orientation_deg": None,
                    "confidence": 0.0,
                },
                "manual_override": manual_override,
                "final_values": {
                    "centroid_px": None,
                    "corners_px": None,
                    "orientation_deg": None,
                    "confidence": 0.0,
                    "source": "undetected",
                },
            }

        centroid = vertebra.centroid_px.tolist() if vertebra.centroid_px is not None else None
        corners = vertebra.corners_px.tolist() if vertebra.corners_px is not None else None
        orientation = float(vertebra.compute_orientation()) if centroid else None
        confidence = float(vertebra.confidence)

        return {
            "level": level,
            "detected": True,
            "scoliovis_data": {
                "centroid_px": centroid,
                "corners_px": corners,
                "orientation_deg": orientation,
                "confidence": confidence,
            },
            "manual_override": manual_override,
            "final_values": {
                "centroid_px": centroid,
                "corners_px": corners,
                "orientation_deg": orientation,
                "confidence": confidence,
                "source": "scoliovis",
            },
        }

    async def detect_landmarks_only(
        self,
        image_data: bytes,
        filename: str,
        case_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Detect landmarks only, without generating a brace.

        Returns the full T1-L5 ``vertebrae_structure`` with manual_override
        support. Writes ``visualization.png`` and ``landmarks.json`` into the case
        output directory.

        Raises:
            ValueError: if ``case_id`` would resolve outside ``config.TEMP_DIR``.
        """
        start = time.perf_counter()
        case_id = case_id or self._new_case_id()
        input_path = self._write_temp_image(image_data, filename)

        try:
            # Import from brace_generator root (modules are already in PYTHONPATH).
            # Note: In Docker, PYTHONPATH includes /app/brace_generator.
            # NOTE(review): other methods import brace_generator.spine_analysis;
            # mixing top-level and package imports may load two module copies.
            from image_loader import load_xray_rgb
            from adapters import ScolioVisAdapter
            from spine_analysis import (
                classify_rigo_type,
                compute_cobb_angles,
                find_apex_vertebrae,
                get_curve_severity,
            )

            image_rgb, pixel_spacing = load_xray_rgb(input_path)
            image_h, image_w = image_rgb.shape[:2]

            detection_image, left_margin = self._crop_to_spine_region(image_rgb)

            adapter = ScolioVisAdapter(device=self.device)
            spine = adapter.predict(detection_image)
            spine.pixel_spacing_mm = pixel_spacing

            # Map coordinates from the crop back to full-image space.
            if left_margin > 0:
                self._shift_landmarks_x(spine, left_margin)

            compute_cobb_angles(spine)
            find_apex_vertebrae(spine)  # result unused; kept in case it annotates spine
            rigo_result = classify_rigo_type(spine)

            output_dir = self._case_output_dir(case_id)

            # Drawn before levels are assigned, so labels show detection indices.
            vis_path = output_dir / "visualization.png"
            self._save_landmarks_visualization(image_rgb, spine, rigo_result, str(vis_path))

            # ScolioVis doesn't assign levels: number detections top-to-bottom by
            # centroid Y starting at T1 (a simplification that assumes the top
            # vertebrae are visible); detections beyond L5 stay unlabelled.
            detected = sorted(
                (v for v in spine.vertebrae if v.centroid_px is not None),
                key=lambda v: v.centroid_px[1],
            )
            for level, vertebra in zip(self._ALL_LEVELS, detected):
                vertebra.level = level

            detected_map = {v.level: v for v in detected if v.level}
            vertebrae_list = [
                self._vertebra_entry(level, detected_map.get(level))
                for level in self._ALL_LEVELS
            ]

            all_levels = list(self._ALL_LEVELS)
            result = {
                "case_id": case_id,
                "status": "landmarks_detected",
                "input": {
                    "image_dimensions": {"width": image_w, "height": image_h},
                    "pixel_spacing_mm": pixel_spacing,
                },
                "detection_quality": {
                    "vertebrae_count": len(spine.vertebrae),
                    "average_confidence": float(np.mean([v.confidence for v in spine.vertebrae])) if spine.vertebrae else 0.0,
                },
                "cobb_angles": self._cobb_summary(spine, get_curve_severity),
                "rigo_classification": {
                    "type": rigo_result["rigo_type"],
                    "description": rigo_result["description"],
                },
                "curve_type": spine.curve_type,
                "vertebrae_structure": {
                    "all_levels": all_levels,
                    "detected_count": len(spine.vertebrae),
                    "total_count": len(all_levels),
                    "vertebrae": vertebrae_list,
                    "manual_edit_instructions": {
                        "to_override": "Set manual_override.enabled=true and fill manual_override fields",
                        "final_values_rule": "When manual_override.enabled=true, final_values uses manual values",
                    },
                },
                "visualization_path": str(vis_path),
                "processing_time_ms": (time.perf_counter() - start) * 1000,
            }

            json_path = output_dir / "landmarks.json"
            with open(json_path, "w") as fh:
                json.dump(result, fh, indent=2, default=self._json_default)

            result["json_path"] = str(json_path)
            return result
        finally:
            Path(input_path).unlink(missing_ok=True)

    async def recalculate_from_landmarks(
        self,
        landmarks_data: Dict[str, Any],
        case_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Recalculate Cobb angles and Rigo classification from landmarks data.

        Uses final_values from each vertebra (which may be manual overrides).

        Raises:
            ValueError: if fewer than three vertebrae have a centroid.
        """
        start = time.perf_counter()
        case_id = case_id or self._new_case_id()

        from brace_generator.spine_analysis import (
            classify_rigo_type,
            compute_cobb_angles,
            find_apex_vertebrae,
            get_curve_severity,
        )

        spine = self._spine_from_landmarks(landmarks_data, "analysis")
        spine.sort_vertebrae()  # top to bottom

        compute_cobb_angles(spine)
        apex_indices = find_apex_vertebrae(spine)
        rigo_result = classify_rigo_type(spine)

        return {
            "case_id": case_id,
            "status": "analysis_recalculated",
            "cobb_angles": self._cobb_summary(spine, get_curve_severity),
            "rigo_classification": {
                "type": rigo_result["rigo_type"],
                "description": rigo_result["description"],
            },
            "curve_type": spine.curve_type,
            "apex_indices": apex_indices,
            "vertebrae_used": len(spine.vertebrae),
            "processing_time_ms": (time.perf_counter() - start) * 1000,
        }