Files
braceiqmed/brace-generator/pipeline.py

347 lines
13 KiB
Python

"""
Complete pipeline: X-ray → Landmarks → Brace STL
"""
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any, Union
import json
from brace_generator.data_models import Spine2D, BraceConfig, VertebraLandmark
from brace_generator.image_loader import load_xray, load_xray_rgb
from brace_generator.adapters import BaseLandmarkAdapter, ScolioVisAdapter, VertLandmarkAdapter
from brace_generator.spine_analysis import (
compute_spine_curve, compute_cobb_angles, find_apex_vertebrae,
get_curve_severity, classify_rigo_type
)
from brace_generator.brace_surface import BraceGenerator
class BracePipeline:
    """
    End-to-end pipeline for generating scoliosis braces from X-rays.

    Stages: load X-ray -> detect vertebra landmarks -> Cobb angle / Rigo
    analysis -> brace mesh generation -> STL export, plus an optional PNG
    visualization and landmark JSON written next to the STL.

    Usage:
        # Basic usage with default model
        pipeline = BracePipeline()
        pipeline.process("xray.png", "brace.stl")

        # With specific model
        pipeline = BracePipeline(model="vertebra-landmark")
        pipeline.process("xray.dcm", "brace.stl")

        # With body scan
        config = BraceConfig(use_body_scan=True, body_scan_path="body.obj")
        pipeline = BracePipeline(config=config)
        pipeline.process("xray.png", "brace.stl")
    """

    # Registry of landmark-detection backends; keys are the accepted ``model`` names.
    AVAILABLE_MODELS = {
        'scoliovis': ScolioVisAdapter,
        'vertebra-landmark': VertLandmarkAdapter,
    }

    # Fewer detected vertebrae than this is treated as a failed detection.
    MIN_VERTEBRAE = 5

    def __init__(
        self,
        model: str = 'scoliovis',
        config: Optional[BraceConfig] = None,
        device: str = 'cpu',
    ):
        """
        Initialize pipeline.

        Args:
            model: Model to use, matched case-insensitively against
                ``AVAILABLE_MODELS`` ('scoliovis' or 'vertebra-landmark').
            config: Brace configuration; ``BraceConfig()`` is used when None.
            device: Device string forwarded to the landmark adapter
                ('cpu' or 'cuda').

        Raises:
            ValueError: If ``model`` is not a registered model name.
        """
        self.device = device
        self.config = config or BraceConfig()
        self.model_name = model.lower()

        # Initialize model adapter
        adapter_cls = self.AVAILABLE_MODELS.get(self.model_name)
        if adapter_cls is None:
            raise ValueError(f"Unknown model: {model}. Available: {list(self.AVAILABLE_MODELS.keys())}")
        self.adapter: BaseLandmarkAdapter = adapter_cls(device=device)
        self.brace_generator = BraceGenerator(self.config)

        # Store last results for inspection
        self.last_spine: Optional[Spine2D] = None
        self.last_image: Optional[np.ndarray] = None

    def process(
        self,
        xray_path: Union[str, Path],
        output_stl_path: Union[str, Path],
        visualize: bool = False,
        save_landmarks: bool = False,
    ) -> Dict[str, Any]:
        """
        Process X-ray and generate brace STL.

        Args:
            xray_path: Path to input X-ray (JPEG, PNG, or DICOM).
            output_stl_path: Path for output STL file.
            visualize: If True, also save a visualization PNG next to the STL
                (same stem, ``.png`` suffix).
            save_landmarks: If True, also save landmarks JSON next to the STL
                (same stem, ``.json`` suffix).

        Returns:
            Dictionary with analysis results (input/output paths, model name,
            vertebra count, Cobb angles, curve/Rigo type, apex indices and
            mesh size).

        Raises:
            ValueError: If fewer than ``MIN_VERTEBRAE`` vertebrae are detected.
        """
        banner = "=" * 60
        print(banner)
        print("Brace Generation Pipeline")
        print(f"Model: {self.adapter.name}")
        print(banner)

        # 1) Load X-ray
        print(f"\n1. Loading X-ray: {xray_path}")
        image_rgb, pixel_spacing = load_xray_rgb(xray_path)
        self.last_image = image_rgb
        print(f" Image size: {image_rgb.shape[:2]}")
        if pixel_spacing is not None:
            print(f" Pixel spacing: {pixel_spacing} mm")

        # 2) Detect landmarks
        print("\n2. Detecting landmarks...")
        spine = self.adapter.predict(image_rgb)
        spine.pixel_spacing_mm = pixel_spacing
        self.last_spine = spine
        n_vertebrae = len(spine.vertebrae)
        print(f" Detected {n_vertebrae} vertebrae")
        if n_vertebrae < self.MIN_VERTEBRAE:
            raise ValueError(
                f"Insufficient vertebrae detected ({n_vertebrae}). "
                f"Need at least {self.MIN_VERTEBRAE}."
            )

        # 3) Compute spine analysis. The Cobb results are read back from
        # spine.cobb_* below, so compute_cobb_angles is relied on to fill them in.
        print("\n3. Analyzing spine curvature...")
        compute_cobb_angles(spine)
        apexes = find_apex_vertebrae(spine)
        # Classify Rigo type once; the same result feeds the printout, the
        # optional PNG/JSON outputs and the returned dict so they all agree.
        rigo_result = classify_rigo_type(spine)
        print(" Cobb Angles:")
        print(f" PT (Proximal Thoracic): {spine.cobb_pt:.1f}° - {get_curve_severity(spine.cobb_pt)}")
        print(f" MT (Main Thoracic): {spine.cobb_mt:.1f}° - {get_curve_severity(spine.cobb_mt)}")
        print(f" TL (Thoracolumbar): {spine.cobb_tl:.1f}° - {get_curve_severity(spine.cobb_tl)}")
        print(f" Curve type: {spine.curve_type}")
        print(f" Rigo Classification: {rigo_result['rigo_type']}")
        print(f" - {rigo_result['description']}")
        print(f" Apex vertebrae indices: {apexes}")

        # 4) Generate brace
        print("\n4. Generating brace mesh...")
        if self.config.use_body_scan:
            print(f" Mode: Using body scan ({self.config.body_scan_path})")
        else:
            print(" Mode: Average body shape")
        brace_mesh = self.brace_generator.generate(spine)
        print(f" Mesh: {len(brace_mesh.vertices)} vertices, {len(brace_mesh.faces)} faces")

        # 5) Export STL
        print(f"\n5. Exporting STL: {output_stl_path}")
        self.brace_generator.export_stl(brace_mesh, output_stl_path)

        # 6) Optional: Save visualization (only report success if a file was written)
        if visualize:
            vis_path = str(Path(output_stl_path).with_suffix('.png'))
            if self._save_visualization(vis_path, spine, image_rgb, rigo_result=rigo_result):
                print(f" Visualization saved: {vis_path}")

        # 7) Optional: Save landmarks JSON
        if save_landmarks:
            json_path = str(Path(output_stl_path).with_suffix('.json'))
            self._save_landmarks_json(json_path, spine, rigo_result=rigo_result)
            print(f" Landmarks saved: {json_path}")

        # Prepare results
        results = {
            'input_image': xray_path,
            'output_stl': output_stl_path,
            'model': self.adapter.name,
            'vertebrae_detected': n_vertebrae,
            'cobb_angles': {
                'PT': spine.cobb_pt,
                'MT': spine.cobb_mt,
                'TL': spine.cobb_tl,
            },
            'curve_type': spine.curve_type,
            'rigo_type': rigo_result['rigo_type'],
            'rigo_description': rigo_result['description'],
            'apex_indices': apexes,
            'mesh_vertices': len(brace_mesh.vertices),
            'mesh_faces': len(brace_mesh.faces),
        }

        print(f"\n{banner}")
        print("Pipeline complete!")
        print(banner)
        return results

    def _save_visualization(
        self,
        path: str,
        spine: Spine2D,
        image: np.ndarray,
        rigo_result: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Save a two-panel PNG of detected landmarks and the spine-curve analysis.

        Left panel: image with vertebra centroids and, where available, the
        four corner points drawn as a closed quadrilateral. Right panel: image
        at half opacity with the fitted spine curve, high-curvature points and
        a Cobb/Rigo summary box.

        Args:
            path: Output image path.
            spine: Analyzed spine (Cobb angles already computed).
            image: Image passed to ``imshow`` for both panels.
            rigo_result: Precomputed ``classify_rigo_type`` result; computed
                here when None.

        Returns:
            True if the figure was written, False if matplotlib is unavailable.
        """
        try:
            import matplotlib
            matplotlib.use('Agg')  # headless backend: no display needed to save files
            import matplotlib.pyplot as plt
        except ImportError:
            print(" Warning: matplotlib not available for visualization")
            return False

        if rigo_result is None:
            rigo_result = classify_rigo_type(spine)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 10))
        try:
            # Left: Original with landmarks
            ax1.imshow(image)
            centroids = spine.get_centroids()
            ax1.scatter(centroids[:, 0], centroids[:, 1], c='red', s=30, zorder=5)
            for vert in spine.vertebrae:
                if vert.corners_px is not None:
                    # Close the quadrilateral by repeating the first corner.
                    quad = vert.corners_px[:4]
                    closed = np.vstack([quad, quad[:1]])
                    ax1.plot(closed[:, 0], closed[:, 1], 'g-', linewidth=1)
            ax1.set_title(f"Detected Landmarks ({len(spine.vertebrae)} vertebrae)")
            ax1.axis('off')

            # Right: Spine curve analysis
            ax2.imshow(image, alpha=0.5)
            try:
                curve, _tangents, _normals, curv = compute_spine_curve(spine)
                ax2.plot(curve[:, 0], curve[:, 1], 'b-', linewidth=2, label='Spine curve')
                # Highlight points more than one std above the mean curvature
                high_curv_mask = curv > curv.mean() + curv.std()
                ax2.scatter(curve[high_curv_mask, 0], curve[high_curv_mask, 1],
                            c='orange', s=20, label='High curvature')
            except Exception as exc:
                # The overlay is best-effort; keep the rest of the figure but say why.
                print(f" Warning: spine curve overlay skipped: {exc}")

            # Add Cobb angles and Rigo type text
            text = "\n".join([
                "Cobb Angles:",
                f"PT: {spine.cobb_pt:.1f}°",
                f"MT: {spine.cobb_mt:.1f}°",
                f"TL: {spine.cobb_tl:.1f}°",
                f"Curve: {spine.curve_type}",
                "-----------",
                f"Rigo: {rigo_result['rigo_type']}",
            ])
            ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
                     verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            ax2.set_title("Spine Analysis")
            ax2.axis('off')
            # Only build a legend when something labelled was actually drawn.
            handles, _labels = ax2.get_legend_handles_labels()
            if handles:
                ax2.legend(loc='lower right')

            fig.tight_layout()
            fig.savefig(path, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        return True

    def _save_landmarks_json(
        self,
        path: str,
        spine: Spine2D,
        rigo_result: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Save landmarks to a JSON file together with Cobb angles and Rigo classification.

        Args:
            path: Output JSON path.
            spine: Analyzed spine (Cobb angles already computed).
            rigo_result: Precomputed ``classify_rigo_type`` result; computed
                here when None.
        """
        def to_native(val):
            """Recursively convert numpy arrays/scalars into JSON-serializable Python values."""
            if isinstance(val, np.ndarray):
                return val.tolist()
            if isinstance(val, np.generic):
                # Covers every numpy scalar (float32, int16, bool_, ...), not just a few.
                return val.item()
            if isinstance(val, (list, tuple)):
                return [to_native(v) for v in val]
            if isinstance(val, dict):
                return {k: to_native(v) for k, v in val.items()}
            return val

        if rigo_result is None:
            rigo_result = classify_rigo_type(spine)

        image_shape = spine.image_shape
        data = {
            'source_model': spine.source_model,
            # Empty or missing shape is stored as null.
            'image_shape': (to_native(list(image_shape))
                            if image_shape is not None and len(image_shape) > 0 else None),
            # Spacing may be an array, a tuple or None depending on the loader.
            'pixel_spacing_mm': to_native(spine.pixel_spacing_mm),
            'cobb_angles': {
                'PT': to_native(spine.cobb_pt),
                'MT': to_native(spine.cobb_mt),
                'TL': to_native(spine.cobb_tl),
            },
            'curve_type': to_native(spine.curve_type),
            'rigo_classification': {
                'type': to_native(rigo_result['rigo_type']),
                'description': to_native(rigo_result['description']),
                'curve_pattern': to_native(rigo_result['curve_pattern']),
                'n_significant_curves': to_native(rigo_result['n_significant_curves']),
            },
            'vertebrae': [],
        }

        for vert in spine.vertebrae:
            vert_data = {
                'level': to_native(vert.level),
                'centroid_px': to_native(vert.centroid_px),
                'orientation_deg': to_native(vert.orientation_deg),
                'confidence': to_native(vert.confidence),
            }
            if vert.corners_px is not None:
                vert_data['corners_px'] = to_native(vert.corners_px)
            data['vertebrae'].append(vert_data)

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)
def main(argv: Optional[list] = None) -> Dict[str, Any]:
    """
    Command-line interface for brace generation.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``
            (the argparse default when None). This lets the CLI be driven
            programmatically.

    Returns:
        The results dictionary produced by ``BracePipeline.process``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Generate scoliosis brace from X-ray')
    parser.add_argument('input', help='Input X-ray image (JPEG, PNG, or DICOM)')
    parser.add_argument('output', help='Output STL file path')
    # Derive choices from the pipeline registry so they never drift out of sync.
    parser.add_argument('--model', choices=list(BracePipeline.AVAILABLE_MODELS),
                        default='scoliovis', help='Landmark detection model')
    parser.add_argument('--device', default='cpu', help='Device (cpu or cuda)')
    parser.add_argument('--body-scan', help='Path to 3D body scan mesh (optional)')
    parser.add_argument('--visualize', action='store_true', help='Save visualization')
    parser.add_argument('--save-landmarks', action='store_true', help='Save landmarks JSON')
    parser.add_argument('--pressure', type=float, default=15.0,
                        help='Pressure strength in mm (default: 15)')
    parser.add_argument('--thickness', type=float, default=4.0,
                        help='Wall thickness in mm (default: 4)')
    args = parser.parse_args(argv)

    # Build config
    config = BraceConfig(
        pressure_strength_mm=args.pressure,
        wall_thickness_mm=args.thickness,
    )
    if args.body_scan:
        config.use_body_scan = True
        config.body_scan_path = args.body_scan

    # Run pipeline
    pipeline = BracePipeline(model=args.model, config=config, device=args.device)
    return pipeline.process(
        args.input,
        args.output,
        visualize=args.visualize,
        save_landmarks=args.save_landmarks,
    )
if __name__ == "__main__":
    # Invoke the CLI only when this module is executed directly, not on import.
    main()