347 lines
13 KiB
Python
347 lines
13 KiB
Python
"""
|
|
Complete pipeline: X-ray → Landmarks → Brace STL
|
|
"""
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Any, Union
|
|
import json
|
|
|
|
from brace_generator.data_models import Spine2D, BraceConfig, VertebraLandmark
|
|
from brace_generator.image_loader import load_xray, load_xray_rgb
|
|
from brace_generator.adapters import BaseLandmarkAdapter, ScolioVisAdapter, VertLandmarkAdapter
|
|
from brace_generator.spine_analysis import (
|
|
compute_spine_curve, compute_cobb_angles, find_apex_vertebrae,
|
|
get_curve_severity, classify_rigo_type
|
|
)
|
|
from brace_generator.brace_surface import BraceGenerator
|
|
|
|
|
|
class BracePipeline:
    """
    End-to-end pipeline for generating scoliosis braces from X-rays.

    Usage:
        # Basic usage with default model
        pipeline = BracePipeline()
        pipeline.process("xray.png", "brace.stl")

        # With specific model
        pipeline = BracePipeline(model="vertebra-landmark")
        pipeline.process("xray.dcm", "brace.stl")

        # With body scan
        config = BraceConfig(use_body_scan=True, body_scan_path="body.obj")
        pipeline = BracePipeline(config=config)
        pipeline.process("xray.png", "brace.stl")
    """

    # Registry of supported landmark models: model name -> adapter class.
    AVAILABLE_MODELS = {
        'scoliovis': ScolioVisAdapter,
        'vertebra-landmark': VertLandmarkAdapter,
    }

    # Minimum number of detected vertebrae required before process() will
    # analyse the spine and build a brace.
    MIN_VERTEBRAE = 5

    def __init__(
        self,
        model: str = 'scoliovis',
        config: Optional[BraceConfig] = None,
        device: str = 'cpu'
    ):
        """
        Initialize pipeline.

        Args:
            model: Model to use ('scoliovis' or 'vertebra-landmark'); matched
                case-insensitively against AVAILABLE_MODELS.
            config: Brace configuration. A default BraceConfig() is used when None.
            device: 'cpu' or 'cuda'; forwarded to the landmark adapter.

        Raises:
            ValueError: If ``model`` is not a key of AVAILABLE_MODELS.
        """
        self.device = device
        self.config = config if config is not None else BraceConfig()
        self.model_name = model.lower()

        # Initialize model adapter
        if self.model_name not in self.AVAILABLE_MODELS:
            raise ValueError(f"Unknown model: {model}. Available: {list(self.AVAILABLE_MODELS.keys())}")

        self.adapter: BaseLandmarkAdapter = self.AVAILABLE_MODELS[self.model_name](device=device)
        self.brace_generator = BraceGenerator(self.config)

        # Store last results for inspection
        self.last_spine: Optional[Spine2D] = None
        self.last_image: Optional[np.ndarray] = None

    def process(
        self,
        xray_path: str,
        output_stl_path: str,
        visualize: bool = False,
        save_landmarks: bool = False
    ) -> Dict[str, Any]:
        """
        Process X-ray and generate brace STL.

        Args:
            xray_path: Path to input X-ray (JPEG, PNG, or DICOM)
            output_stl_path: Path for output STL file
            visualize: If True, also save a visualization image next to the STL
                (same stem, ``.png`` suffix)
            save_landmarks: If True, also save landmarks JSON next to the STL
                (same stem, ``.json`` suffix)

        Returns:
            Dictionary with analysis results (Cobb angles, curve type, Rigo
            classification, apex indices, mesh size).

        Raises:
            ValueError: If fewer than ``MIN_VERTEBRAE`` vertebrae are detected.
        """
        print("=" * 60)
        print("Brace Generation Pipeline")
        print(f"Model: {self.adapter.name}")
        print("=" * 60)

        # 1) Load X-ray
        print(f"\n1. Loading X-ray: {xray_path}")
        image_rgb, pixel_spacing = load_xray_rgb(xray_path)
        self.last_image = image_rgb
        print(f" Image size: {image_rgb.shape[:2]}")
        if pixel_spacing is not None:
            print(f" Pixel spacing: {pixel_spacing} mm")

        # 2) Detect landmarks
        print("\n2. Detecting landmarks...")
        spine = self.adapter.predict(image_rgb)
        spine.pixel_spacing_mm = pixel_spacing
        self.last_spine = spine

        n_vertebrae = len(spine.vertebrae)
        print(f" Detected {n_vertebrae} vertebrae")

        if n_vertebrae < self.MIN_VERTEBRAE:
            raise ValueError(
                f"Insufficient vertebrae detected ({n_vertebrae}). "
                f"Need at least {self.MIN_VERTEBRAE}."
            )

        # 3) Compute spine analysis. compute_cobb_angles' return value is not
        # used; the spine.cobb_* attributes read below are presumably filled in
        # on `spine` by it (TODO confirm in spine_analysis).
        print("\n3. Analyzing spine curvature...")
        compute_cobb_angles(spine)
        apexes = find_apex_vertebrae(spine)

        # Classify Rigo type once; the same result is reused for the returned
        # dict, the visualization and the landmarks JSON so they always agree.
        rigo_result = classify_rigo_type(spine)

        print(" Cobb Angles:")
        print(f" PT (Proximal Thoracic): {spine.cobb_pt:.1f}° - {get_curve_severity(spine.cobb_pt)}")
        print(f" MT (Main Thoracic): {spine.cobb_mt:.1f}° - {get_curve_severity(spine.cobb_mt)}")
        print(f" TL (Thoracolumbar): {spine.cobb_tl:.1f}° - {get_curve_severity(spine.cobb_tl)}")
        print(f" Curve type: {spine.curve_type}")
        print(f" Rigo Classification: {rigo_result['rigo_type']}")
        print(f" - {rigo_result['description']}")
        print(f" Apex vertebrae indices: {apexes}")

        # 4) Generate brace
        print("\n4. Generating brace mesh...")
        if self.config.use_body_scan:
            print(f" Mode: Using body scan ({self.config.body_scan_path})")
        else:
            print(" Mode: Average body shape")

        brace_mesh = self.brace_generator.generate(spine)
        print(f" Mesh: {len(brace_mesh.vertices)} vertices, {len(brace_mesh.faces)} faces")

        # 5) Export STL
        print(f"\n5. Exporting STL: {output_stl_path}")
        self.brace_generator.export_stl(brace_mesh, output_stl_path)

        # 6) Optional: Save visualization
        if visualize:
            vis_path = str(Path(output_stl_path).with_suffix('.png'))
            self._save_visualization(vis_path, spine, image_rgb, rigo_result=rigo_result)
            print(f" Visualization saved: {vis_path}")

        # 7) Optional: Save landmarks JSON
        if save_landmarks:
            json_path = str(Path(output_stl_path).with_suffix('.json'))
            self._save_landmarks_json(json_path, spine, rigo_result=rigo_result)
            print(f" Landmarks saved: {json_path}")

        # Prepare results
        results = {
            'input_image': xray_path,
            'output_stl': output_stl_path,
            'model': self.adapter.name,
            'vertebrae_detected': n_vertebrae,
            'cobb_angles': {
                'PT': spine.cobb_pt,
                'MT': spine.cobb_mt,
                'TL': spine.cobb_tl,
            },
            'curve_type': spine.curve_type,
            'rigo_type': rigo_result['rigo_type'],
            'rigo_description': rigo_result['description'],
            'apex_indices': apexes,
            'mesh_vertices': len(brace_mesh.vertices),
            'mesh_faces': len(brace_mesh.faces),
        }

        print(f"\n{'=' * 60}")
        print("Pipeline complete!")
        print(f"{'=' * 60}")

        return results

    def _save_visualization(
        self,
        path: str,
        spine: Spine2D,
        image: np.ndarray,
        rigo_result: Optional[Dict[str, Any]] = None,
    ):
        """
        Save a two-panel visualization of detected landmarks and spine curve.

        Left panel: the image with vertebra centroids (red) and, where
        available, each vertebra's corner quadrilateral (green). Right panel:
        the image at half opacity with the fitted spine curve, high-curvature
        points, and a text box of Cobb angles and Rigo type.

        Args:
            path: Output image path.
            spine: Analysed spine (Cobb angles already computed).
            image: Image the landmarks were detected on.
            rigo_result: Precomputed ``classify_rigo_type(spine)`` result;
                computed here when None.

        Does nothing (beyond a warning) if matplotlib is not installed.
        """
        try:
            import matplotlib
            # Non-interactive backend so figures can be saved without a display.
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt
        except ImportError:
            print(" Warning: matplotlib not available for visualization")
            return

        if rigo_result is None:
            rigo_result = classify_rigo_type(spine)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 10))
        try:
            # Left: Original with landmarks
            ax1.imshow(image)

            # Draw vertebra centers
            centroids = spine.get_centroids()
            ax1.scatter(centroids[:, 0], centroids[:, 1], c='red', s=30, zorder=5)

            # Draw corners if available, as a closed quadrilateral over the
            # first four corner points.
            for vert in spine.vertebrae:
                if vert.corners_px is None:
                    continue
                quad = np.asarray(vert.corners_px)[[0, 1, 2, 3, 0]]
                ax1.plot(quad[:, 0], quad[:, 1], 'g-', linewidth=1)

            ax1.set_title(f"Detected Landmarks ({len(spine.vertebrae)} vertebrae)")
            ax1.axis('off')

            # Right: Spine curve analysis
            ax2.imshow(image, alpha=0.5)

            # The curve overlay is best-effort: a failure here should not
            # prevent the rest of the figure from being saved, but it must not
            # be silently swallowed either.
            try:
                C, _T, _N, curv = compute_spine_curve(spine)
                ax2.plot(C[:, 0], C[:, 1], 'b-', linewidth=2, label='Spine curve')

                # Highlight points more than one std above mean curvature
                high_curv_mask = curv > curv.mean() + curv.std()
                ax2.scatter(C[high_curv_mask, 0], C[high_curv_mask, 1],
                            c='orange', s=20, label='High curvature')
            except Exception as exc:
                print(f" Warning: could not draw spine curve: {exc}")

            # Add Cobb angles and Rigo type text
            text = (
                "Cobb Angles:\n"
                f"PT: {spine.cobb_pt:.1f}°\n"
                f"MT: {spine.cobb_mt:.1f}°\n"
                f"TL: {spine.cobb_tl:.1f}°\n"
                f"Curve: {spine.curve_type}\n"
                "-----------\n"
                f"Rigo: {rigo_result['rigo_type']}"
            )
            ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
                     verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

            ax2.set_title("Spine Analysis")
            ax2.axis('off')
            # Only add a legend when something labelled was drawn; otherwise
            # matplotlib warns that no labelled artists were found.
            handles, _labels = ax2.get_legend_handles_labels()
            if handles:
                ax2.legend(loc='lower right')

            fig.tight_layout()
            fig.savefig(path, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    @staticmethod
    def _to_native(val: Any) -> Any:
        """
        Recursively convert numpy values into JSON-serializable Python values.

        ``np.ndarray`` -> nested lists, numpy scalars (``np.generic``) -> the
        matching Python scalar, and dict values / list / tuple items are
        converted recursively (tuples become lists). Anything else is
        returned unchanged.
        """
        if isinstance(val, np.ndarray):
            return val.tolist()
        if isinstance(val, np.generic):
            return val.item()
        if isinstance(val, dict):
            return {key: BracePipeline._to_native(item) for key, item in val.items()}
        if isinstance(val, (list, tuple)):
            return [BracePipeline._to_native(item) for item in val]
        return val

    def _save_landmarks_json(
        self,
        path: str,
        spine: Spine2D,
        rigo_result: Optional[Dict[str, Any]] = None,
    ):
        """
        Save landmarks to a JSON file together with the Rigo classification.

        Args:
            path: Output JSON path.
            spine: Analysed spine (Cobb angles already computed).
            rigo_result: Precomputed ``classify_rigo_type(spine)`` result;
                computed here when None.
        """
        if rigo_result is None:
            rigo_result = classify_rigo_type(spine)

        vertebrae = []
        for vert in spine.vertebrae:
            vert_data = {
                'level': vert.level,
                'centroid_px': vert.centroid_px,
                'orientation_deg': vert.orientation_deg,
                'confidence': vert.confidence,
            }
            # corners are optional; the key is omitted when unavailable
            if vert.corners_px is not None:
                vert_data['corners_px'] = vert.corners_px
            vertebrae.append(vert_data)

        data = {
            'source_model': spine.source_model,
            'image_shape': list(spine.image_shape) if spine.image_shape else None,
            'pixel_spacing_mm': spine.pixel_spacing_mm,
            'cobb_angles': {
                'PT': spine.cobb_pt,
                'MT': spine.cobb_mt,
                'TL': spine.cobb_tl,
            },
            'curve_type': spine.curve_type,
            'rigo_classification': {
                'type': rigo_result['rigo_type'],
                'description': rigo_result['description'],
                'curve_pattern': rigo_result['curve_pattern'],
                'n_significant_curves': rigo_result['n_significant_curves'],
            },
            'vertebrae': vertebrae,
        }

        # Convert every numpy value (arrays, scalars, containers of them)
        # in one pass so json.dump never sees a numpy type.
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self._to_native(data), f, indent=2)
|
|
|
|
|
|
def main(argv=None):
    """
    Command-line interface for brace generation.

    Args:
        argv: Argument list to parse (without the program name). When None,
            argparse reads ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        The results dictionary returned by ``BracePipeline.process``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Generate scoliosis brace from X-ray')
    parser.add_argument('input', help='Input X-ray image (JPEG, PNG, or DICOM)')
    parser.add_argument('output', help='Output STL file path')
    # Derive the choices from the pipeline registry so the CLI cannot drift
    # out of sync with the models BracePipeline actually supports.
    parser.add_argument('--model', choices=list(BracePipeline.AVAILABLE_MODELS),
                        default='scoliovis', help='Landmark detection model')
    parser.add_argument('--device', default='cpu', help='Device (cpu or cuda)')
    parser.add_argument('--body-scan', help='Path to 3D body scan mesh (optional)')
    parser.add_argument('--visualize', action='store_true', help='Save visualization')
    parser.add_argument('--save-landmarks', action='store_true', help='Save landmarks JSON')
    parser.add_argument('--pressure', type=float, default=15.0,
                        help='Pressure strength in mm (default: 15)')
    parser.add_argument('--thickness', type=float, default=4.0,
                        help='Wall thickness in mm (default: 4)')

    args = parser.parse_args(argv)

    # A brace wall must have positive thickness; reject nonsense early with a
    # normal argparse usage error instead of failing deep in mesh generation.
    if args.thickness <= 0:
        parser.error('--thickness must be positive')

    # Build config
    config = BraceConfig(
        pressure_strength_mm=args.pressure,
        wall_thickness_mm=args.thickness,
    )

    if args.body_scan:
        config.use_body_scan = True
        config.body_scan_path = args.body_scan

    # Run pipeline
    pipeline = BracePipeline(model=args.model, config=config, device=args.device)
    results = pipeline.process(
        args.input,
        args.output,
        visualize=args.visualize,
        save_landmarks=args.save_landmarks
    )

    return results


if __name__ == '__main__':
    main()
|