Mechanistic Interpretability for Clinical JEPAs
Research (ongoing)
Introducing MI techniques to Joint Embedding Predictive Architectures trained on clinical encounter sequences. Targeting ML4H 2026.

Independent research with input from an advisor. Trains a shared-encoder stop-gradient JEPA on ~55k MIMIC-IV patients, then applies a four-layer interpretability pipeline to the F-code psychiatric subset. Architecture, data extraction, and analysis scaffolding are built; broad-cohort training runs and full results are still ahead.
Why JEPAs are worth opening up
Most mech-interp work on transformers operates on autoregressive models: residual-stream activations decomposed against a token vocabulary, with prediction sitting in logit space. A JEPA predicts in continuous latent space. The prediction and the target live in the same embedding space, so the error is a geometrically meaningful vector, not a scalar loss. Instead of arguing about which logits got pushed up, the questions become: which dimensions does the predictor systematically miss, do those align with what the encoder represents elsewhere, and is the error structured or isotropic?
Why stop-gradient, not EMA
With an EMA target encoder, the target embedding comes from a slightly different function than the context embedding. Any analysis of the prediction error then conflates predictor failure with encoder drift, a confound that varies across training and can't be factored out post hoc. For anything that treats the error vector as a first-class object (SAE features on the error, probing error magnitude against outcomes), this is fatal.
Shared-encoder stop-gradient kills the confound: both paths use identical weights, and the target path runs its forward pass under torch.no_grad() and calls .detach() on the output. A VICReg loss with weighted invariance, variance, and covariance terms handles collapse explicitly. Both architectures live in the repo; stop-grad is the one downstream interpretability depends on.
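The stop-gradient setup described above can be sketched in PyTorch roughly as follows. This is a minimal illustration under stated assumptions, not the repo's code: TinyJEPA, the MLP layers, the dimensions, and the random inputs are hypothetical, and the loss weights (25 / 25 / 1) are the defaults from the VICReg paper rather than this project's values. Applying the variance and covariance terms to the predictor output is also an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def vicreg_loss(pred, target, w_inv=25.0, w_var=25.0, w_cov=1.0, eps=1e-4):
    """VICReg-style loss between predictions and stop-gradient targets.

    Weights are the VICReg paper defaults (an assumption, not project values).
    """
    # Invariance: mean squared error between prediction and target.
    inv = F.mse_loss(pred, target)

    # Variance: hinge pushing each dimension's batch std up toward 1.
    std = torch.sqrt(pred.var(dim=0) + eps)
    var = F.relu(1.0 - std).mean()

    # Covariance: penalize off-diagonal entries of the batch covariance.
    n, d = pred.shape
    centered = pred - pred.mean(dim=0)
    cov = centered.T @ centered / (n - 1)
    off_diag = cov - torch.diag(torch.diagonal(cov))
    cov_pen = off_diag.pow(2).sum() / d

    return w_inv * inv + w_var * var + w_cov * cov_pen


class TinyJEPA(nn.Module):
    """Hypothetical toy model: one encoder shared by both paths, plus a predictor."""

    def __init__(self, in_dim=16, emb_dim=8):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, emb_dim)
        )
        self.predictor = nn.Sequential(
            nn.Linear(emb_dim, 32), nn.ReLU(), nn.Linear(32, emb_dim)
        )

    def forward(self, context, target):
        # Context path: gradients flow through the predictor and the encoder.
        pred = self.predictor(self.encoder(context))
        # Target path: same encoder weights, but no gradient is recorded.
        with torch.no_grad():
            z_tgt = self.encoder(target).detach()
        return pred, z_tgt


torch.manual_seed(0)
model = TinyJEPA()
context = torch.randn(64, 16)  # stand-in for context encounter features
target = torch.randn(64, 16)   # stand-in for target encounter features

pred, z_tgt = model(context, target)
loss = vicreg_loss(pred, z_tgt)
loss.backward()

# Prediction and target share one embedding space, so the error is a vector.
error = (pred - z_tgt).detach()
```

Because the target comes from the same weights as the context path, `error` reflects only what the predictor missed, with no second set of encoder weights to account for.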
Analysis framework
Four analytical objects, five layers of analysis, applied to the F-code psychiatric subset:
- Encoder embedding characterization. Sparse autoencoders, linear probes, UMAP/HDBSCAN clustering. What does the encoder learn about individual encounters?
- Trajectory geometry. Velocity, curvature, drift toward concept centroids in embedding space. The bet is that trajectory shape carries predictive signal beyond the last context vector alone.
- Predictor vs. encoder. CKA, PCA subspace alignment, matched SAE features. Does the predictor use the same vocabulary as the encoder?
- Prediction-error decomposition. PCA against a Marchenko-Pastur null, SAE on errors, cross-reference against encoder SAE features. Separates systematic gaps from residual noise.
- Partial labeling bridge. Tier A (LASSO R² on PC axes), Tier B (cluster enrichment), Tier C (SAE feature inspection). The unexplained residual at each tier is reported as a finding, not noise.
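Three of the primitives named in the list above (trajectory velocity and curvature, linear CKA, and a Marchenko-Pastur cutoff for PCA on errors) can be sketched in NumPy. All function names and the synthetic data are hypothetical. The turning angle between consecutive steps is only one possible discrete proxy for curvature, and the median eigenvalue is a crude noise-variance estimate chosen for this sketch, not necessarily what the project uses.

```python
import numpy as np


def trajectory_geometry(traj):
    """Step speeds and turning angles for one trajectory of shape (T, d)."""
    steps = np.diff(traj, axis=0)              # (T-1, d) displacement per step
    speed = np.linalg.norm(steps, axis=1)      # (T-1,) velocity magnitude
    a, b = steps[:-1], steps[1:]
    denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + 1e-12
    cos = np.sum(a * b, axis=1) / denom
    angle = np.arccos(np.clip(cos, -1.0, 1.0))  # (T-2,) curvature proxy
    return speed, angle


def linear_cka(x, y):
    """Linear CKA between paired representations x (n, d1) and y (n, d2)."""
    x = x - x.mean(axis=0)
    y = y - y.mean(axis=0)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    return cross / (np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro"))


def mp_upper_edge(n, d, sigma2=1.0):
    """Upper edge of the Marchenko-Pastur bulk for n samples in d dimensions."""
    return sigma2 * (1.0 + np.sqrt(d / n)) ** 2


def n_components_above_mp(errors):
    """Count covariance eigenvalues of an (n, d) error matrix above the MP edge."""
    n, d = errors.shape
    centered = errors - errors.mean(axis=0)
    eigvals = np.linalg.eigvalsh(centered.T @ centered / n)
    sigma2 = np.median(eigvals)  # assumption: median eigenvalue ~ noise variance
    return int(np.sum(eigvals > mp_upper_edge(n, d, sigma2)))


rng = np.random.default_rng(0)

# CKA is invariant to orthogonal transforms, so a rotated copy scores near 1.
z = rng.normal(size=(500, 8))
q, _ = np.linalg.qr(rng.normal(size=(8, 8)))
print("CKA(z, rotated z):", linear_cka(z, z @ q))

# Synthetic errors: isotropic noise plus one strong systematic direction.
direction = np.zeros(20)
direction[0] = 1.0
errors = rng.normal(size=(2000, 20)) + 5.0 * rng.normal(size=(2000, 1)) * direction
print("components above MP edge:", n_components_above_mp(errors))
```

Eigenvalues above the Marchenko-Pastur edge are candidates for systematic structure; those inside the bulk are consistent with isotropic noise at the estimated variance.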
Status
- Built. MIMIC-IV extraction (~55k patients, no outcome-conditioned filtering), stop-grad and EMA JEPAs with temporal encoding and per-encounter [CLS] tokens, post-hoc label generation, baseline scaffolding for a supervised transformer, XGBoost, and logistic regression.
- Lesson. A 0.985 AUC on the diagnostic cohort turned out to be a data leak: the cohort had been pre-filtered on outcome. Later, under causal prefix masking, patient-level labels were reading future encounters and got replaced with per-encounter versions (label_30d_per_enc[k], label_escalation_per_enc[k]). Both fixes landed before any reported result.
- Next. Three-seed training on the broad cohort, baseline comparisons, the full four-layer analysis on the F-code subset.
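The difference between a patient-level label and a per-encounter label can be illustrated with a small sketch. The source does not define what the 30-day outcome is, so this example assumes a hypothetical one (the next admission starting within 30 days of the current discharge). The function name, the dates, and the handling of the last encounter are all made up for illustration.

```python
from datetime import datetime, timedelta


def next_admit_within_horizon(admit_times, discharge_times, horizon_days=30):
    """Hypothetical per-encounter label.

    labels[k] is 1 if encounter k+1 starts within `horizon_days` of the
    discharge of encounter k, else 0. The last encounter gets 0 because no
    later encounter is observed (censoring is ignored in this sketch).
    """
    horizon = timedelta(days=horizon_days)
    labels = []
    for k in range(len(admit_times)):
        if k + 1 < len(admit_times):
            gap = admit_times[k + 1] - discharge_times[k]
            labels.append(int(timedelta(0) <= gap <= horizon))
        else:
            labels.append(0)
    return labels


# One made-up patient with three encounters.
admits = [datetime(2020, 1, 1), datetime(2020, 6, 1), datetime(2020, 6, 20)]
discharges = [datetime(2020, 1, 5), datetime(2020, 6, 5), datetime(2020, 6, 25)]

per_enc = next_admit_within_horizon(admits, discharges)

# A patient-level label collapses the whole sequence into one value.
patient_level = int(any(per_enc))

print(per_enc)        # [0, 1, 0]
print(patient_level)  # 1
```

With causal prefix masking, the prefix ending at encounter 0 would inherit the patient-level value 1 even though that value is driven by an event after encounter 1. The per-encounter label for that prefix is 0, since each label is anchored to a fixed window after its own encounter.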