Mechanistic Interpretability for Clinical JEPAs

Research (ongoing)

Introducing MI techniques to Joint Embedding Predictive Architectures trained on clinical encounter sequences. Targeting ML4H 2026.

Mechanistic Interpretability · JEPA · PyTorch
Diagram of the shared-encoder stop-gradient JEPA architecture, with VICReg loss components annotated.

Independent research with input from an advisor. Trains a shared-encoder stop-gradient JEPA on ~55k MIMIC-IV patients, then applies a four-layer interpretability pipeline to the F-code psychiatric subset. Architecture, data extraction, and analysis scaffolding are built; broad-cohort training runs and full results are still ahead.

Why JEPAs are worth opening up

Most mech-interp work on transformers operates on autoregressive models: residual-stream activations decomposed against a token vocabulary, with prediction sitting in logit space. A JEPA predicts in continuous latent space. The prediction and target live in the same embedding space, so the error is a geometrically meaningful vector, not a scalar loss. Instead of arguing about which logits got pushed up, the questions become: which dimensions does the predictor systematically miss, do those align with what the encoder represents elsewhere, and is the error structured or isotropic?
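A minimal sketch of treating the error as a vector, on synthetic data. The names `z_hat`, `z_tgt`, and `error_summary` are illustrative and not taken from the repo. The participation ratio of the error covariance spectrum is my choice of anisotropy measure for this sketch, not necessarily the project's: it approaches the embedding dimension when errors are isotropic and drops toward 1 when they concentrate in a few directions.

```python
import numpy as np

def error_summary(z_hat: np.ndarray, z_tgt: np.ndarray) -> dict:
    """Summarize prediction errors e = z_hat - z_tgt, both of shape (n, d)."""
    e = z_hat - z_tgt
    bias = e.mean(axis=0)                       # per-dimension systematic miss
    cov = np.cov(e, rowvar=False)               # (d, d) error covariance
    eig = np.clip(np.linalg.eigvalsh(cov), 0.0, None)
    # Participation ratio: close to d for isotropic errors, close to 1 for rank-one errors.
    pr = eig.sum() ** 2 / (eig ** 2).sum()
    return {"bias": bias, "participation_ratio": float(pr)}

rng = np.random.default_rng(0)
n, d = 2000, 16
z_tgt = rng.normal(size=(n, d))

# Isotropic case: the error is small noise spread evenly over all dimensions.
iso = error_summary(z_tgt + 0.1 * rng.normal(size=(n, d)), z_tgt)

# Structured case: the predictor misses mostly along a single direction.
direction = np.zeros(d)
direction[0] = 1.0
miss = rng.normal(size=(n, 1)) * direction + 0.01 * rng.normal(size=(n, d))
structured = error_summary(z_tgt + miss, z_tgt)
```

Comparing `iso["participation_ratio"]` with `structured["participation_ratio"]` gives a one-number answer to "structured or isotropic"; the `bias` vector addresses the "systematically miss" question per dimension.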

Why stop-gradient, not EMA

With an EMA target encoder, the target embedding comes from a slightly different function than the context embedding. Any analysis of the prediction error then conflates predictor failure with encoder drift, a confound that varies across training and can't be factored out post-hoc. For anything that treats the error vector as a first-class object (SAE features on the error, probe-magnitude correlation with outcomes), this is fatal.
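The confound can be written out explicitly. The notation is an assumption of this sketch, since the page's original symbols did not survive: f_θ is the online encoder, f_ξ the EMA target encoder with momentum τ, g the predictor, x_c the context encounters, and x_t the target encounter.

```latex
\begin{aligned}
\xi &\leftarrow \tau\,\xi + (1-\tau)\,\theta \\
e_{\mathrm{EMA}} &= g\bigl(f_\theta(x_c)\bigr) - f_\xi(x_t) \\
&= \underbrace{g\bigl(f_\theta(x_c)\bigr) - f_\theta(x_t)}_{\text{predictor failure}}
 \;+\; \underbrace{f_\theta(x_t) - f_\xi(x_t)}_{\text{encoder drift}}
\end{aligned}
```

When the target path shares the online weights, ξ = θ and the drift term is identically zero, so the error vector reflects the predictor alone.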

Shared-encoder stop-gradient removes the confound: both paths use identical weights, and the target path runs its forward pass under torch.no_grad() and calls .detach() on the output. VICReg, with weighted invariance, variance, and covariance terms, handles collapse explicitly. Both architectures live in the repo; stop-grad is the one the downstream interpretability depends on.
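A rough PyTorch sketch of that setup, not the repo's code. Several things here are assumptions: the class name `StopGradJEPA`, the stand-in MLP encoder (the project uses a transformer over encounters), the default weights 25 / 25 / 1 (the VICReg paper's defaults, not necessarily this project's), and the choice to apply the variance and covariance terms only to the gradient-carrying branch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StopGradJEPA(nn.Module):
    """Shared-encoder JEPA sketch: one encoder, target path detached."""

    def __init__(self, in_dim: int = 32, emb_dim: int = 16):
        super().__init__()
        # Stand-in MLPs; hypothetical sizes chosen only for illustration.
        self.encoder = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, emb_dim))
        self.predictor = nn.Sequential(nn.Linear(emb_dim, 64), nn.ReLU(), nn.Linear(64, emb_dim))

    def forward(self, x_ctx, x_tgt):
        z_hat = self.predictor(self.encoder(x_ctx))    # gradient-carrying path
        with torch.no_grad():                          # same weights, no graph recorded
            z_tgt = self.encoder(x_tgt).detach()       # explicit stop-gradient
        return z_hat, z_tgt

def vicreg_loss(z_hat, z_tgt, w_inv=25.0, w_var=25.0, w_cov=1.0, eps=1e-4):
    """VICReg-style loss with variance/covariance on the predicted branch only."""
    n, d = z_hat.shape
    inv = F.mse_loss(z_hat, z_tgt)                     # invariance term
    std = torch.sqrt(z_hat.var(dim=0) + eps)
    var = F.relu(1.0 - std).mean()                     # hinge keeps per-dim std near 1
    zc = z_hat - z_hat.mean(dim=0)
    cov_m = (zc.T @ zc) / (n - 1)
    off_diag = cov_m - torch.diag(torch.diag(cov_m))
    cov = off_diag.pow(2).sum() / d                    # penalize off-diagonal covariance
    return w_inv * inv + w_var * var + w_cov * cov

torch.manual_seed(0)
model = StopGradJEPA()
x_ctx, x_tgt = torch.randn(8, 32), torch.randn(8, 32)
z_hat, z_tgt = model(x_ctx, x_tgt)
loss = vicreg_loss(z_hat, z_tgt)
loss.backward()   # gradients reach the encoder only through the context path
```

Inside `torch.no_grad()` the `.detach()` is redundant, but keeping both mirrors the description above and guards against someone later removing the context manager.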

Analysis framework

Four analytical objects and five layers of analysis, applied to the F-code psychiatric subset:

  • Encoder embedding characterization. Sparse autoencoders, linear probes, UMAP/HDBSCAN clustering. What does the encoder learn about individual encounters?
  • Trajectory geometry. Velocity, curvature, drift toward concept centroids in embedding space. The bet is that trajectory shape carries predictive signal beyond the last context vector alone.
  • Prediction vs. encoder embedding. CKA, PCA subspace alignment, matched SAE features. Does the predictor use the same vocabulary as the encoder?
  • Prediction-error decomposition. PCA against a Marchenko-Pastur null, SAE on errors, cross-reference against encoder SAE features. Separates systematic gaps from residual noise.
  • Partial labeling bridge. Tier A (LASSO R² on PC axes), Tier B (cluster enrichment), Tier C (SAE feature inspection). The unexplained residual at each tier is reported as a finding, not noise.
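Three of the quantities in the list above can be sketched in a few lines of NumPy. These are illustrative definitions under my own assumptions, not the project's implementation: turning angle between consecutive steps stands in for curvature, CKA is the linear variant, and the Marchenko-Pastur null is applied to per-dimension standardized errors (so unit noise variance is assumed). All function names are hypothetical.

```python
import numpy as np

def trajectory_geometry(z: np.ndarray, centroid: np.ndarray):
    """z: (T, d) encounter embeddings for one patient, in time order."""
    v = np.diff(z, axis=0)                               # (T-1, d) step vectors
    speed = np.linalg.norm(v, axis=1)
    # Turning angle between consecutive steps as a discrete curvature proxy.
    cos = (v[:-1] * v[1:]).sum(axis=1) / (speed[:-1] * speed[1:] + 1e-12)
    turning = np.arccos(np.clip(cos, -1.0, 1.0))         # (T-2,)
    # Positive drift means the step moved closer to the concept centroid.
    dist = np.linalg.norm(z - centroid, axis=1)
    drift = -np.diff(dist)                               # (T-1,)
    return speed, turning, drift

def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between two (n, d) representation matrices."""
    x = x - x.mean(axis=0)
    y = y - y.mean(axis=0)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    return float(cross / (np.linalg.norm(x.T @ x, "fro") * np.linalg.norm(y.T @ y, "fro")))

def n_components_above_mp(e: np.ndarray) -> int:
    """Count eigenvalues of the error correlation matrix above the MP upper edge."""
    n, d = e.shape
    e = (e - e.mean(axis=0)) / e.std(axis=0)             # unit variance per dimension
    eig = np.linalg.eigvalsh(e.T @ e / n)
    upper = (1.0 + np.sqrt(d / n)) ** 2                  # MP edge for unit noise variance
    return int((eig > upper).sum())
```

Components above the edge are candidates for systematic gaps; everything below it is indistinguishable from noise at that sample size. Finite-sample fluctuations near the edge mean a borderline component deserves a permutation check rather than a hard cutoff.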

Status

  • Built. MIMIC-IV extraction (~55k patients, no outcome-conditioned filtering), stop-grad and EMA JEPAs with temporal encoding and per-encounter [CLS] tokens, post-hoc label generation, baseline scaffolding for a supervised transformer, XGBoost, and logistic regression.
  • Lesson. A 0.985 AUC on the diagnostic cohort turned out to be a data leak: the cohort had been pre-filtered on outcome. Later, under causal prefix masking, patient-level labels were reading future encounters and got replaced with per-encounter versions (label_30d_per_enc[k], label_escalation_per_enc[k]). Both fixes landed before any reported result.
  • Next. Three-seed training on the broad cohort, baseline comparisons, the full four-layer analysis on the F-code subset.
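The second leak is easy to reproduce on a toy patient. This sketch assumes a 30-day label means "the next admission starts within 30 days of this encounter's discharge"; the page does not define it, so treat that definition, the function names, and the day-number inputs as hypothetical. Only the name `label_30d_per_enc` comes from the text above.

```python
def build_label_30d_per_enc(admit_day, discharge_day, window=30):
    """Per-encounter label: 1 if the next admission starts within `window`
    days of encounter k's discharge, else 0 (assumed definition)."""
    labels = []
    for k in range(len(admit_day)):
        if k + 1 < len(admit_day):
            gap = admit_day[k + 1] - discharge_day[k]
            labels.append(int(0 <= gap <= window))
        else:
            labels.append(0)   # no observed follow-up; a censoring choice
    return labels

def patient_level_label(admit_day, discharge_day, window=30):
    """Leaky variant: one label per patient, shared by every prefix."""
    return int(any(build_label_30d_per_enc(admit_day, discharge_day, window)))

admit = [0, 100, 110, 400]
discharge = [5, 103, 115, 402]
label_30d_per_enc = build_label_30d_per_enc(admit, discharge)   # [0, 1, 0, 0]
leaky = patient_level_label(admit, discharge)                    # 1
```

Under causal prefix masking, a prefix ending at encounter 0 should be scored against `label_30d_per_enc[0]`, which is 0. The patient-level label hands that same prefix a 1 that is only determined by encounters 1 and 2, which the model has not seen.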
Mechanistic Interpretability for Clinical JEPAs · Tanner O'Rourke