Add patient management, deployment scripts, and Docker fixes
This commit is contained in:
346
brace-generator/pipeline.py
Normal file
346
brace-generator/pipeline.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Complete pipeline: X-ray → Landmarks → Brace STL
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Union
|
||||
import json
|
||||
|
||||
from brace_generator.data_models import Spine2D, BraceConfig, VertebraLandmark
|
||||
from brace_generator.image_loader import load_xray, load_xray_rgb
|
||||
from brace_generator.adapters import BaseLandmarkAdapter, ScolioVisAdapter, VertLandmarkAdapter
|
||||
from brace_generator.spine_analysis import (
|
||||
compute_spine_curve, compute_cobb_angles, find_apex_vertebrae,
|
||||
get_curve_severity, classify_rigo_type
|
||||
)
|
||||
from brace_generator.brace_surface import BraceGenerator
|
||||
|
||||
|
||||
class BracePipeline:
|
||||
"""
|
||||
End-to-end pipeline for generating scoliosis braces from X-rays.
|
||||
|
||||
Usage:
|
||||
# Basic usage with default model
|
||||
pipeline = BracePipeline()
|
||||
pipeline.process("xray.png", "brace.stl")
|
||||
|
||||
# With specific model
|
||||
pipeline = BracePipeline(model="vertebra-landmark")
|
||||
pipeline.process("xray.dcm", "brace.stl")
|
||||
|
||||
# With body scan
|
||||
config = BraceConfig(use_body_scan=True, body_scan_path="body.obj")
|
||||
pipeline = BracePipeline(config=config)
|
||||
pipeline.process("xray.png", "brace.stl")
|
||||
"""
|
||||
|
||||
AVAILABLE_MODELS = {
|
||||
'scoliovis': ScolioVisAdapter,
|
||||
'vertebra-landmark': VertLandmarkAdapter,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = 'scoliovis',
|
||||
config: Optional[BraceConfig] = None,
|
||||
device: str = 'cpu'
|
||||
):
|
||||
"""
|
||||
Initialize pipeline.
|
||||
|
||||
Args:
|
||||
model: Model to use ('scoliovis' or 'vertebra-landmark')
|
||||
config: Brace configuration
|
||||
device: 'cpu' or 'cuda'
|
||||
"""
|
||||
self.device = device
|
||||
self.config = config or BraceConfig()
|
||||
self.model_name = model.lower()
|
||||
|
||||
# Initialize model adapter
|
||||
if self.model_name not in self.AVAILABLE_MODELS:
|
||||
raise ValueError(f"Unknown model: {model}. Available: {list(self.AVAILABLE_MODELS.keys())}")
|
||||
|
||||
self.adapter: BaseLandmarkAdapter = self.AVAILABLE_MODELS[self.model_name](device=device)
|
||||
self.brace_generator = BraceGenerator(self.config)
|
||||
|
||||
# Store last results for inspection
|
||||
self.last_spine: Optional[Spine2D] = None
|
||||
self.last_image: Optional[np.ndarray] = None
|
||||
|
||||
def process(
|
||||
self,
|
||||
xray_path: str,
|
||||
output_stl_path: str,
|
||||
visualize: bool = False,
|
||||
save_landmarks: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process X-ray and generate brace STL.
|
||||
|
||||
Args:
|
||||
xray_path: Path to input X-ray (JPEG, PNG, or DICOM)
|
||||
output_stl_path: Path for output STL file
|
||||
visualize: If True, also save visualization image
|
||||
save_landmarks: If True, also save landmarks JSON
|
||||
|
||||
Returns:
|
||||
Dictionary with analysis results
|
||||
"""
|
||||
print(f"=" * 60)
|
||||
print(f"Brace Generation Pipeline")
|
||||
print(f"Model: {self.adapter.name}")
|
||||
print(f"=" * 60)
|
||||
|
||||
# 1) Load X-ray
|
||||
print(f"\n1. Loading X-ray: {xray_path}")
|
||||
image_rgb, pixel_spacing = load_xray_rgb(xray_path)
|
||||
self.last_image = image_rgb
|
||||
print(f" Image size: {image_rgb.shape[:2]}")
|
||||
if pixel_spacing is not None:
|
||||
print(f" Pixel spacing: {pixel_spacing} mm")
|
||||
|
||||
# 2) Detect landmarks
|
||||
print(f"\n2. Detecting landmarks...")
|
||||
spine = self.adapter.predict(image_rgb)
|
||||
spine.pixel_spacing_mm = pixel_spacing
|
||||
self.last_spine = spine
|
||||
|
||||
print(f" Detected {len(spine.vertebrae)} vertebrae")
|
||||
|
||||
if len(spine.vertebrae) < 5:
|
||||
raise ValueError(f"Insufficient vertebrae detected ({len(spine.vertebrae)}). Need at least 5.")
|
||||
|
||||
# 3) Compute spine analysis
|
||||
print(f"\n3. Analyzing spine curvature...")
|
||||
compute_cobb_angles(spine)
|
||||
apexes = find_apex_vertebrae(spine)
|
||||
|
||||
# Classify Rigo type
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
print(f" Cobb Angles:")
|
||||
print(f" PT (Proximal Thoracic): {spine.cobb_pt:.1f}° - {get_curve_severity(spine.cobb_pt)}")
|
||||
print(f" MT (Main Thoracic): {spine.cobb_mt:.1f}° - {get_curve_severity(spine.cobb_mt)}")
|
||||
print(f" TL (Thoracolumbar): {spine.cobb_tl:.1f}° - {get_curve_severity(spine.cobb_tl)}")
|
||||
print(f" Curve type: {spine.curve_type}")
|
||||
print(f" Rigo Classification: {rigo_result['rigo_type']}")
|
||||
print(f" - {rigo_result['description']}")
|
||||
print(f" Apex vertebrae indices: {apexes}")
|
||||
|
||||
# 4) Generate brace
|
||||
print(f"\n4. Generating brace mesh...")
|
||||
if self.config.use_body_scan:
|
||||
print(f" Mode: Using body scan ({self.config.body_scan_path})")
|
||||
else:
|
||||
print(f" Mode: Average body shape")
|
||||
|
||||
brace_mesh = self.brace_generator.generate(spine)
|
||||
print(f" Mesh: {len(brace_mesh.vertices)} vertices, {len(brace_mesh.faces)} faces")
|
||||
|
||||
# 5) Export STL
|
||||
print(f"\n5. Exporting STL: {output_stl_path}")
|
||||
self.brace_generator.export_stl(brace_mesh, output_stl_path)
|
||||
|
||||
# 6) Optional: Save visualization
|
||||
if visualize:
|
||||
vis_path = str(Path(output_stl_path).with_suffix('.png'))
|
||||
self._save_visualization(vis_path, spine, image_rgb)
|
||||
print(f" Visualization saved: {vis_path}")
|
||||
|
||||
# 7) Optional: Save landmarks JSON
|
||||
if save_landmarks:
|
||||
json_path = str(Path(output_stl_path).with_suffix('.json'))
|
||||
self._save_landmarks_json(json_path, spine)
|
||||
print(f" Landmarks saved: {json_path}")
|
||||
|
||||
# Prepare results
|
||||
results = {
|
||||
'input_image': xray_path,
|
||||
'output_stl': output_stl_path,
|
||||
'model': self.adapter.name,
|
||||
'vertebrae_detected': len(spine.vertebrae),
|
||||
'cobb_angles': {
|
||||
'PT': spine.cobb_pt,
|
||||
'MT': spine.cobb_mt,
|
||||
'TL': spine.cobb_tl,
|
||||
},
|
||||
'curve_type': spine.curve_type,
|
||||
'rigo_type': rigo_result['rigo_type'],
|
||||
'rigo_description': rigo_result['description'],
|
||||
'apex_indices': apexes,
|
||||
'mesh_vertices': len(brace_mesh.vertices),
|
||||
'mesh_faces': len(brace_mesh.faces),
|
||||
}
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Pipeline complete!")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
return results
|
||||
|
||||
def _save_visualization(self, path: str, spine: Spine2D, image: np.ndarray):
|
||||
"""Save visualization of detected landmarks and spine curve."""
|
||||
try:
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
print(" Warning: matplotlib not available for visualization")
|
||||
return
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 10))
|
||||
|
||||
# Left: Original with landmarks
|
||||
ax1 = axes[0]
|
||||
ax1.imshow(image)
|
||||
|
||||
# Draw vertebra centers
|
||||
centroids = spine.get_centroids()
|
||||
ax1.scatter(centroids[:, 0], centroids[:, 1], c='red', s=30, zorder=5)
|
||||
|
||||
# Draw corners if available
|
||||
for vert in spine.vertebrae:
|
||||
if vert.corners_px is not None:
|
||||
corners = vert.corners_px
|
||||
# Draw quadrilateral
|
||||
for i in range(4):
|
||||
j = (i + 1) % 4
|
||||
ax1.plot([corners[i, 0], corners[j, 0]],
|
||||
[corners[i, 1], corners[j, 1]], 'g-', linewidth=1)
|
||||
|
||||
ax1.set_title(f"Detected Landmarks ({len(spine.vertebrae)} vertebrae)")
|
||||
ax1.axis('off')
|
||||
|
||||
# Right: Spine curve analysis
|
||||
ax2 = axes[1]
|
||||
ax2.imshow(image, alpha=0.5)
|
||||
|
||||
# Draw spine curve
|
||||
try:
|
||||
C, T, N, curv = compute_spine_curve(spine)
|
||||
ax2.plot(C[:, 0], C[:, 1], 'b-', linewidth=2, label='Spine curve')
|
||||
|
||||
# Highlight high curvature regions
|
||||
high_curv_mask = curv > curv.mean() + curv.std()
|
||||
ax2.scatter(C[high_curv_mask, 0], C[high_curv_mask, 1],
|
||||
c='orange', s=20, label='High curvature')
|
||||
except:
|
||||
pass
|
||||
|
||||
# Get Rigo classification for display
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
# Add Cobb angles and Rigo type text
|
||||
text = f"Cobb Angles:\n"
|
||||
text += f"PT: {spine.cobb_pt:.1f}°\n"
|
||||
text += f"MT: {spine.cobb_mt:.1f}°\n"
|
||||
text += f"TL: {spine.cobb_tl:.1f}°\n"
|
||||
text += f"Curve: {spine.curve_type}\n"
|
||||
text += f"-----------\n"
|
||||
text += f"Rigo: {rigo_result['rigo_type']}"
|
||||
ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10,
|
||||
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
|
||||
|
||||
ax2.set_title("Spine Analysis")
|
||||
ax2.axis('off')
|
||||
ax2.legend(loc='lower right')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def _save_landmarks_json(self, path: str, spine: Spine2D):
|
||||
"""Save landmarks to JSON file with Rigo classification."""
|
||||
def to_native(val):
|
||||
"""Convert numpy types to native Python types."""
|
||||
if isinstance(val, np.ndarray):
|
||||
return val.tolist()
|
||||
elif isinstance(val, (np.float32, np.float64)):
|
||||
return float(val)
|
||||
elif isinstance(val, (np.int32, np.int64)):
|
||||
return int(val)
|
||||
return val
|
||||
|
||||
# Get Rigo classification
|
||||
rigo_result = classify_rigo_type(spine)
|
||||
|
||||
data = {
|
||||
'source_model': spine.source_model,
|
||||
'image_shape': list(spine.image_shape) if spine.image_shape else None,
|
||||
'pixel_spacing_mm': spine.pixel_spacing_mm.tolist() if spine.pixel_spacing_mm is not None else None,
|
||||
'cobb_angles': {
|
||||
'PT': to_native(spine.cobb_pt),
|
||||
'MT': to_native(spine.cobb_mt),
|
||||
'TL': to_native(spine.cobb_tl),
|
||||
},
|
||||
'curve_type': spine.curve_type,
|
||||
'rigo_classification': {
|
||||
'type': rigo_result['rigo_type'],
|
||||
'description': rigo_result['description'],
|
||||
'curve_pattern': rigo_result['curve_pattern'],
|
||||
'n_significant_curves': rigo_result['n_significant_curves'],
|
||||
},
|
||||
'vertebrae': []
|
||||
}
|
||||
|
||||
for vert in spine.vertebrae:
|
||||
vert_data = {
|
||||
'level': vert.level,
|
||||
'centroid_px': vert.centroid_px.tolist(),
|
||||
'orientation_deg': to_native(vert.orientation_deg),
|
||||
'confidence': to_native(vert.confidence),
|
||||
}
|
||||
if vert.corners_px is not None:
|
||||
vert_data['corners_px'] = vert.corners_px.tolist()
|
||||
data['vertebrae'].append(vert_data)
|
||||
|
||||
with open(path, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
"""Command-line interface for brace generation."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Generate scoliosis brace from X-ray')
|
||||
parser.add_argument('input', help='Input X-ray image (JPEG, PNG, or DICOM)')
|
||||
parser.add_argument('output', help='Output STL file path')
|
||||
parser.add_argument('--model', choices=['scoliovis', 'vertebra-landmark'],
|
||||
default='scoliovis', help='Landmark detection model')
|
||||
parser.add_argument('--device', default='cpu', help='Device (cpu or cuda)')
|
||||
parser.add_argument('--body-scan', help='Path to 3D body scan mesh (optional)')
|
||||
parser.add_argument('--visualize', action='store_true', help='Save visualization')
|
||||
parser.add_argument('--save-landmarks', action='store_true', help='Save landmarks JSON')
|
||||
parser.add_argument('--pressure', type=float, default=15.0,
|
||||
help='Pressure strength in mm (default: 15)')
|
||||
parser.add_argument('--thickness', type=float, default=4.0,
|
||||
help='Wall thickness in mm (default: 4)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Build config
|
||||
config = BraceConfig(
|
||||
pressure_strength_mm=args.pressure,
|
||||
wall_thickness_mm=args.thickness,
|
||||
)
|
||||
|
||||
if args.body_scan:
|
||||
config.use_body_scan = True
|
||||
config.body_scan_path = args.body_scan
|
||||
|
||||
# Run pipeline
|
||||
pipeline = BracePipeline(model=args.model, config=config, device=args.device)
|
||||
results = pipeline.process(
|
||||
args.input,
|
||||
args.output,
|
||||
visualize=args.visualize,
|
||||
save_landmarks=args.save_landmarks
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user