mstar.model.vjepa2.config

V-JEPA 2 configuration dataclasses.

A flat config that matches the Hugging Face VJEPA2Config, plus an optional action-conditioned (AC) predictor block. A single VJepa2Config.from_hf_config(config_dict) call handles every open V-JEPA 2 checkpoint: ViT-L, ViT-H, and ViT-g at 256, ViT-g at 384, and the AC variant.
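As a rough illustration of how a single dict-based constructor can cover several checkpoints, the sketch below uses a hypothetical stand-in class, SketchConfig, that carries only a few of the documented fields. It assumes that keys which are not dataclass fields are ignored; how the real from_hf_config treats unknown keys is not stated on this page. The dictionary values are illustrative and are not copied from any checkpoint's config file.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class SketchConfig:
    """Hypothetical stand-in holding a subset of the documented fields."""

    patch_size: int = 16
    crop_size: int = 256
    frames_per_clip: int = 64
    tubelet_size: int = 2
    hidden_size: int = 1024
    num_attention_heads: int = 16
    num_hidden_layers: int = 24

    @classmethod
    def from_hf_config(cls, config_dict: dict[str, Any]) -> "SketchConfig":
        # Assumption: keys that are not dataclass fields are silently dropped,
        # and missing keys fall back to the dataclass defaults.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in config_dict.items() if k in known})


# Illustrative values only; "model_type" stands in for an extra key
# that the dataclass does not declare.
hf_dict = {
    "model_type": "vjepa2",
    "crop_size": 384,
    "hidden_size": 1408,
    "num_hidden_layers": 40,
}

cfg = SketchConfig.from_hf_config(hf_dict)
print(cfg.crop_size, cfg.hidden_size, cfg.patch_size)
```

Because every field has a default, the same call works whether the dictionary overrides many fields or only a few.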

Classes

VJepa2ACPredictorConfig([img_size, ...])

Configuration for the action-conditioned predictor (upstream V-JEPA 2-AC).

VJepa2Config([patch_size, crop_size, ...])

Top-level V-JEPA 2 config (shared across the HF and upstream ports).

class mstar.model.vjepa2.config.VJepa2ACPredictorConfig(img_size=(256, 256), patch_size=16, num_frames=64, tubelet_size=2, embed_dim=1408, predictor_embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, layer_norm_eps=1e-06, is_frame_causal=True, use_rope=True, action_embed_dim=7, use_extrinsics=False)

Bases: object

Configuration for the action-conditioned predictor (upstream V-JEPA 2-AC).

Mirrors vjepa2/src/models/ac_predictor.py defaults.

Parameters:
action_embed_dim: int = 7
attn_drop_rate: float = 0.0
depth: int = 24
drop_path_rate: float = 0.0
drop_rate: float = 0.0
embed_dim: int = 1408
img_size: tuple[int, int] = (256, 256)
is_frame_causal: bool = True
layer_norm_eps: float = 1e-06
mlp_ratio: float = 4.0
num_frames: int = 64
num_heads: int = 16
patch_size: int = 16
predictor_embed_dim: int = 1024
qkv_bias: bool = True
tubelet_size: int = 2
use_extrinsics: bool = False
use_rope: bool = True
class mstar.model.vjepa2.config.VJepa2Config(patch_size=16, crop_size=256, frames_per_clip=64, tubelet_size=2, hidden_size=1024, in_chans=3, num_attention_heads=16, num_hidden_layers=24, drop_path_rate=0.0, mlp_ratio=4.0, layer_norm_eps=1e-06, qkv_bias=True, attention_probs_dropout_prob=0.0, hidden_act='gelu', pred_hidden_size=384, pred_num_attention_heads=12, pred_num_hidden_layers=12, pred_num_mask_tokens=10, pred_zero_init_mask_tokens=True, pred_mlp_ratio=4.0, predictor_kind='masked', ac_predictor=None, max_rollout_horizon=16, rollout_num_output_frames=2, rollout_frames_per_second=4, rollout_anticipation_seconds=1.0, mpc_cost_fn='l1')

Bases: object

Top-level V-JEPA 2 config (shared across the HF and upstream ports).

Field names match HF VJEPA2Config so a plain JSON round-trip works.
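The round-trip claim can be illustrated with a minimal sketch. SketchConfig below is a hypothetical stand-in with a subset of the documented fields, not the real class; it only shows that a flat dataclass of JSON-native types survives serialization through the standard library unchanged.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class SketchConfig:
    """Hypothetical stand-in holding a subset of the documented fields."""

    patch_size: int = 16
    crop_size: int = 256
    frames_per_clip: int = 64
    tubelet_size: int = 2
    hidden_act: str = "gelu"
    qkv_bias: bool = True
    layer_norm_eps: float = 1e-06


original = SketchConfig(crop_size=384)

# dataclass -> dict -> JSON text
payload = json.dumps(asdict(original))

# JSON text -> dict -> dataclass
restored = SketchConfig(**json.loads(payload))

print(restored == original)
```

A nested block such as ac_predictor needs more care in a round trip like this one: JSON has no tuple type, so a value like img_size=(256, 256) comes back as a list, and the nested dictionary has to be converted back into its dataclass. How the real class handles that is not described here.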

Parameters:
  • patch_size (int)
  • crop_size (int)
  • frames_per_clip (int)
  • tubelet_size (int)
  • hidden_size (int)
  • in_chans (int)
  • num_attention_heads (int)
  • num_hidden_layers (int)
  • drop_path_rate (float)
  • mlp_ratio (float)
  • layer_norm_eps (float)
  • qkv_bias (bool)
  • attention_probs_dropout_prob (float)
  • hidden_act (str)
  • pred_hidden_size (int)
  • pred_num_attention_heads (int)
  • pred_num_hidden_layers (int)
  • pred_num_mask_tokens (int)
  • pred_zero_init_mask_tokens (bool)
  • pred_mlp_ratio (float)
  • predictor_kind (str)
  • ac_predictor (VJepa2ACPredictorConfig | None)
  • max_rollout_horizon (int)
  • rollout_num_output_frames (int)
  • rollout_frames_per_second (int)
  • rollout_anticipation_seconds (float)
  • mpc_cost_fn (str)

ac_predictor: VJepa2ACPredictorConfig | None = None
attention_probs_dropout_prob: float = 0.0
crop_size: int = 256
drop_path_rate: float = 0.0
frames_per_clip: int = 64
classmethod from_hf_config(config_dict)
Parameters:

config_dict (dict[str, Any])

Return type:

VJepa2Config

property grid_depth: int
property grid_size: int
hidden_act: str = 'gelu'
hidden_size: int = 1024
in_chans: int = 3
layer_norm_eps: float = 1e-06
max_rollout_horizon: int = 16
mlp_ratio: float = 4.0
mpc_cost_fn: str = 'l1'
num_attention_heads: int = 16
num_hidden_layers: int = 24
property num_patches: int
patch_size: int = 16
pred_hidden_size: int = 384
pred_mlp_ratio: float = 4.0
pred_num_attention_heads: int = 12
pred_num_hidden_layers: int = 12
pred_num_mask_tokens: int = 10
pred_zero_init_mask_tokens: bool = True
predictor_kind: str = 'masked'
qkv_bias: bool = True
rollout_anticipation_seconds: float = 1.0
rollout_frames_per_second: int = 4
rollout_num_output_frames: int = 2
tubelet_size: int = 2
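
The grid_size, grid_depth, and num_patches properties are listed without formulas. The sketch below shows one plausible reading, following the usual video ViT tubelet patching: grid_size as patches per spatial side, grid_depth as tubelets along time, and num_patches as their product over the three axes. These formulas are assumptions, and SketchGrid is a hypothetical stand-in rather than the real class.

```python
from dataclasses import dataclass


@dataclass
class SketchGrid:
    """Hypothetical stand-in with the four fields the grid depends on."""

    patch_size: int = 16
    crop_size: int = 256
    frames_per_clip: int = 64
    tubelet_size: int = 2

    @property
    def grid_size(self) -> int:
        # Assumed: number of patches along one spatial side of the crop.
        return self.crop_size // self.patch_size

    @property
    def grid_depth(self) -> int:
        # Assumed: number of tubelets along the time axis.
        return self.frames_per_clip // self.tubelet_size

    @property
    def num_patches(self) -> int:
        # Assumed: total tokens = time steps x height x width.
        return self.grid_depth * self.grid_size * self.grid_size


g = SketchGrid()
print(g.grid_size, g.grid_depth, g.num_patches)  # 16 32 8192

g384 = SketchGrid(crop_size=384)
print(g384.grid_size, g384.num_patches)  # 24 18432
```

Under these assumptions the default config yields a 16 x 16 spatial grid over 32 time steps, and raising crop_size to 384 increases the token count by a factor of 2.25.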