# File: braceiqmed/brace-generator/image_loader.py
# (file-viewer metadata: 116 lines, 3.3 KiB, Python)

"""
Image loader for X-ray images in JPEG, PNG, BMP, TIFF, and DICOM formats.
"""
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
# Optional backends. Each HAS_* flag records whether the library could be
# imported, so the loaders below can choose an available backend at call time.
try:
    import pydicom
except ImportError:
    HAS_PYDICOM = False
else:
    HAS_PYDICOM = True

try:
    from PIL import Image
except ImportError:
    HAS_PIL = False
else:
    HAS_PIL = True

try:
    import cv2
except ImportError:
    HAS_CV2 = False
else:
    HAS_CV2 = True
def load_xray(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Load an X-ray image from file.
Supports: JPEG, PNG, BMP, DICOM (.dcm)
Args:
path: Path to the image file
Returns:
img_u8: Grayscale image as uint8 array (H, W)
spacing_mm: Pixel spacing [sx, sy] in mm, or None if not available
"""
path = Path(path)
suffix = path.suffix.lower()
# DICOM
if suffix in ['.dcm', '.dicom']:
if not HAS_PYDICOM:
raise ImportError("pydicom is required for DICOM files. Install with: pip install pydicom")
return _load_dicom(str(path))
# Standard image formats
if suffix in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
return _load_standard_image(str(path))
# Try to load as standard image anyway
try:
return _load_standard_image(str(path))
except Exception as e:
raise ValueError(f"Could not load image: {path}. Error: {e}")
def _load_dicom(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""Load DICOM file."""
ds = pydicom.dcmread(path)
arr = ds.pixel_array.astype(np.float32)
# Apply modality LUT if present
if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
arr = arr * ds.RescaleSlope + ds.RescaleIntercept
# Normalize to 0-255
arr = arr - arr.min()
if arr.max() > 0:
arr = arr / arr.max()
img_u8 = (arr * 255).astype(np.uint8)
# Get pixel spacing
spacing_mm = None
if hasattr(ds, 'PixelSpacing'):
# PixelSpacing is [row_spacing, col_spacing] in mm
sy, sx = [float(x) for x in ds.PixelSpacing]
spacing_mm = np.array([sx, sy], dtype=np.float32)
elif hasattr(ds, 'ImagerPixelSpacing'):
sy, sx = [float(x) for x in ds.ImagerPixelSpacing]
spacing_mm = np.array([sx, sy], dtype=np.float32)
return img_u8, spacing_mm
def _load_standard_image(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""Load standard image format (JPEG, PNG, etc.)."""
if HAS_CV2:
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
if img is None:
raise ValueError(f"Could not read image: {path}")
return img.astype(np.uint8), None
elif HAS_PIL:
img = Image.open(path).convert('L') # Convert to grayscale
return np.array(img, dtype=np.uint8), None
else:
raise ImportError("Either opencv-python or Pillow is required. Install with: pip install opencv-python")
def load_xray_rgb(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Load an X-ray and replicate its gray channel into three channels.

    Intended for models that expect 3-channel (RGB) input.

    Args:
        path: Path to the image file (any format accepted by load_xray).

    Returns:
        img_rgb: RGB image as uint8 array (H, W, 3) with identical channels
        spacing_mm: Pixel spacing or None, passed through from load_xray
    """
    gray, spacing_mm = load_xray(path)
    # Add a trailing channel axis, then copy it three times along that axis.
    rgb = np.repeat(gray[..., np.newaxis], 3, axis=-1)
    return rgb, spacing_mm