"""Build a constant `tpose_seed.npy` for the JS host.

The denoiser expects a `seed` input of shape (B, pre_frames=4, 384) — at
inference it's `latent[:, :pre_frames]` where `latent` is the concatenation
of the three RVQVAE encoded latents (spine + arms + legs), divided by
`vqvae_latent_scale` (5.0), exactly as this script does before saving.

Per the original training data flow (scripts/infer_vrma.py:285–290), the seed
is normally computed from the conditioning clip's first 4 latent frames. For a
deterministic JS-side replay we bake a single representative seed: the latent
of `data/test/10_kieks_0_103_103.npz` (used as the realistic-latent reference
throughout P1).

Output: lipsync-wasm/v3/models/tpose_seed.npy  shape=(1, 4, 384) float32
"""
import os
import sys
from pathlib import Path

import numpy as np
import torch

# Repository root: three directory levels above the directory that contains
# this script.
REPO = Path(__file__).resolve().parents[3]
# Root of the GestureVRM project inside the repository.
GESTURE = REPO / "motion" / "GestureVRM"
# Put the GestureVRM root first on sys.path so the `dataloaders` and `models`
# imports below resolve (presumably both packages live under GESTURE).
sys.path.insert(0, str(GESTURE))

from dataloaders.build_vocab import Vocab  # noqa: E402
# Expose Vocab as an attribute of the `__main__` module.
# NOTE(review): presumably required because pickled data loaded later refers
# to `__main__.Vocab` — confirm against how those files were written.
sys.modules['__main__'].Vocab = Vocab

from models.vq.model import RVQVAE  # noqa: E402
from dataloaders.vrm_dataset import (  # noqa: E402
    quaternion_to_rotation_6d, SPINE_HEAD_INDICES, ARMS_INDICES, LEGS_INDICES,
)


class _VqArgs:
    """Fixed settings object handed to ``RVQVAE`` as its first argument.

    The class is a plain attribute bag: every value is a class-level constant
    and instances carry no state of their own. The meaning of each field is
    defined by the RVQVAE implementation, not by this script.
    """

    # Quantizer layout.
    shared_codebook: bool = False
    num_quantizers: int = 6

    # Quantizer-dropout settings.
    quantize_dropout_cutoff_index: int = 0
    quantize_dropout_prob: float = 0.2

    # Remaining scalar hyper-parameters consumed by the model.
    beta: float = 1.0
    mu: float = 0.99


def build_vqvae(dim: int) -> RVQVAE:
    """Construct an untrained ``RVQVAE`` for one body part.

    Args:
        dim: Per-frame feature width of the body part; forwarded to the model
            as ``input_width``.

    Returns:
        A freshly constructed ``RVQVAE``; weights are not loaded here.
    """
    quantizer_args = _VqArgs()
    # All architecture settings except the input width are fixed constants.
    net_kwargs = dict(
        input_width=dim,
        nb_code=1024,
        code_dim=128,
        output_emb_width=128,
        down_t=2,
        stride_t=2,
        width=512,
        depth=3,
        dilation_growth_rate=3,
        activation='relu',
        norm=None,
    )
    return RVQVAE(quantizer_args, **net_kwargs)


def load(rvq: "RVQVAE", ckpt_path: Path) -> None:
    """Load the state dict stored under ``'net'`` in a checkpoint into ``rvq``.

    A single leading ``module.`` prefix (as typically added by wrapper modules
    such as DataParallel) is stripped from every key before loading. Only the
    prefix is removed: names that merely contain ``module.`` elsewhere, e.g.
    ``attn_module.weight``, are left untouched.

    Args:
        rvq: Model that receives the weights; modified in place.
        ckpt_path: Checkpoint file whose top-level object maps ``'net'`` to a
            state dict.

    Raises:
        KeyError: If the checkpoint has no ``'net'`` entry.

    The weights are loaded with ``strict=True``, so any key mismatch is
    reported by ``load_state_dict`` rather than silently ignored.
    """
    prefix = 'module.'
    # NOTE: weights_only=False runs the full unpickler and can execute
    # arbitrary code from the file — only load checkpoints you trust.
    state = torch.load(ckpt_path, map_location='cpu', weights_only=False)['net']
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }
    rvq.load_state_dict(cleaned, strict=True)


def main():
    """Encode the reference clip and save its first latent frames as the seed.

    Steps:
      1. Read rotations and translations of the first ``T_pose`` frames from
         the reference ``.npz`` clip.
      2. Convert the rotations with ``quaternion_to_rotation_6d`` and split
         them into spine / arms / legs features (legs also get the per-frame
         translation delta appended).
      3. Normalise each part with its saved mean/std, encode it with the
         matching RVQVAE checkpoint via ``map2latent``.
      4. Concatenate the three latents along the feature axis, divide by the
         latent scale, keep the first ``PRE_FRAMES`` frames and save them to
         ``lipsync-wasm/v3/models/tpose_seed.npy`` as float32.

    Raises:
        ValueError: If the encoded clip yields fewer than ``PRE_FRAMES``
            latent frames, which would otherwise save a seed of the wrong
            shape.
    """
    npz = GESTURE / "data" / "test" / "10_kieks_0_103_103.npz"
    mean_std = GESTURE / "mean_std"
    ckpt_dir = GESTURE / "ckpt"

    # Pre-encoded clip → first 4 latent frames as seed
    PRE_FRAMES = 4
    T_pose = 128  # one full window; we only keep [:PRE_FRAMES]
    LATENT_SCALE = 5.0  # vqvae_latent_scale, see infer_vrma.py:290

    # NOTE: allow_pickle=True lets np.load unpickle object arrays, which can
    # execute arbitrary code — only use it on trusted clip files.
    # The context manager closes the underlying archive once the two arrays
    # have been read into memory.
    with np.load(npz, allow_pickle=True) as data:
        rotations = data['vrm_rotations'][:T_pose]
        trans = data['trans'][:T_pose]

    rot_6d = quaternion_to_rotation_6d(rotations)                        # (T,20,6)
    # Use the real frame count so a clip shorter than T_pose is flattened
    # correctly instead of failing (or mis-shaping) on a hard-coded 128.
    n_frames = rot_6d.shape[0]
    spine = rot_6d[:, SPINE_HEAD_INDICES].reshape(n_frames, -1)          # 36
    arms = rot_6d[:, ARMS_INDICES].reshape(n_frames, -1)                 # 48
    legs = rot_6d[:, LEGS_INDICES].reshape(n_frames, -1)                 # 36

    # Per-frame translation delta; the first frame has no predecessor and
    # stays zero.
    trans_v = np.zeros_like(trans)
    trans_v[1:] = trans[1:] - trans[:-1]
    legs_t = np.concatenate([legs, trans_v], axis=-1)                    # 39

    out_per_part = []
    for part, feat, dim in [('spine', spine, 36), ('arms', arms, 48), ('legs', legs_t, 39)]:
        mean = np.load(mean_std / f'vrm_{part}_mean.npy')
        std = np.load(mean_std / f'vrm_{part}_std.npy')
        # Small epsilon guards against division by zero for constant features.
        feat_norm = (feat - mean) / (std + 1e-8)
        rvq = build_vqvae(dim).eval()
        load(rvq, ckpt_dir / f'best_{part}.pth')
        with torch.no_grad():
            pose = torch.from_numpy(feat_norm).float().unsqueeze(0)        # (1,T,D)
            latent = rvq.map2latent(pose)                                  # (1, T_latent, 128)
        out_per_part.append(latent.numpy())

    # Concat along feature dim then divide by vqvae_latent_scale=5 to match
    # infer_vrma.py:290. (T_latent=32 for a full window; trim to pre_frames=4.)
    full = np.concatenate(out_per_part, axis=2) / LATENT_SCALE             # (1, T_latent, 384)
    if full.shape[1] < PRE_FRAMES:
        raise ValueError(
            f"clip {npz} produced only {full.shape[1]} latent frames; "
            f"need at least {PRE_FRAMES} for the seed"
        )
    seed = full[:, :PRE_FRAMES, :].astype(np.float32)                     # (1, 4, 384)

    out_path = REPO / "lipsync-wasm" / "v3" / "models" / "tpose_seed.npy"
    # np.save does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    np.save(out_path, seed)
    print(f"saved {out_path}  shape={seed.shape}  dtype={seed.dtype}  "
          f"mean={seed.mean():.4f}  std={seed.std():.4f}")


# Run the seed export only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
