From 3d14964f018b0c3d5e10c2fc62a0054950322ca9 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 23 Jun 2026 16:45:00 +0000 Subject: [PATCH 1/2] Init commit for transfer capability to Cosmos3 pipeline --- docs/source/en/api/pipelines/cosmos3.md | 105 +++- examples/cosmos3/inference_cosmos3.py | 116 +++- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 515 ++++++++++++++++-- 3 files changed, 698 insertions(+), 38 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 1ac8f36457a4..221922fcd4f0 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -32,7 +32,7 @@ From one model you can: - Generate physically plausible video worlds from text, images, or action inputs (image-to-video, text-to-video, action-conditioned video generation). - Reason about physical properties like motion, causality, and spatial relationships. - Predict future video and action sequences from the current state. -- Transfer scenes across viewpoints and conditions with structural control *(coming soon)*. +- Transfer scenes across viewpoints and conditions with structural control (edge, blur, depth, segmentation, world-scenario maps). Under the hood, a single `Cosmos3OmniTransformer` runs a Qwen-style language model in parallel with a diffusion generation pathway: text tokens flow through a causal "understanding" stream while video and sound latents flow through a bi-directionally-attended "generation" stream, joined by a 3D multimodal RoPE. See the [Cosmos World Foundation Model Platform paper](https://huggingface.co/papers/2501.03575) for the architectural background. @@ -371,6 +371,109 @@ export_to_video(result.video, "cosmos3_v2v.mp4", fps=24, macro_block_size=1) +## Transfer (structural control) + +Transfer generates a target clip that follows a **precomputed control video** (a spatial control signal): edge (Canny), blur, depth, segmentation, or a world-scenario map (WSM). Pass it through `control_videos=` as a mapping from hint name to a loaded video. The control map is resized, temporally padded, normalized, and VAE-encoded into a clean conditioning item placed before the noisy target; the model then generates the target to match it. Transfer is video-only (no `image`, `video`, `action`, or `enable_sound`), and the prompt is a pre-upsampled JSON caption (see [Prompt upsampling](#prompt-upsampling)). + +Diffusers does not ship the control assets. Ready-made ones (a control video + matching `prompt.json` per hint, plus a shared `negative_prompt.json`) live in the [Cosmos cookbook](https://github.com/NVIDIA/cosmos/tree/main/cookbooks/cosmos3/generator/transfer/assets). For the edge example below, download them into a local `assets/` folder: + +```bash +base=https://github.com/NVIDIA/cosmos/raw/refs/heads/main/cookbooks/cosmos3/generator/transfer/assets +mkdir -p assets/edge +curl -sL "$base/edge/control_edge.mp4" -o assets/edge/control_edge.mp4 +curl -sL "$base/edge/prompt.json" -o assets/edge/prompt.json +curl -sL "$base/negative_prompt.json" -o assets/negative_prompt.json +``` + +Guidance uses a nested control/text classifier-free-guidance blend. `guidance_scale` is the usual text CFG; `control_guidance` (`!= 1.0`) additionally amplifies the control signal. Recommended starting values per hint (matching the Cosmos Framework defaults): + +| Hint | `guidance_scale` | `control_guidance` | `flow_shift` | Geometry | +| --- | --- | --- | --- | --- | +| Edge / Blur / Depth | 3.0 | 1.5 | 10.0 | 121 frames @ 30 FPS | +| Segmentation | 3.0 | 2.0 | 10.0 | 121 frames @ 30 FPS | +| World scenario (WSM) | 1.0 | 3.0 | 10.0 | 101 frames @ 10 FPS | + +Depth, segmentation, and WSM control maps must be precomputed by external models; edge/blur maps can be produced offline with any Canny/blur tool. The shipped cookbook configs use a single hint each; passing several entries in `control_videos` to combine hints is supported by the pipeline but is not a tuned/validated cookbook path (set `guidance_scale` / `control_guidance` explicitly, since the per-hint defaults above assume a single hint). Long clips are generated autoregressively in chunks of `num_video_frames_per_chunk` and stitched automatically. + + + + +```python +import json +import torch +from diffusers import Cosmos3OmniPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_video + +# Downloaded into assets/ from the Cosmos cookbook (see the curl snippet above). +json_prompt = json.load(open("assets/edge/prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) +control_edge = load_video("assets/edge/control_edge.mp4") + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" +) +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False +) + +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + control_videos={"edge": control_edge}, + num_frames=121, + height=720, + width=1280, + fps=30.0, + num_inference_steps=35, + guidance_scale=3.0, + control_guidance=1.5, +) +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "cosmos3_transfer_edge.mp4", fps=30, macro_block_size=1) +``` + + + + +```python +import json +import torch +from diffusers import Cosmos3OmniPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_video + +# Downloaded into assets/ from the Cosmos cookbook (see the curl snippet above). +json_prompt = json.load(open("assets/edge/prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) +control_edge = load_video("assets/edge/control_edge.mp4") + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" +) +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False +) + +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + control_videos={"edge": control_edge}, + num_frames=121, + height=720, + width=1280, + fps=30.0, + num_inference_steps=35, + guidance_scale=3.0, + control_guidance=1.5, +) +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "cosmos3_transfer_edge.mp4", fps=30, macro_block_size=1) +``` + + + + ## Video-to-video with sound When the checkpoint carries a `sound_tokenizer`, add `enable_sound=True` to the video-to-video call to jointly generate a synchronized audio track. The waveform is returned alongside the video and can be muxed into the MP4 with [`~utils.encode_video`]. diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index 62388c8d1288..16014dabaaec 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -21,6 +21,13 @@ Video-to-video: python inference_cosmos3.py --prompt "..." --video-path /path/to/video.mp4 +Transfer (ready-made control_*.mp4 + prompt.json are hosted in the Cosmos cookbook; --control-path / --prompt +accept URLs or local paths: https://github.com/NVIDIA/cosmos/tree/main/cookbooks/cosmos3/generator/transfer/assets): + base=https://github.com/NVIDIA/cosmos/raw/refs/heads/main/cookbooks/cosmos3/generator/transfer/assets + python inference_cosmos3.py --prompt "$(curl -sL $base/edge/prompt.json)" \ + --transfer-hint edge --control-path $base/edge/control_edge.mp4 \ + --guidance-scale 3.0 --control-guidance 1.5 --flow-shift 10.0 --num-frames 121 --fps 30 + Text-to-video-with-sound (requires a sound-capable checkpoint): python inference_cosmos3.py --prompt "..." --enable-sound """ @@ -62,6 +69,11 @@ def _load_action(path: str | None): def main(): parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument("--prompt", required=True, help="Text prompt.") + parser.add_argument( + "--negative-prompt", + default=None, + help="Optional negative prompt text.", + ) parser.add_argument( "--model", choices=sorted(HF_REPOS), @@ -89,6 +101,60 @@ def main(): default="first", help="Take the video-to-video conditioning frames from the first or last of the source clip (default: first).", ) + parser.add_argument( + "--transfer-hint", + action="append", + choices=["edge", "blur", "depth", "seg", "wsm"], + default=None, + help="Enable transfer with a control hint. Repeat (paired with --control-path) to combine multiple hints.", + ) + parser.add_argument( + "--control-path", + action="append", + default=None, + help="URL or local path to a precomputed control video, paired in order with each --transfer-hint.", + ) + parser.add_argument( + "--control-guidance", + type=float, + default=1.0, + help="Transfer control-CFG scale (recommended 1.5 for edge/blur/depth, 2.0 for seg, 3.0 for wsm).", + ) + parser.add_argument( + "--control-guidance-interval", + default=None, + help="Comma-separated [lo,hi] timestep window for control guidance (default: applied at every step).", + ) + parser.add_argument( + "--guidance-interval", + default=None, + help="Comma-separated [lo,hi] timestep window for text guidance in transfer (default: every step).", + ) + parser.add_argument( + "--num-conditional-frames", + type=int, + default=1, + help="Frames carried over from the previous chunk as conditioning (transfer multi-chunk).", + ) + parser.add_argument( + "--num-first-chunk-conditional-frames", + type=int, + default=0, + help="Leading frames of --video-path used to condition the first transfer chunk (requires --video-path).", + ) + parser.add_argument( + "--num-video-frames-per-chunk", + type=int, + default=None, + help="Max frames generated per autoregressive transfer chunk (default: whole clip in one chunk).", + ) + parser.add_argument( + "--no-share-vision-temporal-positions", + dest="share_vision_temporal_positions", + action="store_false", + default=True, + help="Give control maps and the target distinct temporal mRoPE positions instead of sharing them (transfer).", + ) parser.add_argument("--output", default=".", help="Directory to save generated video/image/audio files.") parser.add_argument( "--height", @@ -198,7 +264,52 @@ def main(): output_dir.mkdir(parents=True, exist_ok=True) generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None - if args.action_mode is not None: + def _parse_interval(value): + if value is None: + return None + parts = [float(v) for v in value.split(",") if v.strip()] + if len(parts) != 2: + raise ValueError(f"Expected a comma-separated [lo,hi] interval, got {value!r}.") + return (parts[0], parts[1]) + + if args.transfer_hint is not None: + control_paths = args.control_path or [] + if len(control_paths) != len(args.transfer_hint): + raise ValueError("Pass one --control-path per --transfer-hint, in matching order.") + control_videos = {hint: load_video(path) for hint, path in zip(args.transfer_hint, control_paths)} + # `--video-path` is an OPTIONAL RGB prefix that only seeds the first chunk, and is consulted solely when + # --num-first-chunk-conditional-frames > 0. It is unrelated to the control hints (which always drive transfer). + conditioning_video = None + if args.num_first_chunk_conditional_frames > 0: + if args.video_path is None: + raise ValueError( + "--num-first-chunk-conditional-frames > 0 requires --video-path (an RGB prefix clip)." + ) + conditioning_video = load_video(args.video_path) + elif args.video_path is not None: + print("Ignoring --video-path: it only applies when --num-first-chunk-conditional-frames > 0.") + result = pipeline( + prompt=args.prompt, + negative_prompt=args.negative_prompt, + control_videos=control_videos, + video=conditioning_video, + num_frames=args.num_frames if args.num_frames != 189 else None, + height=args.height, + width=args.width, + fps=args.fps, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + control_guidance=args.control_guidance, + control_guidance_interval=_parse_interval(args.control_guidance_interval), + guidance_interval=_parse_interval(args.guidance_interval), + num_conditional_frames=args.num_conditional_frames, + num_first_chunk_conditional_frames=args.num_first_chunk_conditional_frames, + num_video_frames_per_chunk=args.num_video_frames_per_chunk, + share_vision_temporal_positions=args.share_vision_temporal_positions, + generator=generator, + enable_safety_check=not args.no_safety_check, + ) + elif args.action_mode is not None: if args.vision_path is None: raise ValueError("--vision-path must point to a conditioning video for action modes.") if args.action_chunk_size is None: @@ -207,6 +318,7 @@ def main(): raw_actions = _load_action(args.action_path) if args.action_mode == "forward_dynamics" else None result = pipeline( prompt=args.prompt, + negative_prompt=args.negative_prompt, action=CosmosActionCondition( mode=args.action_mode, chunk_size=args.action_chunk_size, @@ -234,6 +346,7 @@ def main(): ) result = pipeline( prompt=args.prompt, + negative_prompt=args.negative_prompt, video=video, condition_frame_indexes_vision=condition_frame_indexes_vision, condition_video_keep=args.condition_video_keep, @@ -253,6 +366,7 @@ def main(): image = load_image(args.vision_path) if args.vision_path is not None else None result = pipeline( prompt=args.prompt, + negative_prompt=args.negative_prompt, image=image, num_frames=args.num_frames, height=args.height, diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 538b553d478d..bc6a0456eddb 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -135,6 +135,10 @@ def get_3d_mrope_ids_vae_tokens( _SYSTEM_PROMPT_IMAGE = "You are a helpful assistant who will generate images from a give prompt." _SYSTEM_PROMPT_VIDEO = "You are a helpful assistant who will generate videos from a give prompt." +_SYSTEM_PROMPT_TRANSFER = ( + "You are a helpful assistant that generates images or videos following the user's instructions" + " and control signals (edge maps, blur, depth, or segmentation)." +) _ACTION_RESOLUTION_BINS = { "256": { @@ -502,62 +506,109 @@ def _prepare_text_segment( def _prepare_vision_segment( self, - input_vision_tokens: torch.Tensor, + input_vision_tokens: torch.Tensor | list[torch.Tensor], has_image_condition: bool, mrope_offset: int | float, vision_fps: float | None, curr: int, device: torch.device | str, - condition_frame_indexes: list[int] | None = None, + condition_frame_indexes: list[int] | list[list[int] | None] | None = None, + clean_item_flags: list[bool] | None = None, + share_vision_temporal_positions: bool = False, ) -> dict[str, Any]: """Build the static portion of the vision segment of the joint sequence. Step-varying fields (``vision_tokens`` and ``vision_timesteps``) are NOT included here — the caller splices them in inside the denoising loop. The method is called once per (cond/uncond) prompt before the loop, since everything else only depends on the prompt length and the vision shape. + + For transfer, multiple vision items are packed in order ``[ctrl_1, ..., ctrl_N, target]``: control items are + marked clean via ``clean_item_flags`` (all frames conditioned, no noisy positions, no MSE-loss positions), so + the transformer treats them as fixed context and only predicts the (noisy) target frames. When + ``share_vision_temporal_positions`` is ``True`` every item reuses the same temporal mRoPE offset (the control + maps and the target are temporally aligned) instead of advancing the offset per item. """ config = self.transformer.config latent_patch_size = config.latent_patch_size - _, _, latent_t, latent_h, latent_w = input_vision_tokens.shape - patch_h = math.ceil(latent_h / latent_patch_size) - patch_w = math.ceil(latent_w / latent_patch_size) - num_vision_tokens = latent_t * patch_h * patch_w - - if condition_frame_indexes is None: - condition_frame_indexes = [0] if has_image_condition else [] - cond_frames = {idx for idx in condition_frame_indexes if 0 <= idx < latent_t} - noisy_frame_indexes = torch.tensor( - [idx for idx in range(latent_t) if idx not in cond_frames], device=device, dtype=torch.long - ) - frame_token_stride = patch_h * patch_w - mse_loss_indexes: list[int] = [] - for frame_idx in noisy_frame_indexes.tolist(): - frame_start = curr + frame_idx * frame_token_stride - mse_loss_indexes.extend(range(frame_start, frame_start + frame_token_stride)) + # Normalize to per-item lists so the single-item (non-transfer) path and the multi-item transfer path share + # one implementation. A single tensor with a flat condition_frame_indexes list reproduces the old behavior. + if isinstance(input_vision_tokens, torch.Tensor): + items = [input_vision_tokens] + per_item_condition: list[list[int] | None] = [condition_frame_indexes] # type: ignore[list-item] + else: + items = list(input_vision_tokens) + if condition_frame_indexes is None: + per_item_condition = [None] * len(items) + else: + per_item_condition = list(condition_frame_indexes) # type: ignore[arg-type] + if clean_item_flags is None: + clean_item_flags = [False] * len(items) effective_fps = vision_fps if config.enable_fps_modulation else None - vision_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( - grid_t=latent_t, - grid_h=patch_h, - grid_w=patch_w, - temporal_offset=mrope_offset, - reset_spatial_indices=config.unified_3d_mrope_reset_spatial_ids, - fps=effective_fps, - base_fps=float(config.base_fps), - temporal_compression_factor=self.vae.config.scale_factor_temporal, - ) + token_shapes: list[tuple[int, int, int]] = [] + sequence_index_parts: list[torch.Tensor] = [] + mse_loss_indexes: list[int] = [] + noisy_frame_indexes_per_item: list[torch.Tensor] = [] + mrope_id_parts: list[torch.Tensor] = [] + num_vision_tokens = 0 + num_noisy_vision_tokens = 0 + item_curr = curr + item_mrope_offset: int | float = mrope_offset + + for item, item_condition, is_clean in zip(items, per_item_condition, clean_item_flags): + _, _, latent_t, latent_h, latent_w = item.shape + patch_h = math.ceil(latent_h / latent_patch_size) + patch_w = math.ceil(latent_w / latent_patch_size) + item_num_tokens = latent_t * patch_h * patch_w + frame_token_stride = patch_h * patch_w + + if is_clean: + cond_frames = set(range(latent_t)) + else: + item_condition = item_condition if item_condition is not None else ([0] if has_image_condition else []) + cond_frames = {idx for idx in item_condition if 0 <= idx < latent_t} + noisy_frame_indexes = torch.tensor( + [idx for idx in range(latent_t) if idx not in cond_frames], device=device, dtype=torch.long + ) + + for frame_idx in noisy_frame_indexes.tolist(): + frame_start = item_curr + frame_idx * frame_token_stride + mse_loss_indexes.extend(range(frame_start, frame_start + frame_token_stride)) + + item_mrope_ids, next_mrope_offset = get_3d_mrope_ids_vae_tokens( + grid_t=latent_t, + grid_h=patch_h, + grid_w=patch_w, + temporal_offset=item_mrope_offset, + reset_spatial_indices=config.unified_3d_mrope_reset_spatial_ids, + fps=effective_fps, + base_fps=float(config.base_fps), + temporal_compression_factor=self.vae.config.scale_factor_temporal, + ) + + token_shapes.append((latent_t, patch_h, patch_w)) + sequence_index_parts.append( + torch.arange(item_curr, item_curr + item_num_tokens, dtype=torch.long, device=device) + ) + noisy_frame_indexes_per_item.append(noisy_frame_indexes) + mrope_id_parts.append(item_mrope_ids.to(device)) + num_vision_tokens += item_num_tokens + num_noisy_vision_tokens += len(noisy_frame_indexes) * frame_token_stride + item_curr += item_num_tokens + if not share_vision_temporal_positions: + item_mrope_offset = next_mrope_offset return { # Transformer-facing fields (vision_tokens and vision_timesteps spliced per step). - "vision_token_shapes": [(latent_t, patch_h, patch_w)], - "vision_sequence_indexes": torch.arange(curr, curr + num_vision_tokens, dtype=torch.long, device=device), + "vision_token_shapes": token_shapes, + "vision_sequence_indexes": torch.cat(sequence_index_parts, dim=0), "vision_mse_loss_indexes": torch.tensor(mse_loss_indexes, dtype=torch.long, device=device), - "vision_noisy_frame_indexes": [noisy_frame_indexes], + "vision_noisy_frame_indexes": noisy_frame_indexes_per_item, # Assembly helpers (consumed inline before the transformer call). - "vision_mrope_ids": vision_mrope_ids.to(device), + "vision_mrope_ids": torch.cat(mrope_id_parts, dim=1), "num_vision_tokens": num_vision_tokens, - "num_noisy_vision_tokens": len(noisy_frame_indexes) * frame_token_stride, + "num_noisy_vision_tokens": num_noisy_vision_tokens, } def _prepare_sound_segment( @@ -959,11 +1010,42 @@ def check_inputs( action: "CosmosActionCondition | None" = None, video: list[Image.Image] | torch.Tensor | np.ndarray | None = None, condition_frame_indexes_vision: Iterable[int] = (0, 1), + control_videos: dict[str, Any] | None = None, + num_first_chunk_conditional_frames: int = 0, ) -> None: if not isinstance(prompt, (str, list)) or ( isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt) ): raise ValueError(f"`prompt` must be a str or list of str, got {type(prompt).__name__}.") + + if control_videos is not None: + # Transfer mode: validate the hint mapping and reject combinations the model does not support. + # The supported hints (edge, blur, depth, seg, wsm) are listed in canonical packing order. + supported_hints = ["edge", "blur", "depth", "seg", "wsm"] + if not isinstance(control_videos, dict) or not control_videos: + raise ValueError("`control_videos` must be a non-empty dict mapping hint name -> control video.") + unknown = [k for k in control_videos if k not in supported_hints] + if unknown: + raise ValueError( + f"`control_videos` has unknown hint(s) {unknown}; expected keys from {supported_hints}." + ) + if any(v is None for v in control_videos.values()): + raise ValueError("`control_videos` entries must be loaded videos, not None.") + if action is not None: + raise ValueError("Transfer (`control_videos`) cannot be combined with `action`.") + if image is not None: + raise ValueError("Transfer (`control_videos`) cannot be combined with `image`.") + if enable_sound: + raise ValueError( + "Transfer (`control_videos`) is video-only and cannot be combined with `enable_sound`." + ) + if num_first_chunk_conditional_frames > 0 and video is None: + raise ValueError( + "`num_first_chunk_conditional_frames` > 0 requires a `video` for first-chunk conditioning." + ) + if num_frames is not None and num_frames < 1: + raise ValueError(f"`num_frames` must be >= 1, got {num_frames}.") + return if negative_prompt is not None and not isinstance(negative_prompt, (str, list)): raise ValueError( f"`negative_prompt` must be a str, list of str, or None, got {type(negative_prompt).__name__}." @@ -1085,6 +1167,7 @@ def tokenize_prompt( add_duration_template: bool = True, action_mode: str | None = None, action_view_point: str | None = None, + transfer_mode: bool = False, ) -> tuple[list[int], list[int]]: """Apply prompt-augmentation templates and tokenize cond/uncond prompts via the Qwen2 chat template. @@ -1099,6 +1182,9 @@ def tokenize_prompt( was trained on (see :meth:`_build_action_json_prompt`), using ``action_view_point`` for the framing field; the flat metadata templates are skipped because the JSON already carries duration/fps/resolution/aspect_ratio. + When ``transfer_mode`` is set, the transfer system prompt is used and the prompt / negative prompt are passed + through verbatim (they are pre-upsampled JSON captions), again skipping the flat metadata templates. + Returns: ``(cond_input_ids, uncond_input_ids)`` — token-id lists for this sample. """ @@ -1128,7 +1214,10 @@ def _apply_templates(text: str, is_negative: bool = False) -> str: def _tokenize(text: str) -> BatchEncoding: conversations = [] if use_system_prompt: - system_prompt = _SYSTEM_PROMPT_IMAGE if is_image else _SYSTEM_PROMPT_VIDEO + if transfer_mode: + system_prompt = _SYSTEM_PROMPT_TRANSFER + else: + system_prompt = _SYSTEM_PROMPT_IMAGE if is_image else _SYSTEM_PROMPT_VIDEO conversations.append({"role": "system", "content": system_prompt}) conversations.append({"role": "user", "content": text}) return self.text_tokenizer.apply_chat_template( @@ -1150,6 +1239,11 @@ def _add_special_tokens(input_ids: list[int]) -> list[int]: prompt, view_point=action_view_point, num_frames=num_frames, fps=fps, height=height, width=width ) uncond_text = negative_prompt + elif transfer_mode: + # Transfer prompts are pre-upsampled JSON captions that already carry duration/fps/resolution; pass them + # through verbatim (the metadata templates would corrupt the JSON), mirroring the action-mode branch. + cond_text = prompt + uncond_text = negative_prompt else: cond_text = _apply_templates(prompt) uncond_text = _apply_templates(negative_prompt, is_negative=True) @@ -1234,6 +1328,286 @@ def _apply_video_safety_check(self, video: Any, output_type: str, device: torch. # output_type == "pt" return torch.from_numpy(checked.astype(np.float32) / 255.0).permute(0, 3, 1, 2) + @torch.no_grad() + def _generate_transfer( + self, + *, + prompt: str, + negative_prompt: str | None, + control_videos: dict[str, Any], + video: Any, + num_frames: int | None, + height: int, + width: int, + fps: float, + num_inference_steps: int, + guidance_scale: float, + control_guidance: float, + control_guidance_interval: tuple[float, float] | None, + guidance_interval: tuple[float, float] | None, + num_conditional_frames: int, + num_first_chunk_conditional_frames: int, + num_video_frames_per_chunk: int | None, + share_vision_temporal_positions: bool, + generator: torch.Generator | None, + output_type: str, + return_dict: bool, + use_system_prompt: bool, + enable_safety_check: bool, + device: torch.device, + dtype: torch.dtype, + ) -> "Cosmos3OmniPipelineOutput | tuple": + """Run video transfer: generate a target clip that follows one or more precomputed control hints. + + Control maps are packed as clean (fully conditioned) vision items before the noisy target — sequence layout + ``[ctrl_1, ..., ctrl_N, target]`` — and guidance uses a nested control/text classifier-free-guidance blend + (see the per-step branch selection below). Long clips are produced autoregressively chunk-by-chunk, with each + chunk conditioned on the tail of the previous one, then stitched back together. + """ + if output_type == "latent": + raise ValueError( + "Transfer decodes and stitches chunks in pixel space; `output_type='latent'` is unsupported." + ) + + tcf = int(self.vae.config.scale_factor_temporal) + sf = int(self.vae.config.scale_factor_spatial) + if height % sf != 0 or width % sf != 0: + raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") + + def _pad_temporal(frames: torch.Tensor, target_t: int) -> torch.Tensor: + # frames: [1, 3, T, H, W]. Reflect-pad along time up to target_t, falling back to repeating the last + # frame, mirroring the native Cosmos Framework `pad_temporal_frames`. No truncation (callers slice first). + if frames.shape[2] >= target_t: + return frames + while frames.shape[2] < target_t: + pad_len = min(frames.shape[2] - 1, target_t - frames.shape[2]) + if pad_len <= 0: + pad_frame = frames[:, :, -1:].repeat(1, 1, target_t - frames.shape[2], 1, 1) + frames = torch.cat([frames, pad_frame], dim=2) + break + frames = torch.cat([frames, frames.flip(dims=[2])[:, :, :pad_len]], dim=2) + return frames + + def _decode_to_pixel(latent: torch.Tensor) -> torch.Tensor: + vae_dtype = self.vae.dtype + mean = self._vae_latents_mean.to(device=latent.device, dtype=vae_dtype) + inv_std = self._vae_latents_inv_std.to(device=latent.device, dtype=vae_dtype) + z_raw = latent.to(vae_dtype) / inv_std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1) + return self.vae.decode(z_raw).sample.to(torch.float32).clamp(-1, 1) + + def _active_at(t: torch.Tensor, interval: tuple[float, float] | None) -> bool: + if interval is None: + return True + lo, hi = float(interval[0]), float(interval[1]) + return lo <= float(t.item()) <= hi + + # Canonical hint order, then preprocess every control map to [1, 3, T, H, W] in [-1, 1] at the target geometry. + hint_keys = [k for k in ["edge", "blur", "depth", "seg", "wsm"] if k in control_videos] + control_frames = { + key: self.video_processor.preprocess_video(control_videos[key], height=height, width=width).to( + device=device, dtype=dtype + ) + for key in hint_keys + } + input_frames = None + if video is not None: + input_frames = self.video_processor.preprocess_video(video, height=height, width=width).to( + device=device, dtype=dtype + ) + + # Output frame count / chunking come from the (first) control video, optionally capped by num_frames. + total_frames = next(iter(control_frames.values())).shape[2] + if num_frames is not None: + total_frames = min(total_frames, num_frames) + total_frames = max(1, total_frames) + + per_chunk = num_video_frames_per_chunk if num_video_frames_per_chunk is not None else total_frames + chunk_frames = 1 if total_frames == 1 else per_chunk + chunk_frames = math.ceil((chunk_frames - 1) / tcf) * tcf + 1 + + if total_frames <= chunk_frames: + num_chunks, stride = 1, chunk_frames + else: + stride = chunk_frames - num_conditional_frames + if stride <= 0: + raise ValueError("`num_conditional_frames` must be smaller than `num_video_frames_per_chunk`.") + remaining = total_frames - chunk_frames + num_chunks = 1 + (remaining // stride + (1 if remaining % stride else 0)) + + padded = max(total_frames, chunk_frames) + control_frames = {key: _pad_temporal(frames, padded) for key, frames in control_frames.items()} + if input_frames is not None: + input_frames = _pad_temporal(input_frames, padded) + + # Text packing is invariant across chunks and denoising steps; build it once. Transfer prompts are passed + # through verbatim (pre-upsampled JSON) under the transfer system prompt. + cond_input_ids, uncond_input_ids = self.tokenize_prompt( + prompt, + negative_prompt, + num_frames=chunk_frames, + height=height, + width=width, + fps=fps, + use_system_prompt=use_system_prompt, + transfer_mode=True, + ) + cond_text_segment = self._prepare_text_segment(cond_input_ids, device=device) + uncond_text_segment = self._prepare_text_segment(uncond_input_ids, device=device) + num_hints = len(hint_keys) + + output_chunks: list[torch.Tensor] = [] + previous_output: torch.Tensor | None = None + + for chunk_id in range(num_chunks): + start_frame = chunk_id * stride + end_frame = min(start_frame + chunk_frames, total_frames) + chunk_controls = [ + _pad_temporal(control_frames[key][:, :, start_frame:end_frame], chunk_frames) for key in hint_keys + ] + + # Seed the target with conditioning frames (first chunk from the input video, later chunks from the + # previous chunk's tail), repeat-padding the remaining frames so the whole clip is well-defined. + target = torch.zeros(1, 3, chunk_frames, height, width, device=device, dtype=dtype) + current_conditional_frames = 0 + if chunk_id == 0 and num_first_chunk_conditional_frames > 0 and input_frames is not None: + current_conditional_frames = min( + num_first_chunk_conditional_frames, input_frames.shape[2], chunk_frames + ) + if current_conditional_frames > 0: + target[:, :, :current_conditional_frames] = input_frames[:, :, :current_conditional_frames] + elif chunk_id > 0 and previous_output is not None: + current_conditional_frames = min(num_conditional_frames, previous_output.shape[2], chunk_frames) + if current_conditional_frames > 0: + target[:, :, :current_conditional_frames] = previous_output[:, :, -current_conditional_frames:].to( + device=device, dtype=dtype + ) + if 0 < current_conditional_frames < chunk_frames: + fill = target[:, :, current_conditional_frames - 1 : current_conditional_frames] + target[:, :, current_conditional_frames:] = fill.expand( + -1, -1, chunk_frames - current_conditional_frames, -1, -1 + ) + + # Encode controls as clean latents and build the noisy target latents + conditioning mask. + control_latents = [self._encode_video(ctrl).contiguous().float() for ctrl in chunk_controls] + target_x0 = self._encode_video(target).contiguous().float() + latent_t = target_x0.shape[2] + condition_mask = torch.zeros((latent_t, 1, 1), device=device, dtype=dtype) + latent_condition_frames = 0 + if current_conditional_frames > 0: + latent_condition_frames = (current_conditional_frames - 1) // tcf + 1 + condition_mask[:latent_condition_frames] = 1.0 + noise = randn_tensor(tuple(target_x0.shape), generator=generator, device=device, dtype=dtype) + latents = condition_mask * target_x0 + (1.0 - condition_mask) * noise + velocity_mask = 1.0 - condition_mask + condition_latents = condition_mask * target_x0 + + target_condition_indexes = list(range(latent_condition_frames)) + + # Pre-pack the three CFG sequence variants. cond_full / uncond_full carry every control item; the + # no-control branch drops them (only [text, target]) so the control axis can be amplified. + def _vision_pack(text_segment: dict[str, Any], include_controls: bool) -> dict[str, Any]: + if include_controls: + vision_items = [*control_latents, latents] + condition_indexes = [None] * num_hints + [target_condition_indexes] + clean_flags = [True] * num_hints + [False] + else: + vision_items = [latents] + condition_indexes = [target_condition_indexes] + clean_flags = [False] + vision_segment = self._prepare_vision_segment( + input_vision_tokens=vision_items, + has_image_condition=False, + mrope_offset=text_segment["vision_start_temporal_offset"], + vision_fps=fps, + curr=text_segment["und_len"], + device=device, + condition_frame_indexes=condition_indexes, + clean_item_flags=clean_flags, + share_vision_temporal_positions=share_vision_temporal_positions, + ) + return { + **text_segment, + **vision_segment, + "position_ids": torch.cat( + [text_segment["text_mrope_ids"], vision_segment["vision_mrope_ids"]], dim=1 + ), + "sequence_length": text_segment["und_len"] + vision_segment["num_vision_tokens"], + } + + cond_full_static = _vision_pack(cond_text_segment, include_controls=True) + cond_no_control_static = _vision_pack(cond_text_segment, include_controls=False) + uncond_full_static = _vision_pack(uncond_text_segment, include_controls=True) + num_noisy_vision_tokens = cond_full_static["num_noisy_vision_tokens"] + + def _run(static: dict[str, Any], vision_tokens: list[torch.Tensor], vision_timesteps: torch.Tensor): + preds_vision, _, _ = self.transformer( + input_ids=static["input_ids"], + text_indexes=static["text_indexes"], + position_ids=static["position_ids"], + und_len=static["und_len"], + sequence_length=static["sequence_length"], + vision_tokens=vision_tokens, + vision_token_shapes=static["vision_token_shapes"], + vision_sequence_indexes=static["vision_sequence_indexes"], + vision_mse_loss_indexes=static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=static["vision_noisy_frame_indexes"], + ) + # The target is the last vision item; control items return zeros (no MSE positions). + return preds_vision[-1] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps): + self._current_timestep = t + timestep = t.item() + vision_tokens_full = [c.to(device=device, dtype=dtype) for c in control_latents] + [ + latents.to(device=device, dtype=dtype) + ] + vision_tokens_target = [latents.to(device=device, dtype=dtype)] + vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) + + step_guidance = guidance_scale if _active_at(t, guidance_interval) else 1.0 + step_control = control_guidance if _active_at(t, control_guidance_interval) else 1.0 + needs_text_cfg = step_guidance > 1.0 + needs_control_cfg = step_control != 1.0 + + cond_full = _run(cond_full_static, vision_tokens_full, vision_timesteps) + if needs_control_cfg and needs_text_cfg: + cond_no_control = _run(cond_no_control_static, vision_tokens_target, vision_timesteps) + uncond_full = _run(uncond_full_static, vision_tokens_full, vision_timesteps) + control_cond = cond_no_control + step_control * (cond_full - cond_no_control) + velocity = uncond_full + step_guidance * (control_cond - uncond_full) + elif needs_control_cfg: + cond_no_control = _run(cond_no_control_static, vision_tokens_target, vision_timesteps) + velocity = cond_no_control + step_control * (cond_full - cond_no_control) + elif needs_text_cfg: + uncond_full = _run(uncond_full_static, vision_tokens_full, vision_timesteps) + velocity = uncond_full + step_guidance * (cond_full - uncond_full) + else: + velocity = cond_full + + velocity = velocity * velocity_mask + latents = self.scheduler.step(velocity.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False)[ + 0 + ].squeeze(0) + latents = velocity_mask * latents + (1.0 - velocity_mask) * condition_latents + + output_video = _decode_to_pixel(latents) + previous_output = output_video + # Chunks after the first overlap the previous chunk by the conditioning frames; drop them when stitching. + output_chunks.append(output_video if chunk_id == 0 else output_video[:, :, current_conditional_frames:]) + + self._current_timestep = None + decoded = torch.cat(output_chunks, dim=2)[:, :, :total_frames] + video_out = self.video_processor.postprocess_video(decoded, output_type=output_type)[0] + if enable_safety_check and isinstance(self.safety_checker, CosmosSafetyChecker): + video_out = self._apply_video_safety_check(video_out, output_type=output_type, device=device) + + self.maybe_free_model_hooks() + if not return_dict: + return (video_out, None) + return Cosmos3OmniPipelineOutput(video=video_out, sound=None, action=None) + @property def current_timestep(self): return self._current_timestep @@ -1267,6 +1641,14 @@ def __call__( sound_latents: torch.Tensor | None = None, action_latents: torch.Tensor | None = None, action: CosmosActionCondition | None = None, + control_videos: dict[str, list[Image.Image] | torch.Tensor | np.ndarray] | None = None, + control_guidance: float = 1.0, + control_guidance_interval: tuple[float, float] | None = None, + guidance_interval: tuple[float, float] | None = None, + num_conditional_frames: int = 1, + num_first_chunk_conditional_frames: int = 0, + num_video_frames_per_chunk: int | None = None, + share_vision_temporal_positions: bool = True, output_type: str = "pil", return_dict: bool = True, use_system_prompt: bool = True, @@ -1347,6 +1729,33 @@ def __call__( `action_gen=True`. When set, passing the top-level `image` argument raises; `height` / `width` / `num_frames` must be `None`, since resolution comes from `action.resolution_tier` and frame count from `action.chunk_size`. See [`CosmosActionCondition`]. + control_videos (`dict[str, video]`, *optional*): + Enables video transfer. A mapping from control-hint name (`"edge"`, `"blur"`, `"depth"`, `"seg"`, + `"wsm"`) to a precomputed control video (anything accepted by `video=`). Each control map is resized, + temporally padded, normalized and VAE-encoded into a clean conditioning item; the target clip is then + generated to follow them. Multiple hints can be combined. Transfer is video-only and cannot be combined + with `image`, `video`, `action`, or `enable_sound`. The prompt should be a pre-upsampled JSON caption. + control_guidance (`float`, *optional*, defaults to `1.0`): + Control classifier-free guidance scale for transfer. Values `!= 1.0` amplify the control signal by + blending a "with-control" prediction against a "without-control" prediction (nested with the text + `guidance_scale`). `1.0` disables the control axis (control maps still condition both text branches). + control_guidance_interval (`tuple[float, float]`, *optional*): + Optional `[lo, hi]` timestep window (in scheduler timestep units) outside which control guidance is + skipped. When `None`, control guidance is applied at every step. + guidance_interval (`tuple[float, float]`, *optional*): + Optional `[lo, hi]` timestep window outside which text guidance is skipped (transfer only). + num_conditional_frames (`int`, *optional*, defaults to `1`): + Number of frames carried over from the previous chunk as conditioning at the start of each subsequent + autoregressive chunk (transfer only). + num_first_chunk_conditional_frames (`int`, *optional*, defaults to `0`): + Number of leading frames of `video` used to condition the first transfer chunk. Requires `video` to be + passed alongside `control_videos`; `0` means the first chunk is fully generated. + num_video_frames_per_chunk (`int`, *optional*): + Maximum number of frames generated per autoregressive chunk (transfer only). When `None`, the whole + clip is generated in a single chunk. Longer clips are produced chunk-by-chunk and stitched. + share_vision_temporal_positions (`bool`, *optional*, defaults to `True`): + When `True`, the control maps and the target share the same temporal mRoPE positions (they are + temporally aligned). When `False`, each vision item advances the temporal offset (transfer only). output_type (`str`, *optional*, defaults to `"pil"`): Output format for the video. One of `"pil"` (list of `PIL.Image.Image`), `"np"` (`np.ndarray`, `[T, H, W, C]`), `"pt"` (`torch.Tensor`, `[T, C, H, W]`), or `"latent"` (raw vision latents). @@ -1383,12 +1792,14 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs if action is None: - if num_frames is None: - num_frames = 189 if height is None: height = 720 if width is None: width = 1280 + # For transfer, num_frames defaults to the control video length (resolved in _generate_transfer); for the + # other modes it falls back to the standard ~7.9s clip. + if num_frames is None and control_videos is None: + num_frames = 189 # 1. Check inputs self.check_inputs( @@ -1404,6 +1815,8 @@ def __call__( action, video=video, condition_frame_indexes_vision=condition_frame_indexes_vision, + control_videos=control_videos, + num_first_chunk_conditional_frames=num_first_chunk_conditional_frames, ) # `action_mode` is the only action field consumed directly in __call__ (prompt template + output slicing); @@ -1446,6 +1859,36 @@ def __call__( finally: self.safety_checker.to("cpu") + # Transfer is a distinct mode (autoregressive multi-chunk + nested control/text CFG over multiple vision + # items), so it runs through its own self-contained routine rather than the shared single-clip path below. + if control_videos is not None: + return self._generate_transfer( + prompt=prompt, + negative_prompt=negative_prompt, + control_videos=control_videos, + video=video, + num_frames=num_frames, + height=height, + width=width, + fps=fps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + control_guidance=control_guidance, + control_guidance_interval=control_guidance_interval, + guidance_interval=guidance_interval, + num_conditional_frames=num_conditional_frames, + num_first_chunk_conditional_frames=num_first_chunk_conditional_frames, + num_video_frames_per_chunk=num_video_frames_per_chunk, + share_vision_temporal_positions=share_vision_temporal_positions, + generator=generator, + output_type=output_type, + return_dict=return_dict, + use_system_prompt=use_system_prompt, + enable_safety_check=enable_safety_check, + device=device, + dtype=dtype, + ) + # 2. Tokenize prompt (applies metadata templates and selects mode-specific default negative prompt) cond_input_ids, uncond_input_ids = self.tokenize_prompt( prompt, From b7e10ffaa3f92a268aadb1e954ffb726f3ef29e8 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 26 Jun 2026 14:31:17 +0000 Subject: [PATCH 2/2] Refactor transfer to native pipleine steps --- docs/source/en/api/pipelines/cosmos3.md | 2 +- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 1225 +++++++++-------- 2 files changed, 626 insertions(+), 601 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 221922fcd4f0..af8a8b981cc6 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -385,7 +385,7 @@ curl -sL "$base/edge/prompt.json" -o assets/edge/prompt.json curl -sL "$base/negative_prompt.json" -o assets/negative_prompt.json ``` -Guidance uses a nested control/text classifier-free-guidance blend. `guidance_scale` is the usual text CFG; `control_guidance` (`!= 1.0`) additionally amplifies the control signal. Recommended starting values per hint (matching the Cosmos Framework defaults): +Guidance uses a nested control/text classifier-free-guidance blend. `guidance_scale` is the usual text CFG; `control_guidance` (`!= 1.0`) additionally amplifies the control signal. Recommended starting values per hint: | Hint | `guidance_scale` | `control_guidance` | `flow_shift` | Geometry | | --- | --- | --- | --- | --- | diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index bc6a0456eddb..aaeccbfd9b64 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -1329,284 +1329,77 @@ def _apply_video_safety_check(self, video: Any, output_type: str, device: torch. return torch.from_numpy(checked.astype(np.float32) / 255.0).permute(0, 3, 1, 2) @torch.no_grad() - def _generate_transfer( + def _prepare_transfer_latents( self, *, - prompt: str, - negative_prompt: str | None, - control_videos: dict[str, Any], - video: Any, - num_frames: int | None, + chunk_id: int, + chunk_frames: int, height: int, width: int, - fps: float, - num_inference_steps: int, - guidance_scale: float, - control_guidance: float, - control_guidance_interval: tuple[float, float] | None, - guidance_interval: tuple[float, float] | None, - num_conditional_frames: int, + chunk_controls: list[torch.Tensor], + video: list[Image.Image] | torch.Tensor | np.ndarray | None, + previous_output: torch.Tensor | None, num_first_chunk_conditional_frames: int, - num_video_frames_per_chunk: int | None, - share_vision_temporal_positions: bool, + num_conditional_frames: int, + tcf: int, generator: torch.Generator | None, - output_type: str, - return_dict: bool, - use_system_prompt: bool, - enable_safety_check: bool, - device: torch.device, + device: torch.device | str, dtype: torch.dtype, - ) -> "Cosmos3OmniPipelineOutput | tuple": - """Run video transfer: generate a target clip that follows one or more precomputed control hints. - - Control maps are packed as clean (fully conditioned) vision items before the noisy target — sequence layout - ``[ctrl_1, ..., ctrl_N, target]`` — and guidance uses a nested control/text classifier-free-guidance blend - (see the per-step branch selection below). Long clips are produced autoregressively chunk-by-chunk, with each - chunk conditioned on the tail of the previous one, then stitched back together. - """ - if output_type == "latent": - raise ValueError( - "Transfer decodes and stitches chunks in pixel space; `output_type='latent'` is unsupported." - ) + ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor, torch.Tensor, list[int], int]: + """Build the per-chunk transfer latents: encode the control maps as clean items and the noisy target. - tcf = int(self.vae.config.scale_factor_temporal) - sf = int(self.vae.config.scale_factor_spatial) - if height % sf != 0 or width % sf != 0: - raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") + Seeds the target's conditioning frames (first chunk from the RGB ``video``, later chunks from the previous + chunk's tail), repeat-pads the remaining frames, then encodes the (already sliced + padded) control maps and + the target. Returns the initial noisy ``latents``, the clean ``control_latents``, the ``velocity_mask`` and + ``condition_latents`` used to pin conditioned frames during denoising, the ``target_condition_indexes`` for + sequence packing, and the resolved ``current_conditional_frames``. - def _pad_temporal(frames: torch.Tensor, target_t: int) -> torch.Tensor: - # frames: [1, 3, T, H, W]. Reflect-pad along time up to target_t, falling back to repeating the last - # frame, mirroring the native Cosmos Framework `pad_temporal_frames`. No truncation (callers slice first). - if frames.shape[2] >= target_t: - return frames - while frames.shape[2] < target_t: - pad_len = min(frames.shape[2] - 1, target_t - frames.shape[2]) - if pad_len <= 0: - pad_frame = frames[:, :, -1:].repeat(1, 1, target_t - frames.shape[2], 1, 1) - frames = torch.cat([frames, pad_frame], dim=2) - break - frames = torch.cat([frames, frames.flip(dims=[2])[:, :, :pad_len]], dim=2) - return frames - - def _decode_to_pixel(latent: torch.Tensor) -> torch.Tensor: - vae_dtype = self.vae.dtype - mean = self._vae_latents_mean.to(device=latent.device, dtype=vae_dtype) - inv_std = self._vae_latents_inv_std.to(device=latent.device, dtype=vae_dtype) - z_raw = latent.to(vae_dtype) / inv_std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1) - return self.vae.decode(z_raw).sample.to(torch.float32).clamp(-1, 1) - - def _active_at(t: torch.Tensor, interval: tuple[float, float] | None) -> bool: - if interval is None: - return True - lo, hi = float(interval[0]), float(interval[1]) - return lo <= float(t.item()) <= hi - - # Canonical hint order, then preprocess every control map to [1, 3, T, H, W] in [-1, 1] at the target geometry. - hint_keys = [k for k in ["edge", "blur", "depth", "seg", "wsm"] if k in control_videos] - control_frames = { - key: self.video_processor.preprocess_video(control_videos[key], height=height, width=width).to( - device=device, dtype=dtype - ) - for key in hint_keys - } - input_frames = None - if video is not None: + """ + # Seed the target with conditioning frames (first chunk from the input video, later chunks from the + # previous chunk's tail), repeat-padding the remaining frames so the whole clip is well-defined. + target = torch.zeros(1, 3, chunk_frames, height, width, device=device, dtype=dtype) + current_conditional_frames = 0 + if chunk_id == 0 and num_first_chunk_conditional_frames > 0 and video is not None: input_frames = self.video_processor.preprocess_video(video, height=height, width=width).to( device=device, dtype=dtype ) - - # Output frame count / chunking come from the (first) control video, optionally capped by num_frames. - total_frames = next(iter(control_frames.values())).shape[2] - if num_frames is not None: - total_frames = min(total_frames, num_frames) - total_frames = max(1, total_frames) - - per_chunk = num_video_frames_per_chunk if num_video_frames_per_chunk is not None else total_frames - chunk_frames = 1 if total_frames == 1 else per_chunk - chunk_frames = math.ceil((chunk_frames - 1) / tcf) * tcf + 1 - - if total_frames <= chunk_frames: - num_chunks, stride = 1, chunk_frames - else: - stride = chunk_frames - num_conditional_frames - if stride <= 0: - raise ValueError("`num_conditional_frames` must be smaller than `num_video_frames_per_chunk`.") - remaining = total_frames - chunk_frames - num_chunks = 1 + (remaining // stride + (1 if remaining % stride else 0)) - - padded = max(total_frames, chunk_frames) - control_frames = {key: _pad_temporal(frames, padded) for key, frames in control_frames.items()} - if input_frames is not None: - input_frames = _pad_temporal(input_frames, padded) - - # Text packing is invariant across chunks and denoising steps; build it once. Transfer prompts are passed - # through verbatim (pre-upsampled JSON) under the transfer system prompt. - cond_input_ids, uncond_input_ids = self.tokenize_prompt( - prompt, - negative_prompt, - num_frames=chunk_frames, - height=height, - width=width, - fps=fps, - use_system_prompt=use_system_prompt, - transfer_mode=True, - ) - cond_text_segment = self._prepare_text_segment(cond_input_ids, device=device) - uncond_text_segment = self._prepare_text_segment(uncond_input_ids, device=device) - num_hints = len(hint_keys) - - output_chunks: list[torch.Tensor] = [] - previous_output: torch.Tensor | None = None - - for chunk_id in range(num_chunks): - start_frame = chunk_id * stride - end_frame = min(start_frame + chunk_frames, total_frames) - chunk_controls = [ - _pad_temporal(control_frames[key][:, :, start_frame:end_frame], chunk_frames) for key in hint_keys - ] - - # Seed the target with conditioning frames (first chunk from the input video, later chunks from the - # previous chunk's tail), repeat-padding the remaining frames so the whole clip is well-defined. - target = torch.zeros(1, 3, chunk_frames, height, width, device=device, dtype=dtype) - current_conditional_frames = 0 - if chunk_id == 0 and num_first_chunk_conditional_frames > 0 and input_frames is not None: - current_conditional_frames = min( - num_first_chunk_conditional_frames, input_frames.shape[2], chunk_frames - ) - if current_conditional_frames > 0: - target[:, :, :current_conditional_frames] = input_frames[:, :, :current_conditional_frames] - elif chunk_id > 0 and previous_output is not None: - current_conditional_frames = min(num_conditional_frames, previous_output.shape[2], chunk_frames) - if current_conditional_frames > 0: - target[:, :, :current_conditional_frames] = previous_output[:, :, -current_conditional_frames:].to( - device=device, dtype=dtype - ) - if 0 < current_conditional_frames < chunk_frames: - fill = target[:, :, current_conditional_frames - 1 : current_conditional_frames] - target[:, :, current_conditional_frames:] = fill.expand( - -1, -1, chunk_frames - current_conditional_frames, -1, -1 - ) - - # Encode controls as clean latents and build the noisy target latents + conditioning mask. - control_latents = [self._encode_video(ctrl).contiguous().float() for ctrl in chunk_controls] - target_x0 = self._encode_video(target).contiguous().float() - latent_t = target_x0.shape[2] - condition_mask = torch.zeros((latent_t, 1, 1), device=device, dtype=dtype) - latent_condition_frames = 0 + current_conditional_frames = min(num_first_chunk_conditional_frames, input_frames.shape[2], chunk_frames) if current_conditional_frames > 0: - latent_condition_frames = (current_conditional_frames - 1) // tcf + 1 - condition_mask[:latent_condition_frames] = 1.0 - noise = randn_tensor(tuple(target_x0.shape), generator=generator, device=device, dtype=dtype) - latents = condition_mask * target_x0 + (1.0 - condition_mask) * noise - velocity_mask = 1.0 - condition_mask - condition_latents = condition_mask * target_x0 - - target_condition_indexes = list(range(latent_condition_frames)) - - # Pre-pack the three CFG sequence variants. cond_full / uncond_full carry every control item; the - # no-control branch drops them (only [text, target]) so the control axis can be amplified. - def _vision_pack(text_segment: dict[str, Any], include_controls: bool) -> dict[str, Any]: - if include_controls: - vision_items = [*control_latents, latents] - condition_indexes = [None] * num_hints + [target_condition_indexes] - clean_flags = [True] * num_hints + [False] - else: - vision_items = [latents] - condition_indexes = [target_condition_indexes] - clean_flags = [False] - vision_segment = self._prepare_vision_segment( - input_vision_tokens=vision_items, - has_image_condition=False, - mrope_offset=text_segment["vision_start_temporal_offset"], - vision_fps=fps, - curr=text_segment["und_len"], - device=device, - condition_frame_indexes=condition_indexes, - clean_item_flags=clean_flags, - share_vision_temporal_positions=share_vision_temporal_positions, - ) - return { - **text_segment, - **vision_segment, - "position_ids": torch.cat( - [text_segment["text_mrope_ids"], vision_segment["vision_mrope_ids"]], dim=1 - ), - "sequence_length": text_segment["und_len"] + vision_segment["num_vision_tokens"], - } - - cond_full_static = _vision_pack(cond_text_segment, include_controls=True) - cond_no_control_static = _vision_pack(cond_text_segment, include_controls=False) - uncond_full_static = _vision_pack(uncond_text_segment, include_controls=True) - num_noisy_vision_tokens = cond_full_static["num_noisy_vision_tokens"] - - def _run(static: dict[str, Any], vision_tokens: list[torch.Tensor], vision_timesteps: torch.Tensor): - preds_vision, _, _ = self.transformer( - input_ids=static["input_ids"], - text_indexes=static["text_indexes"], - position_ids=static["position_ids"], - und_len=static["und_len"], - sequence_length=static["sequence_length"], - vision_tokens=vision_tokens, - vision_token_shapes=static["vision_token_shapes"], - vision_sequence_indexes=static["vision_sequence_indexes"], - vision_mse_loss_indexes=static["vision_mse_loss_indexes"], - vision_timesteps=vision_timesteps, - vision_noisy_frame_indexes=static["vision_noisy_frame_indexes"], + target[:, :, :current_conditional_frames] = input_frames[:, :, :current_conditional_frames] + elif chunk_id > 0 and previous_output is not None: + current_conditional_frames = min(num_conditional_frames, previous_output.shape[2], chunk_frames) + if current_conditional_frames > 0: + target[:, :, :current_conditional_frames] = previous_output[:, :, -current_conditional_frames:].to( + device=device, dtype=dtype ) - # The target is the last vision item; control items return zeros (no MSE positions). - return preds_vision[-1] - - self.scheduler.set_timesteps(num_inference_steps, device=device) - for t in self.progress_bar(self.scheduler.timesteps): - self._current_timestep = t - timestep = t.item() - vision_tokens_full = [c.to(device=device, dtype=dtype) for c in control_latents] + [ - latents.to(device=device, dtype=dtype) - ] - vision_tokens_target = [latents.to(device=device, dtype=dtype)] - vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) - - step_guidance = guidance_scale if _active_at(t, guidance_interval) else 1.0 - step_control = control_guidance if _active_at(t, control_guidance_interval) else 1.0 - needs_text_cfg = step_guidance > 1.0 - needs_control_cfg = step_control != 1.0 - - cond_full = _run(cond_full_static, vision_tokens_full, vision_timesteps) - if needs_control_cfg and needs_text_cfg: - cond_no_control = _run(cond_no_control_static, vision_tokens_target, vision_timesteps) - uncond_full = _run(uncond_full_static, vision_tokens_full, vision_timesteps) - control_cond = cond_no_control + step_control * (cond_full - cond_no_control) - velocity = uncond_full + step_guidance * (control_cond - uncond_full) - elif needs_control_cfg: - cond_no_control = _run(cond_no_control_static, vision_tokens_target, vision_timesteps) - velocity = cond_no_control + step_control * (cond_full - cond_no_control) - elif needs_text_cfg: - uncond_full = _run(uncond_full_static, vision_tokens_full, vision_timesteps) - velocity = uncond_full + step_guidance * (cond_full - uncond_full) - else: - velocity = cond_full - - velocity = velocity * velocity_mask - latents = self.scheduler.step(velocity.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False)[ - 0 - ].squeeze(0) - latents = velocity_mask * latents + (1.0 - velocity_mask) * condition_latents - - output_video = _decode_to_pixel(latents) - previous_output = output_video - # Chunks after the first overlap the previous chunk by the conditioning frames; drop them when stitching. - output_chunks.append(output_video if chunk_id == 0 else output_video[:, :, current_conditional_frames:]) - - self._current_timestep = None - decoded = torch.cat(output_chunks, dim=2)[:, :, :total_frames] - video_out = self.video_processor.postprocess_video(decoded, output_type=output_type)[0] - if enable_safety_check and isinstance(self.safety_checker, CosmosSafetyChecker): - video_out = self._apply_video_safety_check(video_out, output_type=output_type, device=device) + if 0 < current_conditional_frames < chunk_frames: + fill = target[:, :, current_conditional_frames - 1 : current_conditional_frames] + target[:, :, current_conditional_frames:] = fill.expand( + -1, -1, chunk_frames - current_conditional_frames, -1, -1 + ) - self.maybe_free_model_hooks() - if not return_dict: - return (video_out, None) - return Cosmos3OmniPipelineOutput(video=video_out, sound=None, action=None) + # Encode controls as clean latents and build the noisy target latents + conditioning mask. + control_latents = [self._encode_video(ctrl).contiguous().float() for ctrl in chunk_controls] + target_x0 = self._encode_video(target).contiguous().float() + latent_t = target_x0.shape[2] + condition_mask = torch.zeros((latent_t, 1, 1), device=device, dtype=dtype) + latent_condition_frames = 0 + if current_conditional_frames > 0: + latent_condition_frames = (current_conditional_frames - 1) // tcf + 1 + condition_mask[:latent_condition_frames] = 1.0 + noise = randn_tensor(tuple(target_x0.shape), generator=generator, device=device, dtype=dtype) + latents = condition_mask * target_x0 + (1.0 - condition_mask) * noise + velocity_mask = 1.0 - condition_mask + condition_latents = condition_mask * target_x0 + target_condition_indexes = list(range(latent_condition_frames)) + return ( + latents, + control_latents, + velocity_mask, + condition_latents, + target_condition_indexes, + current_conditional_frames, + ) @property def current_timestep(self): @@ -1796,8 +1589,8 @@ def __call__( height = 720 if width is None: width = 1280 - # For transfer, num_frames defaults to the control video length (resolved in _generate_transfer); for the - # other modes it falls back to the standard ~7.9s clip. + # For transfer, num_frames defaults to the control video length (resolved in the transfer pre-setup + # below); for the other modes it falls back to the standard ~7.9s clip. if num_frames is None and control_videos is None: num_frames = 189 @@ -1860,40 +1653,83 @@ def __call__( self.safety_checker.to("cpu") # Transfer is a distinct mode (autoregressive multi-chunk + nested control/text CFG over multiple vision - # items), so it runs through its own self-contained routine rather than the shared single-clip path below. - if control_videos is not None: - return self._generate_transfer( - prompt=prompt, - negative_prompt=negative_prompt, - control_videos=control_videos, - video=video, - num_frames=num_frames, - height=height, - width=width, - fps=fps, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - control_guidance=control_guidance, - control_guidance_interval=control_guidance_interval, - guidance_interval=guidance_interval, - num_conditional_frames=num_conditional_frames, - num_first_chunk_conditional_frames=num_first_chunk_conditional_frames, - num_video_frames_per_chunk=num_video_frames_per_chunk, - share_vision_temporal_positions=share_vision_temporal_positions, - generator=generator, - output_type=output_type, - return_dict=return_dict, - use_system_prompt=use_system_prompt, - enable_safety_check=enable_safety_check, - device=device, - dtype=dtype, - ) + # items). It threads through the same numbered steps below, taking the `if transfer:` branch wherever its + # control flow differs and reusing the shared steps everywhere else. + transfer = control_videos is not None + num_chunks = 1 + + # Transfer pre-setup: preprocess the control maps, resolve the autoregressive chunk geometry, and temporally + # pad the controls / optional input video so every chunk window is well-defined. Chunk-invariant, so done once. + if transfer: + if output_type == "latent": + raise ValueError( + "Transfer decodes and stitches chunks in pixel space; `output_type='latent'` is unsupported." + ) + tcf = int(self.vae.config.scale_factor_temporal) + sf = int(self.vae.config.scale_factor_spatial) + if height % sf != 0 or width % sf != 0: + raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") + + def _pad_temporal(frames: torch.Tensor, target_t: int) -> torch.Tensor: + # frames: [1, 3, T, H, W]. Reflect-pad along time up to target_t, falling back to repeating the last + # frame once the clip is too short to keep reflecting. No truncation (callers slice). + if frames.shape[2] >= target_t: + return frames + while frames.shape[2] < target_t: + pad_len = min(frames.shape[2] - 1, target_t - frames.shape[2]) + if pad_len <= 0: + pad_frame = frames[:, :, -1:].repeat(1, 1, target_t - frames.shape[2], 1, 1) + frames = torch.cat([frames, pad_frame], dim=2) + break + frames = torch.cat([frames, frames.flip(dims=[2])[:, :, :pad_len]], dim=2) + return frames + + def _active_at(t: torch.Tensor, interval: tuple[float, float] | None) -> bool: + if interval is None: + return True + lo, hi = float(interval[0]), float(interval[1]) + return lo <= float(t.item()) <= hi + + # Canonical hint order, then preprocess every control map to [1, 3, T, H, W] in [-1, 1] at target geometry. + hint_keys = [k for k in ["edge", "blur", "depth", "seg", "wsm"] if k in control_videos] + control_frames = { + key: self.video_processor.preprocess_video(control_videos[key], height=height, width=width).to( + device=device, dtype=dtype + ) + for key in hint_keys + } + + # Output frame count / chunking come from the (first) control video, optionally capped by num_frames. + total_frames = next(iter(control_frames.values())).shape[2] + if num_frames is not None: + total_frames = min(total_frames, num_frames) + total_frames = max(1, total_frames) - # 2. Tokenize prompt (applies metadata templates and selects mode-specific default negative prompt) + per_chunk = num_video_frames_per_chunk if num_video_frames_per_chunk is not None else total_frames + chunk_frames = 1 if total_frames == 1 else per_chunk + chunk_frames = math.ceil((chunk_frames - 1) / tcf) * tcf + 1 + + if total_frames <= chunk_frames: + num_chunks, stride = 1, chunk_frames + else: + stride = chunk_frames - num_conditional_frames + if stride <= 0: + raise ValueError("`num_conditional_frames` must be smaller than `num_video_frames_per_chunk`.") + remaining = total_frames - chunk_frames + num_chunks = 1 + (remaining // stride + (1 if remaining % stride else 0)) + + padded = max(total_frames, chunk_frames) + control_frames = {key: _pad_temporal(frames, padded) for key, frames in control_frames.items()} + num_hints = len(hint_keys) + output_chunks: list[torch.Tensor] = [] + previous_output: torch.Tensor | None = None + + # 2. Tokenize prompt (applies metadata templates and selects mode-specific default negative prompt). Transfer + # passes the per-chunk frame count and its JSON-passthrough system prompt. cond_input_ids, uncond_input_ids = self.tokenize_prompt( prompt, negative_prompt, - num_frames=num_frames, + num_frames=chunk_frames if transfer else num_frames, height=height, width=width, fps=fps, @@ -1902,328 +1738,517 @@ def __call__( add_duration_template=add_duration_template, action_mode=action_mode, action_view_point=action.view_point if action is not None else None, + transfer_mode=transfer, ) - # 3. Pre-pack the text segment for each prompt — text packing is invariant - # across denoising steps, so we do it once here and reuse inside the loop. + # 3. Pre-pack the text segment for each prompt — text packing is invariant across chunks and denoising + # steps, so we do it once here and reuse inside the loop. cond_text_segment = self._prepare_text_segment(cond_input_ids, device=device) uncond_text_segment = self._prepare_text_segment(uncond_input_ids, device=device) - # 4. Prepare latents (initial noise per modality + pack metadata) - ( - latents, - sound_latents, - action_latents, - fps_vision, - fps_sound, - vision_condition_mask, - sound_condition_mask, - action_condition_mask, - action_domain_id, - action_image_size, - raw_action_dim_resolved, - action_condition_frame_indexes, - ) = self.prepare_latents( - image=image, - video=video, - condition_frame_indexes_vision=condition_frame_indexes_vision, - condition_video_keep=condition_video_keep, - num_frames=num_frames, - height=height, - width=width, - fps=fps, - latents=latents, - sound_latents=sound_latents, - action_latents=action_latents, - generator=generator, - device=device, - dtype=dtype, - enable_sound=enable_sound, - action=action, - ) - vision_condition_indexes_for_pack = torch.nonzero(vision_condition_mask[:, 0, 0] > 0, as_tuple=False).flatten() - vision_condition_indexes_for_pack = [int(idx.item()) for idx in vision_condition_indexes_for_pack] - has_image_condition = bool(vision_condition_indexes_for_pack) - - # 5. Pre-pack the static per-prompt vision / sound sequence segments. The only - # fields that vary across denoising steps are the modality token tensors and the - # per-modality timestep tensors; everything else only depends on prompt length - # and modality shape, so we hoist it out of the loop and splice the step-varying - # fields back in below. - cond_vision_segment = self._prepare_vision_segment( - input_vision_tokens=latents, - has_image_condition=has_image_condition, - mrope_offset=cond_text_segment["vision_start_temporal_offset"], - vision_fps=fps_vision, - curr=cond_text_segment["und_len"], - device=device, - condition_frame_indexes=vision_condition_indexes_for_pack, - ) - cond_sound_segment: dict[str, Any] = {} - if sound_latents is not None: - cond_sound_segment = self._prepare_sound_segment( - input_sound_tokens=sound_latents, - mrope_offset=cond_text_segment["vision_start_temporal_offset"], - sound_fps=fps_sound, - curr=cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"], - device=device, - ) - cond_action_segment: dict[str, Any] = {} - if action_latents is not None: - cond_action_segment = self._prepare_action_segment( - input_action_tokens=action_latents, - condition_frame_indexes=action_condition_frame_indexes, - mrope_offset=cond_text_segment["vision_start_temporal_offset"], - action_fps=fps_vision, - curr=cond_text_segment["und_len"] - + cond_vision_segment["num_vision_tokens"] - + cond_sound_segment.get("sound_len", 0), - device=device, - ) - cond_mrope_segments = [cond_text_segment["text_mrope_ids"], cond_vision_segment["vision_mrope_ids"]] - if cond_sound_segment: - cond_mrope_segments.append(cond_sound_segment["sound_mrope_ids"]) - if cond_action_segment: - cond_mrope_segments.append(cond_action_segment["action_mrope_ids"]) - cond_packed_static = { - **cond_text_segment, - **cond_vision_segment, - **cond_sound_segment, - **cond_action_segment, - "position_ids": torch.cat(cond_mrope_segments, dim=1), - "sequence_length": cond_text_segment["und_len"] - + cond_vision_segment["num_vision_tokens"] - + cond_sound_segment.get("sound_len", 0) - + cond_action_segment.get("action_len", 0), - } - - uncond_vision_segment = self._prepare_vision_segment( - input_vision_tokens=latents, - has_image_condition=has_image_condition, - mrope_offset=uncond_text_segment["vision_start_temporal_offset"], - vision_fps=fps_vision, - curr=uncond_text_segment["und_len"], - device=device, - condition_frame_indexes=vision_condition_indexes_for_pack, - ) - uncond_sound_segment: dict[str, Any] = {} - if sound_latents is not None: - uncond_sound_segment = self._prepare_sound_segment( - input_sound_tokens=sound_latents, - mrope_offset=uncond_text_segment["vision_start_temporal_offset"], - sound_fps=fps_sound, - curr=uncond_text_segment["und_len"] + uncond_vision_segment["num_vision_tokens"], - device=device, - ) - uncond_action_segment: dict[str, Any] = {} - if action_latents is not None: - uncond_action_segment = self._prepare_action_segment( - input_action_tokens=action_latents, - condition_frame_indexes=action_condition_frame_indexes, - mrope_offset=uncond_text_segment["vision_start_temporal_offset"], - action_fps=fps_vision, - curr=uncond_text_segment["und_len"] - + uncond_vision_segment["num_vision_tokens"] - + uncond_sound_segment.get("sound_len", 0), - device=device, - ) - uncond_mrope_segments = [uncond_text_segment["text_mrope_ids"], uncond_vision_segment["vision_mrope_ids"]] - if uncond_sound_segment: - uncond_mrope_segments.append(uncond_sound_segment["sound_mrope_ids"]) - if uncond_action_segment: - uncond_mrope_segments.append(uncond_action_segment["action_mrope_ids"]) - uncond_packed_static = { - **uncond_text_segment, - **uncond_vision_segment, - **uncond_sound_segment, - **uncond_action_segment, - "position_ids": torch.cat(uncond_mrope_segments, dim=1), - "sequence_length": uncond_text_segment["und_len"] - + uncond_vision_segment["num_vision_tokens"] - + uncond_sound_segment.get("sound_len", 0) - + uncond_action_segment.get("action_len", 0), - } - num_noisy_vision_tokens = cond_vision_segment["num_noisy_vision_tokens"] - sound_len = cond_sound_segment.get("sound_len") - action_noisy_len = cond_action_segment.get("num_noisy_action_tokens") - - # 6. Set timesteps. UniPCMultistepScheduler keeps per-step state (_step_index, - # model_outputs history) on the instance, so sound/action each get their own copy. - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - sound_scheduler = copy.deepcopy(self.scheduler) if sound_latents is not None else None - action_scheduler = copy.deepcopy(self.scheduler) if action_latents is not None else None - - # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - timestep = t.item() - - # The transformer projections (proj_in / audio_proj_in) are bf16; cast the per-step - # noisy tokens before packing so the modality tokens enter the model in the right dtype. - vision_tokens = latents.to(device=device, dtype=dtype) - sound_tokens = sound_latents.to(device=device, dtype=dtype) if sound_latents is not None else None - action_tokens = action_latents.to(device=device, dtype=dtype) if action_latents is not None else None - # The static packs both report the same num_noisy_vision_tokens / sound_len, so a - # single per-step timestep tensor per modality is shared by the cond / uncond passes. - vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) - sound_timesteps = ( - torch.full((sound_len,), timestep, device=device) if sound_tokens is not None else None - ) - action_timesteps = ( - torch.full((action_noisy_len,), timestep, device=device) if action_tokens is not None else None + sound = None + action_output = None + # Outer chunk loop. Non-transfer modes always run exactly one iteration (num_chunks == 1). + for chunk_id in range(num_chunks): + # 4. Prepare latents (initial noise per modality + pack metadata). + if transfer: + start_frame = chunk_id * stride + end_frame = min(start_frame + chunk_frames, total_frames) + chunk_controls = [ + _pad_temporal(control_frames[key][:, :, start_frame:end_frame], chunk_frames) for key in hint_keys + ] + ( + latents, + control_latents, + velocity_mask, + condition_latents, + target_condition_indexes, + current_conditional_frames, + ) = self._prepare_transfer_latents( + chunk_id=chunk_id, + chunk_frames=chunk_frames, + height=height, + width=width, + chunk_controls=chunk_controls, + video=video, + previous_output=previous_output, + num_first_chunk_conditional_frames=num_first_chunk_conditional_frames, + num_conditional_frames=num_conditional_frames, + tcf=tcf, + generator=generator, + device=device, + dtype=dtype, ) - - # --- Conditional pass --- - preds_vision, preds_sound, preds_action = self.transformer( - input_ids=cond_packed_static["input_ids"], - text_indexes=cond_packed_static["text_indexes"], - position_ids=cond_packed_static["position_ids"], - und_len=cond_packed_static["und_len"], - sequence_length=cond_packed_static["sequence_length"], - vision_tokens=[vision_tokens], - vision_token_shapes=cond_packed_static["vision_token_shapes"], - vision_sequence_indexes=cond_packed_static["vision_sequence_indexes"], - vision_mse_loss_indexes=cond_packed_static["vision_mse_loss_indexes"], - vision_timesteps=vision_timesteps, - vision_noisy_frame_indexes=cond_packed_static["vision_noisy_frame_indexes"], - sound_tokens=[sound_tokens] if sound_tokens is not None else None, - sound_token_shapes=cond_packed_static.get("sound_token_shapes"), - sound_sequence_indexes=cond_packed_static.get("sound_sequence_indexes"), - sound_mse_loss_indexes=cond_packed_static.get("sound_mse_loss_indexes"), - sound_timesteps=sound_timesteps, - sound_noisy_frame_indexes=cond_packed_static.get("sound_noisy_frame_indexes"), - action_tokens=[action_tokens] if action_tokens is not None else None, - action_token_shapes=cond_packed_static.get("action_token_shapes"), - action_sequence_indexes=cond_packed_static.get("action_sequence_indexes"), - action_mse_loss_indexes=cond_packed_static.get("action_mse_loss_indexes"), - action_timesteps=action_timesteps, - action_noisy_frame_indexes=cond_packed_static.get("action_noisy_frame_indexes"), - action_domain_ids=[action_domain_id] if action_domain_id is not None else None, + else: + ( + latents, + sound_latents, + action_latents, + fps_vision, + fps_sound, + vision_condition_mask, + sound_condition_mask, + action_condition_mask, + action_domain_id, + action_image_size, + raw_action_dim_resolved, + action_condition_frame_indexes, + ) = self.prepare_latents( + image=image, + video=video, + condition_frame_indexes_vision=condition_frame_indexes_vision, + condition_video_keep=condition_video_keep, + num_frames=num_frames, + height=height, + width=width, + fps=fps, + latents=latents, + sound_latents=sound_latents, + action_latents=action_latents, + generator=generator, + device=device, + dtype=dtype, + enable_sound=enable_sound, + action=action, ) - cond_v_vision, cond_v_sound, cond_v_action = self._mask_velocity_predictions( - preds_vision, - preds_sound, - vision_condition_mask=[vision_condition_mask], - sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, - preds_action=preds_action, - action_condition_mask=[action_condition_mask] if action_condition_mask is not None else None, - raw_action_dim=raw_action_dim_resolved, + vision_condition_indexes_for_pack = torch.nonzero( + vision_condition_mask[:, 0, 0] > 0, as_tuple=False + ).flatten() + vision_condition_indexes_for_pack = [int(idx.item()) for idx in vision_condition_indexes_for_pack] + has_image_condition = bool(vision_condition_indexes_for_pack) + + # 5. Pre-pack the static per-prompt sequence segments. The only fields that vary across denoising steps + # are the modality token tensors and the per-modality timestep tensors; everything else only depends on + # prompt length and modality shape, so we hoist it out of the step loop and splice those back in below. + if transfer: + # Pre-pack the three CFG sequence variants. cond_full / uncond_full carry every control item; the + # no-control branch drops them (only [text, target]) so the control axis can be amplified. + def _vision_pack(text_segment: dict[str, Any], include_controls: bool) -> dict[str, Any]: + if include_controls: + vision_items = [*control_latents, latents] + condition_indexes = [None] * num_hints + [target_condition_indexes] + clean_flags = [True] * num_hints + [False] + else: + vision_items = [latents] + condition_indexes = [target_condition_indexes] + clean_flags = [False] + vision_segment = self._prepare_vision_segment( + input_vision_tokens=vision_items, + has_image_condition=False, + mrope_offset=text_segment["vision_start_temporal_offset"], + vision_fps=fps, + curr=text_segment["und_len"], + device=device, + condition_frame_indexes=condition_indexes, + clean_item_flags=clean_flags, + share_vision_temporal_positions=share_vision_temporal_positions, + ) + return { + **text_segment, + **vision_segment, + "position_ids": torch.cat( + [text_segment["text_mrope_ids"], vision_segment["vision_mrope_ids"]], dim=1 + ), + "sequence_length": text_segment["und_len"] + vision_segment["num_vision_tokens"], + } + + cond_full_static = _vision_pack(cond_text_segment, include_controls=True) + cond_no_control_static = _vision_pack(cond_text_segment, include_controls=False) + uncond_full_static = _vision_pack(uncond_text_segment, include_controls=True) + num_noisy_vision_tokens = cond_full_static["num_noisy_vision_tokens"] + else: + cond_vision_segment = self._prepare_vision_segment( + input_vision_tokens=latents, + has_image_condition=has_image_condition, + mrope_offset=cond_text_segment["vision_start_temporal_offset"], + vision_fps=fps_vision, + curr=cond_text_segment["und_len"], + device=device, + condition_frame_indexes=vision_condition_indexes_for_pack, ) - - # --- Unconditional pass (Skip if not using CFG) --- - uncond_v_vision = uncond_v_sound = uncond_v_action = None - if self.do_classifier_free_guidance: - preds_vision, preds_sound, preds_action = self.transformer( - input_ids=uncond_packed_static["input_ids"], - text_indexes=uncond_packed_static["text_indexes"], - position_ids=uncond_packed_static["position_ids"], - und_len=uncond_packed_static["und_len"], - sequence_length=uncond_packed_static["sequence_length"], - vision_tokens=[vision_tokens], - vision_token_shapes=uncond_packed_static["vision_token_shapes"], - vision_sequence_indexes=uncond_packed_static["vision_sequence_indexes"], - vision_mse_loss_indexes=uncond_packed_static["vision_mse_loss_indexes"], - vision_timesteps=vision_timesteps, - vision_noisy_frame_indexes=uncond_packed_static["vision_noisy_frame_indexes"], - sound_tokens=[sound_tokens] if sound_tokens is not None else None, - sound_token_shapes=uncond_packed_static.get("sound_token_shapes"), - sound_sequence_indexes=uncond_packed_static.get("sound_sequence_indexes"), - sound_mse_loss_indexes=uncond_packed_static.get("sound_mse_loss_indexes"), - sound_timesteps=sound_timesteps, - sound_noisy_frame_indexes=uncond_packed_static.get("sound_noisy_frame_indexes"), - action_tokens=[action_tokens] if action_tokens is not None else None, - action_token_shapes=uncond_packed_static.get("action_token_shapes"), - action_sequence_indexes=uncond_packed_static.get("action_sequence_indexes"), - action_mse_loss_indexes=uncond_packed_static.get("action_mse_loss_indexes"), - action_timesteps=action_timesteps, - action_noisy_frame_indexes=uncond_packed_static.get("action_noisy_frame_indexes"), - action_domain_ids=[action_domain_id] if action_domain_id is not None else None, + cond_sound_segment: dict[str, Any] = {} + if sound_latents is not None: + cond_sound_segment = self._prepare_sound_segment( + input_sound_tokens=sound_latents, + mrope_offset=cond_text_segment["vision_start_temporal_offset"], + sound_fps=fps_sound, + curr=cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"], + device=device, ) - uncond_v_vision, uncond_v_sound, uncond_v_action = self._mask_velocity_predictions( - preds_vision, - preds_sound, - vision_condition_mask=[vision_condition_mask], - sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, - preds_action=preds_action, - action_condition_mask=[action_condition_mask] if action_condition_mask is not None else None, - raw_action_dim=raw_action_dim_resolved, + cond_action_segment: dict[str, Any] = {} + if action_latents is not None: + cond_action_segment = self._prepare_action_segment( + input_action_tokens=action_latents, + condition_frame_indexes=action_condition_frame_indexes, + mrope_offset=cond_text_segment["vision_start_temporal_offset"], + action_fps=fps_vision, + curr=cond_text_segment["und_len"] + + cond_vision_segment["num_vision_tokens"] + + cond_sound_segment.get("sound_len", 0), + device=device, ) + cond_mrope_segments = [cond_text_segment["text_mrope_ids"], cond_vision_segment["vision_mrope_ids"]] + if cond_sound_segment: + cond_mrope_segments.append(cond_sound_segment["sound_mrope_ids"]) + if cond_action_segment: + cond_mrope_segments.append(cond_action_segment["action_mrope_ids"]) + cond_packed_static = { + **cond_text_segment, + **cond_vision_segment, + **cond_sound_segment, + **cond_action_segment, + "position_ids": torch.cat(cond_mrope_segments, dim=1), + "sequence_length": cond_text_segment["und_len"] + + cond_vision_segment["num_vision_tokens"] + + cond_sound_segment.get("sound_len", 0) + + cond_action_segment.get("action_len", 0), + } - # --- CFG combine + per-modality scheduler step --- - # UniPC's multistep_uni_p_bh_update einsum ("k,bkc...->bc...") requires sample - # to carry a batch dim; per-modality latents have no batch axis, so wrap for the step. - - # Skip CFG for 1.0 guidance scale - if self.do_classifier_free_guidance: - velocity_vision = uncond_v_vision + guidance_scale * (cond_v_vision - uncond_v_vision) - else: - velocity_vision = cond_v_vision - - latents = self.scheduler.step( - velocity_vision.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False - )[0].squeeze(0) + uncond_vision_segment = self._prepare_vision_segment( + input_vision_tokens=latents, + has_image_condition=has_image_condition, + mrope_offset=uncond_text_segment["vision_start_temporal_offset"], + vision_fps=fps_vision, + curr=uncond_text_segment["und_len"], + device=device, + condition_frame_indexes=vision_condition_indexes_for_pack, + ) + uncond_sound_segment: dict[str, Any] = {} + if sound_latents is not None: + uncond_sound_segment = self._prepare_sound_segment( + input_sound_tokens=sound_latents, + mrope_offset=uncond_text_segment["vision_start_temporal_offset"], + sound_fps=fps_sound, + curr=uncond_text_segment["und_len"] + uncond_vision_segment["num_vision_tokens"], + device=device, + ) + uncond_action_segment: dict[str, Any] = {} + if action_latents is not None: + uncond_action_segment = self._prepare_action_segment( + input_action_tokens=action_latents, + condition_frame_indexes=action_condition_frame_indexes, + mrope_offset=uncond_text_segment["vision_start_temporal_offset"], + action_fps=fps_vision, + curr=uncond_text_segment["und_len"] + + uncond_vision_segment["num_vision_tokens"] + + uncond_sound_segment.get("sound_len", 0), + device=device, + ) + uncond_mrope_segments = [ + uncond_text_segment["text_mrope_ids"], + uncond_vision_segment["vision_mrope_ids"], + ] + if uncond_sound_segment: + uncond_mrope_segments.append(uncond_sound_segment["sound_mrope_ids"]) + if uncond_action_segment: + uncond_mrope_segments.append(uncond_action_segment["action_mrope_ids"]) + uncond_packed_static = { + **uncond_text_segment, + **uncond_vision_segment, + **uncond_sound_segment, + **uncond_action_segment, + "position_ids": torch.cat(uncond_mrope_segments, dim=1), + "sequence_length": uncond_text_segment["und_len"] + + uncond_vision_segment["num_vision_tokens"] + + uncond_sound_segment.get("sound_len", 0) + + uncond_action_segment.get("action_len", 0), + } + num_noisy_vision_tokens = cond_vision_segment["num_noisy_vision_tokens"] + sound_len = cond_sound_segment.get("sound_len") + action_noisy_len = cond_action_segment.get("num_noisy_action_tokens") - if sound_scheduler is not None and cond_v_sound is not None: - # Skip CFG for 1.0 guidance scale - if self.do_classifier_free_guidance: - velocity_sound = uncond_v_sound + guidance_scale * (cond_v_sound - uncond_v_sound) + # 6. Set timesteps. UniPCMultistepScheduler keeps per-step state (_step_index, model_outputs history) on + # the instance, so it is reset per chunk and sound/action each get their own copy. + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + sound_scheduler = copy.deepcopy(self.scheduler) if (not transfer and sound_latents is not None) else None + action_scheduler = copy.deepcopy(self.scheduler) if (not transfer and action_latents is not None) else None + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.item() + + if transfer: + # Nested control/text CFG over [ctrl_1, ..., ctrl_N, target]. Each branch is gated by its + # interval so out-of-window steps skip the extra forward(s). + vision_tokens_full = [c.to(device=device, dtype=dtype) for c in control_latents] + [ + latents.to(device=device, dtype=dtype) + ] + vision_tokens_target = [latents.to(device=device, dtype=dtype)] + vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) + + step_guidance = guidance_scale if _active_at(t, guidance_interval) else 1.0 + step_control = control_guidance if _active_at(t, control_guidance_interval) else 1.0 + needs_text_cfg = step_guidance > 1.0 + needs_control_cfg = step_control != 1.0 + + # The target is the last vision item; control items return zeros (no MSE positions). Each + # branch is run at most once: cond_full always, no-control only when control CFG is active, + # uncond only when text CFG is active. + preds_vision, _, _ = self.transformer( + input_ids=cond_full_static["input_ids"], + text_indexes=cond_full_static["text_indexes"], + position_ids=cond_full_static["position_ids"], + und_len=cond_full_static["und_len"], + sequence_length=cond_full_static["sequence_length"], + vision_tokens=vision_tokens_full, + vision_token_shapes=cond_full_static["vision_token_shapes"], + vision_sequence_indexes=cond_full_static["vision_sequence_indexes"], + vision_mse_loss_indexes=cond_full_static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=cond_full_static["vision_noisy_frame_indexes"], + ) + cond_full = preds_vision[-1] + + cond_no_control = None + if needs_control_cfg: + preds_vision, _, _ = self.transformer( + input_ids=cond_no_control_static["input_ids"], + text_indexes=cond_no_control_static["text_indexes"], + position_ids=cond_no_control_static["position_ids"], + und_len=cond_no_control_static["und_len"], + sequence_length=cond_no_control_static["sequence_length"], + vision_tokens=vision_tokens_target, + vision_token_shapes=cond_no_control_static["vision_token_shapes"], + vision_sequence_indexes=cond_no_control_static["vision_sequence_indexes"], + vision_mse_loss_indexes=cond_no_control_static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=cond_no_control_static["vision_noisy_frame_indexes"], + ) + cond_no_control = preds_vision[-1] + + uncond_full = None + if needs_text_cfg: + preds_vision, _, _ = self.transformer( + input_ids=uncond_full_static["input_ids"], + text_indexes=uncond_full_static["text_indexes"], + position_ids=uncond_full_static["position_ids"], + und_len=uncond_full_static["und_len"], + sequence_length=uncond_full_static["sequence_length"], + vision_tokens=vision_tokens_full, + vision_token_shapes=uncond_full_static["vision_token_shapes"], + vision_sequence_indexes=uncond_full_static["vision_sequence_indexes"], + vision_mse_loss_indexes=uncond_full_static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=uncond_full_static["vision_noisy_frame_indexes"], + ) + uncond_full = preds_vision[-1] + + if needs_control_cfg and needs_text_cfg: + control_cond = cond_no_control + step_control * (cond_full - cond_no_control) + velocity = uncond_full + step_guidance * (control_cond - uncond_full) + elif needs_control_cfg: + velocity = cond_no_control + step_control * (cond_full - cond_no_control) + elif needs_text_cfg: + velocity = uncond_full + step_guidance * (cond_full - uncond_full) + else: + velocity = cond_full + + velocity = velocity * velocity_mask + latents = self.scheduler.step( + velocity.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + # Re-pin conditioned frames exactly (the autoregressive seed), guarding multistep drift. + latents = velocity_mask * latents + (1.0 - velocity_mask) * condition_latents else: - velocity_sound = cond_v_sound - sound_latents = sound_scheduler.step( - velocity_sound.unsqueeze(0), t, sound_latents.unsqueeze(0), return_dict=False - )[0].squeeze(0) - - has_noisy_action = ( - action_condition_mask is not None and action_condition_mask.sum() < action_condition_mask.numel() + # The transformer projections (proj_in / audio_proj_in) are bf16; cast the per-step + # noisy tokens before packing so the modality tokens enter the model in the right dtype. + vision_tokens = latents.to(device=device, dtype=dtype) + sound_tokens = ( + sound_latents.to(device=device, dtype=dtype) if sound_latents is not None else None + ) + action_tokens = ( + action_latents.to(device=device, dtype=dtype) if action_latents is not None else None + ) + # The static packs both report the same num_noisy_vision_tokens / sound_len, so a + # single per-step timestep tensor per modality is shared by the cond / uncond passes. + vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) + sound_timesteps = ( + torch.full((sound_len,), timestep, device=device) if sound_tokens is not None else None + ) + action_timesteps = ( + torch.full((action_noisy_len,), timestep, device=device) + if action_tokens is not None + else None + ) + + # --- Conditional pass --- + preds_vision, preds_sound, preds_action = self.transformer( + input_ids=cond_packed_static["input_ids"], + text_indexes=cond_packed_static["text_indexes"], + position_ids=cond_packed_static["position_ids"], + und_len=cond_packed_static["und_len"], + sequence_length=cond_packed_static["sequence_length"], + vision_tokens=[vision_tokens], + vision_token_shapes=cond_packed_static["vision_token_shapes"], + vision_sequence_indexes=cond_packed_static["vision_sequence_indexes"], + vision_mse_loss_indexes=cond_packed_static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=cond_packed_static["vision_noisy_frame_indexes"], + sound_tokens=[sound_tokens] if sound_tokens is not None else None, + sound_token_shapes=cond_packed_static.get("sound_token_shapes"), + sound_sequence_indexes=cond_packed_static.get("sound_sequence_indexes"), + sound_mse_loss_indexes=cond_packed_static.get("sound_mse_loss_indexes"), + sound_timesteps=sound_timesteps, + sound_noisy_frame_indexes=cond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=cond_packed_static.get("action_token_shapes"), + action_sequence_indexes=cond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=cond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=cond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[action_domain_id] if action_domain_id is not None else None, + ) + cond_v_vision, cond_v_sound, cond_v_action = self._mask_velocity_predictions( + preds_vision, + preds_sound, + vision_condition_mask=[vision_condition_mask], + sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, + preds_action=preds_action, + action_condition_mask=[action_condition_mask] + if action_condition_mask is not None + else None, + raw_action_dim=raw_action_dim_resolved, + ) + + # --- Unconditional pass (Skip if not using CFG) --- + uncond_v_vision = uncond_v_sound = uncond_v_action = None + if self.do_classifier_free_guidance: + preds_vision, preds_sound, preds_action = self.transformer( + input_ids=uncond_packed_static["input_ids"], + text_indexes=uncond_packed_static["text_indexes"], + position_ids=uncond_packed_static["position_ids"], + und_len=uncond_packed_static["und_len"], + sequence_length=uncond_packed_static["sequence_length"], + vision_tokens=[vision_tokens], + vision_token_shapes=uncond_packed_static["vision_token_shapes"], + vision_sequence_indexes=uncond_packed_static["vision_sequence_indexes"], + vision_mse_loss_indexes=uncond_packed_static["vision_mse_loss_indexes"], + vision_timesteps=vision_timesteps, + vision_noisy_frame_indexes=uncond_packed_static["vision_noisy_frame_indexes"], + sound_tokens=[sound_tokens] if sound_tokens is not None else None, + sound_token_shapes=uncond_packed_static.get("sound_token_shapes"), + sound_sequence_indexes=uncond_packed_static.get("sound_sequence_indexes"), + sound_mse_loss_indexes=uncond_packed_static.get("sound_mse_loss_indexes"), + sound_timesteps=sound_timesteps, + sound_noisy_frame_indexes=uncond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=uncond_packed_static.get("action_token_shapes"), + action_sequence_indexes=uncond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=uncond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=uncond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[action_domain_id] if action_domain_id is not None else None, + ) + uncond_v_vision, uncond_v_sound, uncond_v_action = self._mask_velocity_predictions( + preds_vision, + preds_sound, + vision_condition_mask=[vision_condition_mask], + sound_condition_mask=( + [sound_condition_mask] if sound_condition_mask is not None else None + ), + preds_action=preds_action, + action_condition_mask=( + [action_condition_mask] if action_condition_mask is not None else None + ), + raw_action_dim=raw_action_dim_resolved, + ) + + # --- CFG combine + per-modality scheduler step --- + # UniPC's multistep_uni_p_bh_update einsum ("k,bkc...->bc...") requires sample + # to carry a batch dim; per-modality latents have no batch axis, so wrap for the step. + + # Skip CFG for 1.0 guidance scale + if self.do_classifier_free_guidance: + velocity_vision = uncond_v_vision + guidance_scale * (cond_v_vision - uncond_v_vision) + else: + velocity_vision = cond_v_vision + + latents = self.scheduler.step( + velocity_vision.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + + if sound_scheduler is not None and cond_v_sound is not None: + # Skip CFG for 1.0 guidance scale + if self.do_classifier_free_guidance: + velocity_sound = uncond_v_sound + guidance_scale * (cond_v_sound - uncond_v_sound) + else: + velocity_sound = cond_v_sound + sound_latents = sound_scheduler.step( + velocity_sound.unsqueeze(0), t, sound_latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + + has_noisy_action = ( + action_condition_mask is not None + and action_condition_mask.sum() < action_condition_mask.numel() + ) + if action_scheduler is not None and has_noisy_action and cond_v_action is not None: + if self.do_classifier_free_guidance: + velocity_action = uncond_v_action + guidance_scale * (cond_v_action - uncond_v_action) + else: + velocity_action = cond_v_action + action_latents = action_scheduler.step( + velocity_action.unsqueeze(0), t, action_latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + if raw_action_dim_resolved is not None: + action_latents[:, raw_action_dim_resolved:] = 0 + + if callback_on_step_end is not None: + callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + # 8. Postprocess + decode (per chunk for transfer; once for the other modes). + if transfer: + vae_dtype = self.vae.dtype + mean = self._vae_latents_mean.to(device=latents.device, dtype=vae_dtype) + inv_std = self._vae_latents_inv_std.to(device=latents.device, dtype=vae_dtype) + z_raw = latents.to(vae_dtype) / inv_std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1) + output_video = self.vae.decode(z_raw).sample.to(torch.float32).clamp(-1, 1) + previous_output = output_video + # Chunks after the first overlap the previous chunk by the conditioning frames; drop on stitch. + output_chunks.append( + output_video if chunk_id == 0 else output_video[:, :, current_conditional_frames:] ) - if action_scheduler is not None and has_noisy_action and cond_v_action is not None: - if self.do_classifier_free_guidance: - velocity_action = uncond_v_action + guidance_scale * (cond_v_action - uncond_v_action) - else: - velocity_action = cond_v_action - action_latents = action_scheduler.step( - velocity_action.unsqueeze(0), t, action_latents.unsqueeze(0), return_dict=False - )[0].squeeze(0) + else: + sound = self.decode_sound(sound_latents) if sound_latents is not None else None + if action_mode in {"inverse_dynamics", "policy"} and action_latents is not None: + action_output = action_latents if raw_action_dim_resolved is not None: - action_latents[:, raw_action_dim_resolved:] = 0 - - if callback_on_step_end is not None: - callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - latents = callback_outputs.pop("latents", latents) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - self._current_timestep = None - - # 8. Postprocess + decode - sound = self.decode_sound(sound_latents) if sound_latents is not None else None - action_output = None - if action_mode in {"inverse_dynamics", "policy"} and action_latents is not None: - action_output = action_latents - if raw_action_dim_resolved is not None: - action_output = action_output[:, :raw_action_dim_resolved] - action_output = [action_output.detach().cpu()] - if output_type == "latent": - video = latents - else: - in_dtype = latents.dtype - dtype = self.vae.dtype - mean = self._vae_latents_mean.to(device=latents.device, dtype=dtype) - inv_std = self._vae_latents_inv_std.to(device=latents.device, dtype=dtype) - z_raw = latents.to(dtype) / inv_std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1) - decoded = self.vae.decode(z_raw).sample.to(in_dtype) + action_output = action_output[:, :raw_action_dim_resolved] + action_output = [action_output.detach().cpu()] + if output_type == "latent": + video = latents + else: + in_dtype = latents.dtype + vae_dtype = self.vae.dtype + mean = self._vae_latents_mean.to(device=latents.device, dtype=vae_dtype) + inv_std = self._vae_latents_inv_std.to(device=latents.device, dtype=vae_dtype) + z_raw = latents.to(vae_dtype) / inv_std.view(1, -1, 1, 1, 1) + mean.view(1, -1, 1, 1, 1) + decoded = self.vae.decode(z_raw).sample.to(in_dtype) + video = self.video_processor.postprocess_video(decoded, output_type=output_type)[0] + + if transfer: + decoded = torch.cat(output_chunks, dim=2)[:, :, :total_frames] video = self.video_processor.postprocess_video(decoded, output_type=output_type)[0] - if enable_safety_check and isinstance(self.safety_checker, CosmosSafetyChecker) and output_type != "latent": + if ( + enable_safety_check + and isinstance(self.safety_checker, CosmosSafetyChecker) + and (transfer or output_type != "latent") + ): video = self._apply_video_safety_check(video, output_type=output_type, device=device) self.maybe_free_model_hooks()