mstar.model.vjepa2.components.predictor
V-JEPA 2 masked latent predictor.
Ports VJEPA2PredictorEmbeddings and VJEPA2Predictor from HuggingFace
transformers/models/vjepa2/modeling_vjepa2.py.
Takes encoder hidden states, a context_mask (positions the predictor sees),
and a target_mask (positions to predict), and emits predicted embeddings at
the target positions. Not autoregressive: a single forward pass predicts every
target token in parallel.
- Weight layout (prefix predictor.):
  - predictor.embeddings.predictor_embeddings.{weight,bias}
  - predictor.embeddings.mask_tokens
  - predictor.layer.{N}.*
  - predictor.layernorm.{weight,bias}
  - predictor.proj.{weight,bias}
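The layout above can be checked against a checkpoint's key list with a small helper. The following is a minimal sketch in plain Python, not part of mstar: the function name `group_predictor_keys` and the sub-layer name `norm1` under `predictor.layer.{N}.*` are hypothetical, chosen only for illustration.

```python
import re
from collections import defaultdict

PREFIX = "predictor."


def group_predictor_keys(keys):
    """Group state-dict keys under the predictor prefix by component.

    Keys outside the prefix are skipped. Keys under layer.{N}. are grouped
    per layer index; all others are grouped by their first path component.
    """
    groups = defaultdict(list)
    for key in keys:
        if not key.startswith(PREFIX):
            continue
        rest = key[len(PREFIX):]
        match = re.match(r"layer\.(\d+)\.", rest)
        group = f"layer.{match.group(1)}" if match else rest.split(".")[0]
        groups[group].append(key)
    return dict(groups)


# Example key list. "norm1" is a hypothetical sub-layer name; the source
# only documents the pattern predictor.layer.{N}.*
keys = [
    "encoder.layernorm.weight",
    "predictor.embeddings.predictor_embeddings.weight",
    "predictor.embeddings.predictor_embeddings.bias",
    "predictor.embeddings.mask_tokens",
    "predictor.layer.0.norm1.weight",
    "predictor.layer.1.norm1.weight",
    "predictor.layernorm.weight",
    "predictor.layernorm.bias",
    "predictor.proj.weight",
    "predictor.proj.bias",
]
groups = group_predictor_keys(keys)
for name in sorted(groups):
    print(name, len(groups[name]))
```

Grouping by component makes it easier to spot a missing or misnamed block when porting weights, since each group should correspond to one entry in the layout list.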
Classes
| VJEPA2Predictor | |
| VJEPA2PredictorEmbeddings | Project encoder hidden states into predictor space and concatenate learned mask tokens at the target positions. |
- class mstar.model.vjepa2.components.predictor.VJEPA2Predictor(config)
Bases: Module
- Parameters:
config (VJepa2Config)
- forward(encoder_hidden_states, context_mask, target_mask)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- make_layer_loop_fn(static_cm, static_pos_bufs)
Return a closure over the layer loop for PiecewiseCudaGraphRunner capture.
static_pos_bufs["position_mask"] must already be filled with the static [n_seq] position IDs for this rollout config before this method is called (the fn_factory in get_piecewise_runner_config does this via .copy_()). At replay, the runner never updates this buffer (callers pass pos_bufs=None), so the captured ops always see the same position IDs.
The unsqueeze(0) inside fn is a zero-copy view and is CUDA-graph compatible. VJEPA2RopeAttention.get_position_ids broadcasts [1, H, N] IDs against [B, H, N, D] Q/K tensors correctly.
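The static-buffer pattern described here can be illustrated on CPU without any graph capture. This is a simplified sketch under stated assumptions, not the mstar implementation: `make_fn` is a hypothetical stand-in for the closure, the shapes are reduced to `[n_seq]` IDs broadcast against a `[B, n_seq]` tensor, and no `PiecewiseCudaGraphRunner` is involved. It shows only that `.copy_()` fills the buffer in place and that `unsqueeze(0)` returns a view over the same storage.

```python
import torch

n_seq = 6

# Hypothetical static buffer, allocated once before capture.
static_pos_bufs = {"position_mask": torch.zeros(n_seq, dtype=torch.long)}

# Fill in place: copy_ writes into the existing storage, so the address
# that a captured graph would reference stays the same.
addr_before = static_pos_bufs["position_mask"].data_ptr()
static_pos_bufs["position_mask"].copy_(torch.arange(n_seq))
addr_after = static_pos_bufs["position_mask"].data_ptr()


def make_fn(bufs):
    """Return a closure that reads position IDs from the static buffer."""

    def fn(x):
        # unsqueeze(0) is a view: shape [1, n_seq], no new allocation.
        ids = bufs["position_mask"].unsqueeze(0)
        # Broadcasts [1, n_seq] against [B, n_seq].
        return x + ids

    return fn


fn = make_fn(static_pos_bufs)
x = torch.zeros(2, n_seq, dtype=torch.long)
out = fn(x)

# The view shares storage with the buffer it was taken from.
view = static_pos_bufs["position_mask"].unsqueeze(0)
same_storage = view.data_ptr() == static_pos_bufs["position_mask"].data_ptr()
```

Because the closure reads the buffer by reference, anything written into it with `.copy_()` before capture is what every later call sees, which matches the requirement that the buffer be filled before `make_layer_loop_fn` is called.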
- class mstar.model.vjepa2.components.predictor.VJEPA2PredictorEmbeddings(config)
Bases: Module
Project encoder hidden states into predictor space and concatenate learned mask tokens at the target positions.
- Parameters:
config (VJepa2Config)
- forward(hidden_states, context_mask, target_mask, mask_index=1)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
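The project-then-concatenate step in the class description can be sketched with a toy module. This is an illustrative approximation and not the mstar or HuggingFace code: `ToyPredictorEmbeddings` is a hypothetical name, the masks are assumed to be tensors of position indices with shape `[B, N]`, and the mask-token parameter shape and the `mask_index=0` default are assumptions made for the example.

```python
import torch
from torch import nn


class ToyPredictorEmbeddings(nn.Module):
    """Hypothetical stand-in for the embeddings step (illustration only)."""

    def __init__(self, enc_dim: int, pred_dim: int, num_mask_tokens: int = 1):
        super().__init__()
        # Linear projection from encoder width to predictor width.
        self.predictor_embeddings = nn.Linear(enc_dim, pred_dim)
        # One learned token per mask type (assumed shape).
        self.mask_tokens = nn.Parameter(
            torch.zeros(num_mask_tokens, 1, 1, pred_dim)
        )

    def forward(self, hidden_states, context_mask, target_mask, mask_index=0):
        batch = hidden_states.size(0)
        n_target = target_mask.size(1)
        # Context tokens: encoder outputs projected into predictor space.
        context = self.predictor_embeddings(hidden_states)
        # Target tokens: the same learned mask token at every target slot.
        target = self.mask_tokens[mask_index].expand(batch, n_target, -1)
        # Context first, then targets; positions follow the same order.
        tokens = torch.cat([context, target], dim=1)
        positions = torch.cat([context_mask, target_mask], dim=1)
        return tokens, positions


torch.manual_seed(0)
emb = ToyPredictorEmbeddings(enc_dim=8, pred_dim=4)
hidden = torch.randn(2, 3, 8)                         # [B, N_ctx, enc_dim]
context_mask = torch.tensor([[0, 1, 4]]).repeat(2, 1)  # visible positions
target_mask = torch.tensor([[2, 3]]).repeat(2, 1)      # positions to predict
tokens, positions = emb(hidden, context_mask, target_mask)
print(tokens.shape, positions[0].tolist())
```

All target slots are present in the sequence from the start, which is what lets a single forward pass produce a prediction for every target position at once.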