"""
Complete pipeline: X-ray → Landmarks → Brace STL
"""
import numpy as np
from pathlib import Path
from typing import Optional, Dict, Any, Union
import json

from brace_generator.data_models import Spine2D, BraceConfig, VertebraLandmark
from brace_generator.image_loader import load_xray, load_xray_rgb
from brace_generator.adapters import BaseLandmarkAdapter, ScolioVisAdapter, VertLandmarkAdapter
from brace_generator.spine_analysis import (
    compute_spine_curve, compute_cobb_angles, find_apex_vertebrae,
    get_curve_severity, classify_rigo_type
)
from brace_generator.brace_surface import BraceGenerator


def _to_native(val: Any) -> Any:
    """
    Recursively convert NumPy values into JSON-serializable Python types.

    - ``np.ndarray``      -> nested ``list``
    - any NumPy scalar    -> the matching Python scalar (via ``.item()``)
    - ``list`` / ``tuple`` -> ``list`` with each element converted
    - anything else       -> returned unchanged
    """
    if isinstance(val, np.ndarray):
        return val.tolist()
    if isinstance(val, np.generic):
        # np.generic covers every NumPy scalar width (float16/32/64, int8..64,
        # uint*, bool_), not only the float32/64 and int32/64 cases.
        return val.item()
    if isinstance(val, (list, tuple)):
        # Sequences may hold NumPy scalars (e.g. a pixel-spacing tuple).
        return [_to_native(v) for v in val]
    return val


class BracePipeline:
    """
    End-to-end pipeline for generating scoliosis braces from X-rays.

    Usage:
        # Basic usage with default model
        pipeline = BracePipeline()
        pipeline.process("xray.png", "brace.stl")

        # With specific model
        pipeline = BracePipeline(model="vertebra-landmark")
        pipeline.process("xray.dcm", "brace.stl")

        # With body scan
        config = BraceConfig(use_body_scan=True, body_scan_path="body.obj")
        pipeline = BracePipeline(config=config)
        pipeline.process("xray.png", "brace.stl")
    """

    # Registry of supported landmark models: lower-case name -> adapter class.
    # The CLI in main() derives its --model choices from this mapping.
    AVAILABLE_MODELS = {
        'scoliovis': ScolioVisAdapter,
        'vertebra-landmark': VertLandmarkAdapter,
    }

    # Minimum number of detected vertebrae required before analysis proceeds.
    MIN_VERTEBRAE = 5

    def __init__(
        self,
        model: str = 'scoliovis',
        config: Optional[BraceConfig] = None,
        device: str = 'cpu'
    ):
        """
        Initialize pipeline.

        Args:
            model: Model to use ('scoliovis' or 'vertebra-landmark');
                matched case-insensitively against AVAILABLE_MODELS.
            config: Brace configuration; a default ``BraceConfig()`` is
                created when None.
            device: 'cpu' or 'cuda'; forwarded to the adapter constructor.

        Raises:
            ValueError: If ``model`` is not a key of AVAILABLE_MODELS.
        """
        self.device = device
        self.config = config if config is not None else BraceConfig()
        self.model_name = model.lower()

        # Initialize model adapter
        if self.model_name not in self.AVAILABLE_MODELS:
            raise ValueError(
                f"Unknown model: {model}. "
                f"Available: {list(self.AVAILABLE_MODELS.keys())}"
            )

        self.adapter: BaseLandmarkAdapter = self.AVAILABLE_MODELS[self.model_name](device=device)
        self.brace_generator = BraceGenerator(self.config)

        # Store last results for inspection
        self.last_spine: Optional[Spine2D] = None
        self.last_image: Optional[np.ndarray] = None

    def process(
        self,
        xray_path: str,
        output_stl_path: str,
        visualize: bool = False,
        save_landmarks: bool = False
    ) -> Dict[str, Any]:
        """
        Process X-ray and generate brace STL.

        Steps: load image -> detect landmarks -> Cobb/Rigo analysis ->
        brace mesh generation -> STL export, plus optional visualization PNG
        and landmarks JSON written next to the STL (same stem, different
        suffix).

        Args:
            xray_path: Path to input X-ray (JPEG, PNG, or DICOM)
            output_stl_path: Path for output STL file
            visualize: If True, also save visualization image (.png)
            save_landmarks: If True, also save landmarks JSON (.json)

        Returns:
            Dictionary with analysis results (Cobb angles, curve/Rigo type,
            apex indices, mesh size, input/output paths).

        Raises:
            ValueError: If fewer than MIN_VERTEBRAE vertebrae are detected.
        """
        print("=" * 60)
        print("Brace Generation Pipeline")
        print(f"Model: {self.adapter.name}")
        print("=" * 60)

        # 1) Load X-ray
        print(f"\n1. Loading X-ray: {xray_path}")
        image_rgb, pixel_spacing = load_xray_rgb(xray_path)
        self.last_image = image_rgb
        print(f" Image size: {image_rgb.shape[:2]}")
        if pixel_spacing is not None:
            print(f" Pixel spacing: {pixel_spacing} mm")

        # 2) Detect landmarks
        print("\n2. Detecting landmarks...")
        spine = self.adapter.predict(image_rgb)
        spine.pixel_spacing_mm = pixel_spacing
        # Stored before validation so a failed run can still be inspected.
        self.last_spine = spine
        print(f" Detected {len(spine.vertebrae)} vertebrae")

        if len(spine.vertebrae) < self.MIN_VERTEBRAE:
            raise ValueError(
                f"Insufficient vertebrae detected ({len(spine.vertebrae)}). "
                f"Need at least {self.MIN_VERTEBRAE}."
            )

        # 3) Compute spine analysis
        print("\n3. Analyzing spine curvature...")
        # Return value is unused; presumably this fills spine.cobb_pt/mt/tl
        # and spine.curve_type in place (they are read right below).
        compute_cobb_angles(spine)
        apexes = find_apex_vertebrae(spine)

        # Classify Rigo type once; reused by the optional outputs below.
        rigo_result = classify_rigo_type(spine)

        print(" Cobb Angles:")
        print(f" PT (Proximal Thoracic): {spine.cobb_pt:.1f}° - {get_curve_severity(spine.cobb_pt)}")
        print(f" MT (Main Thoracic): {spine.cobb_mt:.1f}° - {get_curve_severity(spine.cobb_mt)}")
        print(f" TL (Thoracolumbar): {spine.cobb_tl:.1f}° - {get_curve_severity(spine.cobb_tl)}")
        print(f" Curve type: {spine.curve_type}")
        print(f" Rigo Classification: {rigo_result['rigo_type']}")
        print(f" - {rigo_result['description']}")
        print(f" Apex vertebrae indices: {apexes}")

        # 4) Generate brace
        print("\n4. Generating brace mesh...")
        if self.config.use_body_scan:
            print(f" Mode: Using body scan ({self.config.body_scan_path})")
        else:
            print(" Mode: Average body shape")

        brace_mesh = self.brace_generator.generate(spine)
        print(f" Mesh: {len(brace_mesh.vertices)} vertices, {len(brace_mesh.faces)} faces")

        # 5) Export STL
        print(f"\n5. Exporting STL: {output_stl_path}")
        self.brace_generator.export_stl(brace_mesh, output_stl_path)

        # 6) Optional: Save visualization
        if visualize:
            vis_path = str(Path(output_stl_path).with_suffix('.png'))
            # Only report success when the image was actually written
            # (the helper returns False when matplotlib is unavailable).
            if self._save_visualization(vis_path, spine, image_rgb, rigo_result=rigo_result):
                print(f" Visualization saved: {vis_path}")

        # 7) Optional: Save landmarks JSON
        if save_landmarks:
            json_path = str(Path(output_stl_path).with_suffix('.json'))
            self._save_landmarks_json(json_path, spine, rigo_result=rigo_result)
            print(f" Landmarks saved: {json_path}")

        # Prepare results
        results = {
            'input_image': xray_path,
            'output_stl': output_stl_path,
            'model': self.adapter.name,
            'vertebrae_detected': len(spine.vertebrae),
            'cobb_angles': {
                'PT': spine.cobb_pt,
                'MT': spine.cobb_mt,
                'TL': spine.cobb_tl,
            },
            'curve_type': spine.curve_type,
            'rigo_type': rigo_result['rigo_type'],
            'rigo_description': rigo_result['description'],
            'apex_indices': apexes,
            'mesh_vertices': len(brace_mesh.vertices),
            'mesh_faces': len(brace_mesh.faces),
        }

        print(f"\n{'=' * 60}")
        print("Pipeline complete!")
        print(f"{'=' * 60}")

        return results

    def _save_visualization(
        self,
        path: str,
        spine: Spine2D,
        image: np.ndarray,
        rigo_result: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Save visualization of detected landmarks and spine curve.

        Left panel: the X-ray with vertebra centroids (red) and, when
        available, vertebra corner quadrilaterals (green).
        Right panel: the X-ray at half opacity with the fitted spine curve,
        high-curvature points, and a text box of Cobb angles / Rigo type.

        Uses matplotlib's object-oriented API with an explicit Agg canvas so
        no global pyplot/backend state is modified.

        Args:
            path: Output image path.
            spine: Analyzed spine (Cobb angles expected to be populated).
            image: Image passed to ``imshow``.
            rigo_result: Pre-computed ``classify_rigo_type(spine)`` result;
                computed here when None.

        Returns:
            True if the image was written, False if matplotlib is missing.
        """
        try:
            from matplotlib.figure import Figure
            from matplotlib.backends.backend_agg import FigureCanvasAgg
        except ImportError:
            print(" Warning: matplotlib not available for visualization")
            return False

        if rigo_result is None:
            rigo_result = classify_rigo_type(spine)

        fig = Figure(figsize=(14, 10))
        FigureCanvasAgg(fig)  # attach a non-interactive canvas for rendering
        ax1, ax2 = fig.subplots(1, 2)

        # Left: Original with landmarks
        ax1.imshow(image)

        # Draw vertebra centers
        centroids = spine.get_centroids()
        ax1.scatter(centroids[:, 0], centroids[:, 1], c='red', s=30, zorder=5)

        # Draw corners if available, as a closed quadrilateral (first 4 corners).
        for vert in spine.vertebrae:
            corners = vert.corners_px
            if corners is None:
                continue
            loop = np.vstack([corners[:4], corners[:1]])
            ax1.plot(loop[:, 0], loop[:, 1], 'g-', linewidth=1)

        ax1.set_title(f"Detected Landmarks ({len(spine.vertebrae)} vertebrae)")
        ax1.axis('off')

        # Right: Spine curve analysis
        ax2.imshow(image, alpha=0.5)

        # The curve overlay is best-effort: a fitting failure should not
        # prevent the rest of the figure from being saved.
        try:
            C, T, N, curv = compute_spine_curve(spine)
            ax2.plot(C[:, 0], C[:, 1], 'b-', linewidth=2, label='Spine curve')

            # Highlight high curvature regions (more than 1 std above mean)
            high_curv_mask = curv > curv.mean() + curv.std()
            ax2.scatter(C[high_curv_mask, 0], C[high_curv_mask, 1],
                        c='orange', s=20, label='High curvature')
        except Exception as exc:
            print(f" Warning: spine curve overlay skipped ({exc})")

        # Add Cobb angles and Rigo type text
        text = "\n".join([
            "Cobb Angles:",
            f"PT: {spine.cobb_pt:.1f}°",
            f"MT: {spine.cobb_mt:.1f}°",
            f"TL: {spine.cobb_tl:.1f}°",
            f"Curve: {spine.curve_type}",
            "-----------",
            f"Rigo: {rigo_result['rigo_type']}",
        ])

        ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
                 verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        ax2.set_title("Spine Analysis")
        ax2.axis('off')
        # Only draw a legend when the curve overlay actually added labelled
        # artists; otherwise matplotlib warns about an empty legend.
        handles, _labels = ax2.get_legend_handles_labels()
        if handles:
            ax2.legend(loc='lower right')

        fig.tight_layout()
        fig.savefig(path, dpi=150, bbox_inches='tight')
        return True

    def _save_landmarks_json(
        self,
        path: str,
        spine: Spine2D,
        rigo_result: Optional[Dict[str, Any]] = None
    ):
        """
        Save landmarks to JSON file with Rigo classification.

        All NumPy arrays/scalars (including those inside lists/tuples) are
        converted to native Python types so ``json.dump`` cannot fail on them.

        Args:
            path: Output JSON path.
            spine: Analyzed spine whose landmarks and angles are written.
            rigo_result: Pre-computed ``classify_rigo_type(spine)`` result;
                computed here when None.
        """
        if rigo_result is None:
            rigo_result = classify_rigo_type(spine)

        data = {
            'source_model': spine.source_model,
            'image_shape': _to_native(list(spine.image_shape)) if spine.image_shape else None,
            # Accepts ndarray, list/tuple, scalar or None (None passes through).
            'pixel_spacing_mm': _to_native(spine.pixel_spacing_mm),
            'cobb_angles': {
                'PT': _to_native(spine.cobb_pt),
                'MT': _to_native(spine.cobb_mt),
                'TL': _to_native(spine.cobb_tl),
            },
            'curve_type': _to_native(spine.curve_type),
            'rigo_classification': {
                'type': _to_native(rigo_result['rigo_type']),
                'description': _to_native(rigo_result['description']),
                'curve_pattern': _to_native(rigo_result['curve_pattern']),
                'n_significant_curves': _to_native(rigo_result['n_significant_curves']),
            },
            'vertebrae': []
        }

        for vert in spine.vertebrae:
            vert_data = {
                'level': _to_native(vert.level),
                'centroid_px': _to_native(vert.centroid_px),
                'orientation_deg': _to_native(vert.orientation_deg),
                'confidence': _to_native(vert.confidence),
            }
            if vert.corners_px is not None:
                vert_data['corners_px'] = _to_native(vert.corners_px)
            data['vertebrae'].append(vert_data)

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)


def main():
    """
    Command-line interface for brace generation.

    Parses arguments, builds a BraceConfig, runs BracePipeline.process and
    returns its results dictionary.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Generate scoliosis brace from X-ray')
    parser.add_argument('input', help='Input X-ray image (JPEG, PNG, or DICOM)')
    parser.add_argument('output', help='Output STL file path')
    # Choices come from the pipeline's registry so the CLI cannot drift from it.
    parser.add_argument('--model', choices=list(BracePipeline.AVAILABLE_MODELS),
                        default='scoliovis', help='Landmark detection model')
    parser.add_argument('--device', default='cpu', help='Device (cpu or cuda)')
    parser.add_argument('--body-scan', help='Path to 3D body scan mesh (optional)')
    parser.add_argument('--visualize', action='store_true', help='Save visualization')
    parser.add_argument('--save-landmarks', action='store_true', help='Save landmarks JSON')
    parser.add_argument('--pressure', type=float, default=15.0,
                        help='Pressure strength in mm (default: 15)')
    parser.add_argument('--thickness', type=float, default=4.0,
                        help='Wall thickness in mm (default: 4)')

    args = parser.parse_args()

    # Build config
    config = BraceConfig(
        pressure_strength_mm=args.pressure,
        wall_thickness_mm=args.thickness,
    )

    if args.body_scan:
        config.use_body_scan = True
        config.body_scan_path = args.body_scan

    # Run pipeline
    pipeline = BracePipeline(model=args.model, config=config, device=args.device)
    results = pipeline.process(
        args.input,
        args.output,
        visualize=args.visualize,
        save_landmarks=args.save_landmarks
    )

    return results


if __name__ == '__main__':
    main()