mstar.model.vjepa2.components.predictor

V-JEPA 2 masked latent predictor.

Ports VJEPA2PredictorEmbeddings and VJEPA2Predictor from HuggingFace transformers/models/vjepa2/modeling_vjepa2.py.

Takes encoder hidden states, a context_mask (positions the predictor sees), and a target_mask (positions to predict), and emits predicted embeddings at the target positions. The predictor is NOT autoregressive: a single forward pass predicts every target token in parallel.

Weight layout (prefix predictor.):

predictor.embeddings.predictor_embeddings.{weight,bias}
predictor.embeddings.mask_tokens
predictor.layer.{N}.*
predictor.layernorm.{weight,bias}
predictor.proj.{weight,bias}
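When checking a checkpoint against the layout above, it can help to group key names by component. The following is a minimal sketch in plain Python. The key list is hypothetical (in particular, the sub-module name `norm1` under `predictor.layer.{N}` is an assumption made for illustration), and `group_predictor_keys` is an illustrative helper, not part of mstar.

```python
from collections import defaultdict

# Hypothetical key names that follow the layout listed above. A real
# checkpoint contains many more entries under predictor.layer.{N}.
example_keys = [
    "predictor.embeddings.predictor_embeddings.weight",
    "predictor.embeddings.predictor_embeddings.bias",
    "predictor.embeddings.mask_tokens",
    "predictor.layer.0.norm1.weight",  # "norm1" is an assumed sub-module name
    "predictor.layer.1.norm1.weight",
    "predictor.layernorm.weight",
    "predictor.layernorm.bias",
    "predictor.proj.weight",
    "predictor.proj.bias",
    "encoder.layer.0.norm1.weight",    # not a predictor weight, skipped below
]

PREFIX = "predictor."


def group_predictor_keys(keys):
    """Group keys under PREFIX by component; layers are grouped per index."""
    groups = defaultdict(list)
    for key in keys:
        if not key.startswith(PREFIX):
            continue
        rest = key[len(PREFIX):]
        parts = rest.split(".")
        # predictor.layer.{N}.* is grouped as "layer.{N}"; everything else
        # is grouped by its first path component.
        group = f"layer.{parts[1]}" if parts[0] == "layer" else parts[0]
        groups[group].append(rest)
    return dict(groups)


groups = group_predictor_keys(example_keys)
for name, members in sorted(groups.items()):
    print(name, len(members))
```

Grouping by the first path component keeps the check independent of what each transformer layer contains, which this page does not enumerate.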

Classes

VJEPA2Predictor(config)

VJEPA2PredictorEmbeddings(config)

Project encoder hidden states into predictor space and concatenate learned mask tokens at the target positions.

class mstar.model.vjepa2.components.predictor.VJEPA2Predictor(config)

Bases: Module

Parameters:

config (VJepa2Config)

forward(encoder_hidden_states, context_mask, target_mask)

Predict embeddings at the target positions from the encoder hidden states. A single call predicts every target token in parallel.

Note

Call the Module instance rather than forward() directly: the instance call runs the registered hooks, whereas a direct forward() call silently skips them.

Parameters:

encoder_hidden_states – hidden states produced by the encoder

context_mask – positions the predictor sees

target_mask – positions to predict

Return type:

Tensor
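The note about hooks applies to any torch.nn.Module. The short sketch below uses a plain nn.Linear as a stand-in (it is not the predictor itself) to show that a registered forward hook fires when the instance is called but not when forward() is called directly.

```python
import torch
from torch import nn

calls = []  # records the output shape each time the hook fires

layer = nn.Linear(4, 2)  # stand-in module, not VJEPA2Predictor
handle = layer.register_forward_hook(
    lambda module, inputs, output: calls.append(tuple(output.shape))
)

x = torch.randn(3, 4)
layer(x)          # instance call: the hook runs
layer.forward(x)  # direct call: the hook is skipped
handle.remove()

print(calls)
```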

make_layer_loop_fn(static_cm, static_pos_bufs)

Return a closure over the layer loop for PiecewiseCudaGraphRunner capture.

static_pos_bufs["position_mask"] must already be filled with the static [n_seq] position IDs for this rollout config before this method is called (the fn_factory in get_piecewise_runner_config does this via .copy_()). At replay, the runner never updates this buffer (callers pass pos_bufs=None), so the captured ops always see the same position IDs.

The unsqueeze(0) inside fn is a zero-copy view, CUDA-graph compatible. VJEPA2RopeAttention.get_position_ids broadcasts [1, H, N] IDs against [B, H, N, D] Q/K tensors correctly.

Parameters:

static_pos_bufs (dict[str, Tensor])

Return type:

Callable[[Tensor], Tensor]
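The two buffer properties described for make_layer_loop_fn can be checked on CPU without capturing a CUDA graph. This is a minimal sketch under stated assumptions: the buffer length `n_seq = 6` and the position IDs are made up, and only standard PyTorch tensor calls are used. It shows that filling a preallocated buffer with .copy_() keeps the same storage address, and that unsqueeze(0) returns a view onto that storage rather than a copy.

```python
import torch

n_seq = 6  # assumed sequence length, for illustration only

# Preallocated static buffer, standing in for static_pos_bufs["position_mask"].
static_pos_bufs = {"position_mask": torch.zeros(n_seq, dtype=torch.long)}
addr_before = static_pos_bufs["position_mask"].data_ptr()

# Fill the buffer in place, as the text says the fn_factory does via .copy_().
static_pos_bufs["position_mask"].copy_(torch.arange(n_seq))
addr_after = static_pos_bufs["position_mask"].data_ptr()

# unsqueeze(0) adds a leading dimension without allocating new storage.
view = static_pos_bufs["position_mask"].unsqueeze(0)  # shape [1, n_seq]
shares_storage = view.data_ptr() == addr_after

# Only to demonstrate aliasing: an in-place write to the buffer is visible
# through the view. Per the text, the runner itself never updates the buffer
# at replay.
static_pos_bufs["position_mask"].add_(10)
print(view.tolist())
```

A stable storage address matters because a captured graph replays operations against fixed memory addresses. Rebinding the dictionary entry to a new tensor instead of using .copy_() would leave the captured ops reading the old storage.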

class mstar.model.vjepa2.components.predictor.VJEPA2PredictorEmbeddings(config)

Bases: Module

Project encoder hidden states into predictor space and concatenate learned mask tokens at the target positions.

Parameters:

config (VJepa2Config)

forward(hidden_states, context_mask, target_mask, mask_index=1)

Project the encoder hidden states into predictor space and concatenate learned mask tokens at the target positions.

Note

Call the Module instance rather than forward() directly: the instance call runs the registered hooks, whereas a direct forward() call silently skips them.

Parameters:

hidden_states – encoder hidden states to project

context_mask – positions the predictor sees

target_mask – positions to predict

mask_index – defaults to 1

Return type:

tuple[Tensor, Tensor]
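The "project, then concatenate mask tokens" step can be sketched as follows. This is a simplified illustration, not the mstar or HuggingFace implementation: the class name `ToyPredictorEmbeddings`, the single shared mask token, the dimensions, and the assumption that both masks are integer position-ID tensors of shape [B, N] are all hypothetical. The returned pair (embeddings, positions) is likewise an assumption about what a tuple[Tensor, Tensor] result could carry.

```python
import torch
from torch import nn


class ToyPredictorEmbeddings(nn.Module):
    """Illustrative stand-in for the projection + mask-token step."""

    def __init__(self, encoder_dim: int, predictor_dim: int):
        super().__init__()
        # Maps encoder hidden states into the (assumed smaller) predictor space.
        self.proj = nn.Linear(encoder_dim, predictor_dim)
        # One learned placeholder vector, repeated at every target position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_dim))

    def forward(self, hidden_states, context_mask, target_mask):
        # hidden_states: [B, N_ctx, encoder_dim], context tokens only
        # context_mask:  [B, N_ctx] position IDs of the context tokens
        # target_mask:   [B, N_tgt] position IDs of the tokens to predict
        context = self.proj(hidden_states)
        batch, n_tgt = target_mask.shape
        targets = self.mask_token.expand(batch, n_tgt, -1)
        # Context tokens first, then one mask token per target position.
        embeddings = torch.cat([context, targets], dim=1)
        positions = torch.cat([context_mask, target_mask], dim=1)
        return embeddings, positions


torch.manual_seed(0)
emb = ToyPredictorEmbeddings(encoder_dim=8, predictor_dim=4)
hidden = torch.randn(2, 3, 8)
ctx = torch.tensor([[0, 1, 2], [0, 2, 4]])
tgt = torch.tensor([[3, 4], [1, 3]])
embeddings, positions = emb(hidden, ctx, tgt)
print(embeddings.shape, positions.shape)
```

Because every target position receives a placeholder in the same sequence, the downstream layers can attend over context and targets together and produce all target predictions in one pass, which matches the non-autoregressive behaviour described at the top of this page.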