# Source file: braceiqmed/brace-generator/body_integration.py
# (repository viewer metadata: 412 lines, 15 KiB, Python)

"""
Body scan integration for patient-specific brace fitting.
Based on EXPERIMENT_10's approach:
1. Extract body measurements from 3D scan
2. Compute body basis (coordinate frame)
3. Select template based on Rigo classification
4. Fit shell to body using basis alignment
"""
import sys
import json
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
from dataclasses import dataclass, asdict
# trimesh is optional at import time; callers consult HAS_TRIMESH before use.
try:
    import trimesh
except ImportError:
    HAS_TRIMESH = False
else:
    HAS_TRIMESH = True

# Put the EXPERIMENT_10 directory on sys.path so its modules can be imported
# by bare name below.
EXPERIMENTS_DIR = Path(__file__).parent.parent.joinpath("EXPERIMENTS")
EXP10_DIR = EXPERIMENTS_DIR.joinpath("EXPERIMENT_10")
if str(EXP10_DIR) not in sys.path:
    sys.path.insert(0, str(EXP10_DIR))

# EXPERIMENT_10 modules are optional as well; HAS_EXP10 records whether they
# loaded, and a warning is printed when they did not.
try:
    from body_measurements import (
        extract_body_measurements,
        measurements_to_dict,
        BodyMeasurements,
    )
    from body_basis import compute_body_basis, body_basis_to_dict, BodyBasis
    from shell_fitter_v2 import (
        fit_shell_to_body_v2,
        compute_brace_basis_from_geometry,
        brace_basis_to_dict,
        RIGO_TO_VASE,
        FittingFeedback,
    )
except ImportError as exc:
    print(f"Warning: Could not import EXPERIMENT_10 modules: {exc}")
    HAS_EXP10 = False
else:
    HAS_EXP10 = True

# Directory holding the vase template meshes.
VASES_DIR = Path(__file__).parent.parent.parent.joinpath("_vase", "_vase")
def extract_measurements_from_scan(scan_path: str) -> Dict[str, Any]:
    """
    Measure a 3D body scan and return the result as a flat dictionary.

    The EXPERIMENT_10 extractor is used when it is available. If it is not
    available, or if it raises, the trimesh bounding-box fallback is used
    instead.

    Args:
        scan_path: Path to STL/OBJ/PLY body scan file
    Returns:
        Dictionary with measurements suitable for API response. On the
        EXPERIMENT_10 path it also carries the full nested result under
        the "detailed" key.
    Raises:
        ImportError: If trimesh is not installed.
    """
    if not HAS_TRIMESH:
        raise ImportError("trimesh is required for body scan processing")

    if not HAS_EXP10:
        return _extract_measurements_trimesh_fallback(scan_path)

    # (flat API key, section of the detailed result, field inside that section)
    layout = (
        ("total_height_mm", "overall_dimensions", "total_height_mm"),
        ("shoulder_width_mm", "widths_mm", "shoulder_width"),
        ("chest_width_mm", "widths_mm", "chest_width"),
        ("chest_depth_mm", "depths_mm", "chest_depth"),
        ("waist_width_mm", "widths_mm", "waist_width"),
        ("waist_depth_mm", "depths_mm", "waist_depth"),
        ("hip_width_mm", "widths_mm", "hip_width"),
        ("hip_depth_mm", "depths_mm", "hip_depth"),
        ("brace_coverage_height_mm", "brace_coverage_region", "coverage_height_mm"),
        ("chest_circumference_mm", "circumferences_mm", "chest"),
        ("waist_circumference_mm", "circumferences_mm", "waist"),
        ("hip_circumference_mm", "circumferences_mm", "hip"),
    )

    try:
        detailed = measurements_to_dict(extract_body_measurements(scan_path))
        summary = {
            api_key: detailed[section][field]
            for api_key, section, field in layout
        }
        # Keep the complete nested result alongside the flattened values.
        summary["detailed"] = detailed
        return summary
    except Exception as e:
        # Best effort: any failure here degrades to the simpler fallback.
        print(f"EXPERIMENT_10 measurement extraction failed: {e}, using fallback")

    return _extract_measurements_trimesh_fallback(scan_path)
def _extract_measurements_trimesh_fallback(scan_path: str) -> Dict[str, Any]:
    """
    Simple fallback for body measurements using trimesh bounding box analysis.
    Less accurate than EXPERIMENT_10 but provides basic measurements.

    The longest bounding-box axis is taken as the height axis. Widths and
    depths are the vertex extents inside a band of +/-5% of the total height
    centred on fixed height ratios, and circumferences use an ellipse
    approximation. No unit conversion is applied: values are in the scan's
    own units, which the "_mm" key names assume to be millimetres.

    Args:
        scan_path: Path to the body scan file, passed to ``trimesh.load``.
    Returns:
        Flat dictionary of float measurements plus
        ``"measurement_source": "trimesh_fallback"``.
    Raises:
        ValueError: If the loaded scan has no bounds (no geometry).
    """
    mesh = trimesh.load(scan_path)
    if isinstance(mesh, trimesh.Scene):
        # Multi-object files load as a Scene, which exposes bounds but has no
        # `vertices`; reload coerced into a single mesh.
        mesh = trimesh.load(scan_path, force="mesh")

    bounds = getattr(mesh, "bounds", None)
    if bounds is None:
        raise ValueError(f"Body scan contains no geometry: {scan_path}")
    min_pt, max_pt = bounds[0], bounds[1]

    # Auto-detect orientation: the longest axis is usually height.
    extents = max_pt - min_pt
    height_axis = int(np.argmax(extents))
    # The two remaining axes, in ascending order, are width and depth.
    width_axis, depth_axis = (axis for axis in range(3) if axis != height_axis)
    total_height = extents[height_axis]
    width = extents[width_axis]
    depth = extents[depth_axis]

    # Approximate height ratios for the human body, measured from the bottom.
    chest_height_ratio = 0.75
    waist_height_ratio = 0.60
    hip_height_ratio = 0.50
    shoulder_height_ratio = 0.82

    vertices = np.asarray(mesh.vertices)
    heights = vertices[:, height_axis]
    half_band = total_height * 0.05

    def get_width_at_height(height_ratio):
        """Return (width, depth) of the vertices in the band at this ratio."""
        h = min_pt[height_axis] + total_height * height_ratio
        mask = (heights > h - half_band) & (heights < h + half_band)
        if not np.any(mask):
            # No vertices in the band: fall back to the whole-mesh extents.
            return width, depth
        slice_verts = vertices[mask]
        slice_width = np.ptp(slice_verts[:, width_axis])
        slice_depth = np.ptp(slice_verts[:, depth_axis])
        return slice_width, slice_depth

    def estimate_circumference(w, d):
        """Ellipse perimeter for full axes w and d (Ramanujan's approximation)."""
        a, b = w / 2, d / 2
        if a + b <= 0:
            # Degenerate slice (e.g. a single vertex): avoid 0/0 -> NaN.
            return 0.0
        h = ((a - b) ** 2) / ((a + b) ** 2)
        return np.pi * (a + b) * (1 + 3 * h / (10 + np.sqrt(4 - 3 * h)))

    shoulder_w, shoulder_d = get_width_at_height(shoulder_height_ratio)
    chest_w, chest_d = get_width_at_height(chest_height_ratio)
    waist_w, waist_d = get_width_at_height(waist_height_ratio)
    hip_w, hip_d = get_width_at_height(hip_height_ratio)

    return {
        "total_height_mm": float(total_height),
        "shoulder_width_mm": float(shoulder_w),
        "chest_width_mm": float(chest_w),
        "chest_depth_mm": float(chest_d),
        "waist_width_mm": float(waist_w),
        "waist_depth_mm": float(waist_d),
        "hip_width_mm": float(hip_w),
        "hip_depth_mm": float(hip_d),
        "brace_coverage_height_mm": float(total_height * 0.55),  # 55% coverage
        "chest_circumference_mm": float(estimate_circumference(chest_w, chest_d)),
        "waist_circumference_mm": float(estimate_circumference(waist_w, waist_d)),
        "hip_circumference_mm": float(estimate_circumference(hip_w, hip_d)),
        "measurement_source": "trimesh_fallback",
    }
def generate_fitted_brace(
    body_scan_path: str,
    rigo_type: str,
    output_dir: str,
    case_id: str,
    clearance_mm: float = 8.0,
    wall_thickness_mm: float = 2.4
) -> Dict[str, Any]:
    """
    Generate a patient-specific brace fitted to body scan.

    Args:
        body_scan_path: Path to 3D body scan (STL/OBJ/PLY)
        rigo_type: Rigo classification (A1, A2, B1, etc.). A type missing
            from RIGO_TO_VASE falls back to "A1_vase.OBJ" with a printed
            warning.
        output_dir: Directory to save output files (created if missing)
        case_id: Case identifier for naming files
        clearance_mm: Clearance between body and shell (default 8mm)
        wall_thickness_mm: Shell wall thickness (default 2.4mm for 3D printing).
            Accepted for interface compatibility; not used by this function.
    Returns:
        Dictionary with output file paths and fitting info. The
        "visualization" output is listed only when the PNG was written.
    Raises:
        ImportError: If trimesh or the EXPERIMENT_10 modules are unavailable.
        FileNotFoundError: If the template file is in none of the known
            locations.
    """
    if not HAS_TRIMESH:
        raise ImportError("trimesh is required for brace fitting")
    if not HAS_EXP10:
        raise ImportError("EXPERIMENT_10 modules not available")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Select template based on Rigo type
    if rigo_type not in RIGO_TO_VASE:
        # Make the fallback visible instead of silently fitting another template.
        print(f"Warning: Unknown Rigo type {rigo_type!r}, using default template A1_vase.OBJ")
    template_file = RIGO_TO_VASE.get(rigo_type, "A1_vase.OBJ")
    template_path = VASES_DIR / template_file
    if not template_path.exists():
        # Try alternative paths
        alt_paths = [
            EXPERIMENTS_DIR / "EXPERIMENT_10" / "_vase" / template_file,
            Path(__file__).parent.parent.parent / "_vase" / template_file,
        ]
        template_path = next((alt for alt in alt_paths if alt.exists()), None)
        if template_path is None:
            raise FileNotFoundError(f"Template not found: {template_file}")

    # Fit shell to body
    # Returns: (shell_mesh, body_mesh, combined_mesh, feedback)
    fitted_mesh, body_mesh, combined_mesh, feedback = fit_shell_to_body_v2(
        body_scan_path=body_scan_path,
        template_path=str(template_path),
        clearance_mm=clearance_mm
    )

    # Generate output files
    outputs = {}

    # Shell STL (for 3D printing)
    shell_stl = output_path / f"{case_id}_shell.stl"
    fitted_mesh.export(str(shell_stl))
    outputs["shell_stl"] = str(shell_stl)

    # Shell GLB (for web viewing)
    shell_glb = output_path / f"{case_id}_shell.glb"
    fitted_mesh.export(str(shell_glb))
    outputs["shell_glb"] = str(shell_glb)

    # Combined body + shell STL (for visualization)
    # combined_mesh is already returned from fit_shell_to_body_v2
    combined_stl = output_path / f"{case_id}_body_with_shell.stl"
    combined_mesh.export(str(combined_stl))
    outputs["combined_stl"] = str(combined_stl)

    # Feedback JSON
    feedback_json = output_path / f"{case_id}_feedback.json"
    with open(feedback_json, "w", encoding="utf-8") as f:
        json.dump(asdict(feedback), f, indent=2, default=_json_serializer)
    outputs["feedback_json"] = str(feedback_json)

    # Create visualization (best effort)
    try:
        viz_path = output_path / f"{case_id}_visualization.png"
        # Remove any stale image so the existence check below reflects this run.
        if viz_path.exists():
            viz_path.unlink()
        create_fitting_visualization(body_mesh, fitted_mesh, feedback, str(viz_path))
        # The visualizer returns without writing when matplotlib is missing,
        # so only advertise the file if it was actually produced.
        if viz_path.exists():
            outputs["visualization"] = str(viz_path)
    except Exception as e:
        print(f"Warning: Could not create visualization: {e}")

    # Return result
    return {
        "template_used": template_file,
        "rigo_type": rigo_type,
        "clearance_mm": clearance_mm,
        "fitting": {
            "scale_right": feedback.scale_right,
            "scale_up": feedback.scale_up,
            "scale_forward": feedback.scale_forward,
            "pelvis_distance_mm": feedback.pelvis_distance_mm,
            "up_alignment_dot": feedback.up_alignment_dot,
            "warnings": feedback.warnings,
        },
        "body_measurements": {
            "max_width_mm": feedback.max_body_width_mm,
            "max_depth_mm": feedback.max_body_depth_mm,
        },
        "shell_dimensions": {
            "width_mm": feedback.target_shell_width_mm,
            "depth_mm": feedback.target_shell_depth_mm,
            "bounds_min": feedback.final_bounds_min,
            "bounds_max": feedback.final_bounds_max,
        },
        "mesh_stats": {
            "vertices": len(fitted_mesh.vertices),
            "faces": len(fitted_mesh.faces),
        },
        "outputs": outputs,
    }
def create_fitting_visualization(
    body_mesh: 'trimesh.Trimesh',
    shell_mesh: 'trimesh.Trimesh',
    feedback: 'FittingFeedback',
    output_path: str
) -> None:
    """
    Create a multi-panel visualization of the fitted brace.

    Draws a 2x3 figure: front, side, top and isometric scatter views of the
    body (gray) and shell (blue), plus a text panel with fitting information
    and a text panel with warnings. The figure is saved to ``output_path``.

    If matplotlib cannot be imported the function returns without writing
    anything. The figure is always closed, even when plotting or saving
    raises, so repeated calls do not accumulate open figures.

    Args:
        body_mesh: Body mesh to draw.
        shell_mesh: Fitted shell mesh to draw.
        feedback: Fitting feedback whose fields fill the text panels.
        output_path: Destination path of the image file.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')  # non-interactive backend; no display required
        import matplotlib.pyplot as plt
        # Unused name; kept because older matplotlib releases need this import
        # to register the '3d' projection.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    except ImportError:
        return

    # (subplot index, title, elevation, azimuth) for the four 3D panels
    views = (
        (1, 'Front View', 0, 0),
        (2, 'Side View', 0, 90),
        (3, 'Top View', 90, 0),
        (6, 'Isometric View', 20, 45),
    )

    fig = plt.figure(figsize=(16, 10))
    try:
        for index, title, elev, azim in views:
            ax = fig.add_subplot(2, 3, index, projection='3d')
            plot_mesh_silhouette(ax, body_mesh, 'gray', alpha=0.3)
            plot_mesh_silhouette(ax, shell_mesh, 'blue', alpha=0.6)
            ax.set_title(title)
            ax.view_init(elev=elev, azim=azim)

        # Panel 4: Fitting info
        ax4 = fig.add_subplot(2, 3, 4)
        ax4.axis('off')
        info_text = f"""
Fitting Information
-------------------
Template: {feedback.template_name}
Clearance: {feedback.clearance_mm} mm
Scale Factors:
Right: {feedback.scale_right:.3f}
Up: {feedback.scale_up:.3f}
Forward: {feedback.scale_forward:.3f}
Alignment:
Pelvis Distance: {feedback.pelvis_distance_mm:.2f} mm
Up Alignment: {feedback.up_alignment_dot:.4f}
Shell vs Body:
Width Margin: {feedback.shell_minus_body_width_mm:.1f} mm
Depth Margin: {feedback.shell_minus_body_depth_mm:.1f} mm
"""
        ax4.text(0.1, 0.9, info_text, transform=ax4.transAxes, fontsize=10,
                 verticalalignment='top', fontfamily='monospace')

        # Panel 5: Warnings
        ax5 = fig.add_subplot(2, 3, 5)
        ax5.axis('off')
        warnings_text = "Warnings:\n" + ("\n".join(feedback.warnings) if feedback.warnings else "None")
        ax5.text(0.1, 0.9, warnings_text, transform=ax5.transAxes, fontsize=10,
                 verticalalignment='top', color='orange' if feedback.warnings else 'green')

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure on success and on failure alike.
        plt.close(fig)
def plot_mesh_silhouette(ax, mesh, color, alpha=0.5):
    """
    Scatter a mesh's vertices onto a 3D axes and frame it in a cube.

    At most 5000 vertices are drawn; larger meshes are randomly subsampled
    without replacement. The axis limits are then set to a cube centred on
    the mesh centroid whose half-side is half of the largest mesh extent.
    """
    points = mesh.vertices
    n_points = len(points)
    if n_points > 5000:
        # Thin out dense meshes to keep the scatter plot manageable.
        points = points[np.random.choice(n_points, 5000, replace=False)]

    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
    ax.scatter(xs, ys, zs, c=color, alpha=alpha, s=1)

    # Same range on every axis gives an equal aspect ratio.
    half_span = np.max(mesh.extents) / 2
    center = mesh.centroid
    limit_setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim)
    for set_limits, coord in zip(limit_setters, center):
        set_limits(coord - half_span, coord + half_span)
def _json_serializer(obj):
    """
    JSON serializer for numpy types, for use as ``json.dump(..., default=...)``.

    Converts arrays to nested lists and numpy scalars to the matching Python
    builtin. All numpy float, integer (signed and unsigned) and boolean
    scalar widths are accepted, not only the 32/64-bit ones.

    Args:
        obj: Object the json encoder could not serialize on its own.
    Returns:
        A list, float, int or bool equivalent of ``obj``.
    Raises:
        TypeError: If ``obj`` is not a supported numpy type.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        # np.bool_ is not an np.integer subclass, so it needs its own branch.
        return bool(obj)
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")