116 lines
3.3 KiB
Python
116 lines
3.3 KiB
Python
"""
|
|
X-ray image loader supporting JPEG, PNG, BMP, TIFF, and DICOM formats.
|
|
"""
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import Tuple, Optional
|
|
|
|
try:
|
|
import pydicom
|
|
HAS_PYDICOM = True
|
|
except ImportError:
|
|
HAS_PYDICOM = False
|
|
|
|
try:
|
|
from PIL import Image
|
|
HAS_PIL = True
|
|
except ImportError:
|
|
HAS_PIL = False
|
|
|
|
try:
|
|
import cv2
|
|
HAS_CV2 = True
|
|
except ImportError:
|
|
HAS_CV2 = False
|
|
|
|
|
|
def load_xray(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Load an X-ray image from file.

    Supports: JPEG, PNG, BMP, TIFF, DICOM (.dcm / .dicom). A file with any
    other extension is first checked for a DICOM Part 10 header (the ``DICM``
    marker at byte offset 128) and loaded as DICOM if it has one and pydicom
    is installed; otherwise the standard image readers are tried.

    Args:
        path: Path to the image file.

    Returns:
        img_u8: Grayscale image as uint8 array (H, W)
        spacing_mm: Pixel spacing [sx, sy] in mm, or None if not available

    Raises:
        ImportError: A ``.dcm``/``.dicom`` file was given but pydicom is not
            installed, or a recognised raster extension was given and neither
            OpenCV nor Pillow is installed.
        ValueError: OpenCV could not decode a recognised raster file, or a
            file with an unrecognised extension could not be loaded by any
            reader (the underlying error is chained as ``__cause__``).
    """
    path = Path(path)
    suffix = path.suffix.lower()

    # DICOM, identified by extension.
    if suffix in ('.dcm', '.dicom'):
        if not HAS_PYDICOM:
            raise ImportError("pydicom is required for DICOM files. Install with: pip install pydicom")
        return _load_dicom(str(path))

    # Standard image formats.
    if suffix in ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'):
        return _load_standard_image(str(path))

    # Unknown extension: DICOM exports often have no extension at all, so
    # sniff for the Part 10 header before falling back to raster readers.
    try:
        if HAS_PYDICOM and _looks_like_dicom(path):
            return _load_dicom(str(path))
        return _load_standard_image(str(path))
    except Exception as e:
        raise ValueError(f"Could not load image: {path}. Error: {e}") from e


def _looks_like_dicom(path: Path) -> bool:
    """Return True if the file has the DICOM Part 10 ``DICM`` marker at byte 128."""
    try:
        with open(path, 'rb') as fh:
            header = fh.read(132)
    except OSError:
        # Unreadable/missing file: let the regular loaders report the error.
        return False
    return len(header) == 132 and header[128:132] == b'DICM'
|
|
|
|
|
|
def _load_dicom(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Load a DICOM file as a min-max normalised uint8 grayscale image.

    Processing order:
      1. read the dataset with ``pydicom.dcmread`` and take ``pixel_array``;
      2. apply the linear modality rescale (RescaleSlope / RescaleIntercept,
         using 1 / 0 for whichever one is absent);
      3. invert MONOCHROME1 data so that larger output values are brighter,
         as with MONOCHROME2 and ordinary image files;
      4. stretch the value range linearly to 0-255.

    Args:
        path: Path to the DICOM file.

    Returns:
        img_u8: uint8 image derived from ``ds.pixel_array``.
        spacing_mm: ``[sx, sy]`` in mm from PixelSpacing, else from
            ImagerPixelSpacing, else None.
    """
    ds = pydicom.dcmread(path)
    arr = ds.pixel_array.astype(np.float32)
    # NOTE(review): multi-frame or colour datasets give a 3-D pixel_array,
    # which is passed through as-is; confirm callers never feed those.

    # Modality LUT (linear form); absent slope/intercept default to 1 / 0.
    slope = getattr(ds, 'RescaleSlope', None)
    intercept = getattr(ds, 'RescaleIntercept', None)
    if slope is not None or intercept is not None:
        slope = 1.0 if slope is None else float(slope)
        intercept = 0.0 if intercept is None else float(intercept)
        arr = arr * slope + intercept

    # MONOCHROME1 means the minimum value is displayed as white; flip it so
    # the output polarity matches MONOCHROME2 / JPEG / PNG inputs.
    photometric = str(getattr(ds, 'PhotometricInterpretation', '')).strip().upper()
    if photometric == 'MONOCHROME1':
        arr = arr.max() - arr

    # Normalize to 0-255 (a constant image maps to all zeros).
    arr = arr - arr.min()
    peak = arr.max()
    if peak > 0:
        arr = arr / peak
    img_u8 = (arr * 255).astype(np.uint8)

    # Pixel spacing: both attributes are [row_spacing, col_spacing] in mm,
    # i.e. [sy, sx]. PixelSpacing is preferred; an absent or malformed value
    # falls through to ImagerPixelSpacing instead of raising.
    spacing_mm = None
    for keyword in ('PixelSpacing', 'ImagerPixelSpacing'):
        value = getattr(ds, keyword, None)
        if value is not None and len(value) == 2:
            sy, sx = (float(v) for v in value)
            spacing_mm = np.array([sx, sy], dtype=np.float32)
            break

    return img_u8, spacing_mm
|
|
|
|
|
|
def _load_standard_image(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Load a standard image format (JPEG, PNG, BMP, TIFF, ...) as grayscale.

    OpenCV is used when it is installed; otherwise Pillow is used.

    Args:
        path: Path to the image file.

    Returns:
        img_u8: Grayscale image as a uint8 array.
        spacing_mm: Always None (no pixel spacing is extracted for these formats).

    Raises:
        ValueError: OpenCV is in use and ``cv2.imread`` could not read the file.
        ImportError: Neither OpenCV nor Pillow is installed.
        Pillow's own exceptions propagate unchanged on the Pillow path.
    """
    if HAS_CV2:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError(f"Could not read image: {path}")
        return np.asarray(img, dtype=np.uint8), None

    if HAS_PIL:
        # Image.open is lazy and holds the file open; the context manager
        # releases the handle once the grayscale copy has been materialised.
        with Image.open(path) as pil_img:
            gray = pil_img.convert('L')  # Convert to grayscale
            return np.array(gray, dtype=np.uint8), None

    raise ImportError("Either opencv-python or Pillow is required. Install with: pip install opencv-python")
|
|
|
|
|
|
def load_xray_rgb(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Load an X-ray and replicate its grayscale channel three times.

    Intended for models that expect 3-channel input. Loading is delegated to
    ``load_xray``, so supported formats and pixel spacing are identical.

    Returns:
        img_rgb: RGB image as uint8 array (H, W, 3) with three identical channels
        spacing_mm: Pixel spacing or None
    """
    gray, spacing_mm = load_xray(path)

    # Append a channel axis and repeat it three times: (H, W) -> (H, W, 3).
    rgb = np.repeat(gray[..., np.newaxis], 3, axis=-1)
    return rgb, spacing_mm
|