From d28d2f20c6efa58bb58eed3579d3cbe3557efa19 Mon Sep 17 00:00:00 2001 From: Mo_Saghafian Date: Fri, 30 Jan 2026 01:51:33 -0800 Subject: [PATCH] Add patient management, deployment scripts, and Docker fixes --- api/.dockerignore | 7 + api/Dockerfile | 7 +- api/db/sqlite.js | 699 ++++++- api/server.js | 690 ++++++- brace-generator/Dockerfile | 22 +- brace-generator/adapters.py | 508 +++++ brace-generator/brace_surface.py | 354 ++++ brace-generator/data_models.py | 177 ++ brace-generator/image_loader.py | 115 ++ brace-generator/pipeline.py | 346 ++++ brace-generator/spine_analysis.py | 464 +++++ frontend/Dockerfile | 3 +- frontend/nginx.conf | 26 + frontend/src/App.tsx | 47 + frontend/src/api/adminApi.ts | 168 +- frontend/src/api/braceflowApi.ts | 36 +- frontend/src/api/patientApi.ts | 248 +++ frontend/src/components/AppShell.tsx | 2 + frontend/src/components/pipeline/pipeline.css | 44 + frontend/src/context/AuthContext.tsx | 2 +- frontend/src/pages/Dashboard.tsx | 128 +- frontend/src/pages/PatientDetail.tsx | 338 ++++ frontend/src/pages/PatientForm.tsx | 336 ++++ frontend/src/pages/PatientList.tsx | 264 +++ frontend/src/pages/PipelineCaseDetail.tsx | 30 +- frontend/src/pages/admin/AdminActivity.tsx | 746 ++++++-- frontend/src/pages/admin/AdminCases.tsx | 53 +- frontend/src/styles.css | 1678 +++++++++++++++++ frontend/vite.config.ts | 14 + scripts/deploy-to-server.ps1 | 68 + scripts/deploy-to-server.sh | 62 + scripts/update-local.ps1 | 60 + scripts/update-local.sh | 38 + 33 files changed, 7496 insertions(+), 284 deletions(-) create mode 100644 api/.dockerignore create mode 100644 brace-generator/adapters.py create mode 100644 brace-generator/brace_surface.py create mode 100644 brace-generator/data_models.py create mode 100644 brace-generator/image_loader.py create mode 100644 brace-generator/pipeline.py create mode 100644 brace-generator/spine_analysis.py create mode 100644 frontend/src/api/patientApi.ts create mode 100644 frontend/src/pages/PatientDetail.tsx create mode 100644 frontend/src/pages/PatientForm.tsx create mode 100644 frontend/src/pages/PatientList.tsx create mode 100644 scripts/deploy-to-server.ps1 create mode 100644 scripts/deploy-to-server.sh create mode 100644 scripts/update-local.ps1 create mode 100644 scripts/update-local.sh diff --git a/api/.dockerignore b/api/.dockerignore new file mode 100644 index 0000000..4808d35 --- /dev/null +++ b/api/.dockerignore @@ -0,0 +1,7 @@ +node_modules +npm-debug.log +.git +.gitignore +*.md +.env +.env.* diff --git a/api/Dockerfile b/api/Dockerfile index 4b3cd03..4c88d00 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -12,10 +12,11 @@ WORKDIR /app # Copy package files COPY package*.json ./ -# Install dependencies -RUN npm ci --only=production +# Install dependencies (rebuild native modules for Linux) +RUN npm ci --only=production && \ + npm rebuild better-sqlite3 -# Copy application code +# Copy application code (excluding node_modules via .dockerignore) COPY . . 
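+# Why the explicit rebuild: better-sqlite3 ships a native addon, so a binary
+# compiled on a Windows or macOS host will not load inside this Linux image.
+# `npm rebuild better-sqlite3` recompiles it for the container platform, and
+# the new .dockerignore keeps the host's node_modules out of the build
+# context so a stale binding can never be copied back in by `COPY . .`.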
# Create data directories diff --git a/api/db/sqlite.js b/api/db/sqlite.js index 9eda87c..546dad0 100644 --- a/api/db/sqlite.js +++ b/api/db/sqlite.js @@ -18,10 +18,36 @@ db.pragma('journal_mode = WAL'); // Create tables db.exec(` - -- Main cases table + -- Patients table + CREATE TABLE IF NOT EXISTS patients ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + mrn TEXT UNIQUE, + first_name TEXT NOT NULL, + last_name TEXT NOT NULL, + date_of_birth TEXT, + gender TEXT CHECK(gender IN ('male', 'female', 'other')), + email TEXT, + phone TEXT, + address TEXT, + diagnosis TEXT, + curve_type TEXT, + medical_history TEXT, + referring_physician TEXT, + insurance_info TEXT, + notes TEXT, + is_active INTEGER NOT NULL DEFAULT 1, + created_by INTEGER, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')), + FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE SET NULL + ); + + -- Main cases table (linked to patients) CREATE TABLE IF NOT EXISTS brace_cases ( case_id TEXT PRIMARY KEY, + patient_id INTEGER, case_type TEXT NOT NULL DEFAULT 'braceflow', + visit_date TEXT DEFAULT (date('now')), status TEXT NOT NULL DEFAULT 'created' CHECK(status IN ( 'created', 'running', 'completed', 'failed', 'cancelled', 'processing_brace', 'brace_generated', 'brace_failed', @@ -38,9 +64,12 @@ db.exec(` body_scan_path TEXT DEFAULT NULL, body_scan_url TEXT DEFAULT NULL, body_scan_metadata TEXT DEFAULT NULL, + is_archived INTEGER NOT NULL DEFAULT 0, + archived_at TEXT DEFAULT NULL, created_by INTEGER DEFAULT NULL, created_at TEXT NOT NULL DEFAULT (datetime('now')), - updated_at TEXT NOT NULL DEFAULT (datetime('now')) + updated_at TEXT NOT NULL DEFAULT (datetime('now')), + FOREIGN KEY (patient_id) REFERENCES patients(id) ON DELETE SET NULL ); -- Case steps table @@ -109,7 +138,35 @@ db.exec(` FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE SET NULL ); + -- API request logging table (for tracking all HTTP API calls) + CREATE TABLE IF NOT EXISTS api_requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER DEFAULT NULL, + username TEXT DEFAULT NULL, + method TEXT NOT NULL, + path TEXT NOT NULL, + route_pattern TEXT DEFAULT NULL, + query_params TEXT DEFAULT NULL, + request_params TEXT DEFAULT NULL, + file_uploads TEXT DEFAULT NULL, + status_code INTEGER DEFAULT NULL, + response_time_ms INTEGER DEFAULT NULL, + response_summary TEXT DEFAULT NULL, + ip_address TEXT DEFAULT NULL, + user_agent TEXT DEFAULT NULL, + request_body_size INTEGER DEFAULT NULL, + response_body_size INTEGER DEFAULT NULL, + error_message TEXT DEFAULT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE SET NULL + ); + -- Create indexes + CREATE INDEX IF NOT EXISTS idx_patients_name ON patients(last_name, first_name); + CREATE INDEX IF NOT EXISTS idx_patients_mrn ON patients(mrn); + CREATE INDEX IF NOT EXISTS idx_patients_created_by ON patients(created_by); + CREATE INDEX IF NOT EXISTS idx_patients_active ON patients(is_active); + CREATE INDEX IF NOT EXISTS idx_cases_patient ON brace_cases(patient_id); CREATE INDEX IF NOT EXISTS idx_cases_status ON brace_cases(status); CREATE INDEX IF NOT EXISTS idx_cases_created ON brace_cases(created_at); CREATE INDEX IF NOT EXISTS idx_steps_case_id ON brace_case_steps(case_id); @@ -120,6 +177,11 @@ db.exec(` CREATE INDEX IF NOT EXISTS idx_audit_user ON audit_log(user_id); CREATE INDEX IF NOT EXISTS idx_audit_action ON audit_log(action); CREATE INDEX IF NOT EXISTS idx_audit_created ON 
audit_log(created_at); + CREATE INDEX IF NOT EXISTS idx_api_requests_user ON api_requests(user_id); + CREATE INDEX IF NOT EXISTS idx_api_requests_path ON api_requests(path); + CREATE INDEX IF NOT EXISTS idx_api_requests_method ON api_requests(method); + CREATE INDEX IF NOT EXISTS idx_api_requests_status ON api_requests(status_code); + CREATE INDEX IF NOT EXISTS idx_api_requests_created ON api_requests(created_at); `); // Migration: Add new columns to existing tables @@ -143,6 +205,42 @@ try { db.exec(`ALTER TABLE brace_cases ADD COLUMN created_by INTEGER DEFAULT NULL`); } catch (e) { /* Column already exists */ } +// Migration: Add new columns to api_requests table for enhanced logging +try { + db.exec(`ALTER TABLE api_requests ADD COLUMN route_pattern TEXT DEFAULT NULL`); +} catch (e) { /* Column already exists */ } + +try { + db.exec(`ALTER TABLE api_requests ADD COLUMN request_params TEXT DEFAULT NULL`); +} catch (e) { /* Column already exists */ } + +try { + db.exec(`ALTER TABLE api_requests ADD COLUMN file_uploads TEXT DEFAULT NULL`); +} catch (e) { /* Column already exists */ } + +try { + db.exec(`ALTER TABLE api_requests ADD COLUMN response_summary TEXT DEFAULT NULL`); +} catch (e) { /* Column already exists */ } + +// Migration: Add patient_id to brace_cases +try { + db.exec(`ALTER TABLE brace_cases ADD COLUMN patient_id INTEGER DEFAULT NULL`); +} catch (e) { /* Column already exists */ } + +// Migration: Add visit_date to brace_cases +try { + db.exec(`ALTER TABLE brace_cases ADD COLUMN visit_date TEXT DEFAULT NULL`); +} catch (e) { /* Column already exists */ } + +// Migration: Add is_archived to brace_cases +try { + db.exec(`ALTER TABLE brace_cases ADD COLUMN is_archived INTEGER NOT NULL DEFAULT 0`); +} catch (e) { /* Column already exists */ } + +try { + db.exec(`ALTER TABLE brace_cases ADD COLUMN archived_at TEXT DEFAULT NULL`); +} catch (e) { /* Column already exists */ } + // Insert default admin user if not exists (password: admin123) // Note: In production, use proper bcrypt hashing. This is a simple hash for dev. 
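+// (A production setup would store a bcrypt hash instead, e.g.
+// bcrypt.hashSync(password, 10) from the bcryptjs package, and would not
+// ship a default credential at all.)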
try { @@ -168,12 +266,12 @@ const STEP_NAMES = [ ]; /** - * Create a new case + * Create a new case (optionally linked to a patient) */ -export function createCase(caseId, caseType = 'braceflow', notes = null) { +export function createCase(caseId, caseType = 'braceflow', notes = null, patientId = null, visitDate = null) { const insertCase = db.prepare(` - INSERT INTO brace_cases (case_id, case_type, status, notes, created_at, updated_at) - VALUES (?, ?, 'created', ?, datetime('now'), datetime('now')) + INSERT INTO brace_cases (case_id, patient_id, case_type, visit_date, status, notes, created_at, updated_at) + VALUES (?, ?, ?, ?, 'created', ?, datetime('now'), datetime('now')) `); const insertStep = db.prepare(` @@ -182,40 +280,108 @@ export function createCase(caseId, caseType = 'braceflow', notes = null) { `); const transaction = db.transaction(() => { - insertCase.run(caseId, caseType, notes); + insertCase.run(caseId, patientId, caseType, visitDate || new Date().toISOString().split('T')[0], notes); STEP_NAMES.forEach((stepName, idx) => { insertStep.run(caseId, stepName, idx + 1); }); }); transaction(); - return { caseId, status: 'created', steps: STEP_NAMES }; + return { caseId, patientId, status: 'created', steps: STEP_NAMES }; } /** - * List all cases + * List all cases with patient info + * @param {Object} options - Query options + * @param {boolean} options.includeArchived - Include archived cases (for admin view) + * @param {boolean} options.archivedOnly - Only show archived cases */ -export function listCases() { +export function listCases(options = {}) { + const { includeArchived = false, archivedOnly = false } = options; + + let whereClause = ''; + if (archivedOnly) { + whereClause = 'WHERE c.is_archived = 1'; + } else if (!includeArchived) { + whereClause = 'WHERE c.is_archived = 0'; + } + const stmt = db.prepare(` - SELECT case_id as caseId, case_type, status, current_step, notes, - analysis_result, landmarks_data, created_at, updated_at - FROM brace_cases - ORDER BY created_at DESC + SELECT c.case_id as caseId, c.patient_id, c.case_type, c.visit_date, c.status, c.current_step, c.notes, + c.analysis_result, c.landmarks_data, c.is_archived, c.archived_at, c.created_at, c.updated_at, + p.first_name as patient_first_name, p.last_name as patient_last_name, + p.mrn as patient_mrn + FROM brace_cases c + LEFT JOIN patients p ON c.patient_id = p.id + ${whereClause} + ORDER BY c.created_at DESC `); - return stmt.all(); + + const rows = stmt.all(); + + // Transform to include patient object + return rows.map(row => ({ + caseId: row.caseId, + patient_id: row.patient_id, + patient: row.patient_id ? { + id: row.patient_id, + firstName: row.patient_first_name, + lastName: row.patient_last_name, + fullName: `${row.patient_first_name} ${row.patient_last_name}`, + mrn: row.patient_mrn + } : null, + case_type: row.case_type, + visit_date: row.visit_date, + status: row.status, + current_step: row.current_step, + notes: row.notes, + analysis_result: row.analysis_result, + landmarks_data: row.landmarks_data, + is_archived: row.is_archived === 1, + archived_at: row.archived_at, + created_at: row.created_at, + updated_at: row.updated_at + })); } /** - * Get case by ID with steps + * Archive a case (soft delete) + */ +export function archiveCase(caseId) { + const stmt = db.prepare(` + UPDATE brace_cases + SET is_archived = 1, archived_at = datetime('now'), updated_at = datetime('now') + WHERE case_id = ? 
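+      -- soft delete: the row is only flagged; nothing on disk is removed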
+ `); + return stmt.run(caseId); +} + +/** + * Unarchive a case + */ +export function unarchiveCase(caseId) { + const stmt = db.prepare(` + UPDATE brace_cases + SET is_archived = 0, archived_at = NULL, updated_at = datetime('now') + WHERE case_id = ? + `); + return stmt.run(caseId); +} + +/** + * Get case by ID with steps and patient info */ export function getCase(caseId) { const caseStmt = db.prepare(` - SELECT case_id, case_type, status, current_step, notes, - analysis_result, landmarks_data, analysis_data, markers_data, - body_scan_path, body_scan_url, body_scan_metadata, - created_at, updated_at - FROM brace_cases - WHERE case_id = ? + SELECT c.case_id, c.patient_id, c.case_type, c.visit_date, c.status, c.current_step, c.notes, + c.analysis_result, c.landmarks_data, c.analysis_data, c.markers_data, + c.body_scan_path, c.body_scan_url, c.body_scan_metadata, + c.created_at, c.updated_at, + p.first_name as patient_first_name, p.last_name as patient_last_name, + p.mrn as patient_mrn, p.date_of_birth as patient_dob, p.gender as patient_gender + FROM brace_cases c + LEFT JOIN patients p ON c.patient_id = p.id + WHERE c.case_id = ? `); const stepsStmt = db.prepare(` @@ -267,9 +433,23 @@ export function getCase(caseId) { } } catch (e) { /* ignore */ } + // Build patient object if patient_id exists + const patient = caseData.patient_id ? { + id: caseData.patient_id, + firstName: caseData.patient_first_name, + lastName: caseData.patient_last_name, + fullName: `${caseData.patient_first_name} ${caseData.patient_last_name}`, + mrn: caseData.patient_mrn, + dateOfBirth: caseData.patient_dob, + gender: caseData.patient_gender + } : null; + return { caseId: caseData.case_id, + patient_id: caseData.patient_id, + patient, case_type: caseData.case_type, + visit_date: caseData.visit_date, status: caseData.status, current_step: caseData.current_step, notes: caseData.notes, @@ -458,6 +638,244 @@ export function updateStepStatus(caseId, stepName, status, errorMessage = null) return stmt.run(status, errorMessage, status, status, caseId, stepName); } +// ============================================ +// PATIENT MANAGEMENT +// ============================================ + +/** + * Create a new patient + */ +export function createPatient(data) { + const stmt = db.prepare(` + INSERT INTO patients ( + mrn, first_name, last_name, date_of_birth, gender, + email, phone, address, diagnosis, curve_type, + medical_history, referring_physician, insurance_info, notes, + created_by, created_at, updated_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now')) + `); + + const result = stmt.run( + data.mrn || null, + data.firstName, + data.lastName, + data.dateOfBirth || null, + data.gender || null, + data.email || null, + data.phone || null, + data.address || null, + data.diagnosis || null, + data.curveType || null, + data.medicalHistory || null, + data.referringPhysician || null, + data.insuranceInfo || null, + data.notes || null, + data.createdBy || null + ); + + return { + id: result.lastInsertRowid, + ...data + }; +} + +/** + * Get patient by ID + */ +export function getPatient(patientId) { + const stmt = db.prepare(` + SELECT p.*, u.username as created_by_username + FROM patients p + LEFT JOIN users u ON p.created_by = u.id + WHERE p.id = ? 
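+      -- LEFT JOIN so the patient still loads if the creating user was deleted
+      -- (created_by is declared ON DELETE SET NULL)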
+ `); + return stmt.get(patientId); +} + +/** + * List all patients with optional filters + */ +export function listPatients(options = {}) { + const { search, isActive = true, createdBy, limit = 50, offset = 0, sortBy = 'created_at', sortOrder = 'DESC' } = options; + + let where = []; + let values = []; + + if (isActive !== undefined && isActive !== null) { + where.push('p.is_active = ?'); + values.push(isActive ? 1 : 0); + } + + if (createdBy) { + where.push('p.created_by = ?'); + values.push(createdBy); + } + + if (search) { + where.push('(p.first_name LIKE ? OR p.last_name LIKE ? OR p.mrn LIKE ? OR p.email LIKE ?)'); + const searchPattern = `%${search}%`; + values.push(searchPattern, searchPattern, searchPattern, searchPattern); + } + + const whereClause = where.length > 0 ? `WHERE ${where.join(' AND ')}` : ''; + const validSortColumns = ['created_at', 'updated_at', 'last_name', 'first_name', 'date_of_birth']; + const sortColumn = validSortColumns.includes(sortBy) ? sortBy : 'created_at'; + const order = sortOrder.toUpperCase() === 'ASC' ? 'ASC' : 'DESC'; + + // Get total count + const countStmt = db.prepare(`SELECT COUNT(*) as count FROM patients p ${whereClause}`); + const totalCount = countStmt.get(...values).count; + + // Get patients with case count + const stmt = db.prepare(` + SELECT p.*, + u.username as created_by_username, + (SELECT COUNT(*) FROM brace_cases c WHERE c.patient_id = p.id) as case_count, + (SELECT MAX(c.created_at) FROM brace_cases c WHERE c.patient_id = p.id) as last_visit + FROM patients p + LEFT JOIN users u ON p.created_by = u.id + ${whereClause} + ORDER BY p.${sortColumn} ${order} + LIMIT ? OFFSET ? + `); + + const patients = stmt.all(...values, limit, offset); + + return { + patients, + total: totalCount, + limit, + offset + }; +} + +/** + * Update patient + */ +export function updatePatient(patientId, data) { + const fields = []; + const values = []; + + if (data.mrn !== undefined) { fields.push('mrn = ?'); values.push(data.mrn); } + if (data.firstName !== undefined) { fields.push('first_name = ?'); values.push(data.firstName); } + if (data.lastName !== undefined) { fields.push('last_name = ?'); values.push(data.lastName); } + if (data.dateOfBirth !== undefined) { fields.push('date_of_birth = ?'); values.push(data.dateOfBirth); } + if (data.gender !== undefined) { fields.push('gender = ?'); values.push(data.gender); } + if (data.email !== undefined) { fields.push('email = ?'); values.push(data.email); } + if (data.phone !== undefined) { fields.push('phone = ?'); values.push(data.phone); } + if (data.address !== undefined) { fields.push('address = ?'); values.push(data.address); } + if (data.diagnosis !== undefined) { fields.push('diagnosis = ?'); values.push(data.diagnosis); } + if (data.curveType !== undefined) { fields.push('curve_type = ?'); values.push(data.curveType); } + if (data.medicalHistory !== undefined) { fields.push('medical_history = ?'); values.push(data.medicalHistory); } + if (data.referringPhysician !== undefined) { fields.push('referring_physician = ?'); values.push(data.referringPhysician); } + if (data.insuranceInfo !== undefined) { fields.push('insurance_info = ?'); values.push(data.insuranceInfo); } + if (data.notes !== undefined) { fields.push('notes = ?'); values.push(data.notes); } + if (data.isActive !== undefined) { fields.push('is_active = ?'); values.push(data.isActive ? 
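+      // SQLite has no boolean type, so flags are stored as 1/0 integers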
1 : 0); } + + if (fields.length === 0) return null; + + fields.push('updated_at = datetime(\'now\')'); + values.push(patientId); + + const stmt = db.prepare(`UPDATE patients SET ${fields.join(', ')} WHERE id = ?`); + return stmt.run(...values); +} + +/** + * Archive patient (soft delete - set is_active = 0) + */ +export function archivePatient(patientId) { + const stmt = db.prepare(` + UPDATE patients + SET is_active = 0, updated_at = datetime('now') + WHERE id = ? + `); + return stmt.run(patientId); +} + +/** + * Unarchive patient (restore - set is_active = 1) + */ +export function unarchivePatient(patientId) { + const stmt = db.prepare(` + UPDATE patients + SET is_active = 1, updated_at = datetime('now') + WHERE id = ? + `); + return stmt.run(patientId); +} + +/** + * Delete patient - kept for backwards compatibility, now archives + */ +export function deletePatient(patientId, hard = false) { + if (hard) { + // Hard delete should never be used in normal operation + const stmt = db.prepare(`DELETE FROM patients WHERE id = ?`); + return stmt.run(patientId); + } else { + return archivePatient(patientId); + } +} + +/** + * Get cases for a patient + * @param {number} patientId - Patient ID + * @param {Object} options - Query options + * @param {boolean} options.includeArchived - Include archived cases + */ +export function getPatientCases(patientId, options = {}) { + const { includeArchived = false } = options; + + const archivedFilter = includeArchived ? '' : 'AND is_archived = 0'; + + const stmt = db.prepare(` + SELECT case_id, case_type, status, current_step, visit_date, notes, + analysis_result, landmarks_data, body_scan_path, body_scan_url, + is_archived, archived_at, created_at, updated_at + FROM brace_cases + WHERE patient_id = ? ${archivedFilter} + ORDER BY created_at DESC + `); + return stmt.all(patientId); +} + +/** + * Get patient statistics + */ +export function getPatientStats() { + const total = db.prepare(`SELECT COUNT(*) as count FROM patients`).get(); + const active = db.prepare(`SELECT COUNT(*) as count FROM patients WHERE is_active = 1`).get(); + const withCases = db.prepare(` + SELECT COUNT(DISTINCT patient_id) as count + FROM brace_cases + WHERE patient_id IS NOT NULL + `).get(); + + const byGender = db.prepare(` + SELECT gender, COUNT(*) as count + FROM patients + WHERE is_active = 1 + GROUP BY gender + `).all(); + + const recentPatients = db.prepare(` + SELECT COUNT(*) as count + FROM patients + WHERE created_at >= datetime('now', '-30 days') + `).get(); + + return { + total: total.count, + active: active.count, + inactive: total.count - active.count, + withCases: withCases.count, + byGender: byGender.reduce((acc, row) => { acc[row.gender || 'unspecified'] = row.count; return acc; }, {}), + recentPatients: recentPatients.count + }; +} + // ============================================ // USER MANAGEMENT // ============================================ @@ -631,6 +1049,210 @@ export function getAuditLog(options = {}) { return stmt.all(...values, limit, offset); } +// ============================================ +// API REQUEST LOGGING +// ============================================ + +/** + * Log an API request with full details + */ +export function logApiRequest(data) { + const stmt = db.prepare(` + INSERT INTO api_requests ( + user_id, username, method, path, route_pattern, query_params, + request_params, file_uploads, status_code, response_time_ms, + response_summary, ip_address, user_agent, request_body_size, + response_body_size, error_message, created_at + ) + 
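+    -- the 16 placeholders below map 1:1 to the columns above;
+    -- created_at is filled in by SQLite via datetime('now')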
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now')) + `); + + return stmt.run( + data.userId || null, + data.username || null, + data.method, + data.path, + data.routePattern || null, + data.queryParams ? JSON.stringify(data.queryParams) : null, + data.requestParams ? JSON.stringify(data.requestParams) : null, + data.fileUploads ? JSON.stringify(data.fileUploads) : null, + data.statusCode || null, + data.responseTimeMs || null, + data.responseSummary ? JSON.stringify(data.responseSummary) : null, + data.ipAddress || null, + data.userAgent || null, + data.requestBodySize || null, + data.responseBodySize || null, + data.errorMessage || null + ); +} + +/** + * Get API request logs with filters + */ +export function getApiRequests(options = {}) { + const { + userId, + username, + method, + path, + statusCode, + minStatusCode, + maxStatusCode, + startDate, + endDate, + limit = 100, + offset = 0 + } = options; + + let where = []; + let values = []; + + if (userId) { where.push('user_id = ?'); values.push(userId); } + if (username) { where.push('username LIKE ?'); values.push(`%${username}%`); } + if (method) { where.push('method = ?'); values.push(method); } + if (path) { where.push('path LIKE ?'); values.push(`%${path}%`); } + if (statusCode) { where.push('status_code = ?'); values.push(statusCode); } + if (minStatusCode) { where.push('status_code >= ?'); values.push(minStatusCode); } + if (maxStatusCode) { where.push('status_code < ?'); values.push(maxStatusCode); } + if (startDate) { where.push('created_at >= ?'); values.push(startDate); } + if (endDate) { where.push('created_at <= ?'); values.push(endDate); } + + const whereClause = where.length > 0 ? `WHERE ${where.join(' AND ')}` : ''; + + // Get total count + const countStmt = db.prepare(`SELECT COUNT(*) as count FROM api_requests ${whereClause}`); + const totalCount = countStmt.get(...values).count; + + // Get paginated results + const stmt = db.prepare(` + SELECT * + FROM api_requests + ${whereClause} + ORDER BY created_at DESC + LIMIT ? OFFSET ? + `); + + const requests = stmt.all(...values, limit, offset); + + return { + requests, + total: totalCount, + limit, + offset + }; +} + +/** + * Get API request statistics + */ +export function getApiRequestStats(options = {}) { + const { startDate, endDate } = options; + + let where = []; + let values = []; + + if (startDate) { where.push('created_at >= ?'); values.push(startDate); } + if (endDate) { where.push('created_at <= ?'); values.push(endDate); } + + const whereClause = where.length > 0 ? 
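+    // emit a WHERE clause only when a date filter is actually set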
`WHERE ${where.join(' AND ')}` : ''; + + // Total requests + const total = db.prepare(`SELECT COUNT(*) as count FROM api_requests ${whereClause}`).get(...values); + + // By method + const byMethod = db.prepare(` + SELECT method, COUNT(*) as count + FROM api_requests ${whereClause} + GROUP BY method ORDER BY count DESC + `).all(...values); + + // By status code category + const byStatusCategory = db.prepare(` + SELECT + CASE + WHEN status_code >= 200 AND status_code < 300 THEN '2xx Success' + WHEN status_code >= 300 AND status_code < 400 THEN '3xx Redirect' + WHEN status_code >= 400 AND status_code < 500 THEN '4xx Client Error' + WHEN status_code >= 500 THEN '5xx Server Error' + ELSE 'Unknown' + END as category, + COUNT(*) as count + FROM api_requests ${whereClause} + GROUP BY category ORDER BY count DESC + `).all(...values); + + // Top endpoints + const topEndpoints = db.prepare(` + SELECT method, path, COUNT(*) as count, + AVG(response_time_ms) as avg_response_time + FROM api_requests ${whereClause} + GROUP BY method, path + ORDER BY count DESC + LIMIT 20 + `).all(...values); + + // Top users + const topUsers = db.prepare(` + SELECT user_id, username, COUNT(*) as count + FROM api_requests + ${whereClause ? whereClause + ' AND username IS NOT NULL' : 'WHERE username IS NOT NULL'} + GROUP BY user_id, username + ORDER BY count DESC + LIMIT 10 + `).all(...values); + + // Average response time + const avgResponseTime = db.prepare(` + SELECT AVG(response_time_ms) as avg, + MIN(response_time_ms) as min, + MAX(response_time_ms) as max + FROM api_requests + ${whereClause ? whereClause + ' AND response_time_ms IS NOT NULL' : 'WHERE response_time_ms IS NOT NULL'} + `).get(...values); + + // Requests per hour (last 24 hours) + const requestsPerHour = db.prepare(` + SELECT strftime('%Y-%m-%d %H:00', created_at) as hour, COUNT(*) as count + FROM api_requests + WHERE created_at >= datetime('now', '-24 hours') + GROUP BY hour + ORDER BY hour ASC + `).all(); + + // Error rate + const errors = db.prepare(` + SELECT COUNT(*) as count FROM api_requests + ${whereClause ? whereClause + ' AND status_code >= 400' : 'WHERE status_code >= 400'} + `).get(...values); + + return { + total: total.count, + byMethod: byMethod.reduce((acc, row) => { acc[row.method] = row.count; return acc; }, {}), + byStatusCategory: byStatusCategory.reduce((acc, row) => { acc[row.category] = row.count; return acc; }, {}), + topEndpoints, + topUsers, + responseTime: { + avg: Math.round(avgResponseTime?.avg || 0), + min: avgResponseTime?.min || 0, + max: avgResponseTime?.max || 0 + }, + requestsPerHour, + errorRate: total.count > 0 ? Math.round((errors.count / total.count) * 100 * 10) / 10 : 0 + }; +} + +/** + * Cleanup old API request logs (older than N days) + */ +export function cleanupOldApiRequests(daysToKeep = 30) { + const stmt = db.prepare(` + DELETE FROM api_requests + WHERE created_at < datetime('now', '-' || ? 
|| ' days') + `); + return stmt.run(daysToKeep); +} + // ============================================ // ANALYTICS QUERIES // ============================================ @@ -819,7 +1441,7 @@ export function getUserStats() { * List cases with filters (for admin) */ export function listCasesFiltered(options = {}) { - const { status, createdBy, search, limit = 50, offset = 0, sortBy = 'created_at', sortOrder = 'DESC' } = options; + const { status, createdBy, search, limit = 50, offset = 0, sortBy = 'created_at', sortOrder = 'DESC', includeArchived = false, archivedOnly = false } = options; let where = []; let values = []; @@ -828,6 +1450,13 @@ export function listCasesFiltered(options = {}) { if (createdBy) { where.push('c.created_by = ?'); values.push(createdBy); } if (search) { where.push('c.case_id LIKE ?'); values.push(`%${search}%`); } + // Archive filtering + if (archivedOnly) { + where.push('c.is_archived = 1'); + } else if (!includeArchived) { + where.push('c.is_archived = 0'); + } + const whereClause = where.length > 0 ? `WHERE ${where.join(' AND ')}` : ''; const validSortColumns = ['created_at', 'updated_at', 'status', 'case_id']; const sortColumn = validSortColumns.includes(sortBy) ? sortBy : 'created_at'; @@ -839,6 +1468,7 @@ export function listCasesFiltered(options = {}) { const stmt = db.prepare(` SELECT c.case_id as caseId, c.case_type, c.status, c.current_step, c.notes, c.analysis_result, c.landmarks_data, c.body_scan_path, + c.is_archived, c.archived_at, c.created_by, c.created_at, c.updated_at, u.username as created_by_username FROM brace_cases c @@ -848,7 +1478,10 @@ export function listCasesFiltered(options = {}) { LIMIT ? OFFSET ? `); - const cases = stmt.all(...values, limit, offset); + const cases = stmt.all(...values, limit, offset).map(row => ({ + ...row, + is_archived: row.is_archived === 1 + })); return { cases, @@ -859,6 +1492,7 @@ export function listCasesFiltered(options = {}) { } export default { + // Case management createCase, listCases, listCasesFiltered, @@ -874,8 +1508,20 @@ export default { saveBodyScan, clearBodyScan, deleteCase, + archiveCase, + unarchiveCase, updateStepStatus, STEP_NAMES, + // Patient management + createPatient, + getPatient, + listPatients, + updatePatient, + deletePatient, + archivePatient, + unarchivePatient, + getPatientCases, + getPatientStats, // User management getUserByUsername, getUserById, @@ -892,6 +1538,11 @@ export default { // Audit logging logAudit, getAuditLog, + // API request logging + logApiRequest, + getApiRequests, + getApiRequestStats, + cleanupOldApiRequests, // Analytics getCaseStats, getRigoDistribution, diff --git a/api/server.js b/api/server.js index f266266..a317e9a 100644 --- a/api/server.js +++ b/api/server.js @@ -92,6 +92,239 @@ app.use('/files', express.static(DATA_DIR, { } })); +// ============================================ +// API REQUEST LOGGING MIDDLEWARE +// ============================================ + +/** + * Sanitize request parameters - remove sensitive data but keep structure + */ +function sanitizeParams(body, sensitiveKeys = ['password', 'token', 'secret', 'apiKey', 'authorization']) { + if (!body || typeof body !== 'object') return null; + + const sanitized = {}; + for (const [key, value] of Object.entries(body)) { + // Skip very large values (like base64 images) + if (typeof value === 'string' && value.length > 500) { + sanitized[key] = `[String: ${value.length} chars]`; + } else if (sensitiveKeys.some(sk => key.toLowerCase().includes(sk.toLowerCase()))) { + sanitized[key] = 
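+        // the key name matched a sensitive pattern (password, token, secret, ...)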
'[REDACTED]'; + } else if (typeof value === 'object' && value !== null) { + if (Array.isArray(value)) { + sanitized[key] = `[Array: ${value.length} items]`; + } else { + sanitized[key] = sanitizeParams(value, sensitiveKeys); + } + } else { + sanitized[key] = value; + } + } + return Object.keys(sanitized).length > 0 ? sanitized : null; +} + +/** + * Extract file upload information + */ +function extractFileInfo(req) { + const files = []; + + // Single file upload (req.file from multer) + if (req.file) { + files.push({ + fieldname: req.file.fieldname, + originalname: req.file.originalname, + mimetype: req.file.mimetype, + size: req.file.size, + destination: req.file.destination?.replace(/\\/g, '/').split('/').slice(-2).join('/'), // Last 2 dirs only + filename: req.file.filename + }); + } + + // Multiple files upload (req.files from multer) + if (req.files) { + const fileList = Array.isArray(req.files) ? req.files : Object.values(req.files).flat(); + for (const file of fileList) { + files.push({ + fieldname: file.fieldname, + originalname: file.originalname, + mimetype: file.mimetype, + size: file.size, + destination: file.destination?.replace(/\\/g, '/').split('/').slice(-2).join('/'), + filename: file.filename + }); + } + } + + return files.length > 0 ? files : null; +} + +/** + * Extract response summary - key fields from response body + */ +function extractResponseSummary(body, statusCode) { + if (!body || typeof body !== 'object') return null; + + const summary = {}; + + // Common success indicators + if (body.success !== undefined) summary.success = body.success; + if (body.message) summary.message = body.message.substring(0, 200); + if (body.error) summary.error = typeof body.error === 'string' ? body.error.substring(0, 200) : 'Error object'; + + // Case-related + if (body.caseId) summary.caseId = body.caseId; + if (body.case_id) summary.caseId = body.case_id; + if (body.status) summary.status = body.status; + + // User-related + if (body.user?.id) summary.userId = body.user.id; + if (body.user?.username) summary.username = body.user.username; + if (body.token) summary.tokenGenerated = true; + + // Analysis/brace results + if (body.rigoType || body.rigo_classification) { + summary.rigoType = body.rigoType || body.rigo_classification?.type; + } + if (body.cobb_angles || body.cobbAngles) { + const angles = body.cobb_angles || body.cobbAngles; + summary.cobbAngles = { PT: angles.PT, MT: angles.MT, TL: angles.TL }; + } + if (body.vertebrae_detected) summary.vertebraeDetected = body.vertebrae_detected; + + // Brace outputs + if (body.braces) { + summary.bracesGenerated = { + regular: !!body.braces.regular, + vase: !!body.braces.vase + }; + } + if (body.brace) { + summary.braceGenerated = true; + if (body.brace.vertices) summary.braceVertices = body.brace.vertices; + } + + // File outputs + if (body.glbUrl || body.stlUrl || body.url) { + summary.filesGenerated = []; + if (body.glbUrl) summary.filesGenerated.push('GLB'); + if (body.stlUrl) summary.filesGenerated.push('STL'); + if (body.url) summary.outputUrl = body.url.split('/').slice(-2).join('/'); + } + + // Landmarks + if (body.landmarks) { + summary.landmarksCount = Array.isArray(body.landmarks) ? 
body.landmarks.length : 'object'; + } + + // List responses + if (body.cases && Array.isArray(body.cases)) summary.casesCount = body.cases.length; + if (body.users && Array.isArray(body.users)) summary.usersCount = body.users.length; + if (body.entries && Array.isArray(body.entries)) summary.entriesCount = body.entries.length; + if (body.requests && Array.isArray(body.requests)) summary.requestsCount = body.requests.length; + if (body.total !== undefined) summary.total = body.total; + + // Body scan + if (body.body_scan_url) summary.bodyScanUploaded = true; + if (body.measurements || body.body_measurements) { + summary.measurementsExtracted = true; + } + + // Error responses + if (statusCode >= 400) { + summary.errorCode = statusCode; + } + + return Object.keys(summary).length > 0 ? summary : null; +} + +/** + * Get route pattern from request (e.g., /api/cases/:caseId) + */ +function getRoutePattern(req) { + // Express stores the matched route in req.route + if (req.route && req.route.path) { + return req.baseUrl + req.route.path; + } + // Fallback: replace common ID patterns + return req.path + .replace(/\/case-[\w-]+/g, '/:caseId') + .replace(/\/\d+/g, '/:id'); +} + +// Logs all API requests for the activity page +app.use('/api', (req, res, next) => { + const startTime = Date.now(); + + // Capture original functions + const originalEnd = res.end; + const originalJson = res.json; + let responseBody = null; + let responseBodySize = 0; + + // Override res.json to capture response body + res.json = function(body) { + responseBody = body; + if (body) { + try { + responseBodySize = JSON.stringify(body).length; + } catch (e) { /* ignore */ } + } + return originalJson.call(this, body); + }; + + // Override res.end to log the request after it completes + res.end = function(chunk, encoding) { + const responseTime = Date.now() - startTime; + + // Calculate request body size + let requestBodySize = 0; + if (req.body && Object.keys(req.body).length > 0) { + try { + requestBodySize = JSON.stringify(req.body).length; + } catch (e) { /* ignore */ } + } + + // Get user info from req.user (set by authMiddleware) + let userId = req.user?.id || null; + let username = req.user?.username || null; + + // Skip logging for health check and static files to reduce noise + const skipPaths = ['/api/health', '/api/favicon.ico']; + const shouldLog = !skipPaths.includes(req.path) && !req.path.startsWith('/files'); + + if (shouldLog) { + // Log asynchronously to not block response + setImmediate(() => { + try { + db.logApiRequest({ + userId, + username, + method: req.method, + path: req.path, + routePattern: getRoutePattern(req), + queryParams: Object.keys(req.query).length > 0 ? 
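+            // NULL when the request carries no query string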
req.query : null, + requestParams: sanitizeParams(req.body), + fileUploads: extractFileInfo(req), + statusCode: res.statusCode, + responseTimeMs: responseTime, + responseSummary: extractResponseSummary(responseBody, res.statusCode), + ipAddress: req.ip || req.connection?.remoteAddress, + userAgent: req.get('User-Agent'), + requestBodySize, + responseBodySize + }); + } catch (e) { + console.error('Failed to log API request:', e.message); + } + }); + } + + return originalEnd.call(this, chunk, encoding); + }; + + next(); +}); + // File upload configuration const storage = multer.diskStorage({ destination: (req, file, cb) => { @@ -155,10 +388,21 @@ app.post('/api/cases', (req, res) => { /** * List all cases * GET /api/cases + * Query params: + * - includeArchived: boolean - Include archived cases (admin only) + * - archivedOnly: boolean - Show only archived cases (admin only) */ app.get('/api/cases', (req, res) => { try { - const cases = db.listCases(); + const { includeArchived, archivedOnly } = req.query; + + // Parse boolean query params + const options = { + includeArchived: includeArchived === 'true', + archivedOnly: archivedOnly === 'true' + }; + + const cases = db.listCases(options); res.json(cases); } catch (err) { console.error('List cases error:', err); @@ -1497,28 +1741,79 @@ app.post('/api/cases/:caseId/skip-body-scan', (req, res) => { // ============================================== /** - * Delete case - * DELETE /api/cases/:caseId + * Archive case (soft delete - keeps all files) + * POST /api/cases/:caseId/archive */ -app.delete('/api/cases/:caseId', (req, res) => { +app.post('/api/cases/:caseId/archive', authMiddleware, (req, res) => { try { const { caseId } = req.params; - // Delete from database - db.deleteCase(caseId); - - // Delete files - const uploadDir = path.join(UPLOADS_DIR, caseId); - const outputDir = path.join(OUTPUTS_DIR, caseId); - - if (fs.existsSync(uploadDir)) { - fs.rmSync(uploadDir, { recursive: true }); - } - if (fs.existsSync(outputDir)) { - fs.rmSync(outputDir, { recursive: true }); + const caseData = db.getCase(caseId); + if (!caseData) { + return res.status(404).json({ message: 'Case not found' }); } - res.json({ caseId, deleted: true }); + // Archive the case (soft delete - no files are deleted) + db.archiveCase(caseId); + + // Log the archive action + db.logAudit(req.user?.id, 'case_archived', 'brace_cases', caseId, null, { archived: true }); + + res.json({ caseId, archived: true, message: 'Case archived successfully' }); + } catch (err) { + console.error('Archive case error:', err); + res.status(500).json({ message: 'Failed to archive case', error: err.message }); + } +}); + +/** + * Unarchive case (restore) + * POST /api/cases/:caseId/unarchive + */ +app.post('/api/cases/:caseId/unarchive', authMiddleware, (req, res) => { + try { + const { caseId } = req.params; + + const caseData = db.getCase(caseId); + if (!caseData) { + return res.status(404).json({ message: 'Case not found' }); + } + + // Unarchive the case + db.unarchiveCase(caseId); + + // Log the unarchive action + db.logAudit(req.user?.id, 'case_unarchived', 'brace_cases', caseId, { archived: true }, { archived: false }); + + res.json({ caseId, archived: false, message: 'Case restored successfully' }); + } catch (err) { + console.error('Unarchive case error:', err); + res.status(500).json({ message: 'Failed to unarchive case', error: err.message }); + } +}); + +/** + * Delete case - DEPRECATED: Use archive instead + * DELETE /api/cases/:caseId + * This endpoint now archives instead of 
deleting to preserve data + */ +app.delete('/api/cases/:caseId', authMiddleware, (req, res) => { + try { + const { caseId } = req.params; + + const caseData = db.getCase(caseId); + if (!caseData) { + return res.status(404).json({ message: 'Case not found' }); + } + + // Archive instead of delete (preserves all files) + db.archiveCase(caseId); + + // Log the archive action + db.logAudit(req.user?.id, 'case_archived', 'brace_cases', caseId, null, { archived: true }); + + // Return deleted: true for backwards compatibility + res.json({ caseId, deleted: true, archived: true, message: 'Case archived (files preserved)' }); } catch (err) { console.error('Delete case error:', err); res.status(500).json({ message: 'Failed to delete case', error: err.message }); @@ -1613,6 +1908,249 @@ app.get('/api/cases/:caseId/assets', (req, res) => { } }); +// ============================================ +// PATIENT API +// ============================================ + +/** + * Create a new patient + * POST /api/patients + */ +app.post('/api/patients', authMiddleware, (req, res) => { + try { + const { + mrn, firstName, lastName, dateOfBirth, gender, + email, phone, address, diagnosis, curveType, + medicalHistory, referringPhysician, insuranceInfo, notes + } = req.body; + + if (!firstName || !lastName) { + return res.status(400).json({ message: 'First name and last name are required' }); + } + + const patient = db.createPatient({ + mrn, + firstName, + lastName, + dateOfBirth, + gender, + email, + phone, + address, + diagnosis, + curveType, + medicalHistory, + referringPhysician, + insuranceInfo, + notes, + createdBy: req.user?.id + }); + + db.logAudit(req.user?.id, 'create_patient', 'patient', patient.id.toString(), + { firstName, lastName, mrn }, req.ip); + + res.status(201).json({ patient }); + } catch (err) { + console.error('Create patient error:', err); + res.status(500).json({ message: 'Failed to create patient', error: err.message }); + } +}); + +/** + * List patients + * GET /api/patients + */ +app.get('/api/patients', authMiddleware, (req, res) => { + try { + const { search, isActive, limit = 50, offset = 0, sortBy, sortOrder } = req.query; + + const result = db.listPatients({ + search, + isActive: isActive === 'false' ? false : (isActive === 'all' ? 
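+      // 'all' maps to null, which listPatients treats as "no is_active filter"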
null : true), + limit: parseInt(limit), + offset: parseInt(offset), + sortBy, + sortOrder + }); + + res.json(result); + } catch (err) { + console.error('List patients error:', err); + res.status(500).json({ message: 'Failed to list patients', error: err.message }); + } +}); + +/** + * Get patient by ID + * GET /api/patients/:patientId + */ +app.get('/api/patients/:patientId', authMiddleware, (req, res) => { + try { + const { patientId } = req.params; + const { includeArchivedCases } = req.query; + + const patient = db.getPatient(parseInt(patientId)); + + if (!patient) { + return res.status(404).json({ message: 'Patient not found' }); + } + + // Get patient's cases (filter archived unless explicitly requested) + const cases = db.getPatientCases(parseInt(patientId), { + includeArchived: includeArchivedCases === 'true' + }); + + res.json({ patient, cases }); + } catch (err) { + console.error('Get patient error:', err); + res.status(500).json({ message: 'Failed to get patient', error: err.message }); + } +}); + +/** + * Update patient + * PUT /api/patients/:patientId + */ +app.put('/api/patients/:patientId', authMiddleware, (req, res) => { + try { + const { patientId } = req.params; + const patient = db.getPatient(parseInt(patientId)); + + if (!patient) { + return res.status(404).json({ message: 'Patient not found' }); + } + + const updateData = req.body; + db.updatePatient(parseInt(patientId), updateData); + + db.logAudit(req.user?.id, 'update_patient', 'patient', patientId, updateData, req.ip); + + const updatedPatient = db.getPatient(parseInt(patientId)); + res.json({ patient: updatedPatient }); + } catch (err) { + console.error('Update patient error:', err); + res.status(500).json({ message: 'Failed to update patient', error: err.message }); + } +}); + +/** + * Archive patient (soft delete - preserves all data) + * POST /api/patients/:patientId/archive + */ +app.post('/api/patients/:patientId/archive', authMiddleware, (req, res) => { + try { + const { patientId } = req.params; + + const patient = db.getPatient(parseInt(patientId)); + if (!patient) { + return res.status(404).json({ message: 'Patient not found' }); + } + + db.archivePatient(parseInt(patientId)); + + db.logAudit(req.user?.id, 'patient_archived', 'patient', patientId, + { firstName: patient.first_name, lastName: patient.last_name }, { archived: true }); + + res.json({ patientId: parseInt(patientId), archived: true, message: 'Patient archived successfully' }); + } catch (err) { + console.error('Archive patient error:', err); + res.status(500).json({ message: 'Failed to archive patient', error: err.message }); + } +}); + +/** + * Unarchive patient (restore) + * POST /api/patients/:patientId/unarchive + */ +app.post('/api/patients/:patientId/unarchive', authMiddleware, (req, res) => { + try { + const { patientId } = req.params; + + const patient = db.getPatient(parseInt(patientId)); + if (!patient) { + return res.status(404).json({ message: 'Patient not found' }); + } + + db.unarchivePatient(parseInt(patientId)); + + db.logAudit(req.user?.id, 'patient_unarchived', 'patient', patientId, + { archived: true }, { firstName: patient.first_name, lastName: patient.last_name, archived: false }); + + res.json({ patientId: parseInt(patientId), archived: false, message: 'Patient restored successfully' }); + } catch (err) { + console.error('Unarchive patient error:', err); + res.status(500).json({ message: 'Failed to unarchive patient', error: err.message }); + } +}); + +/** + * Delete patient - DEPRECATED: Use archive instead + * DELETE 
/api/patients/:patientId + * This endpoint now archives instead of deleting to preserve data + */ +app.delete('/api/patients/:patientId', authMiddleware, (req, res) => { + try { + const { patientId } = req.params; + + const patient = db.getPatient(parseInt(patientId)); + if (!patient) { + return res.status(404).json({ message: 'Patient not found' }); + } + + // Archive instead of delete (preserves all data) + db.archivePatient(parseInt(patientId)); + + db.logAudit(req.user?.id, 'patient_archived', 'patient', patientId, + { firstName: patient.first_name, lastName: patient.last_name }, { archived: true }); + + res.json({ message: 'Patient archived successfully', archived: true }); + } catch (err) { + console.error('Delete patient error:', err); + res.status(500).json({ message: 'Failed to archive patient', error: err.message }); + } +}); + +/** + * Create a case for a patient + * POST /api/patients/:patientId/cases + */ +app.post('/api/patients/:patientId/cases', authMiddleware, (req, res) => { + try { + const { patientId } = req.params; + const { notes, visitDate } = req.body; + + const patient = db.getPatient(parseInt(patientId)); + if (!patient) { + return res.status(404).json({ message: 'Patient not found' }); + } + + const caseId = `case-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`; + const result = db.createCase(caseId, 'braceflow', notes, parseInt(patientId), visitDate); + + db.logAudit(req.user?.id, 'create_case', 'case', caseId, + { patientId, patientName: `${patient.first_name} ${patient.last_name}` }, req.ip); + + res.status(201).json(result); + } catch (err) { + console.error('Create patient case error:', err); + res.status(500).json({ message: 'Failed to create case', error: err.message }); + } +}); + +/** + * Get patient statistics + * GET /api/patients/stats + */ +app.get('/api/patients-stats', authMiddleware, (req, res) => { + try { + const stats = db.getPatientStats(); + res.json({ stats }); + } catch (err) { + console.error('Get patient stats error:', err); + res.status(500).json({ message: 'Failed to get patient stats', error: err.message }); + } +}); + // ============================================ // AUTHENTICATION API // ============================================ @@ -1899,7 +2437,7 @@ app.delete('/api/admin/users/:userId', authMiddleware, adminMiddleware, (req, re */ app.get('/api/admin/cases', authMiddleware, adminMiddleware, (req, res) => { try { - const { status, createdBy, search, limit = 50, offset = 0, sortBy, sortOrder } = req.query; + const { status, createdBy, search, limit = 50, offset = 0, sortBy, sortOrder, includeArchived, archivedOnly } = req.query; const result = db.listCasesFiltered({ status, @@ -1908,7 +2446,9 @@ app.get('/api/admin/cases', authMiddleware, adminMiddleware, (req, res) => { limit: parseInt(limit), offset: parseInt(offset), sortBy, - sortOrder + sortOrder, + includeArchived: includeArchived === 'true', + archivedOnly: archivedOnly === 'true' }); res.json(result); @@ -2018,6 +2558,106 @@ app.get('/api/admin/audit-log', authMiddleware, adminMiddleware, (req, res) => { } }); +// ============================================ +// ADMIN API - API REQUEST ACTIVITY LOG +// ============================================ + +/** + * Get API request logs (admin only) + * GET /api/admin/activity + */ +app.get('/api/admin/activity', authMiddleware, adminMiddleware, (req, res) => { + try { + const { + userId, + username, + method, + path, + statusCode, + statusCategory, // '2xx', '4xx', '5xx' + startDate, + endDate, + limit = 
100, + offset = 0 + } = req.query; + + const options = { + userId: userId ? parseInt(userId) : undefined, + username, + method, + path, + statusCode: statusCode ? parseInt(statusCode) : undefined, + startDate, + endDate, + limit: parseInt(limit), + offset: parseInt(offset) + }; + + // Handle status category filter + if (statusCategory === '2xx') { + options.minStatusCode = 200; + options.maxStatusCode = 300; + } else if (statusCategory === '3xx') { + options.minStatusCode = 300; + options.maxStatusCode = 400; + } else if (statusCategory === '4xx') { + options.minStatusCode = 400; + options.maxStatusCode = 500; + } else if (statusCategory === '5xx') { + options.minStatusCode = 500; + options.maxStatusCode = 600; + } + + const result = db.getApiRequests(options); + res.json(result); + } catch (err) { + console.error('Get API activity error:', err); + res.status(500).json({ message: 'Failed to get API activity', error: err.message }); + } +}); + +/** + * Get API request statistics (admin only) + * GET /api/admin/activity/stats + */ +app.get('/api/admin/activity/stats', authMiddleware, adminMiddleware, (req, res) => { + try { + const { startDate, endDate } = req.query; + + const stats = db.getApiRequestStats({ + startDate, + endDate + }); + + res.json({ stats }); + } catch (err) { + console.error('Get API activity stats error:', err); + res.status(500).json({ message: 'Failed to get API activity stats', error: err.message }); + } +}); + +/** + * Cleanup old API request logs (admin only) + * DELETE /api/admin/activity/cleanup + */ +app.delete('/api/admin/activity/cleanup', authMiddleware, adminMiddleware, (req, res) => { + try { + const { daysToKeep = 30 } = req.query; + + const result = db.cleanupOldApiRequests(parseInt(daysToKeep)); + + db.logAudit(req.user.id, 'cleanup_api_logs', 'system', null, { daysToKeep, deletedCount: result.changes }, req.ip); + + res.json({ + message: `Cleaned up API request logs older than ${daysToKeep} days`, + deletedCount: result.changes + }); + } catch (err) { + console.error('Cleanup API activity error:', err); + res.status(500).json({ message: 'Failed to cleanup API activity', error: err.message }); + } +}); + // ============================================ // Start server // ============================================ @@ -2048,6 +2688,15 @@ app.listen(PORT, () => { console.log(' DELETE /api/cases/:id Delete case'); console.log(' GET /api/cases/:id/assets Get files'); console.log(''); + console.log('Patient Endpoints:'); + console.log(' POST /api/patients Create patient'); + console.log(' GET /api/patients List patients'); + console.log(' GET /api/patients/:id Get patient'); + console.log(' PUT /api/patients/:id Update patient'); + console.log(' DELETE /api/patients/:id Delete patient'); + console.log(' POST /api/patients/:id/cases Create case for patient'); + console.log(' GET /api/patients-stats Get patient statistics'); + console.log(''); console.log('Auth Endpoints:'); console.log(' POST /api/auth/login Login'); console.log(' POST /api/auth/logout Logout'); @@ -2061,5 +2710,8 @@ app.listen(PORT, () => { console.log(' GET /api/admin/cases List cases (filtered)'); console.log(' GET /api/admin/analytics/dashboard Get dashboard stats'); console.log(' GET /api/admin/audit-log Get audit log'); + console.log(' GET /api/admin/activity Get API request logs'); + console.log(' GET /api/admin/activity/stats Get API activity stats'); + console.log(' DELETE /api/admin/activity/cleanup Cleanup old API logs'); console.log(''); }); diff --git a/brace-generator/Dockerfile 
b/brace-generator/Dockerfile index c361659..3034afd 100644 --- a/brace-generator/Dockerfile +++ b/brace-generator/Dockerfile @@ -27,7 +27,7 @@ WORKDIR /app RUN pip install --no-cache-dir --upgrade pip && \ pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu -# Copy and install requirements (from brace-generator folder) +# Copy and install requirements COPY brace-generator/requirements.txt /app/requirements.txt RUN pip install --no-cache-dir -r requirements.txt @@ -35,8 +35,16 @@ RUN pip install --no-cache-dir -r requirements.txt COPY scoliovis-api/requirements.txt /app/requirements-scoliovis.txt RUN pip install --no-cache-dir -r requirements-scoliovis.txt || true -# Copy brace-generator code -COPY brace-generator/ /app/brace_generator/server_DEV/ +# Create brace_generator package structure +RUN mkdir -p /app/brace_generator + +# Copy brace-generator code as a package +COPY brace-generator/*.py /app/brace_generator/ +COPY brace-generator/__init__.py /app/brace_generator/__init__.py + +# Also keep server_DEV structure for compatibility +RUN mkdir -p /app/brace_generator/server_DEV +COPY brace-generator/*.py /app/brace_generator/server_DEV/ # Copy scoliovis-api COPY scoliovis-api/ /app/scoliovis-api/ @@ -44,8 +52,8 @@ COPY scoliovis-api/ /app/scoliovis-api/ # Copy templates COPY templates/ /app/templates/ -# Set Python path -ENV PYTHONPATH=/app:/app/brace_generator/server_DEV:/app/scoliovis-api +# Set Python path - include both locations +ENV PYTHONPATH=/app:/app/scoliovis-api # Environment variables ENV HOST=0.0.0.0 @@ -61,8 +69,8 @@ RUN mkdir -p /tmp/brace_generator /app/data/uploads /app/data/outputs EXPOSE 8002 # Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \ +HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ CMD curl -f http://localhost:8002/health || exit 1 -# Run the server +# Run the server from the brace_generator package CMD ["python", "-m", "uvicorn", "brace_generator.server_DEV.app:app", "--host", "0.0.0.0", "--port", "8002"] diff --git a/brace-generator/adapters.py b/brace-generator/adapters.py new file mode 100644 index 0000000..f0cf3f2 --- /dev/null +++ b/brace-generator/adapters.py @@ -0,0 +1,508 @@ +""" +Model adapters that convert different model outputs to unified Spine2D format. +Each adapter wraps a specific model and produces consistent output. +""" +import sys +import numpy as np +from pathlib import Path +from typing import Optional, Dict, Any +from abc import ABC, abstractmethod + +from data_models import VertebraLandmark, Spine2D + + +class BaseLandmarkAdapter(ABC): + """Base class for landmark detection model adapters.""" + + @abstractmethod + def predict(self, image: np.ndarray) -> Spine2D: + """ + Run inference on an image and return unified spine landmarks. + + Args: + image: Input image as numpy array (grayscale or RGB) + + Returns: + Spine2D object with detected landmarks + """ + pass + + @property + @abstractmethod + def name(self) -> str: + """Model name for identification.""" + pass + + +class ScolioVisAdapter(BaseLandmarkAdapter): + """ + Adapter for ScolioVis-API (Keypoint R-CNN model). + + Uses the original ScolioVis inference code for best accuracy. + Outputs: 4 keypoints per vertebra + Cobb angles (PT, MT, TL) + curve type (S/C) + """ + + def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'): + """ + Initialize ScolioVis model. 
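+        When weights_path is None, the checkpoint is auto-discovered under
+        scoliovis-api/ (models/, the package root, then weights/); a
+        FileNotFoundError is raised if no checkpoint can be found.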
+ + Args: + weights_path: Path to keypointsrcnn_weights.pt (auto-detects if None) + device: 'cpu' or 'cuda' + """ + self.device = device + self.model = None + self.weights_path = weights_path + self._scoliovis_path = None + self._load_model() + + def _load_model(self): + """Load the Keypoint R-CNN model.""" + import torch + import torchvision + from torchvision.models.detection import keypointrcnn_resnet50_fpn + from torchvision.models.detection.rpn import AnchorGenerator + + # Find weights and scoliovis module + scoliovis_api_path = Path(__file__).parent.parent / 'scoliovis-api' + if self.weights_path is None: + possible_paths = [ + scoliovis_api_path / 'models' / 'keypointsrcnn_weights.pt', + scoliovis_api_path / 'keypointsrcnn_weights.pt', + scoliovis_api_path / 'weights' / 'keypointsrcnn_weights.pt', + ] + for p in possible_paths: + if p.exists(): + self.weights_path = str(p) + break + + if self.weights_path is None or not Path(self.weights_path).exists(): + raise FileNotFoundError( + "ScolioVis weights not found. Please provide weights_path or ensure " + "scoliovis-api/models/keypointsrcnn_weights.pt exists." + ) + + # Store path to scoliovis module for Cobb angle calculation + self._scoliovis_path = scoliovis_api_path + + # Create model with same anchor generator as original training + anchor_generator = AnchorGenerator( + sizes=(32, 64, 128, 256, 512), + aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0) + ) + + self.model = keypointrcnn_resnet50_fpn( + weights=None, + weights_backbone=None, + num_classes=2, # background + vertebra + num_keypoints=4, # 4 corners per vertebra + rpn_anchor_generator=anchor_generator + ) + + # Load weights + checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False) + self.model.load_state_dict(checkpoint) + self.model.to(self.device) + self.model.eval() + + print(f"ScolioVis model loaded from {self.weights_path}") + + @property + def name(self) -> str: + return "ScolioVis-API" + + def _filter_output(self, output, max_verts: int = 17): + """ + Filter model output using NMS and score threshold. + Matches the original ScolioVis filtering logic. 
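The score-threshold and NMS stages can be exercised in isolation; a self-contained sketch with invented boxes and scores:

    # Sketch: score threshold (0.5) followed by IoU-0.3 NMS, as in _filter_output
    import numpy as np
    import torch
    import torchvision

    scores = torch.tensor([0.9, 0.8, 0.3])          # third box falls below 0.5
    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],        # IoU ~0.68 with box 0 -> suppressed
                          [50., 50., 60., 60.]])
    keep = np.where(scores.numpy() > 0.5)[0].tolist()
    kept = torchvision.ops.nms(boxes[keep], scores[keep], 0.3).cpu().numpy()
    print(kept)  # -> [0]; box 1 is removed by NMS, box 2 by the score threshold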
+ """ + import torch + import torchvision + + scores = output['scores'].detach().cpu().numpy() + + # Get indices of scores over threshold (0.5) + high_scores_idxs = np.where(scores > 0.5)[0].tolist() + if len(high_scores_idxs) == 0: + return [], [], [] + + # Apply NMS with IoU threshold 0.3 + post_nms_idxs = torchvision.ops.nms( + output['boxes'][high_scores_idxs], + output['scores'][high_scores_idxs], + 0.3 + ).cpu().numpy() + + # Get filtered results + np_keypoints = output['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy() + np_bboxes = output['boxes'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy() + np_scores = output['scores'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy() + + # Take top N by score (usually 17 for full spine) + sorted_scores_idxs = np.argsort(-1 * np_scores) + np_scores = np_scores[sorted_scores_idxs][:max_verts] + np_keypoints = np.array([np_keypoints[idx] for idx in sorted_scores_idxs])[:max_verts] + np_bboxes = np.array([np_bboxes[idx] for idx in sorted_scores_idxs])[:max_verts] + + # Sort by ymin (top to bottom) + if len(np_keypoints) > 0: + ymins = np.array([kps[0][1] for kps in np_keypoints]) + sorted_ymin_idxs = np.argsort(ymins) + np_scores = np.array([np_scores[idx] for idx in sorted_ymin_idxs]) + np_keypoints = np.array([np_keypoints[idx] for idx in sorted_ymin_idxs]) + np_bboxes = np.array([np_bboxes[idx] for idx in sorted_ymin_idxs]) + + # Convert to lists + keypoints_list = [] + for kps in np_keypoints: + keypoints_list.append([list(map(float, kp[:2])) for kp in kps]) + + bboxes_list = [] + for bbox in np_bboxes: + bboxes_list.append(list(map(int, bbox.tolist()))) + + scores_list = np_scores.tolist() + + return bboxes_list, keypoints_list, scores_list + + def predict(self, image: np.ndarray) -> Spine2D: + """Run inference and return unified landmarks with ScolioVis Cobb angles.""" + import torch + from torchvision.transforms import functional as F + + # Ensure RGB + if len(image.shape) == 2: + image_rgb = np.stack([image, image, image], axis=-1) + else: + image_rgb = image + + image_shape = image_rgb.shape # (H, W, C) + + # Convert to tensor (ScolioVis uses torchvision's to_tensor) + img_tensor = F.to_tensor(image_rgb).to(self.device) + + # Run inference + with torch.no_grad(): + outputs = self.model([img_tensor]) + + # Filter output using original ScolioVis logic + bboxes, keypoints, scores = self._filter_output(outputs[0]) + + if len(keypoints) == 0: + return Spine2D( + vertebrae=[], + image_shape=image_shape[:2], + source_model=self.name + ) + + # Convert to unified format + vertebrae = [] + for i in range(len(bboxes)): + kps = np.array(keypoints[i], dtype=np.float32) # (4, 2) + + # Corners order from ScolioVis: [top_left, top_right, bottom_left, bottom_right] + corners = kps + centroid = np.mean(corners, axis=0) + + # Compute orientation from top edge (kps[0] to kps[1]) + top_left, top_right = corners[0], corners[1] + dx = top_right[0] - top_left[0] + dy = top_right[1] - top_left[1] + orientation = np.degrees(np.arctan2(dy, dx)) + + vert = VertebraLandmark( + level=None, # ScolioVis doesn't assign levels + centroid_px=centroid, + corners_px=corners, + endplate_upper_px=corners[:2], # top-left, top-right + endplate_lower_px=corners[2:], # bottom-left, bottom-right + orientation_deg=orientation, + confidence=float(scores[i]), + meta={'box': bboxes[i]} + ) + vertebrae.append(vert) + + # Create Spine2D + spine = Spine2D( + vertebrae=vertebrae, + image_shape=image_shape[:2], + source_model=self.name + ) + + # Use 
original ScolioVis Cobb angle calculation if available + if len(keypoints) >= 5: + try: + # Import original ScolioVis cobb_angle_cal + if str(self._scoliovis_path) not in sys.path: + sys.path.insert(0, str(self._scoliovis_path)) + + from scoliovis.cobb_angle_cal import cobb_angle_cal, keypoints_to_landmark_xy + + landmark_xy = keypoints_to_landmark_xy(keypoints) + cobb_angles_list, angles_with_pos, curve_type, midpoint_lines = cobb_angle_cal( + landmark_xy, image_shape + ) + + # Store Cobb angles in spine object + spine.cobb_angles = { + 'PT': cobb_angles_list[0], + 'MT': cobb_angles_list[1], + 'TL': cobb_angles_list[2] + } + spine.curve_type = curve_type + spine.meta = { + 'angles_with_pos': angles_with_pos, + 'midpoint_lines': midpoint_lines + } + + except Exception as e: + print(f"Warning: Could not use ScolioVis Cobb calculation: {e}") + # Fallback to our own calculation + from spine_analysis import compute_cobb_angles + compute_cobb_angles(spine) + + return spine + + +class VertLandmarkAdapter(BaseLandmarkAdapter): + """ + Adapter for Vertebra-Landmark-Detection (SpineNet model). + + Outputs: 68 landmarks (4 corners × 17 vertebrae) + """ + + def __init__(self, weights_path: Optional[str] = None, device: str = 'cpu'): + """ + Initialize SpineNet model. + + Args: + weights_path: Path to model_last.pth (auto-detects if None) + device: 'cpu' or 'cuda' + """ + self.device = device + self.model = None + self.weights_path = weights_path + self._load_model() + + def _load_model(self): + """Load the SpineNet model.""" + import torch + + # Find weights + if self.weights_path is None: + possible_paths = [ + Path(__file__).parent.parent / 'Vertebra-Landmark-Detection' / 'weights_spinal' / 'model_last.pth', + ] + for p in possible_paths: + if p.exists(): + self.weights_path = str(p) + break + + if self.weights_path is None or not Path(self.weights_path).exists(): + raise FileNotFoundError( + "Vertebra-Landmark-Detection weights not found. 
" + "Download from Google Drive and place in weights_spinal/model_last.pth" + ) + + # Add repo to path to import model + repo_path = Path(__file__).parent.parent / 'Vertebra-Landmark-Detection' + if str(repo_path) not in sys.path: + sys.path.insert(0, str(repo_path)) + + from models import spinal_net + + # Create model + heads = {'hm': 1, 'reg': 2, 'wh': 8} + self.model = spinal_net.SpineNet( + heads=heads, + pretrained=False, + down_ratio=4, + final_kernel=1, + head_conv=256 + ) + + # Load weights + checkpoint = torch.load(self.weights_path, map_location=self.device, weights_only=False) + self.model.load_state_dict(checkpoint['state_dict'], strict=False) + self.model.to(self.device) + self.model.eval() + + print(f"Vertebra-Landmark-Detection model loaded from {self.weights_path}") + + @property + def name(self) -> str: + return "Vertebra-Landmark-Detection" + + def _nms(self, heat, kernel=3): + """Apply NMS using max pooling.""" + import torch + import torch.nn.functional as F + hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=(kernel - 1) // 2) + keep = (hmax == heat).float() + return heat * keep + + def _gather_feat(self, feat, ind): + """Gather features by index.""" + dim = feat.size(2) + ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) + feat = feat.gather(1, ind) + return feat + + def _tranpose_and_gather_feat(self, feat, ind): + """Transpose and gather features - matches original decoder.""" + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + feat = self._gather_feat(feat, ind) + return feat + + def _decode_predictions(self, output: Dict, down_ratio: int = 4, K: int = 17): + """Decode model output using original decoder logic.""" + import torch + + hm = output['hm'].sigmoid() + reg = output['reg'] + wh = output['wh'] + + batch, cat, height, width = hm.size() + + # Apply NMS + hm = self._nms(hm) + + # Get top K from heatmap + topk_scores, topk_inds = torch.topk(hm.view(batch, cat, -1), K) + topk_inds = topk_inds % (height * width) + topk_ys = (topk_inds // width).float() + topk_xs = (topk_inds % width).float() + + # Get overall top K + topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K) + topk_inds = self._gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K) + topk_ys = self._gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K) + topk_xs = self._gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K) + + scores = topk_score.view(batch, K, 1) + + # Get regression offset and apply + reg = self._tranpose_and_gather_feat(reg, topk_inds) + reg = reg.view(batch, K, 2) + xs = topk_xs.view(batch, K, 1) + reg[:, :, 0:1] + ys = topk_ys.view(batch, K, 1) + reg[:, :, 1:2] + + # Get corner offsets + wh = self._tranpose_and_gather_feat(wh, topk_inds) + wh = wh.view(batch, K, 8) + + # Calculate corners by SUBTRACTING offsets (original decoder logic) + tl_x = xs - wh[:, :, 0:1] + tl_y = ys - wh[:, :, 1:2] + tr_x = xs - wh[:, :, 2:3] + tr_y = ys - wh[:, :, 3:4] + bl_x = xs - wh[:, :, 4:5] + bl_y = ys - wh[:, :, 5:6] + br_x = xs - wh[:, :, 6:7] + br_y = ys - wh[:, :, 7:8] + + # Combine into output format: [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score] + pts = torch.cat([xs, ys, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, scores], dim=2) + + # Scale to image coordinates + pts[:, :, :10] *= down_ratio + + return pts[0].cpu().numpy() # (K, 11) + + def predict(self, image: np.ndarray) -> Spine2D: + """Run inference and return unified landmarks.""" + import torch + 
import cv2 + + # Ensure RGB + if len(image.shape) == 2: + image_rgb = np.stack([image, image, image], axis=-1) + else: + image_rgb = image + + orig_h, orig_w = image_rgb.shape[:2] + + # Resize to model input size (1024x512) + input_h, input_w = 1024, 512 + img_resized = cv2.resize(image_rgb, (input_w, input_h)) + + # Normalize and convert to tensor - use original preprocessing! + # Original: out_image = image / 255. - 0.5 (NOT ImageNet stats) + img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0 - 0.5 + img_tensor = img_tensor.unsqueeze(0).to(self.device) + + # Run inference + with torch.no_grad(): + output = self.model(img_tensor) + + # Decode predictions - returns (K, 11) array + # Format: [cx, cy, tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y, score] + pts = self._decode_predictions(output, down_ratio=4, K=17) + + # Scale coordinates back to original image size + scale_x = orig_w / input_w + scale_y = orig_h / input_h + + # Convert to unified format + vertebrae = [] + threshold = 0.3 + + for i in range(len(pts)): + score = pts[i, 10] + if score < threshold: + continue + + # Get center and corners (already scaled by down_ratio in decoder) + cx = pts[i, 0] * scale_x + cy = pts[i, 1] * scale_y + + # Corners: tl, tr, bl, br + tl = np.array([pts[i, 2] * scale_x, pts[i, 3] * scale_y]) + tr = np.array([pts[i, 4] * scale_x, pts[i, 5] * scale_y]) + bl = np.array([pts[i, 6] * scale_x, pts[i, 7] * scale_y]) + br = np.array([pts[i, 8] * scale_x, pts[i, 9] * scale_y]) + + # Reorder to [tl, tr, br, bl] for consistency with ScolioVis + corners = np.array([tl, tr, br, bl], dtype=np.float32) + centroid = np.array([cx, cy], dtype=np.float32) + + # Compute orientation from top edge + dx = tr[0] - tl[0] + dy = tr[1] - tl[1] + orientation = np.degrees(np.arctan2(dy, dx)) + + vert = VertebraLandmark( + level=None, + centroid_px=centroid, + corners_px=corners, + endplate_upper_px=np.array([tl, tr]), # top edge + endplate_lower_px=np.array([bl, br]), # bottom edge + orientation_deg=orientation, + confidence=float(score), + meta={'raw_pts': pts[i].tolist()} + ) + vertebrae.append(vert) + + # Sort by y-coordinate (top to bottom) + vertebrae.sort(key=lambda v: float(v.centroid_px[1])) + + # Assign vertebra levels (T1-L5 = 17 vertebrae typically) + level_names = ['T1', 'T2', 'T3', 'T4', 'T5', 'T6', 'T7', 'T8', 'T9', 'T10', 'T11', 'T12', 'L1', 'L2', 'L3', 'L4', 'L5'] + for i, vert in enumerate(vertebrae): + if i < len(level_names): + vert.level = level_names[i] + + # Create Spine2D + spine = Spine2D( + vertebrae=vertebrae, + image_shape=(orig_h, orig_w), + source_model=self.name + ) + + # Compute Cobb angles + if len(vertebrae) >= 7: + from spine_analysis import compute_cobb_angles + compute_cobb_angles(spine) + + return spine diff --git a/brace-generator/brace_surface.py b/brace-generator/brace_surface.py new file mode 100644 index 0000000..9928cd3 --- /dev/null +++ b/brace-generator/brace_surface.py @@ -0,0 +1,354 @@ +""" +Brace surface generation from spine landmarks. + +Two modes: +- Version A: Generic/average body shape (parametric torso) +- Version B: Uses actual 3D body scan mesh +""" +import numpy as np +from typing import Tuple, Optional, List +from pathlib import Path + +from .data_models import Spine2D, BraceConfig +from .spine_analysis import compute_spine_curve, find_apex_vertebrae + +try: + import trimesh + HAS_TRIMESH = True +except ImportError: + HAS_TRIMESH = False + + +class BraceGenerator: + """ + Generates 3D brace shell from spine landmarks. 
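A usage sketch for both modes, assuming spine is a Spine2D produced by one of the adapters, body.obj is a hypothetical scan path, and the brace_generator package layout created in the Dockerfile is importable:

    # Sketch: parametric shell first, then a scan-based shell from the same landmarks
    from brace_generator.data_models import BraceConfig
    from brace_generator.brace_surface import BraceGenerator

    gen = BraceGenerator(BraceConfig(pressure_strength_mm=15.0, wall_thickness_mm=4.0))
    mesh = gen.generate(spine)                    # Version A: average body
    gen.export_stl(mesh, "brace.stl")

    scan_cfg = BraceConfig(use_body_scan=True, body_scan_path="body.obj")  # hypothetical
    scan_mesh = BraceGenerator(scan_cfg).generate(spine)                   # Version B: body scan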
+ """ + + def __init__(self, config: Optional[BraceConfig] = None): + """ + Initialize brace generator. + + Args: + config: Brace configuration parameters + """ + if not HAS_TRIMESH: + raise ImportError("trimesh is required for brace generation. Install with: pip install trimesh") + + self.config = config or BraceConfig() + + def generate(self, spine: Spine2D) -> 'trimesh.Trimesh': + """ + Generate brace mesh from spine landmarks. + + Args: + spine: Spine2D object with detected vertebrae + + Returns: + trimesh.Trimesh object representing the brace shell + """ + if self.config.use_body_scan and self.config.body_scan_path: + return self._generate_from_body_scan(spine) + else: + return self._generate_from_average_body(spine) + + def _torso_profile(self, z01: float) -> Tuple[float, float]: + """ + Get torso cross-section radii at a given height. + + Args: + z01: Normalized height (0=top, 1=bottom) + + Returns: + (a_mm, b_mm): Radii in left-right and front-back directions + """ + # Torso shape varies with height + # Wider at chest (z~0.3) and hips (z~0.8), narrower at waist (z~0.5) + + # Base radii from config + base_a = self.config.torso_width_mm / 2 + base_b = self.config.torso_depth_mm / 2 + + # Shape modulation + # Chest region (z ~ 0.2-0.4): wider + # Waist region (z ~ 0.5): narrower + # Hip region (z ~ 0.8-1.0): wider + + if z01 < 0.3: + # Upper chest - moderate width + mod = 1.0 + elif z01 < 0.5: + # Transition to waist + t = (z01 - 0.3) / 0.2 + mod = 1.0 - 0.15 * t # Decrease by 15% + elif z01 < 0.7: + # Waist region - narrowest + mod = 0.85 + else: + # Hips - widen again + t = (z01 - 0.7) / 0.3 + mod = 0.85 + 0.2 * t # Increase by 20% + + return base_a * mod, base_b * mod + + def _generate_from_average_body(self, spine: Spine2D) -> 'trimesh.Trimesh': + """ + Generate brace using parametric average body shape. + + The brace follows the spine curve laterally and applies + pressure zones at curve apexes. 
+ """ + cfg = self.config + + # 1) Compute spine curve + try: + C_px, T_px, N_px, curvature = compute_spine_curve(spine, smooth=5.0, n_samples=cfg.n_vertical_slices) + except ValueError as e: + raise ValueError(f"Cannot generate brace: {e}") + + # 2) Convert to mm + if spine.pixel_spacing_mm is not None: + sx, sy = spine.pixel_spacing_mm + elif cfg.pixel_spacing_mm is not None: + sx, sy = cfg.pixel_spacing_mm + else: + sx = sy = 0.25 # Default assumption + + C_mm = np.zeros_like(C_px) + C_mm[:, 0] = C_px[:, 0] * sx + C_mm[:, 1] = C_px[:, 1] * sy + + # 3) Determine brace vertical extent + y_mm = C_mm[:, 1] + y_min, y_max = y_mm.min(), y_mm.max() + spine_height = y_max - y_min + + # Brace height (might extend beyond detected vertebrae) + brace_height = min(cfg.brace_height_mm, spine_height * 1.1) + + # 4) Normalize curvature for pressure zones + curv_norm = (curvature - curvature.min()) / (curvature.max() - curvature.min() + 1e-8) + + # 5) Build vertices + n_z = cfg.n_vertical_slices + n_theta = cfg.n_circumference_points + + # Opening angle (front of brace might be open) + opening_half = np.radians(cfg.front_opening_deg / 2) + + vertices = [] + + for i in range(n_z): + z01 = i / (n_z - 1) # 0 to 1 + + # Z coordinate (vertical position in 3D) + z_mm = y_min + z01 * spine_height + + # Get torso profile at this height + a_mm, b_mm = self._torso_profile(z01) + + # Lateral offset from spine curve + x_offset = C_mm[i, 0] - (C_mm[0, 0] + C_mm[-1, 0]) / 2 # Deviation from midline + + # Pressure modulation based on curvature + pressure = cfg.pressure_strength_mm * curv_norm[i] + + for j in range(n_theta): + theta = 2 * np.pi * (j / n_theta) + + # Skip vertices in the opening region (front = theta around 0) + # Actually, we'll still create them but can mark them for later removal + + # Base ellipse point + x = a_mm * np.cos(theta) + y = b_mm * np.sin(theta) + + # Apply lateral offset (brace follows spine curve) + x += x_offset + + # Apply pressure zones + # Pressure on sides (theta near π/2 or 3π/2 = sides) + # The side that's convex gets pushed in + side_factor = abs(np.cos(theta)) # Max at sides (theta=0 or π) + + # Determine which side based on spine deviation + if x_offset > 0: + # Spine deviated right, push on right side + if np.cos(theta) > 0: # Right side + x -= pressure * side_factor + else: + # Spine deviated left, push on left side + if np.cos(theta) < 0: # Left side + x -= pressure * side_factor * np.sign(np.cos(theta)) + + # Vertex position: x=left/right, y=front/back, z=vertical + vertices.append([x, y, z_mm]) + + vertices = np.array(vertices, dtype=np.float32) + + # 6) Build faces (quad strips between adjacent rings) + faces = [] + + def vid(i, j): + return i * n_theta + (j % n_theta) + + for i in range(n_z - 1): + for j in range(n_theta): + j2 = (j + 1) % n_theta + + # Two triangles per quad + a = vid(i, j) + b = vid(i, j2) + c = vid(i + 1, j2) + d = vid(i + 1, j) + + faces.append([a, b, c]) + faces.append([a, c, d]) + + faces = np.array(faces, dtype=np.int32) + + # 7) Create outer shell mesh + outer_shell = trimesh.Trimesh(vertices=vertices, faces=faces, process=True) + + # 8) Create inner shell (offset inward by wall thickness) + outer_shell.fix_normals() + vn = outer_shell.vertex_normals + inner_vertices = vertices - cfg.wall_thickness_mm * vn + + # Inner faces need reversed winding + inner_faces = faces[:, ::-1] + + # 9) Combine into solid shell + all_vertices = np.vstack([vertices, inner_vertices]) + inner_faces_offset = inner_faces + len(vertices) + all_faces = np.vstack([faces, 
inner_faces_offset]) + + # 10) Add end caps (top and bottom rings) + # Top cap (connect outer to inner at z=0) + top_faces = [] + for j in range(n_theta): + j2 = (j + 1) % n_theta + outer_j = vid(0, j) + outer_j2 = vid(0, j2) + inner_j = outer_j + len(vertices) + inner_j2 = outer_j2 + len(vertices) + top_faces.append([outer_j, inner_j, inner_j2]) + top_faces.append([outer_j, inner_j2, outer_j2]) + + # Bottom cap + bottom_faces = [] + for j in range(n_theta): + j2 = (j + 1) % n_theta + outer_j = vid(n_z - 1, j) + outer_j2 = vid(n_z - 1, j2) + inner_j = outer_j + len(vertices) + inner_j2 = outer_j2 + len(vertices) + bottom_faces.append([outer_j, outer_j2, inner_j2]) + bottom_faces.append([outer_j, inner_j2, inner_j]) + + all_faces = np.vstack([all_faces, top_faces, bottom_faces]) + + # Create final mesh + brace = trimesh.Trimesh(vertices=all_vertices, faces=all_faces, process=True) + brace.merge_vertices() + # Remove degenerate faces + valid_faces = brace.nondegenerate_faces() + brace.update_faces(valid_faces) + brace.fix_normals() + + return brace + + def _generate_from_body_scan(self, spine: Spine2D) -> 'trimesh.Trimesh': + """ + Generate brace by offsetting from a 3D body scan mesh. + + The body scan provides the actual torso shape, and we: + 1. Offset outward for clearance + 2. Apply pressure zones based on spine curvature + 3. Thicken for wall thickness + """ + cfg = self.config + + if not cfg.body_scan_path or not Path(cfg.body_scan_path).exists(): + raise FileNotFoundError(f"Body scan not found: {cfg.body_scan_path}") + + # Load body scan + body = trimesh.load(cfg.body_scan_path, force='mesh') + body.remove_unreferenced_vertices() + body.fix_normals() + + # Compute spine curve for pressure mapping + try: + C_px, T_px, N_px, curvature = compute_spine_curve(spine, smooth=5.0, n_samples=200) + except ValueError: + curvature = np.zeros(200) + + # Convert spine coordinates to mm + if spine.pixel_spacing_mm is not None: + sx, sy = spine.pixel_spacing_mm + else: + sx = sy = 0.25 + + y_mm = C_px[:, 1] * sy + y_min, y_max = y_mm.min(), y_mm.max() + H = y_max - y_min + 1e-6 + + # Normalize curvature + curv_norm = (curvature - curvature.min()) / (curvature.max() - curvature.min() + 1e-8) + + # 1) Offset body surface outward for clearance (inner brace surface) + clearance_mm = 6.0 # Gap between body and brace + vn = body.vertex_normals + inner_surface = trimesh.Trimesh( + vertices=body.vertices + clearance_mm * vn, + faces=body.faces.copy(), + process=True + ) + + # 2) Apply pressure deformation + # Map each vertex's Z coordinate to spine curvature + z_coords = inner_surface.vertices[:, 2] # Assuming Z is vertical + z_min, z_max = z_coords.min(), z_coords.max() + z01 = (z_coords - z_min) / (z_max - z_min + 1e-6) + + # Sample curvature at each vertex height + curv_idx = np.clip((z01 * (len(curv_norm) - 1)).astype(int), 0, len(curv_norm) - 1) + pressure_per_vertex = cfg.pressure_strength_mm * curv_norm[curv_idx] + + # Apply pressure on sides (based on X coordinate) + x_coords = inner_surface.vertices[:, 0] + x_range = np.abs(x_coords).max() + 1e-6 + side_factor = np.abs(x_coords) / x_range # 0 at center, 1 at sides + + deformation = (pressure_per_vertex * side_factor)[:, np.newaxis] * inner_surface.vertex_normals + inner_surface.vertices = inner_surface.vertices - deformation + + # 3) Create outer surface (offset by wall thickness) + inner_surface.fix_normals() + outer_surface = trimesh.Trimesh( + vertices=inner_surface.vertices + cfg.wall_thickness_mm * inner_surface.vertex_normals, + 
faces=inner_surface.faces.copy(), + process=True + ) + + # 4) Combine surfaces + # For a true solid, we'd need to stitch edges - simplified here + brace = trimesh.util.concatenate([inner_surface, outer_surface]) + brace.merge_vertices() + valid_faces = brace.nondegenerate_faces() + brace.update_faces(valid_faces) + brace.fix_normals() + + return brace + + def export_stl(self, mesh: 'trimesh.Trimesh', output_path: str): + """ + Export mesh to STL file. + + Args: + mesh: trimesh.Trimesh object + output_path: Path for output STL file + """ + mesh.export(output_path) + print(f"Exported brace to {output_path}") + print(f" Vertices: {len(mesh.vertices)}") + print(f" Faces: {len(mesh.faces)}") diff --git a/brace-generator/data_models.py b/brace-generator/data_models.py new file mode 100644 index 0000000..07c256c --- /dev/null +++ b/brace-generator/data_models.py @@ -0,0 +1,177 @@ +""" +Data models for unified spine landmark representation. +This is the "glue" that connects different model outputs to the brace generator. +""" +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any +import numpy as np + + +@dataclass +class VertebraLandmark: + """ + Unified representation of a single vertebra's landmarks. + All coordinates are in pixels (can be converted to mm with pixel_spacing). + """ + # Vertebra level identifier (e.g., "T1", "T4", "L1", etc.) - None if unknown + level: Optional[str] = None + + # Center point of vertebra [x, y] in pixels + centroid_px: np.ndarray = field(default_factory=lambda: np.zeros(2)) + + # Four corner points [top_left, top_right, bottom_right, bottom_left] shape (4, 2) + corners_px: Optional[np.ndarray] = None + + # Upper endplate points [left, right] shape (2, 2) + endplate_upper_px: Optional[np.ndarray] = None + + # Lower endplate points [left, right] shape (2, 2) + endplate_lower_px: Optional[np.ndarray] = None + + # Orientation angle of vertebra in degrees (tilt in coronal plane) + orientation_deg: Optional[float] = None + + # Detection confidence (0-1) + confidence: float = 1.0 + + # Additional metadata from source model + meta: Optional[Dict[str, Any]] = None + + def compute_orientation(self) -> float: + """Compute vertebra orientation from corners or endplates.""" + if self.orientation_deg is not None: + return self.orientation_deg + + # Try to compute from upper endplate + if self.endplate_upper_px is not None: + left, right = self.endplate_upper_px[0], self.endplate_upper_px[1] + dx = right[0] - left[0] + dy = right[1] - left[1] + angle = np.degrees(np.arctan2(dy, dx)) + self.orientation_deg = angle + return angle + + # Try to compute from corners (top-left to top-right) + if self.corners_px is not None: + top_left, top_right = self.corners_px[0], self.corners_px[1] + dx = top_right[0] - top_left[0] + dy = top_right[1] - top_left[1] + angle = np.degrees(np.arctan2(dy, dx)) + self.orientation_deg = angle + return angle + + return 0.0 + + def compute_centroid(self) -> np.ndarray: + """Compute centroid from corners if not set.""" + if self.corners_px is not None and np.all(self.centroid_px == 0): + self.centroid_px = np.mean(self.corners_px, axis=0) + return self.centroid_px + + +@dataclass +class Spine2D: + """ + Complete 2D spine representation from an X-ray. + Contains all detected vertebrae and computed angles. 
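A hand-assembled sketch with synthetic coordinates; corners_px follows the [top_left, top_right, bottom_right, bottom_left] order documented on VertebraLandmark (note that ScolioVisAdapter stores the raw keypoint order [tl, tr, bl, br], so code that needs endplates should read the explicit endplate fields rather than assume corner order):

    # Sketch: a two-vertebra Spine2D assembled by hand
    import numpy as np
    from data_models import Spine2D, VertebraLandmark

    v1 = VertebraLandmark(level="T5", corners_px=np.array(
        [[90., 100.], [110., 102.], [111., 120.], [89., 118.]]))  # tl, tr, br, bl
    v2 = VertebraLandmark(level="T6", centroid_px=np.array([102., 145.]))
    spine = Spine2D(vertebrae=[v1, v2], image_shape=(512, 256))
    print(spine.get_centroids())     # v1's centroid (100, 110) is derived from corners
    print(v1.compute_orientation())  # ~5.7 degrees, tilt of the top edge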
+ """ + # List of vertebrae, ordered from top (C7/T1) to bottom (L5/S1) + vertebrae: List[VertebraLandmark] = field(default_factory=list) + + # Pixel spacing in mm [sx, sy] - from DICOM if available + pixel_spacing_mm: Optional[np.ndarray] = None + + # Original image shape (height, width) + image_shape: Optional[tuple] = None + + # Computed Cobb angles in degrees (individual fields) + cobb_pt: Optional[float] = None # Proximal Thoracic + cobb_mt: Optional[float] = None # Main Thoracic + cobb_tl: Optional[float] = None # Thoracolumbar/Lumbar + + # Cobb angles as dictionary (alternative format) + cobb_angles: Optional[Dict[str, float]] = None # {'PT': angle, 'MT': angle, 'TL': angle} + + # Curve type: "S" (double curve) or "C" (single curve) or "Normal" + curve_type: Optional[str] = None + + # Rigo-Chêneau classification + rigo_type: Optional[str] = None # A1, A2, A3, B1, B2, C1, C2, E1, E2, Normal + rigo_description: Optional[str] = None # Detailed description + + # Source model that generated this data + source_model: Optional[str] = None + + # Additional metadata + meta: Optional[Dict[str, Any]] = None + + def get_cobb_angles(self) -> Dict[str, float]: + """Get Cobb angles as dictionary, preferring computed individual fields.""" + # Prefer individual fields (set by compute_cobb_angles) over dictionary + # This ensures consistency between displayed values and classification + if self.cobb_pt is not None or self.cobb_mt is not None or self.cobb_tl is not None: + return { + 'PT': self.cobb_pt or 0.0, + 'MT': self.cobb_mt or 0.0, + 'TL': self.cobb_tl or 0.0 + } + if self.cobb_angles is not None: + return self.cobb_angles + return {'PT': 0.0, 'MT': 0.0, 'TL': 0.0} + + def get_centroids(self) -> np.ndarray: + """Get array of all vertebra centroids, shape (N, 2).""" + centroids = [] + for v in self.vertebrae: + v.compute_centroid() + centroids.append(v.centroid_px) + return np.array(centroids, dtype=np.float32) + + def get_orientations(self) -> np.ndarray: + """Get array of all vertebra orientations in degrees, shape (N,).""" + return np.array([v.compute_orientation() for v in self.vertebrae], dtype=np.float32) + + def to_mm(self, coords_px: np.ndarray) -> np.ndarray: + """Convert pixel coordinates to millimeters.""" + if self.pixel_spacing_mm is None: + # Default assumption: 0.25 mm/pixel (typical for spine X-rays) + spacing = np.array([0.25, 0.25]) + else: + spacing = self.pixel_spacing_mm + return coords_px * spacing + + def sort_vertebrae(self): + """Sort vertebrae by vertical position (top to bottom).""" + self.vertebrae.sort(key=lambda v: float(v.centroid_px[1])) + + +@dataclass +class BraceConfig: + """ + Configuration parameters for brace generation. 
+ """ + # Brace dimensions + brace_height_mm: float = 400.0 # Total height of brace + wall_thickness_mm: float = 4.0 # Shell thickness + + # Torso shape parameters (for average body mode) + torso_width_mm: float = 280.0 # Left-right diameter at widest + torso_depth_mm: float = 200.0 # Front-back diameter at widest + + # Correction parameters + pressure_strength_mm: float = 15.0 # Max indentation at apex + pressure_spread_deg: float = 45.0 # Angular spread of pressure zone + + # Mesh resolution + n_vertical_slices: int = 100 # Number of cross-sections + n_circumference_points: int = 72 # Points per cross-section (every 5°) + + # Opening (for brace accessibility) + front_opening_deg: float = 60.0 # Angular width of front opening (0 = closed) + + # Mode + use_body_scan: bool = False # True = use 3D body scan, False = average body + body_scan_path: Optional[str] = None # Path to body scan mesh + + # Scale + pixel_spacing_mm: Optional[np.ndarray] = None # Override pixel spacing diff --git a/brace-generator/image_loader.py b/brace-generator/image_loader.py new file mode 100644 index 0000000..6f5f059 --- /dev/null +++ b/brace-generator/image_loader.py @@ -0,0 +1,115 @@ +""" +Image loader supporting JPEG, PNG, and DICOM formats. +""" +import numpy as np +from pathlib import Path +from typing import Tuple, Optional + +try: + import pydicom + HAS_PYDICOM = True +except ImportError: + HAS_PYDICOM = False + +try: + from PIL import Image + HAS_PIL = True +except ImportError: + HAS_PIL = False + +try: + import cv2 + HAS_CV2 = True +except ImportError: + HAS_CV2 = False + + +def load_xray(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Load an X-ray image from file. + + Supports: JPEG, PNG, BMP, DICOM (.dcm) + + Args: + path: Path to the image file + + Returns: + img_u8: Grayscale image as uint8 array (H, W) + spacing_mm: Pixel spacing [sx, sy] in mm, or None if not available + """ + path = Path(path) + suffix = path.suffix.lower() + + # DICOM + if suffix in ['.dcm', '.dicom']: + if not HAS_PYDICOM: + raise ImportError("pydicom is required for DICOM files. Install with: pip install pydicom") + return _load_dicom(str(path)) + + # Standard image formats + if suffix in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']: + return _load_standard_image(str(path)) + + # Try to load as standard image anyway + try: + return _load_standard_image(str(path)) + except Exception as e: + raise ValueError(f"Could not load image: {path}. 
Error: {e}") + + +def _load_dicom(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Load DICOM file.""" + ds = pydicom.dcmread(path) + arr = ds.pixel_array.astype(np.float32) + + # Apply modality LUT if present + if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'): + arr = arr * ds.RescaleSlope + ds.RescaleIntercept + + # Normalize to 0-255 + arr = arr - arr.min() + if arr.max() > 0: + arr = arr / arr.max() + img_u8 = (arr * 255).astype(np.uint8) + + # Get pixel spacing + spacing_mm = None + if hasattr(ds, 'PixelSpacing'): + # PixelSpacing is [row_spacing, col_spacing] in mm + sy, sx = [float(x) for x in ds.PixelSpacing] + spacing_mm = np.array([sx, sy], dtype=np.float32) + elif hasattr(ds, 'ImagerPixelSpacing'): + sy, sx = [float(x) for x in ds.ImagerPixelSpacing] + spacing_mm = np.array([sx, sy], dtype=np.float32) + + return img_u8, spacing_mm + + +def _load_standard_image(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Load standard image format (JPEG, PNG, etc.).""" + if HAS_CV2: + img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + if img is None: + raise ValueError(f"Could not read image: {path}") + return img.astype(np.uint8), None + elif HAS_PIL: + img = Image.open(path).convert('L') # Convert to grayscale + return np.array(img, dtype=np.uint8), None + else: + raise ImportError("Either opencv-python or Pillow is required. Install with: pip install opencv-python") + + +def load_xray_rgb(path: str) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Load X-ray as RGB (for models that expect 3-channel input). + + Returns: + img_rgb: RGB image as uint8 array (H, W, 3) + spacing_mm: Pixel spacing or None + """ + img_gray, spacing = load_xray(path) + + # Convert grayscale to RGB by stacking + img_rgb = np.stack([img_gray, img_gray, img_gray], axis=-1) + + return img_rgb, spacing diff --git a/brace-generator/pipeline.py b/brace-generator/pipeline.py new file mode 100644 index 0000000..61e907c --- /dev/null +++ b/brace-generator/pipeline.py @@ -0,0 +1,346 @@ +""" +Complete pipeline: X-ray → Landmarks → Brace STL +""" +import numpy as np +from pathlib import Path +from typing import Optional, Dict, Any, Union +import json + +from brace_generator.data_models import Spine2D, BraceConfig, VertebraLandmark +from brace_generator.image_loader import load_xray, load_xray_rgb +from brace_generator.adapters import BaseLandmarkAdapter, ScolioVisAdapter, VertLandmarkAdapter +from brace_generator.spine_analysis import ( + compute_spine_curve, compute_cobb_angles, find_apex_vertebrae, + get_curve_severity, classify_rigo_type +) +from brace_generator.brace_surface import BraceGenerator + + +class BracePipeline: + """ + End-to-end pipeline for generating scoliosis braces from X-rays. + + Usage: + # Basic usage with default model + pipeline = BracePipeline() + pipeline.process("xray.png", "brace.stl") + + # With specific model + pipeline = BracePipeline(model="vertebra-landmark") + pipeline.process("xray.dcm", "brace.stl") + + # With body scan + config = BraceConfig(use_body_scan=True, body_scan_path="body.obj") + pipeline = BracePipeline(config=config) + pipeline.process("xray.png", "brace.stl") + """ + + AVAILABLE_MODELS = { + 'scoliovis': ScolioVisAdapter, + 'vertebra-landmark': VertLandmarkAdapter, + } + + def __init__( + self, + model: str = 'scoliovis', + config: Optional[BraceConfig] = None, + device: str = 'cpu' + ): + """ + Initialize pipeline. 
+ + Args: + model: Model to use ('scoliovis' or 'vertebra-landmark') + config: Brace configuration + device: 'cpu' or 'cuda' + """ + self.device = device + self.config = config or BraceConfig() + self.model_name = model.lower() + + # Initialize model adapter + if self.model_name not in self.AVAILABLE_MODELS: + raise ValueError(f"Unknown model: {model}. Available: {list(self.AVAILABLE_MODELS.keys())}") + + self.adapter: BaseLandmarkAdapter = self.AVAILABLE_MODELS[self.model_name](device=device) + self.brace_generator = BraceGenerator(self.config) + + # Store last results for inspection + self.last_spine: Optional[Spine2D] = None + self.last_image: Optional[np.ndarray] = None + + def process( + self, + xray_path: str, + output_stl_path: str, + visualize: bool = False, + save_landmarks: bool = False + ) -> Dict[str, Any]: + """ + Process X-ray and generate brace STL. + + Args: + xray_path: Path to input X-ray (JPEG, PNG, or DICOM) + output_stl_path: Path for output STL file + visualize: If True, also save visualization image + save_landmarks: If True, also save landmarks JSON + + Returns: + Dictionary with analysis results + """ + print(f"=" * 60) + print(f"Brace Generation Pipeline") + print(f"Model: {self.adapter.name}") + print(f"=" * 60) + + # 1) Load X-ray + print(f"\n1. Loading X-ray: {xray_path}") + image_rgb, pixel_spacing = load_xray_rgb(xray_path) + self.last_image = image_rgb + print(f" Image size: {image_rgb.shape[:2]}") + if pixel_spacing is not None: + print(f" Pixel spacing: {pixel_spacing} mm") + + # 2) Detect landmarks + print(f"\n2. Detecting landmarks...") + spine = self.adapter.predict(image_rgb) + spine.pixel_spacing_mm = pixel_spacing + self.last_spine = spine + + print(f" Detected {len(spine.vertebrae)} vertebrae") + + if len(spine.vertebrae) < 5: + raise ValueError(f"Insufficient vertebrae detected ({len(spine.vertebrae)}). Need at least 5.") + + # 3) Compute spine analysis + print(f"\n3. Analyzing spine curvature...") + compute_cobb_angles(spine) + apexes = find_apex_vertebrae(spine) + + # Classify Rigo type + rigo_result = classify_rigo_type(spine) + + print(f" Cobb Angles:") + print(f" PT (Proximal Thoracic): {spine.cobb_pt:.1f}° - {get_curve_severity(spine.cobb_pt)}") + print(f" MT (Main Thoracic): {spine.cobb_mt:.1f}° - {get_curve_severity(spine.cobb_mt)}") + print(f" TL (Thoracolumbar): {spine.cobb_tl:.1f}° - {get_curve_severity(spine.cobb_tl)}") + print(f" Curve type: {spine.curve_type}") + print(f" Rigo Classification: {rigo_result['rigo_type']}") + print(f" - {rigo_result['description']}") + print(f" Apex vertebrae indices: {apexes}") + + # 4) Generate brace + print(f"\n4. Generating brace mesh...") + if self.config.use_body_scan: + print(f" Mode: Using body scan ({self.config.body_scan_path})") + else: + print(f" Mode: Average body shape") + + brace_mesh = self.brace_generator.generate(spine) + print(f" Mesh: {len(brace_mesh.vertices)} vertices, {len(brace_mesh.faces)} faces") + + # 5) Export STL + print(f"\n5. 
Exporting STL: {output_stl_path}") + self.brace_generator.export_stl(brace_mesh, output_stl_path) + + # 6) Optional: Save visualization + if visualize: + vis_path = str(Path(output_stl_path).with_suffix('.png')) + self._save_visualization(vis_path, spine, image_rgb) + print(f" Visualization saved: {vis_path}") + + # 7) Optional: Save landmarks JSON + if save_landmarks: + json_path = str(Path(output_stl_path).with_suffix('.json')) + self._save_landmarks_json(json_path, spine) + print(f" Landmarks saved: {json_path}") + + # Prepare results + results = { + 'input_image': xray_path, + 'output_stl': output_stl_path, + 'model': self.adapter.name, + 'vertebrae_detected': len(spine.vertebrae), + 'cobb_angles': { + 'PT': spine.cobb_pt, + 'MT': spine.cobb_mt, + 'TL': spine.cobb_tl, + }, + 'curve_type': spine.curve_type, + 'rigo_type': rigo_result['rigo_type'], + 'rigo_description': rigo_result['description'], + 'apex_indices': apexes, + 'mesh_vertices': len(brace_mesh.vertices), + 'mesh_faces': len(brace_mesh.faces), + } + + print(f"\n{'=' * 60}") + print(f"Pipeline complete!") + print(f"{'=' * 60}") + + return results + + def _save_visualization(self, path: str, spine: Spine2D, image: np.ndarray): + """Save visualization of detected landmarks and spine curve.""" + try: + import matplotlib + matplotlib.use('Agg') + import matplotlib.pyplot as plt + except ImportError: + print(" Warning: matplotlib not available for visualization") + return + + fig, axes = plt.subplots(1, 2, figsize=(14, 10)) + + # Left: Original with landmarks + ax1 = axes[0] + ax1.imshow(image) + + # Draw vertebra centers + centroids = spine.get_centroids() + ax1.scatter(centroids[:, 0], centroids[:, 1], c='red', s=30, zorder=5) + + # Draw corners if available + for vert in spine.vertebrae: + if vert.corners_px is not None: + corners = vert.corners_px + # Draw quadrilateral + for i in range(4): + j = (i + 1) % 4 + ax1.plot([corners[i, 0], corners[j, 0]], + [corners[i, 1], corners[j, 1]], 'g-', linewidth=1) + + ax1.set_title(f"Detected Landmarks ({len(spine.vertebrae)} vertebrae)") + ax1.axis('off') + + # Right: Spine curve analysis + ax2 = axes[1] + ax2.imshow(image, alpha=0.5) + + # Draw spine curve + try: + C, T, N, curv = compute_spine_curve(spine) + ax2.plot(C[:, 0], C[:, 1], 'b-', linewidth=2, label='Spine curve') + + # Highlight high curvature regions + high_curv_mask = curv > curv.mean() + curv.std() + ax2.scatter(C[high_curv_mask, 0], C[high_curv_mask, 1], + c='orange', s=20, label='High curvature') + except: + pass + + # Get Rigo classification for display + rigo_result = classify_rigo_type(spine) + + # Add Cobb angles and Rigo type text + text = f"Cobb Angles:\n" + text += f"PT: {spine.cobb_pt:.1f}°\n" + text += f"MT: {spine.cobb_mt:.1f}°\n" + text += f"TL: {spine.cobb_tl:.1f}°\n" + text += f"Curve: {spine.curve_type}\n" + text += f"-----------\n" + text += f"Rigo: {rigo_result['rigo_type']}" + ax2.text(0.02, 0.98, text, transform=ax2.transAxes, fontsize=10, + verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) + + ax2.set_title("Spine Analysis") + ax2.axis('off') + ax2.legend(loc='lower right') + + plt.tight_layout() + plt.savefig(path, dpi=150, bbox_inches='tight') + plt.close() + + def _save_landmarks_json(self, path: str, spine: Spine2D): + """Save landmarks to JSON file with Rigo classification.""" + def to_native(val): + """Convert numpy types to native Python types.""" + if isinstance(val, np.ndarray): + return val.tolist() + elif isinstance(val, (np.float32, np.float64)): + return 
float(val) + elif isinstance(val, (np.int32, np.int64)): + return int(val) + return val + + # Get Rigo classification + rigo_result = classify_rigo_type(spine) + + data = { + 'source_model': spine.source_model, + 'image_shape': list(spine.image_shape) if spine.image_shape else None, + 'pixel_spacing_mm': spine.pixel_spacing_mm.tolist() if spine.pixel_spacing_mm is not None else None, + 'cobb_angles': { + 'PT': to_native(spine.cobb_pt), + 'MT': to_native(spine.cobb_mt), + 'TL': to_native(spine.cobb_tl), + }, + 'curve_type': spine.curve_type, + 'rigo_classification': { + 'type': rigo_result['rigo_type'], + 'description': rigo_result['description'], + 'curve_pattern': rigo_result['curve_pattern'], + 'n_significant_curves': rigo_result['n_significant_curves'], + }, + 'vertebrae': [] + } + + for vert in spine.vertebrae: + vert_data = { + 'level': vert.level, + 'centroid_px': vert.centroid_px.tolist(), + 'orientation_deg': to_native(vert.orientation_deg), + 'confidence': to_native(vert.confidence), + } + if vert.corners_px is not None: + vert_data['corners_px'] = vert.corners_px.tolist() + data['vertebrae'].append(vert_data) + + with open(path, 'w') as f: + json.dump(data, f, indent=2) + + +def main(): + """Command-line interface for brace generation.""" + import argparse + + parser = argparse.ArgumentParser(description='Generate scoliosis brace from X-ray') + parser.add_argument('input', help='Input X-ray image (JPEG, PNG, or DICOM)') + parser.add_argument('output', help='Output STL file path') + parser.add_argument('--model', choices=['scoliovis', 'vertebra-landmark'], + default='scoliovis', help='Landmark detection model') + parser.add_argument('--device', default='cpu', help='Device (cpu or cuda)') + parser.add_argument('--body-scan', help='Path to 3D body scan mesh (optional)') + parser.add_argument('--visualize', action='store_true', help='Save visualization') + parser.add_argument('--save-landmarks', action='store_true', help='Save landmarks JSON') + parser.add_argument('--pressure', type=float, default=15.0, + help='Pressure strength in mm (default: 15)') + parser.add_argument('--thickness', type=float, default=4.0, + help='Wall thickness in mm (default: 4)') + + args = parser.parse_args() + + # Build config + config = BraceConfig( + pressure_strength_mm=args.pressure, + wall_thickness_mm=args.thickness, + ) + + if args.body_scan: + config.use_body_scan = True + config.body_scan_path = args.body_scan + + # Run pipeline + pipeline = BracePipeline(model=args.model, config=config, device=args.device) + results = pipeline.process( + args.input, + args.output, + visualize=args.visualize, + save_landmarks=args.save_landmarks + ) + + return results + + +if __name__ == '__main__': + main() diff --git a/brace-generator/spine_analysis.py b/brace-generator/spine_analysis.py new file mode 100644 index 0000000..7c8d3e1 --- /dev/null +++ b/brace-generator/spine_analysis.py @@ -0,0 +1,464 @@ +""" +Spine analysis functions for computing curves, Cobb angles, and identifying apex vertebrae. +""" +import numpy as np +from scipy.interpolate import splprep, splev +from typing import Tuple, List, Optional + +from data_models import Spine2D, VertebraLandmark + + +def compute_spine_curve( + spine: Spine2D, + smooth: float = 1.0, + n_samples: int = 200 +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Compute smooth spine centerline from vertebra centroids. 
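A sketch on synthetic centroids (bare imports, matching the module's own style and assuming the brace-generator folder is on sys.path):

    # Sketch: smooth centerline through 12 synthetic centroids bowed 40 px laterally
    import numpy as np
    from data_models import Spine2D, VertebraLandmark
    from spine_analysis import compute_spine_curve

    ys = np.linspace(100, 700, 12)
    xs = 250 + 40 * np.sin(np.linspace(0, np.pi, 12))
    verts = [VertebraLandmark(centroid_px=np.array([x, y])) for x, y in zip(xs, ys)]
    C, T, N, curvature = compute_spine_curve(Spine2D(vertebrae=verts), smooth=2.0)
    print(C.shape, int(curvature.argmax()))  # (200, 2); curvature peaks near the apex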
+ + Args: + spine: Spine2D object with detected vertebrae + smooth: Smoothing factor for spline (higher = smoother) + n_samples: Number of points to sample along the curve + + Returns: + C: Curve points, shape (n_samples, 2) + T: Tangent vectors, shape (n_samples, 2) + N: Normal vectors, shape (n_samples, 2) + curvature: Curvature at each point, shape (n_samples,) + """ + pts = spine.get_centroids() + + if len(pts) < 4: + raise ValueError(f"Need at least 4 vertebrae for spline, got {len(pts)}") + + # Fit parametric spline through centroids + x = pts[:, 0] + y = pts[:, 1] + + try: + tck, u = splprep([x, y], s=smooth, k=min(3, len(pts)-1)) + except Exception as e: + # Fallback: simple linear interpolation + t = np.linspace(0, 1, n_samples) + xs = np.interp(t, np.linspace(0, 1, len(x)), x) + ys = np.interp(t, np.linspace(0, 1, len(y)), y) + C = np.stack([xs, ys], axis=1).astype(np.float32) + T = np.gradient(C, axis=0) + T = T / (np.linalg.norm(T, axis=1, keepdims=True) + 1e-8) + N = np.stack([-T[:, 1], T[:, 0]], axis=1) + curvature = np.zeros(n_samples, dtype=np.float32) + return C, T, N, curvature + + # Sample the spline + u_new = np.linspace(0, 1, n_samples) + xs, ys = splev(u_new, tck) + + # First and second derivatives + dx, dy = splev(u_new, tck, der=1) + ddx, ddy = splev(u_new, tck, der=2) + + # Curve points + C = np.stack([xs, ys], axis=1).astype(np.float32) + + # Tangent vectors (normalized) + T = np.stack([dx, dy], axis=1) + T_norm = np.linalg.norm(T, axis=1, keepdims=True) + 1e-8 + T = (T / T_norm).astype(np.float32) + + # Normal vectors (perpendicular to tangent) + N = np.stack([-T[:, 1], T[:, 0]], axis=1).astype(np.float32) + + # Curvature: |x'y'' - y'x''| / (x'^2 + y'^2)^(3/2) + curvature = np.abs(dx * ddy - dy * ddx) / (np.power(dx**2 + dy**2, 1.5) + 1e-8) + curvature = curvature.astype(np.float32) + + return C, T, N, curvature + + +def compute_cobb_angles(spine: Spine2D) -> Tuple[float, float, float]: + """ + Compute Cobb angles from vertebra orientations. 
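A worked sketch of the region-wise max-minus-min estimate on nine synthetic tilts:

    # Sketch: thirds are [2, 5, 9] / [12, 4, -3] / [-8, -2, 1] degrees
    import numpy as np
    from data_models import Spine2D, VertebraLandmark
    from spine_analysis import compute_cobb_angles

    tilts = [2, 5, 9, 12, 4, -3, -8, -2, 1]
    verts = [VertebraLandmark(centroid_px=np.array([0.0, float(i)]),
                              orientation_deg=float(t))
             for i, t in enumerate(tilts)]
    spine = Spine2D(vertebrae=verts)
    print(compute_cobb_angles(spine))  # (7.0, 15.0, 9.0): PT=9-2, MT=12-(-3), TL=1-(-8)
    print(spine.curve_type)            # "C" - only MT exceeds 10 degrees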
+ + The Cobb angle is measured as the angle between: + - The superior endplate of the most tilted vertebra at the top of a curve + - The inferior endplate of the most tilted vertebra at the bottom + + This function estimates 3 Cobb angles: + - PT: Proximal Thoracic (T1-T6 region) + - MT: Main Thoracic (T6-T12 region) + - TL: Thoracolumbar/Lumbar (T12-L5 region) + + Args: + spine: Spine2D object with detected vertebrae + + Returns: + (pt_angle, mt_angle, tl_angle) in degrees + """ + orientations = spine.get_orientations() + n = len(orientations) + + if n < 7: + spine.cobb_pt = 0.0 + spine.cobb_mt = 0.0 + spine.cobb_tl = 0.0 + return 0.0, 0.0, 0.0 + + # Divide into regions (approximately) + # PT: top 1/3, MT: middle 1/3, TL: bottom 1/3 + pt_end = n // 3 + mt_end = 2 * n // 3 + + # Find max tilt difference in each region + def region_cobb(start_idx: int, end_idx: int) -> float: + if end_idx <= start_idx: + return 0.0 + region_angles = orientations[start_idx:end_idx] + if len(region_angles) < 2: + return 0.0 + # Cobb angle = max angle - min angle in region + return abs(float(np.max(region_angles) - np.min(region_angles))) + + pt_angle = region_cobb(0, pt_end) + mt_angle = region_cobb(pt_end, mt_end) + tl_angle = region_cobb(mt_end, n) + + # Store in spine object + spine.cobb_pt = pt_angle + spine.cobb_mt = mt_angle + spine.cobb_tl = tl_angle + + # Determine curve type + if mt_angle > 10 and tl_angle > 10: + spine.curve_type = "S" # Double curve + elif mt_angle > 10 or tl_angle > 10: + spine.curve_type = "C" # Single curve + else: + spine.curve_type = "Normal" + + return pt_angle, mt_angle, tl_angle + + +def find_apex_vertebrae(spine: Spine2D) -> List[int]: + """ + Find indices of apex vertebrae (most deviated from midline). + + Args: + spine: Spine2D with computed curve + + Returns: + List of vertebra indices that are curve apexes + """ + centroids = spine.get_centroids() + + if len(centroids) < 5: + return [] + + # Find midline (linear fit through endpoints) + start = centroids[0] + end = centroids[-1] + + # Distance from midline for each vertebra + midline_vec = end - start + midline_len = np.linalg.norm(midline_vec) + + if midline_len < 1e-6: + return [] + + midline_unit = midline_vec / midline_len + + # Calculate perpendicular distance to midline + deviations = [] + for i, pt in enumerate(centroids): + v = pt - start + # Project onto midline + proj_len = np.dot(v, midline_unit) + proj = proj_len * midline_unit + # Perpendicular distance + perp = v - proj + dist = np.linalg.norm(perp) + # Sign: positive if to the right of midline + sign = np.sign(np.cross(midline_unit, v / (np.linalg.norm(v) + 1e-8))) + deviations.append(dist * sign) + + deviations = np.array(deviations) + + # Find local extrema (peaks and valleys) + apexes = [] + for i in range(1, len(deviations) - 1): + # Local maximum + if deviations[i] > deviations[i-1] and deviations[i] > deviations[i+1]: + if abs(deviations[i]) > 5: # Minimum deviation threshold (pixels) + apexes.append(i) + # Local minimum + elif deviations[i] < deviations[i-1] and deviations[i] < deviations[i+1]: + if abs(deviations[i]) > 5: + apexes.append(i) + + return apexes + + +def get_curve_severity(cobb_angle: float) -> str: + """ + Get clinical severity classification from Cobb angle. 
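The 10/25/40 degree thresholds in one pass:

    # Sketch: severity bands
    from spine_analysis import get_curve_severity

    for angle in (8, 15, 30, 55):
        print(angle, get_curve_severity(angle))
    # -> 8 Normal, 15 Mild, 30 Moderate, 55 Severe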
+ + Args: + cobb_angle: Cobb angle in degrees + + Returns: + Severity string: "Normal", "Mild", "Moderate", or "Severe" + """ + if cobb_angle < 10: + return "Normal" + elif cobb_angle < 25: + return "Mild" + elif cobb_angle < 40: + return "Moderate" + else: + return "Severe" + + +def classify_rigo_type(spine: Spine2D) -> dict: + """ + Classify scoliosis according to Rigo-Chêneau brace classification. + + Rigo Classification Types: + - A1, A2, A3: 3-curve patterns (thoracic major) + - B1, B2: 4-curve patterns (double major) + - C1, C2: Single thoracolumbar/lumbar + - E1, E2: Single thoracic + + Args: + spine: Spine2D object with detected vertebrae and Cobb angles + + Returns: + dict with 'rigo_type', 'description', 'apex_region', 'curve_pattern' + """ + # Get Cobb angles + cobb_angles = spine.get_cobb_angles() + pt = cobb_angles.get('PT', 0) + mt = cobb_angles.get('MT', 0) + tl = cobb_angles.get('TL', 0) + + n_verts = len(spine.vertebrae) + + # Calculate lateral deviations to determine curve direction + centroids = spine.get_centroids() + deviations = _calculate_lateral_deviations(centroids) + + # Find apex positions and directions + apex_info = _find_apex_info(centroids, deviations, n_verts) + + # Determine curve pattern based on number of significant curves + significant_curves = [] + if pt >= 10: + significant_curves.append(('PT', pt)) + if mt >= 10: + significant_curves.append(('MT', mt)) + if tl >= 10: + significant_curves.append(('TL', tl)) + + n_curves = len(significant_curves) + + # Classification logic + rigo_type = "N/A" + description = "" + curve_pattern = "" + + # No significant scoliosis + if n_curves == 0 or max(pt, mt, tl) < 10: + rigo_type = "Normal" + description = "No significant scoliosis (all Cobb angles < 10°)" + curve_pattern = "None" + + # Single curve patterns + elif n_curves == 1: + max_curve = significant_curves[0][0] + max_angle = significant_curves[0][1] + + if max_curve == 'MT' or max_curve == 'PT': + # Thoracic single curve + if apex_info['thoracic_apex_idx'] is not None: + # Check if there's a compensatory lumbar + if tl > 5: + rigo_type = "E2" + description = f"Single thoracic curve ({max_angle:.1f}°) with lumbar compensatory ({tl:.1f}°)" + curve_pattern = "Thoracic with compensation" + else: + rigo_type = "E1" + description = f"True single thoracic curve ({max_angle:.1f}°)" + curve_pattern = "Single thoracic" + else: + rigo_type = "E1" + description = f"Single thoracic curve ({max_angle:.1f}°)" + curve_pattern = "Single thoracic" + + elif max_curve == 'TL': + # Thoracolumbar/Lumbar single curve + if mt > 5 or pt > 5: + rigo_type = "C2" + description = f"Thoracolumbar curve ({tl:.1f}°) with upper compensatory" + curve_pattern = "TL/L with compensation" + else: + rigo_type = "C1" + description = f"Single thoracolumbar/lumbar curve ({tl:.1f}°)" + curve_pattern = "Single TL/L" + + # Double curve patterns + elif n_curves >= 2: + # Determine which curves are primary + thoracic_total = pt + mt + lumbar_total = tl + + # Check curve directions for S vs C pattern + is_s_curve = apex_info['is_s_pattern'] + + if is_s_curve: + # S-curve: typically 3 or 4 curve patterns + if thoracic_total > lumbar_total * 1.5: + # Thoracic dominant - Type A (3-curve) + if apex_info['lumbar_apex_low']: + rigo_type = "A1" + description = f"3-curve: Thoracic major ({mt:.1f}°), lumbar apex low" + elif apex_info['apex_at_tl_junction']: + rigo_type = "A2" + description = f"3-curve: Thoracolumbar transition ({mt:.1f}°/{tl:.1f}°)" + else: + rigo_type = "A3" + description = f"3-curve: Thoracic 
major ({mt:.1f}°) with structural lumbar ({tl:.1f}°)" + curve_pattern = "3-curve (thoracic major)" + + elif lumbar_total > thoracic_total * 1.5: + # Lumbar dominant + rigo_type = "C2" + description = f"Lumbar major ({tl:.1f}°) with thoracic compensatory ({mt:.1f}°)" + curve_pattern = "Lumbar major" + + else: + # Double major - Type B (4-curve) + if tl >= mt: + rigo_type = "B1" + description = f"4-curve: Double major, lumbar prominent ({tl:.1f}°/{mt:.1f}°)" + else: + rigo_type = "B2" + description = f"4-curve: Double major, thoracic prominent ({mt:.1f}°/{tl:.1f}°)" + curve_pattern = "4-curve (double major)" + else: + # C-curve pattern (curves in same direction) + if mt >= tl: + if tl > 5: + rigo_type = "A3" + description = f"Long thoracic curve ({mt:.1f}°) extending to lumbar ({tl:.1f}°)" + else: + rigo_type = "E2" + description = f"Thoracic curve ({mt:.1f}°) with minor lumbar ({tl:.1f}°)" + curve_pattern = "Extended thoracic" + else: + rigo_type = "C2" + description = f"TL/Lumbar curve ({tl:.1f}°) with thoracic involvement ({mt:.1f}°)" + curve_pattern = "Extended lumbar" + + # Store in spine object + spine.rigo_type = rigo_type + spine.rigo_description = description + + return { + 'rigo_type': rigo_type, + 'description': description, + 'curve_pattern': curve_pattern, + 'apex_info': apex_info, + 'cobb_angles': cobb_angles, + 'n_significant_curves': n_curves + } + + +def _calculate_lateral_deviations(centroids: np.ndarray) -> np.ndarray: + """Calculate lateral deviation from midline for each vertebra.""" + if len(centroids) < 2: + return np.zeros(len(centroids)) + + # Midline from first to last vertebra + start = centroids[0] + end = centroids[-1] + midline_vec = end - start + midline_len = np.linalg.norm(midline_vec) + + if midline_len < 1e-6: + return np.zeros(len(centroids)) + + midline_unit = midline_vec / midline_len + + deviations = [] + for pt in centroids: + v = pt - start + # Project onto midline + proj_len = np.dot(v, midline_unit) + proj = proj_len * midline_unit + # Perpendicular vector + perp = v - proj + dist = np.linalg.norm(perp) + # Sign: positive = right, negative = left + sign = np.sign(np.cross(midline_unit, perp / (dist + 1e-8))) + deviations.append(dist * sign) + + return np.array(deviations) + + +def _find_apex_info(centroids: np.ndarray, deviations: np.ndarray, n_verts: int) -> dict: + """Find apex positions and determine curve pattern.""" + info = { + 'thoracic_apex_idx': None, + 'lumbar_apex_idx': None, + 'lumbar_apex_low': False, + 'apex_at_tl_junction': False, + 'is_s_pattern': False, + 'apex_directions': [] + } + + if len(deviations) < 3: + return info + + # Find local extrema (apexes) + apexes = [] + apex_values = [] + for i in range(1, len(deviations) - 1): + if (deviations[i] > deviations[i-1] and deviations[i] > deviations[i+1]) or \ + (deviations[i] < deviations[i-1] and deviations[i] < deviations[i+1]): + if abs(deviations[i]) > 3: # Minimum threshold + apexes.append(i) + apex_values.append(deviations[i]) + + # Determine S-pattern (alternating signs at apexes) + if len(apex_values) >= 2: + signs = [np.sign(v) for v in apex_values] + # S-pattern if adjacent apexes have opposite signs + for i in range(len(signs) - 1): + if signs[i] != signs[i+1]: + info['is_s_pattern'] = True + break + + # Classify apex regions + # Assume: top 40% = thoracic, middle 20% = TL junction, bottom 40% = lumbar + thoracic_end = int(0.4 * n_verts) + tl_junction_end = int(0.6 * n_verts) + + for apex_idx in apexes: + if apex_idx < thoracic_end: + if info['thoracic_apex_idx'] is None or 
\ + abs(deviations[apex_idx]) > abs(deviations[info['thoracic_apex_idx']]): + info['thoracic_apex_idx'] = apex_idx + elif apex_idx < tl_junction_end: + info['apex_at_tl_junction'] = True + else: + if info['lumbar_apex_idx'] is None or \ + abs(deviations[apex_idx]) > abs(deviations[info['lumbar_apex_idx']]): + info['lumbar_apex_idx'] = apex_idx + + # Check if lumbar apex is in lower region (bottom 30%) + if info['lumbar_apex_idx'] is not None: + if info['lumbar_apex_idx'] > int(0.7 * n_verts): + info['lumbar_apex_low'] = True + + info['apex_directions'] = apex_values + + return info diff --git a/frontend/Dockerfile b/frontend/Dockerfile index 160619f..1d837b5 100644 --- a/frontend/Dockerfile +++ b/frontend/Dockerfile @@ -16,8 +16,9 @@ RUN npm ci # Copy source code COPY . . -# Build the app (uses relative API URLs) +# Build the app (uses relative API URLs for all API calls) ENV VITE_API_URL="" +ENV VITE_API_BASE="" RUN npm run build # Stage 2: Serve with nginx diff --git a/frontend/nginx.conf b/frontend/nginx.conf index 9c8005e..fdde77a 100644 --- a/frontend/nginx.conf +++ b/frontend/nginx.conf @@ -16,8 +16,24 @@ server { # Increase max body size for file uploads client_max_body_size 100M; + # CORS headers - allow all origins + add_header 'Access-Control-Allow-Origin' '*' always; + add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS, PATCH' always; + add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always; + # Proxy API requests to the API container location /api/ { + # Handle preflight OPTIONS requests + if ($request_method = 'OPTIONS') { + add_header 'Access-Control-Allow-Origin' '*'; + add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS, PATCH'; + add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization'; + add_header 'Access-Control-Max-Age' 1728000; + add_header 'Content-Type' 'text/plain; charset=utf-8'; + add_header 'Content-Length' 0; + return 204; + } + proxy_pass http://api:3002/api/; proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; @@ -29,6 +45,11 @@ server { proxy_cache_bypass $http_upgrade; proxy_read_timeout 300s; proxy_connect_timeout 75s; + + # Add CORS headers to response + add_header 'Access-Control-Allow-Origin' '*' always; + add_header 'Access-Control-Allow-Methods' 'GET, POST, PUT, DELETE, OPTIONS, PATCH' always; + add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always; } # Proxy file requests to the API container @@ -39,6 +60,11 @@ server { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_read_timeout 300s; + + # Add CORS headers + add_header 'Access-Control-Allow-Origin' '*' always; + add_header 'Access-Control-Allow-Methods' 'GET, POST, OPTIONS' always; + add_header 'Access-Control-Allow-Headers' 'DNT,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Range,Authorization' always; } # Serve static assets with caching diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 4aface6..db77bd9 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -8,6 +8,11 @@ import CaseDetailPage from "./pages/CaseDetail"; import PipelineCaseDetail from "./pages/PipelineCaseDetail"; import ShellEditorPage from "./pages/ShellEditorPage"; +// Patient pages +import 
PatientList from "./pages/PatientList";
+import PatientDetail from "./pages/PatientDetail";
+import PatientForm from "./pages/PatientForm";
+
 // Admin pages
 import AdminDashboard from "./pages/admin/AdminDashboard";
 import AdminUsers from "./pages/admin/AdminUsers";
@@ -110,6 +115,48 @@ function AppRoutes() {
         }
       />
+      {/* Patient routes */}
+      <Route
+        path="/patients"
+        element={
+          <PatientList />
+        }
+      />
+      <Route
+        path="/patients/new"
+        element={
+          <PatientForm />
+        }
+      />
+      <Route
+        path="/patients/:patientId"
+        element={
+          <PatientDetail />
+        }
+      />
+      <Route
+        path="/patients/:patientId/edit"
+        element={
+          <PatientForm />
+        }
+      />
+
       {/* Admin routes */}
diff --git a/frontend/src/api/adminApi.ts b/frontend/src/api/adminApi.ts
--- a/frontend/src/api/adminApi.ts
+++ b/frontend/src/api/adminApi.ts
@@ ... @@ async function adminFetch<T>(path: string, init?: RequestInit): Promise<T> {
   const url = `${API_BASE}${path}`;
@@ -101,6 +101,8 @@ export type AdminCase = {
   analysis_result: any;
   landmarks_data: any;
   body_scan_path: string | null;
+  is_archived?: boolean;
+  archived_at?: string | null;
   created_by: number | null;
   created_by_username: string | null;
   created_at: string;
@@ -122,6 +124,8 @@ export async function listCasesAdmin(params?: {
   offset?: number;
   sortBy?: string;
   sortOrder?: "ASC" | "DESC";
+  includeArchived?: boolean;
+  archivedOnly?: boolean;
 }): Promise<AdminCasesResponse> {
   const searchParams = new URLSearchParams();
   if (params?.status) searchParams.set("status", params.status);
@@ -131,11 +135,31 @@
   if (params?.offset) searchParams.set("offset", params.offset.toString());
   if (params?.sortBy) searchParams.set("sortBy", params.sortBy);
   if (params?.sortOrder) searchParams.set("sortOrder", params.sortOrder);
+  if (params?.includeArchived) searchParams.set("includeArchived", "true");
+  if (params?.archivedOnly) searchParams.set("archivedOnly", "true");
 
   const query = searchParams.toString();
   return adminFetch(`/admin/cases${query ? `?${query}` : ""}`);
 }
 
+/**
+ * Restore (unarchive) a case - admin only
+ */
+export async function restoreCase(caseId: string): Promise<{ caseId: string; archived: boolean; message: string }> {
+  return adminFetch(`/cases/${encodeURIComponent(caseId)}/unarchive`, {
+    method: "POST",
+  });
+}
+
+/**
+ * Restore (unarchive) a patient - admin only
+ */
+export async function restorePatient(patientId: number): Promise<{ patientId: number; archived: boolean; message: string }> {
+  return adminFetch(`/patients/${patientId}/unarchive`, {
+    method: "POST",
+  });
+}
+
 // ============================================
 // ANALYTICS
 // ============================================
@@ -234,3 +258,145 @@ export async function getAuditLog(params?: {
   const query = searchParams.toString();
   return adminFetch(`/admin/audit-log${query ? `?${query}` : ""}`);
 }
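
A quick usage sketch of the archive workflow these helpers enable. It assumes the
list response exposes `cases` and `total` fields (matching the `AdminCasesResponse`
shape the list call returns); the handler names are illustrative and not part of
the patch:

    import { listCasesAdmin, restoreCase } from "./adminApi";

    // Load only archived cases, e.g. for an "Archived" tab in AdminCases.
    async function loadArchived() {
      const { cases, total } = await listCasesAdmin({ archivedOnly: true, limit: 50 });
      console.log(`${total} archived case(s)`);
      return cases;
    }

    // Restore one case, then refresh the archived list.
    async function onRestore(caseId: string) {
      const { message } = await restoreCase(caseId);
      console.log(message);
      await loadArchived();
    }
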
+
+// ============================================
+// API REQUEST ACTIVITY LOG
+// ============================================
+
+export type FileUploadInfo = {
+  fieldname: string;
+  originalname: string;
+  mimetype: string;
+  size: number;
+  destination?: string;
+  filename?: string;
+};
+
+export type ResponseSummary = {
+  success?: boolean;
+  message?: string;
+  error?: string;
+  caseId?: string;
+  status?: string;
+  userId?: number;
+  username?: string;
+  tokenGenerated?: boolean;
+  rigoType?: string;
+  cobbAngles?: { PT?: number; MT?: number; TL?: number };
+  vertebraeDetected?: number;
+  bracesGenerated?: { regular: boolean; vase: boolean };
+  braceGenerated?: boolean;
+  braceVertices?: number;
+  filesGenerated?: string[];
+  outputUrl?: string;
+  landmarksCount?: number | string;
+  casesCount?: number;
+  usersCount?: number;
+  entriesCount?: number;
+  requestsCount?: number;
+  total?: number;
+  bodyScanUploaded?: boolean;
+  measurementsExtracted?: boolean;
+  errorCode?: number;
+};
+
+export type ApiRequestEntry = {
+  id: number;
+  user_id: number | null;
+  username: string | null;
+  method: string;
+  path: string;
+  route_pattern: string | null;
+  query_params: string | null;
+  request_params: string | null;
+  file_uploads: string | null;
+  status_code: number | null;
+  response_time_ms: number | null;
+  response_summary: string | null;
+  ip_address: string | null;
+  user_agent: string | null;
+  request_body_size: number | null;
+  response_body_size: number | null;
+  error_message: string | null;
+  created_at: string;
+};
+
+export type ApiRequestsResponse = {
+  requests: ApiRequestEntry[];
+  total: number;
+  limit: number;
+  offset: number;
+};
+
+export type ApiActivityStats = {
+  total: number;
+  byMethod: Record<string, number>;
+  byStatusCategory: Record<string, number>;
+  topEndpoints: {
+    method: string;
+    path: string;
+    count: number;
+    avg_response_time: number;
+  }[];
+  topUsers: {
+    user_id: number | null;
+    username: string | null;
+    count: number;
+  }[];
+  responseTime: {
+    avg: number;
+    min: number;
+    max: number;
+  };
+  requestsPerHour: {
+    hour: string;
+    count: number;
+  }[];
+  errorRate: number;
+};
+
+export async function getApiActivity(params?: {
+  userId?: number;
+  username?: string;
+  method?: string;
+  path?: string;
+  statusCode?: number;
+  statusCategory?: "2xx" | "3xx" | "4xx" | "5xx";
+  startDate?: string;
+  endDate?: string;
+  limit?: number;
+  offset?: number;
+}): Promise<ApiRequestsResponse> {
+  const searchParams = new URLSearchParams();
+  if (params?.userId) searchParams.set("userId", params.userId.toString());
+  if (params?.username) searchParams.set("username", params.username);
+  if (params?.method) searchParams.set("method", params.method);
+  if (params?.path) searchParams.set("path", params.path);
+  if (params?.statusCode) searchParams.set("statusCode", params.statusCode.toString());
+  if (params?.statusCategory) searchParams.set("statusCategory", params.statusCategory);
+  if (params?.startDate) searchParams.set("startDate", params.startDate);
+  if (params?.endDate) searchParams.set("endDate", params.endDate);
+  if (params?.limit) searchParams.set("limit", params.limit.toString());
+  if (params?.offset) searchParams.set("offset", params.offset.toString());
+
+  const query = searchParams.toString();
+  return adminFetch(`/admin/activity${query ? `?${query}` : ""}`);
+}
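
For a dashboard view, the activity call above and the stats call that follows
compose naturally. A minimal sketch (illustrative wiring only; the AdminActivity
page adds filter and pagination state on top of this):

    import { getApiActivity, getApiActivityStats } from "./adminApi";

    async function loadActivityDashboard() {
      // Fetch recent server errors and the aggregate stats in parallel.
      const [failures, { stats }] = await Promise.all([
        getApiActivity({ statusCategory: "5xx", limit: 25 }),
        getApiActivityStats(),
      ]);
      console.log("error rate:", stats.errorRate, "5xx sampled:", failures.requests.length);
    }
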
+
+export async function getApiActivityStats(params?: {
+  startDate?: string;
+  endDate?: string;
+}): Promise<{ stats: ApiActivityStats }> {
+  const searchParams = new URLSearchParams();
+  if (params?.startDate) searchParams.set("startDate", params.startDate);
+  if (params?.endDate) searchParams.set("endDate", params.endDate);
+
+  const query = searchParams.toString();
+  return adminFetch(`/admin/activity/stats${query ? `?${query}` : ""}`);
+}
+
+export async function cleanupApiActivityLogs(daysToKeep: number = 30): Promise<{ message: string; deletedCount: number }> {
+  return adminFetch(`/admin/activity/cleanup?daysToKeep=${daysToKeep}`, {
+    method: "DELETE",
+  });
+}
diff --git a/frontend/src/api/braceflowApi.ts b/frontend/src/api/braceflowApi.ts
index 8dbf144..fd70fba 100644
--- a/frontend/src/api/braceflowApi.ts
+++ b/frontend/src/api/braceflowApi.ts
@@ -1,5 +1,18 @@
+export type CasePatient = {
+  id: number;
+  firstName: string;
+  lastName: string;
+  fullName: string;
+  mrn?: string | null;
+  dateOfBirth?: string | null;
+  gender?: string | null;
+};
+
 export type CaseRecord = {
   caseId: string;
+  patient_id?: number | null;
+  patient?: CasePatient | null;
+  visit_date?: string | null;
   status: string;
   current_step: string | null;
   created_at: string;
@@ -356,7 +369,28 @@ export async function analyzeXray(
 }
 
 /**
- * Delete a case and all associated files
+ * Archive a case (soft delete - preserves all files)
+ */
+export async function archiveCase(caseId: string): Promise<{ caseId: string; archived: boolean; message: string }> {
+  return await safeFetch<{ caseId: string; archived: boolean; message: string }>(
+    `/cases/${encodeURIComponent(caseId)}/archive`,
+    { method: "POST" }
+  );
+}
+
+/**
+ * Unarchive a case (restore)
+ */
+export async function unarchiveCase(caseId: string): Promise<{ caseId: string; archived: boolean; message: string }> {
+  return await safeFetch<{ caseId: string; archived: boolean; message: string }>(
+    `/cases/${encodeURIComponent(caseId)}/unarchive`,
+    { method: "POST" }
+  );
+}
+
+/**
+ * Delete a case - DEPRECATED: Use archiveCase instead
+ * This now archives instead of deleting
  */
 export async function deleteCase(caseId: string): Promise<{ message: string }> {
   return await safeFetch<{ message: string }>(
diff --git a/frontend/src/api/patientApi.ts b/frontend/src/api/patientApi.ts
new file mode 100644
index 0000000..2542477
--- /dev/null
+++ b/frontend/src/api/patientApi.ts
@@ -0,0 +1,248 @@
+/**
+ * Patient API Client
+ * API functions for patient management
+ */
+
+import { getAuthHeaders } from "../context/AuthContext";
+
+const API_BASE = import.meta.env.VITE_API_BASE || "/api";
+
+async function patientFetch<T>(path: string, init?: RequestInit): Promise<T> {
+  const url = `${API_BASE}${path}`;
+
+  const response = await fetch(url, {
+    ...init,
+    headers: {
+      "Content-Type": "application/json",
+      ...getAuthHeaders(),
+      ...init?.headers,
+    },
+  });
+
+  const text = await response.text();
+
+  if (!response.ok) {
+    let message = `Request failed: ${response.status}`;
+    try {
+      const error = text ? JSON.parse(text) : null;
+      if (error?.message) message = error.message;
+    } catch {
+      // Non-JSON error body (e.g. an HTML error page from the proxy): keep the status message.
+    }
+    throw new Error(message);
+  }
+
+  return text ? JSON.parse(text) : ({} as T);
+}
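
A note on the wrapper above: reading the body as text first keeps empty responses
(e.g. a 204 from DELETE) from crashing `response.json()`, and the generic lets the
call sites below stay fully typed. Explicit instantiation also works; this one-liner
is illustrative only, since patientFetch is module-private and the ID is made up:

    const { patient } = await patientFetch<{ patient: Patient }>("/patients/42");
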
+
+// ============================================
+// TYPES
+// ============================================
+
+export type Patient = {
+  id: number;
+  mrn: string | null;
+  first_name: string;
+  last_name: string;
+  date_of_birth: string | null;
+  gender: "male" | "female" | "other" | null;
+  email: string | null;
+  phone: string | null;
+  address: string | null;
+  diagnosis: string | null;
+  curve_type: string | null;
+  medical_history: string | null;
+  referring_physician: string | null;
+  insurance_info: string | null;
+  notes: string | null;
+  is_active: number;
+  created_by: number | null;
+  created_by_username: string | null;
+  created_at: string;
+  updated_at: string;
+  case_count?: number;
+  last_visit?: string | null;
+};
+
+export type PatientCase = {
+  case_id: string;
+  case_type: string;
+  status: string;
+  current_step: string | null;
+  visit_date: string | null;
+  notes: string | null;
+  analysis_result: any;
+  landmarks_data: any;
+  body_scan_path: string | null;
+  body_scan_url: string | null;
+  created_at: string;
+  updated_at: string;
+};
+
+export type PatientInput = {
+  mrn?: string;
+  firstName: string;
+  lastName: string;
+  dateOfBirth?: string;
+  gender?: "male" | "female" | "other";
+  email?: string;
+  phone?: string;
+  address?: string;
+  diagnosis?: string;
+  curveType?: string;
+  medicalHistory?: string;
+  referringPhysician?: string;
+  insuranceInfo?: string;
+  notes?: string;
+};
+
+export type PatientListResponse = {
+  patients: Patient[];
+  total: number;
+  limit: number;
+  offset: number;
+};
+
+export type PatientStats = {
+  total: number;
+  active: number;
+  inactive: number;
+  withCases: number;
+  byGender: Record<string, number>;
+  recentPatients: number;
+};
+
+// ============================================
+// API FUNCTIONS
+// ============================================
+
+/**
+ * Create a new patient
+ */
+export async function createPatient(data: PatientInput): Promise<{ patient: Patient }> {
+  return patientFetch("/patients", {
+    method: "POST",
+    body: JSON.stringify(data),
+  });
+}
+
+/**
+ * List patients with optional filters
+ */
+export async function listPatients(params?: {
+  search?: string;
+  isActive?: boolean | "all";
+  limit?: number;
+  offset?: number;
+  sortBy?: string;
+  sortOrder?: "ASC" | "DESC";
+}): Promise<PatientListResponse> {
+  const searchParams = new URLSearchParams();
+  if (params?.search) searchParams.set("search", params.search);
+  if (params?.isActive !== undefined) {
+    searchParams.set("isActive", params.isActive === "all" ? "all" : String(params.isActive));
+  }
+  if (params?.limit) searchParams.set("limit", params.limit.toString());
+  if (params?.offset) searchParams.set("offset", params.offset.toString());
+  if (params?.sortBy) searchParams.set("sortBy", params.sortBy);
+  if (params?.sortOrder) searchParams.set("sortOrder", params.sortOrder);
+
+  const query = searchParams.toString();
+  return patientFetch(`/patients${query ? `?${query}` : ""}`);
+}
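
Driving the list call from UI state looks roughly like this (a sketch; the sort
column name and page size are assumptions, and the PatientList page owns the real
state):

    import { listPatients } from "./patientApi";

    const PAGE_SIZE = 25;

    async function searchPatients(term: string, page: number) {
      const { patients, total } = await listPatients({
        search: term,
        isActive: true,             // default to active records only
        limit: PAGE_SIZE,
        offset: page * PAGE_SIZE,   // server-side pagination
        sortBy: "last_name",
        sortOrder: "ASC",
      });
      return { patients, pageCount: Math.ceil(total / PAGE_SIZE) };
    }
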
+
+/**
+ * Get patient by ID with their cases
+ */
+export async function getPatient(patientId: number): Promise<{ patient: Patient; cases: PatientCase[] }> {
+  return patientFetch(`/patients/${patientId}`);
+}
+
+/**
+ * Update patient
+ */
+export async function updatePatient(
+  patientId: number,
+  data: Partial<PatientInput> & { isActive?: boolean }
+): Promise<{ patient: Patient }> {
+  return patientFetch(`/patients/${patientId}`, {
+    method: "PUT",
+    body: JSON.stringify(data),
+  });
+}
+
+/**
+ * Archive patient (soft delete - preserves all data)
+ */
+export async function archivePatient(patientId: number): Promise<{ patientId: number; archived: boolean; message: string }> {
+  return patientFetch(`/patients/${patientId}/archive`, {
+    method: "POST",
+  });
+}
+
+/**
+ * Unarchive patient (restore)
+ */
+export async function unarchivePatient(patientId: number): Promise<{ patientId: number; archived: boolean; message: string }> {
+  return patientFetch(`/patients/${patientId}/unarchive`, {
+    method: "POST",
+  });
+}
+
+/**
+ * Delete patient - DEPRECATED: Use archivePatient instead
+ * This now archives instead of deleting
+ */
+export async function deletePatient(patientId: number, hard = false): Promise<{ message: string }> {
+  return patientFetch(`/patients/${patientId}${hard ? "?hard=true" : ""}`, {
+    method: "DELETE",
+  });
+}
+
+/**
+ * Create a case for a patient
+ */
+export async function createPatientCase(
+  patientId: number,
+  data?: { notes?: string; visitDate?: string }
+): Promise<{ caseId: string; patientId: number; status: string }> {
+  return patientFetch(`/patients/${patientId}/cases`, {
+    method: "POST",
+    body: JSON.stringify(data || {}),
+  });
+}
+
+/**
+ * Get patient statistics
+ */
+export async function getPatientStats(): Promise<{ stats: PatientStats }> {
+  return patientFetch("/patients-stats");
+}
+
+// ============================================
+// UTILITY FUNCTIONS
+// ============================================
+
+/**
+ * Format patient name
+ */
+export function formatPatientName(patient: Patient | { first_name: string; last_name: string }): string {
+  return `${patient.first_name} ${patient.last_name}`;
+}
+
+/**
+ * Calculate patient age from date of birth
+ */
+export function calculateAge(dateOfBirth: string | null): number | null {
+  if (!dateOfBirth) return null;
+  const today = new Date();
+  const birth = new Date(dateOfBirth);
+  let age = today.getFullYear() - birth.getFullYear();
+  const monthDiff = today.getMonth() - birth.getMonth();
+  if (monthDiff < 0 || (monthDiff === 0 && today.getDate() < birth.getDate())) {
+    age--;
+  }
+  return age;
+}
+
+/**
+ * Format date for display
+ */
+export function formatDate(date: string | null): string {
+  if (!date) return "-";
+  return new Date(date).toLocaleDateString();
+}
diff --git a/frontend/src/components/AppShell.tsx b/frontend/src/components/AppShell.tsx
index 632663a..f7e712e 100644
--- a/frontend/src/components/AppShell.tsx
+++ b/frontend/src/components/AppShell.tsx
@@ -34,6 +34,7 @@ export function AppShell({ children }: { children: React.ReactNode }) {
   const [shouldFadeIn, setShouldFadeIn] = useState(false);
   const prevPathRef = useRef(location.pathname);
 
+  const isPatients = location.pathname === "/patients" || location.pathname.startsWith("/patients/");
   const isCases = location.pathname === "/cases" || location.pathname.startsWith("/cases/");
   const isEditShell = location.pathname.startsWith("/editor");
   const isAdmin = location.pathname.startsWith("/admin");
@@ -72,6 +73,7 @@ 
export function AppShell({ children }: { children: React.ReactNode }) {
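
Taken together, the patient endpoints support a register-then-visit flow along these
lines (a hypothetical handler using only functions from patientApi.ts; the import
path, field values, and error handling are illustrative):

    import { createPatient, createPatientCase } from "./api/patientApi";

    async function registerAndStartCase(): Promise<string> {
      const { patient } = await createPatient({
        firstName: "Jane",
        lastName: "Doe",
        dateOfBirth: "2012-04-09",
        gender: "female",
        diagnosis: "Adolescent idiopathic scoliosis",
      });
      // Each brace case hangs off a patient and records the visit date.
      const { caseId } = await createPatientCase(patient.id, {
        visitDate: new Date().toISOString().slice(0, 10),
        notes: "Initial scan and fitting",
      });
      return caseId;
    }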