```
├── .gitattributes (omitted)
├── .gitignore (700 tokens)
├── LICENSE (omitted)
├── README.md (200 tokens)
├── __init__.py
├── diffusers_helper/
│   ├── bucket_tools.py (200 tokens)
│   ├── dit_common.py (300 tokens)
│   ├── k_diffusion/
│   │   ├── uni_pc_fm.py (900 tokens)
│   │   └── wrapper.py (400 tokens)
│   ├── memory.py (900 tokens)
│   ├── models/
│   │   └── hunyuan_video_packed.py (8.3k tokens)
│   ├── pipelines/
│   │   └── k_diffusion_hunyuan.py (800 tokens)
│   └── utils.py (3.5k tokens)
├── example_workflows/
│   └── framepack_hv_example.json (6.4k tokens)
├── fp8_optimization.py (300 tokens)
├── nodes.py (6.1k tokens)
├── requirements.txt
├── transformer_config.json (100 tokens)
└── utils.py (900 tokens)
```

## /.gitignore

```gitignore path="/.gitignore"
hf_download/
outputs/
repo/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g.
# github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

demo_gradio.py
```

## /README.md

# ComfyUI Wrapper for [FramePack by lllyasviel](https://lllyasviel.github.io/frame_pack_gitpage/)

# WORK IN PROGRESS

Mostly working; some liberties were taken to make it run faster.

Uses the native ComfyUI models for the text encoders, VAE and sigclip:

https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files

https://huggingface.co/Comfy-Org/sigclip_vision_384/tree/main

The transformer model itself is either autodownloaded from

https://huggingface.co/lllyasviel/FramePackI2V_HY/tree/main

to `ComfyUI\models\diffusers\lllyasviel\FramePackI2V_HY`,

or loaded from a single file placed in `ComfyUI\models\diffusion_models`:

https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_fp8_e4m3fn.safetensors

https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors

## /__init__.py

```py path="/__init__.py"
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
```

## /diffusers_helper/bucket_tools.py

```py path="/diffusers_helper/bucket_tools.py"
bucket_options = {
    (416, 960),
    (448, 864),
    (480, 832),
    (512, 768),
    (544, 704),
    (576, 672),
    (608, 640),
    (640, 608),
    (672, 576),
    (704, 544),
    (768, 512),
    (832, 480),
    (864, 448),
    (960, 416),
}


def find_nearest_bucket(h, w, resolution=640):
    min_metric = float('inf')
    best_bucket = None
    for (bucket_h, bucket_w) in bucket_options:
        metric = abs(h * bucket_w - w * bucket_h)
        if metric <= min_metric:
            min_metric = metric
            best_bucket = (bucket_h, bucket_w)

    if resolution != 640:
        scale_factor = resolution / 640.0
        scaled_height = round(best_bucket[0] * scale_factor / 16) * 16
        scaled_width = round(best_bucket[1] * scale_factor / 16) * 16
        best_bucket = (scaled_height, scaled_width)

    print(f'Resolution: {best_bucket[1]} x {best_bucket[0]}')
    return best_bucket
```

## /diffusers_helper/dit_common.py

```py path="/diffusers_helper/dit_common.py"
import torch
import accelerate.accelerator

from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous


accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x


def LayerNorm_forward(self, x):
    return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)


LayerNorm.forward = LayerNorm_forward
torch.nn.LayerNorm.forward = LayerNorm_forward


def FP32LayerNorm_forward(self, x):
    origin_dtype = x.dtype
    return torch.nn.functional.layer_norm(
        x.float(),
        self.normalized_shape,
        self.weight.float() if self.weight is not None else None,
        self.bias.float() if self.bias is not None else None,
        self.eps,
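        # the normalization itself runs in fp32 (input and affine params are
        # upcast above) and the result is cast back to the original dtype below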
    ).to(origin_dtype)


FP32LayerNorm.forward = FP32LayerNorm_forward


def RMSNorm_forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

    if self.weight is None:
        return hidden_states.to(input_dtype)

    return hidden_states.to(input_dtype) * self.weight.to(input_dtype)


RMSNorm.forward = RMSNorm_forward


def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
    emb = self.linear(self.silu(conditioning_embedding))
    scale, shift = emb.chunk(2, dim=1)
    x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
    return x


AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
```

## /diffusers_helper/k_diffusion/uni_pc_fm.py

```py path="/diffusers_helper/k_diffusion/uni_pc_fm.py"
# Better Flow Matching UniPC by Lvmin Zhang
# (c) 2025
# CC BY-SA 4.0
# Attribution-ShareAlike 4.0 International Licence

import torch

from comfy.utils import ProgressBar
from tqdm.auto import trange


def expand_dims(v, dims):
    return v[(...,) + (None,) * (dims - 1)]


class FlowMatchUniPC:
    def __init__(self, model, extra_args, variant='bh1'):
        self.model = model
        self.variant = variant
        self.extra_args = extra_args

    def model_fn(self, x, t):
        return self.model(x, t, **self.extra_args)

    def update_fn(self, x, model_prev_list, t_prev_list, t, order):
        assert order <= len(model_prev_list)
        dims = x.dim()

        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = - torch.log(t_prev_0)
        lambda_t = - torch.log(t)
        model_prev_0 = model_prev_list[-1]

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = - torch.log(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
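        # Build the UniPC linear system: each row of R holds powers of the step
        # ratios in rks (measured in lambda = -log(sigma) space), and b collects
        # the phi-function terms. Solving R @ rhos = b gives the predictor
        # (rhos_p) and corrector (rhos_c) weights used below (rough summary;
        # see the UniPC paper for the exact derivation).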
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h[0]
        h_phi_1 = torch.expm1(hh)
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError('Bad variant!')

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        use_predictor = len(D1s) > 0

        if use_predictor:
            D1s = torch.stack(D1s, dim=1)
            if order == 2:
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None
            rhos_p = None

        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0

        if use_predictor:
            pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
        else:
            pred_res = 0

        x_t = x_t_ - expand_dims(B_h, dims) * pred_res
        model_t = self.model_fn(x_t, t)

        if D1s is not None:
            corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
        else:
            corr_res = 0

        D1_t = (model_t - model_prev_0)
        x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)

        return x_t, model_t

    def sample(self, x, sigmas, callback=None, disable_pbar=False):
        order = min(3, len(sigmas) - 2)
        model_prev_list, t_prev_list = [], []
        comfy_pbar = ProgressBar(len(sigmas) - 1)
        for i in trange(len(sigmas) - 1, disable=disable_pbar):
            vec_t = sigmas[i].expand(x.shape[0])

            if i == 0:
                model_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]
            elif i < order:
                init_order = i
                x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
                model_prev_list.append(model_x)
                t_prev_list.append(vec_t)
            else:
                x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
                model_prev_list.append(model_x)
                t_prev_list.append(vec_t)

            model_prev_list = model_prev_list[-order:]
            t_prev_list = t_prev_list[-order:]

            if callback is not None:
                callback_latent = model_prev_list[-1].detach()[0].permute(1, 0, 2, 3)
                callback(i, callback_latent, None, len(sigmas) - 1)
            comfy_pbar.update(1)

        return model_prev_list[-1]


def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    assert variant in ['bh1', 'bh2']
    return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
```

## /diffusers_helper/k_diffusion/wrapper.py

```py path="/diffusers_helper/k_diffusion/wrapper.py"
import torch


def append_dims(x, target_dims):
    return x[(...,) + (None,) * (target_dims - x.ndim)]


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
    if guidance_rescale == 0:
        return noise_cfg

    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg

    return noise_cfg


def fm_wrapper(transformer, t_scale=1000.0):
    def k_model(x, sigma, **extra_args):
        dtype = extra_args['dtype']
        cfg_scale = extra_args['cfg_scale']
        cfg_rescale = extra_args['cfg_rescale']
        concat_latent = extra_args['concat_latent']

        original_dtype = x.dtype
        sigma = sigma.float()

        x = x.to(dtype)
        timestep = (sigma * t_scale).to(dtype)

        if concat_latent is None:
            hidden_states = x
        else:
            hidden_states = torch.cat([x,
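                # optional conditioning latents are appended along the channel
                # dimension (dim=1) before the transformer call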
                                       concat_latent.to(x)], dim=1)

        pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()

        if cfg_scale == 1.0:
            pred_negative = torch.zeros_like(pred_positive)
        else:
            pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()

        pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
        pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)

        x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)

        return x0.to(dtype=original_dtype)

    return k_model
```

## /diffusers_helper/memory.py

```py path="/diffusers_helper/memory.py"
# By lllyasviel

import torch

cpu = torch.device('cpu')
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')

gpu_complete_modules = []


class DynamicSwapInstaller:
    @staticmethod
    def _install_module(module: torch.nn.Module, **kwargs):
        original_class = module.__class__
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
                    else:
                        return p.to(**kwargs)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    return _buffers[name].to(**kwargs)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
            '__getattr__': hacked_get_attr,
        })

        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, **kwargs):
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, **kwargs)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return


def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
    if hasattr(model, 'scale_shift_table'):
        model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
        return

    for k, p in model.named_modules():
        if hasattr(p, 'weight'):
            p.to(target_device)
            return


def get_cuda_free_memory_gb(device=None):
    if device is None:
        device = gpu

    memory_stats = torch.cuda.memory_stats(device)
    bytes_active = memory_stats['active_bytes.all.current']
    bytes_reserved = memory_stats['reserved_bytes.all.current']
    bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
    bytes_inactive_reserved = bytes_reserved - bytes_active
    bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
    return bytes_total_available / (1024 ** 3)


def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
    print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')

    for m in model.modules():
        if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
            torch.cuda.empty_cache()
            return

        if hasattr(m, 'weight'):
            m.to(device=target_device)

    model.to(device=target_device)
    torch.cuda.empty_cache()
    return


def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
    print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')

    for m in model.modules():
        if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
            torch.cuda.empty_cache()
            return

        if hasattr(m, 'weight'):
            m.to(device=cpu)

    model.to(device=cpu)
    torch.cuda.empty_cache()
    return


def unload_complete_models(*args):
    for m in gpu_complete_modules + list(args):
        m.to(device=cpu)
        print(f'Unloaded {m.__class__.__name__} as complete.')

    gpu_complete_modules.clear()
    torch.cuda.empty_cache()
    return


def load_model_as_complete(model, target_device, unload=True):
    if unload:
        unload_complete_models()

    model.to(device=target_device)
    print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')

    gpu_complete_modules.append(model)
    return
```

## /diffusers_helper/models/hunyuan_video_packed.py

```py path="/diffusers_helper/models/hunyuan_video_packed.py"
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import einops
import torch.nn as nn
import numpy as np

from diffusers.loaders import FromOriginalModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import logging
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from ...diffusers_helper.dit_common import LayerNorm

enabled_backends = []

if torch.backends.cuda.flash_sdp_enabled():
    enabled_backends.append("flash")
if torch.backends.cuda.math_sdp_enabled():
    enabled_backends.append("math")
if torch.backends.cuda.mem_efficient_sdp_enabled():
    enabled_backends.append("mem_efficient")
if torch.backends.cuda.cudnn_sdp_enabled():
    enabled_backends.append("cudnn")

try:
    # raise NotImplementedError
    from flash_attn import flash_attn_varlen_func, flash_attn_func
except:
    flash_attn_varlen_func = None
    flash_attn_func = None

try:
    # raise NotImplementedError
    from sageattention import sageattn_varlen, sageattn
except:
    sageattn_varlen = None
    sageattn = None


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def pad_for_3d_conv(x, kernel_size):
    b, c, t, h, w = x.shape
    pt, ph, pw = kernel_size
    pad_t = (pt - (t % pt)) % pt
    pad_h = (ph - (h % ph)) % ph
    pad_w = (pw - (w % pw)) % pw
    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')


def center_down_sample_3d(x, kernel_size):
    # pt, ph, pw = kernel_size
    # cp = (pt * ph * pw) // 2
    # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
    # xc = xp[cp]
    # return xc
    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)


def get_cu_seqlens(text_mask, img_len):
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        s = text_len[i] + img_len
        s1 = i * max_len + s
        s2 = (i + 1) * max_len
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens


def apply_rotary_emb_transposed(x, freqs_cis):
    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    out = x.float() * cos + x_rotated.float() * sin
    out = out.to(x)
    return out


def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv,
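                     # attention_mode selects the backend used below: 'sageattn',
                     # 'flash_attn', or 'sdpa' (torch scaled_dot_product_attention)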
max_seqlen_q, max_seqlen_kv, attention_mode='sdpa'): if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None: if attention_mode == "sageattn": x = sageattn(q, k, v, tensor_layout='NHD') if attention_mode == "flash_attn": x = flash_attn_func(q, k, v) elif attention_mode == "sdpa": x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2) return x # batch_size = q.shape[0] # q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # if sageattn_varlen is not None: # x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) # elif flash_attn_varlen_func is not None: # x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) # else: # raise NotImplementedError('No Attn Installed!') # x = x.view(batch_size, max_seqlen_q, *x.shape[2:]) # return x class HunyuanAttnProcessorFlashAttnDouble: def __init__(self, attention_mode): self.attention_mode = attention_mode def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb): cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) value = value.unflatten(2, (attn.heads, -1)) query = attn.norm_q(query) key = attn.norm_k(key) query = apply_rotary_emb_transposed(query, image_rotary_emb) key = apply_rotary_emb_transposed(key, image_rotary_emb) encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) encoder_query = attn.norm_added_q(encoder_query) encoder_key = attn.norm_added_k(encoder_key) query = torch.cat([query, encoder_query], dim=1) key = torch.cat([key, encoder_key], dim=1) value = torch.cat([value, encoder_value], dim=1) hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, self.attention_mode) hidden_states = hidden_states.flatten(-2) txt_length = encoder_hidden_states.shape[1] hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states class HunyuanAttnProcessorFlashAttnSingle: def __init__(self, attention_mode): self.attention_mode = attention_mode def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb): cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) value = value.unflatten(2, (attn.heads, -1)) query = attn.norm_q(query) key = attn.norm_k(key) txt_length = encoder_hidden_states.shape[1] query = torch.cat([apply_rotary_emb_transposed(query[:, 
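        # note: RoPE is applied only to the image/video tokens; the text tokens at
        # the end of the packed sequence pass through without rotary embedding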
:-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1) key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1) hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, self.attention_mode) hidden_states = hidden_states.flatten(-2) hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] return hidden_states, encoder_hidden_states class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") def forward(self, timestep, guidance, pooled_projection): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) guidance_proj = self.time_proj(guidance) guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) time_guidance_emb = timesteps_emb + guidance_emb pooled_projections = self.text_embedder(pooled_projection) conditioning = time_guidance_emb + pooled_projections return conditioning class CombinedTimestepTextProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") def forward(self, timestep, pooled_projection): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) pooled_projections = self.text_embedder(pooled_projection) conditioning = timesteps_emb + pooled_projections return conditioning class HunyuanVideoAdaNorm(nn.Module): def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: super().__init__() out_features = out_features or 2 * in_features self.linear = nn.Linear(in_features, out_features) self.nonlinearity = nn.SiLU() def forward( self, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: temb = self.linear(self.nonlinearity(temb)) gate_msa, gate_mlp = temb.chunk(2, dim=-1) gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) return gate_msa, gate_mlp class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, mlp_width_ratio: str = 4.0, mlp_drop_rate: float = 0.0, attention_bias: bool = True, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, heads=num_attention_heads, dim_head=attention_head_dim, bias=attention_bias, ) self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * 
hidden_size) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: norm_hidden_states = self.norm1(hidden_states) attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=None, attention_mask=attention_mask, ) gate_msa, gate_mlp = self.norm_out(temb) hidden_states = hidden_states + attn_output * gate_msa ff_output = self.ff(self.norm2(hidden_states)) hidden_states = hidden_states + ff_output * gate_mlp return hidden_states class HunyuanVideoIndividualTokenRefiner(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, num_layers: int, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, attention_bias: bool = True, ) -> None: super().__init__() self.refiner_blocks = nn.ModuleList( [ HunyuanVideoIndividualTokenRefinerBlock( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, attention_bias=attention_bias, ) for _ in range(num_layers) ] ) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> None: self_attn_mask = None if attention_mask is not None: batch_size = attention_mask.shape[0] seq_len = attention_mask.shape[1] attention_mask = attention_mask.to(hidden_states.device).bool() self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() self_attn_mask[:, :, :, 0] = True for block in self.refiner_blocks: hidden_states = block(hidden_states, temb, self_attn_mask) return hidden_states class HunyuanVideoTokenRefiner(nn.Module): def __init__( self, in_channels: int, num_attention_heads: int, attention_head_dim: int, num_layers: int, mlp_ratio: float = 4.0, mlp_drop_rate: float = 0.0, attention_bias: bool = True, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.time_text_embed = CombinedTimestepTextProjEmbeddings( embedding_dim=hidden_size, pooled_projection_dim=in_channels ) self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) self.token_refiner = HunyuanVideoIndividualTokenRefiner( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, num_layers=num_layers, mlp_width_ratio=mlp_ratio, mlp_drop_rate=mlp_drop_rate, attention_bias=attention_bias, ) def forward( self, hidden_states: torch.Tensor, timestep: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, ) -> torch.Tensor: if attention_mask is None: pooled_projections = hidden_states.mean(dim=1) else: original_dtype = hidden_states.dtype mask_float = attention_mask.float().unsqueeze(-1) pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) pooled_projections = pooled_projections.to(original_dtype) temb = self.time_text_embed(timestep, pooled_projections) hidden_states = self.proj_in(hidden_states) hidden_states = self.token_refiner(hidden_states, temb, attention_mask) return hidden_states class HunyuanVideoRotaryPosEmbed(nn.Module): def __init__(self, rope_dim, theta): super().__init__() self.DT, self.DY, self.DX = rope_dim self.theta = theta @torch.no_grad() def get_frequency(self, dim, pos): T, H, W = pos.shape freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim)) freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, 
W)).repeat_interleave(2, dim=0) return freqs.cos(), freqs.sin() @torch.no_grad() def forward_inner(self, frame_indices, height, width, device): GT, GY, GX = torch.meshgrid( frame_indices.to(device=device, dtype=torch.float32), torch.arange(0, height, device=device, dtype=torch.float32), torch.arange(0, width, device=device, dtype=torch.float32), indexing="ij" ) FCT, FST = self.get_frequency(self.DT, GT) FCY, FSY = self.get_frequency(self.DY, GY) FCX, FSX = self.get_frequency(self.DX, GX) result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) return result.to(device) @torch.no_grad() def forward(self, frame_indices, height, width, device): frame_indices = frame_indices.unbind(0) results = [self.forward_inner(f, height, width, device) for f in frame_indices] results = torch.stack(results, dim=0) return results class AdaLayerNormZero(nn.Module): def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) if norm_type == "layer_norm": self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) else: raise ValueError(f"unknown norm_type {norm_type}") def forward( self, x: torch.Tensor, emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = emb.unsqueeze(-2) emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) x = self.norm(x) * (1 + scale_msa) + shift_msa return x, gate_msa, shift_mlp, scale_mlp, gate_mlp class AdaLayerNormZeroSingle(nn.Module): def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) if norm_type == "layer_norm": self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) else: raise ValueError(f"unknown norm_type {norm_type}") def forward( self, x: torch.Tensor, emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = emb.unsqueeze(-2) emb = self.linear(self.silu(emb)) shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) x = self.norm(x) * (1 + scale_msa) + shift_msa return x, gate_msa class AdaLayerNormContinuous(nn.Module): def __init__( self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine=True, eps=1e-5, bias=True, norm_type="layer_norm", ): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) if norm_type == "layer_norm": self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) else: raise ValueError(f"unknown norm_type {norm_type}") def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: emb = emb.unsqueeze(-2) emb = self.linear(self.silu(emb)) scale, shift = emb.chunk(2, dim=-1) x = self.norm(x) * (1 + scale) + shift return x class HunyuanVideoSingleTransformerBlock(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0, qk_norm: str = "rms_norm", attention_mode: str = "sdpa", ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim mlp_dim = int(hidden_size * mlp_ratio) self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=hidden_size, bias=True, processor=HunyuanAttnProcessorFlashAttnSingle(attention_mode), 
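            # pre_only=True: the attention returns unprojected outputs, and the
            # block's proj_out below projects the concatenated attn + MLP branches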
qk_norm=qk_norm, eps=1e-6, pre_only=True, ) self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") self.proj_mlp = nn.Linear(hidden_size, mlp_dim) self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.shape[1] hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) residual = hidden_states # 1. Input normalization norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) norm_hidden_states, norm_encoder_hidden_states = ( norm_hidden_states[:, :-text_seq_length, :], norm_hidden_states[:, -text_seq_length:, :], ) # 2. Attention attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) attn_output = torch.cat([attn_output, context_attn_output], dim=1) # 3. Modulation and residual connection hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) hidden_states = gate * self.proj_out(hidden_states) hidden_states = hidden_states + residual hidden_states, encoder_hidden_states = ( hidden_states[:, :-text_seq_length, :], hidden_states[:, -text_seq_length:, :], ) return hidden_states, encoder_hidden_states class HunyuanVideoTransformerBlock(nn.Module): def __init__( self, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float, qk_norm: str = "rms_norm", attention_mode: str = "sdpa", ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") self.attn = Attention( query_dim=hidden_size, cross_attention_dim=None, added_kv_proj_dim=hidden_size, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=hidden_size, context_pre_only=False, bias=True, processor=HunyuanAttnProcessorFlashAttnDouble(attention_mode), qk_norm=qk_norm, eps=1e-6, ) self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # 1. Input normalization norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb) # 2. Joint attention attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, attention_mask=attention_mask, image_rotary_emb=freqs_cis, ) # 3. 
Modulation and residual connection hidden_states = hidden_states + attn_output * gate_msa encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa norm_hidden_states = self.norm2(hidden_states) norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp # 4. Feed-forward ff_output = self.ff(norm_hidden_states) context_ff_output = self.ff_context(norm_encoder_hidden_states) hidden_states = hidden_states + gate_mlp * ff_output encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output return hidden_states, encoder_hidden_states class ClipVisionProjection(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.up = nn.Linear(in_channels, out_channels * 3) self.down = nn.Linear(out_channels * 3, out_channels) def forward(self, x): projected_x = self.down(nn.functional.silu(self.up(x))) return projected_x class HunyuanVideoPatchEmbed(nn.Module): def __init__(self, patch_size, in_chans, embed_dim): super().__init__() self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) class HunyuanVideoPatchEmbedForCleanLatents(nn.Module): def __init__(self, inner_dim): super().__init__() self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) @torch.no_grad() def initialize_weight_from_another_conv3d(self, another_layer): weight = another_layer.weight.detach().clone() bias = another_layer.bias.detach().clone() sd = { 'proj.weight': weight.clone(), 'proj.bias': bias.clone(), 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0, 'proj_2x.bias': bias.clone(), 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0, 'proj_4x.bias': bias.clone(), } sd = {k: v.clone() for k, v in sd.items()} self.load_state_dict(sd) return class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): @register_to_config def __init__( self, in_channels: int = 16, out_channels: int = 16, num_attention_heads: int = 24, attention_head_dim: int = 128, num_layers: int = 20, num_single_layers: int = 40, num_refiner_layers: int = 2, mlp_ratio: float = 4.0, patch_size: int = 2, patch_size_t: int = 1, qk_norm: str = "rms_norm", guidance_embeds: bool = True, text_embed_dim: int = 4096, pooled_projection_dim: int = 768, rope_theta: float = 256.0, rope_axes_dim: Tuple[int] = (16, 56, 56), has_image_proj=False, image_proj_dim=1152, has_clean_x_embedder=False, attention_mode="sdpa", ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels # 1. Latent and condition embedders self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) self.context_embedder = HunyuanVideoTokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) self.clean_x_embedder = None self.image_projection = None # 2. RoPE self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) # 3. 
Dual stream transformer blocks self.transformer_blocks = nn.ModuleList( [ HunyuanVideoTransformerBlock( num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm, attention_mode=attention_mode ) for _ in range(num_layers) ] ) # 4. Single stream transformer blocks self.single_transformer_blocks = nn.ModuleList( [ HunyuanVideoSingleTransformerBlock( num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm, attention_mode=attention_mode ) for _ in range(num_single_layers) ] ) # 5. Output projection self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) self.inner_dim = inner_dim self.use_gradient_checkpointing = False self.enable_teacache = False if has_image_proj: self.install_image_projection(image_proj_dim) if has_clean_x_embedder: self.install_clean_x_embedder() self.high_quality_fp32_output_for_inference = False def install_image_projection(self, in_channels): self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim) self.config['has_image_proj'] = True self.config['image_proj_dim'] = in_channels def install_clean_x_embedder(self): self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim) self.config['has_clean_x_embedder'] = True def enable_gradient_checkpointing(self): self.use_gradient_checkpointing = True print('self.use_gradient_checkpointing = True') def disable_gradient_checkpointing(self): self.use_gradient_checkpointing = False print('self.use_gradient_checkpointing = False') def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15): self.enable_teacache = enable_teacache self.cnt = 0 self.num_steps = num_steps self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = None self.previous_residual = None self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]) def gradient_checkpointing_method(self, block, *args): if self.use_gradient_checkpointing: result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False) else: result = block(*args) return result def process_input_hidden_states( self, latents, latent_indices=None, clean_latents=None, clean_latent_indices=None, clean_latents_2x=None, clean_latent_2x_indices=None, clean_latents_4x=None, clean_latent_4x_indices=None ): hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents) B, C, T, H, W = hidden_states.shape if latent_indices is None: latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1) hidden_states = hidden_states.flatten(2).transpose(1, 2) rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device) rope_freqs = rope_freqs.flatten(2).transpose(1, 2) if clean_latents is not None and clean_latent_indices is not None: clean_latents = clean_latents.to(hidden_states) clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents) clean_latents = clean_latents.flatten(2).transpose(1, 2) clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device) clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2) hidden_states = torch.cat([clean_latents, hidden_states], dim=1) rope_freqs = torch.cat([clean_latent_rope_freqs, 
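            # FramePack context: clean (previously generated) latents are embedded at
            # full resolution here and at 2x / 4x downsampling below, then prepended
            # to the noisy latent tokens together with matching RoPE frequencies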
rope_freqs], dim=1) if clean_latents_2x is not None and clean_latent_2x_indices is not None: clean_latents_2x = clean_latents_2x.to(hidden_states) clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x) clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device) clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2)) clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2)) clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2) hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1) if clean_latents_4x is not None and clean_latent_4x_indices is not None: clean_latents_4x = clean_latents_4x.to(hidden_states) clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8)) clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x) clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device) clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4)) clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4)) clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2) hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1) rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1) return hidden_states, rope_freqs def forward( self, hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance, latent_indices=None, clean_latents=None, clean_latent_indices=None, clean_latents_2x=None, clean_latent_2x_indices=None, clean_latents_4x=None, clean_latent_4x_indices=None, image_embeddings=None, attention_kwargs=None, return_dict=True ): if attention_kwargs is None: attention_kwargs = {} batch_size, num_channels, num_frames, height, width = hidden_states.shape p, p_t = self.config['patch_size'], self.config['patch_size_t'] post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p original_context_length = post_patch_num_frames * post_patch_height * post_patch_width hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices) temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections) encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask) if self.image_projection is not None and image_embeddings is not None: extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings) extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device) # must cat before (not after) encoder_hidden_states, due to attn masking encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1) encoder_attention_mask = 
torch.cat([extra_attention_mask, encoder_attention_mask], dim=1) with torch.no_grad(): if batch_size == 1: # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want # If they are not same, then their impls are wrong. Ours are always the correct one. text_len = encoder_attention_mask.sum().item() encoder_hidden_states = encoder_hidden_states[:, :text_len] attention_mask = None, None, None, None else: img_seq_len = hidden_states.shape[1] txt_seq_len = encoder_hidden_states.shape[1] cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len) cu_seqlens_kv = cu_seqlens_q max_seqlen_q = img_seq_len + txt_seq_len max_seqlen_kv = max_seqlen_q attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv if self.enable_teacache: modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] if self.cnt == 0 or self.cnt == self.num_steps-1: should_calc = True self.accumulated_rel_l1_distance = 0 else: curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item() self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1) should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh if should_calc: self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = modulated_inp self.cnt += 1 if self.cnt == self.num_steps: self.cnt = 0 if not should_calc: hidden_states = hidden_states + self.previous_residual else: ori_hidden_states = hidden_states.clone() for block_id, block in enumerate(self.transformer_blocks): hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs ) for block_id, block in enumerate(self.single_transformer_blocks): hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs ) self.previous_residual = hidden_states - ori_hidden_states else: for block_id, block in enumerate(self.transformer_blocks): hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs ) for block_id, block in enumerate(self.single_transformer_blocks): hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs ) hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb) hidden_states = hidden_states[:, -original_context_length:, :] if self.high_quality_fp32_output_for_inference: hidden_states = hidden_states.to(dtype=torch.float32) if self.proj_out.weight.dtype != torch.float32: self.proj_out.to(dtype=torch.float32) hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states) hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)', t=post_patch_num_frames, h=post_patch_height, w=post_patch_width, pt=p_t, ph=p, pw=p) if return_dict: return Transformer2DModelOutput(sample=hidden_states) return hidden_states, ``` ## /diffusers_helper/pipelines/k_diffusion_hunyuan.py ```py path="/diffusers_helper/pipelines/k_diffusion_hunyuan.py" import torch import math from ..k_diffusion.uni_pc_fm import sample_unipc from ..k_diffusion.wrapper import fm_wrapper from ..utils import repeat_to_batch_size def flux_time_shift(t, mu=1.15, sigma=1.0): return math.exp(mu) / 
(math.exp(mu) + (1 / t - 1) ** sigma) def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0): k = (y2 - y1) / (x2 - x1) b = y1 - k * x1 mu = k * context_length + b mu = min(mu, math.log(exp_max)) return mu def get_flux_sigmas_from_mu(n, mu): sigmas = torch.linspace(1, 0, steps=n + 1) sigmas = flux_time_shift(sigmas, mu=mu) return sigmas @torch.inference_mode() def sample_hunyuan( transformer, sampler='unipc', initial_latent=None, concat_latent=None, strength=1.0, width=512, height=512, frames=16, real_guidance_scale=1.0, distilled_guidance_scale=6.0, guidance_rescale=0.0, shift=None, num_inference_steps=25, batch_size=None, generator=None, prompt_embeds=None, prompt_embeds_mask=None, prompt_poolers=None, negative_prompt_embeds=None, negative_prompt_embeds_mask=None, negative_prompt_poolers=None, dtype=torch.bfloat16, device=None, negative_kwargs=None, callback=None, **kwargs, ): device = device or transformer.device if batch_size is None: batch_size = int(prompt_embeds.shape[0]) latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32) B, C, T, H, W = latents.shape seq_length = T * H * W // 4 if shift is None: mu = calculate_flux_mu(seq_length, exp_max=7.0) else: mu = math.log(shift) sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device) k_model = fm_wrapper(transformer) if initial_latent is not None: sigmas = sigmas * strength first_sigma = sigmas[0].to(device=device, dtype=torch.float32) initial_latent = initial_latent.to(device=device, dtype=torch.float32) latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma if concat_latent is not None: concat_latent = concat_latent.to(latents) distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype) prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size) prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size) prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size) negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size) negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size) negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size) concat_latent = repeat_to_batch_size(concat_latent, batch_size) sampler_kwargs = dict( dtype=dtype, cfg_scale=real_guidance_scale, cfg_rescale=guidance_rescale, concat_latent=concat_latent, positive=dict( pooled_projections=prompt_poolers, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_embeds_mask, guidance=distilled_guidance, **kwargs, ), negative=dict( pooled_projections=negative_prompt_poolers, encoder_hidden_states=negative_prompt_embeds, encoder_attention_mask=negative_prompt_embeds_mask, guidance=distilled_guidance, **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}), ) ) if sampler == 'unipc_bh1': variant = 'bh1' elif sampler == 'unipc_bh2': variant = 'bh2' results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, variant=variant, callback=callback) return results ``` ## /diffusers_helper/utils.py ```py path="/diffusers_helper/utils.py" import os #import cv2 import json import random import glob import torch import einops import numpy as np import datetime import torchvision import safetensors.torch as sf from PIL import Image # def min_resize(x, m): # if x.shape[0] < x.shape[1]: # 
s0 = m # s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1])) # else: # s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0])) # s1 = m # new_max = max(s1, s0) # raw_max = max(x.shape[0], x.shape[1]) # if new_max < raw_max: # interpolation = cv2.INTER_AREA # else: # interpolation = cv2.INTER_LANCZOS4 # y = cv2.resize(x, (s1, s0), interpolation=interpolation) # return y # def d_resize(x, y): # H, W, C = y.shape # new_min = min(H, W) # raw_min = min(x.shape[0], x.shape[1]) # if new_min < raw_min: # interpolation = cv2.INTER_AREA # else: # interpolation = cv2.INTER_LANCZOS4 # y = cv2.resize(x, (W, H), interpolation=interpolation) # return y def resize_and_center_crop(image, target_width, target_height): if target_height == image.shape[0] and target_width == image.shape[1]: return image pil_image = Image.fromarray(image) original_width, original_height = pil_image.size scale_factor = max(target_width / original_width, target_height / original_height) resized_width = int(round(original_width * scale_factor)) resized_height = int(round(original_height * scale_factor)) resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) left = (resized_width - target_width) / 2 top = (resized_height - target_height) / 2 right = (resized_width + target_width) / 2 bottom = (resized_height + target_height) / 2 cropped_image = resized_image.crop((left, top, right, bottom)) return np.array(cropped_image) def resize_and_center_crop_pytorch(image, target_width, target_height): B, C, H, W = image.shape if H == target_height and W == target_width: return image scale_factor = max(target_width / W, target_height / H) resized_width = int(round(W * scale_factor)) resized_height = int(round(H * scale_factor)) resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False) top = (resized_height - target_height) // 2 left = (resized_width - target_width) // 2 cropped = resized[:, :, top:top + target_height, left:left + target_width] return cropped def resize_without_crop(image, target_width, target_height): if target_height == image.shape[0] and target_width == image.shape[1]: return image pil_image = Image.fromarray(image) resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS) return np.array(resized_image) def just_crop(image, w, h): if h == image.shape[0] and w == image.shape[1]: return image original_height, original_width = image.shape[:2] k = min(original_height / h, original_width / w) new_width = int(round(w * k)) new_height = int(round(h * k)) x_start = (original_width - new_width) // 2 y_start = (original_height - new_height) // 2 cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width] return cropped_image def write_to_json(data, file_path): temp_file_path = file_path + ".tmp" with open(temp_file_path, 'wt', encoding='utf-8') as temp_file: json.dump(data, temp_file, indent=4) os.replace(temp_file_path, file_path) return def read_from_json(file_path): with open(file_path, 'rt', encoding='utf-8') as file: data = json.load(file) return data def get_active_parameters(m): return {k: v for k, v in m.named_parameters() if v.requires_grad} def cast_training_params(m, dtype=torch.float32): result = {} for n, param in m.named_parameters(): if param.requires_grad: param.data = param.to(dtype) result[n] = param return result def separate_lora_AB(parameters, B_patterns=None): parameters_normal = {} parameters_B = {} if B_patterns is None: B_patterns = ['.lora_B.', '__zero__'] for k, v 
in parameters.items(): if any(B_pattern in k for B_pattern in B_patterns): parameters_B[k] = v else: parameters_normal[k] = v return parameters_normal, parameters_B def set_attr_recursive(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) setattr(obj, attrs[-1], value) return def print_tensor_list_size(tensors): total_size = 0 total_elements = 0 if isinstance(tensors, dict): tensors = tensors.values() for tensor in tensors: total_size += tensor.nelement() * tensor.element_size() total_elements += tensor.nelement() total_size_MB = total_size / (1024 ** 2) total_elements_B = total_elements / 1e9 print(f"Total number of tensors: {len(tensors)}") print(f"Total size of tensors: {total_size_MB:.2f} MB") print(f"Total number of parameters: {total_elements_B:.3f} billion") return @torch.no_grad() def batch_mixture(a, b=None, probability_a=0.5, mask_a=None): batch_size = a.size(0) if b is None: b = torch.zeros_like(a) if mask_a is None: mask_a = torch.rand(batch_size) < probability_a mask_a = mask_a.to(a.device) mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1)) result = torch.where(mask_a, a, b) return result @torch.no_grad() def zero_module(module): for p in module.parameters(): p.detach().zero_() return module @torch.no_grad() def supress_lower_channels(m, k, alpha=0.01): data = m.weight.data.clone() assert int(data.shape[1]) >= k data[:, :k] = data[:, :k] * alpha m.weight.data = data.contiguous().clone() return m def freeze_module(m): if not hasattr(m, '_forward_inside_frozen_module'): m._forward_inside_frozen_module = m.forward m.requires_grad_(False) m.forward = torch.no_grad()(m.forward) return m def get_latest_safetensors(folder_path): safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors')) if not safetensors_files: raise ValueError('No file to resume!') latest_file = max(safetensors_files, key=os.path.getmtime) latest_file = os.path.abspath(os.path.realpath(latest_file)) return latest_file def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32): tags = tags_str.split(', ') tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags))) prompt = ', '.join(tags) return prompt def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0): numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma) if round_to_int: numbers = np.round(numbers).astype(int) return numbers.tolist() def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False): edges = np.linspace(0, 1, n + 1) points = np.random.uniform(edges[:-1], edges[1:]) numbers = inclusive + (exclusive - inclusive) * points if round_to_int: numbers = np.round(numbers).astype(int) return numbers.tolist() def soft_append_bcthw(history, current, overlap=0): if overlap <= 0: return torch.cat([history, current], dim=2) assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})" assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})" weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1) blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap] output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2) return output.to(history) def save_bcthw_as_mp4(x, output_filename, fps=10): b, c, t, h, w = x.shape per_row = b for p in [6, 5, 4, 3, 2]: if b % p == 0: per_row = p break 
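# per_row is the number of grid columns: the largest of 6..2 that evenly divides the batch size (falling back to the whole batch on one row); the batch is tiled into this grid before being written out as a single video.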
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 x = x.detach().cpu().to(torch.uint8) x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row) torchvision.io.write_video(output_filename, x, fps=fps, video_codec='h264', options={'crf': '0'}) return x def save_bcthw_as_png(x, output_filename): os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 x = x.detach().cpu().to(torch.uint8) x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)') torchvision.io.write_png(x, output_filename) return output_filename def save_bchw_as_png(x, output_filename): os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 x = x.detach().cpu().to(torch.uint8) x = einops.rearrange(x, 'b c h w -> c h (b w)') torchvision.io.write_png(x, output_filename) return output_filename def add_tensors_with_padding(tensor1, tensor2): if tensor1.shape == tensor2.shape: return tensor1 + tensor2 shape1 = tensor1.shape shape2 = tensor2.shape new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2)) padded_tensor1 = torch.zeros(new_shape) padded_tensor2 = torch.zeros(new_shape) padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1 padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2 result = padded_tensor1 + padded_tensor2 return result def print_free_mem(): torch.cuda.empty_cache() free_mem, total_mem = torch.cuda.mem_get_info(0) free_mem_mb = free_mem / (1024 ** 2) total_mem_mb = total_mem / (1024 ** 2) print(f"Free memory: {free_mem_mb:.2f} MB") print(f"Total memory: {total_mem_mb:.2f} MB") return def print_gpu_parameters(device, state_dict, log_count=1): summary = {"device": device, "keys_count": len(state_dict)} logged_params = {} for i, (key, tensor) in enumerate(state_dict.items()): if i >= log_count: break logged_params[key] = tensor.flatten()[:3].tolist() summary["params"] = logged_params print(str(summary)) return def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18): from PIL import Image, ImageDraw, ImageFont txt = Image.new("RGB", (width, height), color="white") draw = ImageDraw.Draw(txt) font = ImageFont.truetype(font_path, size=size) if text == '': return np.array(txt) # Split text into lines that fit within the image width lines = [] words = text.split() current_line = words[0] for word in words[1:]: line_with_word = f"{current_line} {word}" if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width: current_line = line_with_word else: lines.append(current_line) current_line = word lines.append(current_line) # Draw the text line by line y = 0 line_height = draw.textbbox((0, 0), "A", font=font)[3] for line in lines: if y + line_height > height: break # stop drawing if the next line will be outside the image draw.text((0, y), line, fill="black", font=font) y += line_height return np.array(txt) # def blue_mark(x): # x = x.copy() # c = x[:, :, 2] # b = cv2.blur(c, (9, 9)) # x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1) # return x # def green_mark(x): # x = x.copy() # x[:, :, 2] = -1 # x[:, :, 0] = -1 # return x # def frame_mark(x): # x = x.copy() # x[:64] = -1 # x[-64:] = -1 # x[:, :8] = 1 # x[:, -8:] = 1 # return x @torch.inference_mode() def pytorch2numpy(imgs): results = [] for x in imgs: y = x.movedim(0, -1) y = y * 127.5 + 127.5 y = 
y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8) results.append(y) return results @torch.inference_mode() def numpy2pytorch(imgs): h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0 h = h.movedim(-1, 1) return h @torch.no_grad() def duplicate_prefix_to_suffix(x, count, zero_out=False): if zero_out: return torch.cat([x, torch.zeros_like(x[:count])], dim=0) else: return torch.cat([x, x[:count]], dim=0) def weighted_mse(a, b, weight): return torch.mean(weight.float() * (a.float() - b.float()) ** 2) def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0): x = (x - x_min) / (x_max - x_min) x = max(0.0, min(x, 1.0)) x = x ** sigma return y_min + x * (y_max - y_min) def expand_to_dims(x, target_dims): return x.view(*x.shape, *([1] * max(0, target_dims - x.dim()))) def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int): if tensor is None: return None first_dim = tensor.shape[0] if first_dim == batch_size: return tensor if batch_size % first_dim != 0: raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.") repeat_times = batch_size // first_dim return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1)) def dim5(x): return expand_to_dims(x, 5) def dim4(x): return expand_to_dims(x, 4) def dim3(x): return expand_to_dims(x, 3) def crop_or_pad_yield_mask(x, length): B, F, C = x.shape device = x.device dtype = x.dtype if F < length: y = torch.zeros((B, length, C), dtype=dtype, device=device) mask = torch.zeros((B, length), dtype=torch.bool, device=device) y[:, :F, :] = x mask[:, :F] = True return y, mask return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device) def extend_dim(x, dim, minimal_length, zero_pad=False): original_length = int(x.shape[dim]) if original_length >= minimal_length: return x if zero_pad: padding_shape = list(x.shape) padding_shape[dim] = minimal_length - original_length padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device) else: idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1) last_element = x[idx] padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim) return torch.cat([x, padding], dim=dim) def lazy_positional_encoding(t, repeats=None): if not isinstance(t, list): t = [t] from diffusers.models.embeddings import get_timestep_embedding te = torch.tensor(t) te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0) if repeats is None: return te te = te[:, None, :].expand(-1, repeats, -1) return te def state_dict_offset_merge(A, B, C=None): result = {} keys = A.keys() for key in keys: A_value = A[key] B_value = B[key].to(A_value) if C is None: result[key] = A_value + B_value else: C_value = C[key].to(A_value) result[key] = A_value + B_value - C_value return result def state_dict_weighted_merge(state_dicts, weights): if len(state_dicts) != len(weights): raise ValueError("Number of state dictionaries must match number of weights") if not state_dicts: return {} total_weight = sum(weights) if total_weight == 0: raise ValueError("Sum of weights cannot be zero") normalized_weights = [w / total_weight for w in weights] keys = state_dicts[0].keys() result = {} for key in keys: result[key] = state_dicts[0][key] * normalized_weights[0] for i in range(1, len(state_dicts)): state_dict_value = state_dicts[i][key].to(result[key]) result[key] += state_dict_value * normalized_weights[i] return result def 
group_files_by_folder(all_files): grouped_files = {} for file in all_files: folder_name = os.path.basename(os.path.dirname(file)) if folder_name not in grouped_files: grouped_files[folder_name] = [] grouped_files[folder_name].append(file) list_of_lists = list(grouped_files.values()) return list_of_lists def generate_timestamp(): now = datetime.datetime.now() timestamp = now.strftime('%y%m%d_%H%M%S') milliseconds = f"{int(now.microsecond / 1000):03d}" random_number = random.randint(0, 9999) return f"{timestamp}_{milliseconds}_{random_number}" def write_PIL_image_with_png_info(image, metadata, path): from PIL.PngImagePlugin import PngInfo png_info = PngInfo() for key, value in metadata.items(): png_info.add_text(key, value) image.save(path, "PNG", pnginfo=png_info) return image def torch_safe_save(content, path): torch.save(content, path + '_tmp') os.replace(path + '_tmp', path) return path def move_optimizer_to_device(optimizer, device): for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(device) ``` ## /example_workflows/framepack_hv_example.json ```json path="/example_workflows/framepack_hv_example.json" { "id": "ce2cb810-7775-4564-8928-dd5bed1053cd", "revision": 0, "last_node_id": 69, "last_link_id": 158, "nodes": [ { "id": 15, "type": "ConditioningZeroOut", "pos": [ 1346.0872802734375, 263.21856689453125 ], "size": [ 317.4000244140625, 26 ], "flags": { "collapsed": true }, "order": 18, "mode": 0, "inputs": [ { "name": "conditioning", "type": "CONDITIONING", "link": 118 } ], "outputs": [ { "name": "CONDITIONING", "type": "CONDITIONING", "links": [ 108 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "ConditioningZeroOut" }, "widgets_values": [], "color": "#332922", "bgcolor": "#593930" }, { "id": 13, "type": "DualCLIPLoader", "pos": [ 320.9956359863281, 166.8336181640625 ], "size": [ 340.2243957519531, 130 ], "flags": {}, "order": 0, "mode": 0, "inputs": [], "outputs": [ { "name": "CLIP", "type": "CLIP", "links": [ 102 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "DualCLIPLoader" }, "widgets_values": [ "clip_l.safetensors", "llava_llama3_fp16.safetensors", "hunyuan_video", "default" ], "color": "#432", "bgcolor": "#653" }, { "id": 54, "type": "DownloadAndLoadFramePackModel", "pos": [ 1256.5235595703125, -277.76226806640625 ], "size": [ 315, 130 ], "flags": {}, "order": 1, "mode": 4, "inputs": [ { "name": "compile_args", "shape": 7, "type": "FRAMEPACKCOMPILEARGS", "link": null } ], "outputs": [ { "name": "model", "type": "FramePackMODEL", "links": null } ], "properties": { "aux_id": "kijai/ComfyUI-FramePackWrapper", "ver": "49fe507eca8246cc9d08a8093892f40c1180e88f", "Node name for S&R": "DownloadAndLoadFramePackModel" }, "widgets_values": [ "lllyasviel/FramePackI2V_HY", "bf16", "disabled", "sdpa" ] }, { "id": 55, "type": "MarkdownNote", "pos": [ 567.05908203125, -628.8865966796875 ], "size": [ 459.8609619140625, 285.9714660644531 ], "flags": {}, "order": 2, "mode": 0, "inputs": [], "outputs": [], "properties": {}, "widgets_values": [ "Model 
links:\n\n[https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_fp8_e4m3fn.safetensors](https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_fp8_e4m3fn.safetensors)\n\n[https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors](https://huggingface.co/Kijai/HunyuanVideo_comfy/blob/main/FramePackI2V_HY_bf16.safetensors)\n\nsigclip:\n\n[https://huggingface.co/Comfy-Org/sigclip_vision_384/tree/main](https://huggingface.co/Comfy-Org/sigclip_vision_384/tree/main)\n\ntext encoder and VAE:\n\n[https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files](https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged/tree/main/split_files)" ], "color": "#432", "bgcolor": "#653" }, { "id": 17, "type": "CLIPVisionEncode", "pos": [ 1545.9541015625, 359.1331481933594 ], "size": [ 380.4000244140625, 78 ], "flags": {}, "order": 23, "mode": 0, "inputs": [ { "name": "clip_vision", "type": "CLIP_VISION", "link": 149 }, { "name": "image", "type": "IMAGE", "link": 116 } ], "outputs": [ { "name": "CLIP_VISION_OUTPUT", "type": "CLIP_VISION_OUTPUT", "links": [ 141 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "CLIPVisionEncode" }, "widgets_values": [ "center" ], "color": "#233", "bgcolor": "#355" }, { "id": 64, "type": "GetNode", "pos": [ 1554.2071533203125, 486.79547119140625 ], "size": [ 210, 60 ], "flags": { "collapsed": true }, "order": 3, "mode": 0, "inputs": [], "outputs": [ { "name": "CLIP_VISION", "type": "CLIP_VISION", "links": [ 149 ] } ], "title": "Get_ClipVisionModle", "properties": {}, "widgets_values": [ "ClipVisionModle" ], "color": "#233", "bgcolor": "#355" }, { "id": 48, "type": "GetImageSizeAndCount", "pos": [ 1259.2060546875, 626.8657836914062 ], "size": [ 277.20001220703125, 86 ], "flags": {}, "order": 21, "mode": 0, "inputs": [ { "name": "image", "type": "IMAGE", "link": 125 } ], "outputs": [ { "name": "image", "type": "IMAGE", "links": [ 116, 156 ] }, { "label": "704 width", "name": "width", "type": "INT", "links": null }, { "label": "544 height", "name": "height", "type": "INT", "links": null }, { "label": "1 count", "name": "count", "type": "INT", "links": null } ], "properties": { "cnr_id": "comfyui-kjnodes", "ver": "8ecf5cd05e0a1012087b0da90eea9a13674668db", "Node name for S&R": "GetImageSizeAndCount" }, "widgets_values": [] }, { "id": 60, "type": "GetImageSizeAndCount", "pos": [ 1279.781494140625, 1060.245361328125 ], "size": [ 277.20001220703125, 86 ], "flags": {}, "order": 22, "mode": 0, "inputs": [ { "name": "image", "type": "IMAGE", "link": 139 } ], "outputs": [ { "name": "image", "type": "IMAGE", "links": [ 151, 152 ] }, { "label": "704 width", "name": "width", "type": "INT", "links": null }, { "label": "544 height", "name": "height", "type": "INT", "links": null }, { "label": "1 count", "name": "count", "type": "INT", "links": null } ], "properties": { "cnr_id": "comfyui-kjnodes", "ver": "8ecf5cd05e0a1012087b0da90eea9a13674668db", "Node name for S&R": "GetImageSizeAndCount" }, "widgets_values": [] }, { "id": 12, "type": "VAELoader", "pos": [ 570.5363159179688, -282.70068359375 ], "size": [ 469.0488586425781, 58 ], "flags": {}, "order": 4, "mode": 0, "inputs": [], "outputs": [ { "name": "VAE", "type": "VAE", "links": [ 153 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "VAELoader" }, "widgets_values": [ "hyvid\\hunyuan_video_vae_bf16_repack.safetensors" ], "color": "#322", "bgcolor": "#533" }, { "id": 66, "type": 
"SetNode", "pos": [ 1083.503173828125, -358.4913330078125 ], "size": [ 210, 60 ], "flags": { "collapsed": true }, "order": 15, "mode": 0, "inputs": [ { "name": "VAE", "type": "VAE", "link": 153 } ], "outputs": [ { "name": "*", "type": "*", "links": null } ], "title": "Set_VAE", "properties": { "previousName": "VAE" }, "widgets_values": [ "VAE" ], "color": "#322", "bgcolor": "#533" }, { "id": 20, "type": "VAEEncode", "pos": [ 1733.111083984375, 633.30419921875 ], "size": [ 210, 46 ], "flags": {}, "order": 24, "mode": 0, "inputs": [ { "name": "pixels", "type": "IMAGE", "link": 156 }, { "name": "vae", "type": "VAE", "link": 155 } ], "outputs": [ { "name": "LATENT", "type": "LATENT", "links": [ 86 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "VAEEncode" }, "widgets_values": [], "color": "#322", "bgcolor": "#533" }, { "id": 68, "type": "GetNode", "pos": [ 1729.60693359375, 734.5352172851562 ], "size": [ 210, 34 ], "flags": { "collapsed": true }, "order": 5, "mode": 0, "inputs": [], "outputs": [ { "name": "VAE", "type": "VAE", "links": [ 155 ] } ], "title": "Get_VAE", "properties": {}, "widgets_values": [ "VAE" ], "color": "#322", "bgcolor": "#533" }, { "id": 62, "type": "VAEEncode", "pos": [ 1612.563232421875, 1048.6236572265625 ], "size": [ 210, 46 ], "flags": {}, "order": 26, "mode": 0, "inputs": [ { "name": "pixels", "type": "IMAGE", "link": 152 }, { "name": "vae", "type": "VAE", "link": 158 } ], "outputs": [ { "name": "LATENT", "type": "LATENT", "links": [ 147 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "VAEEncode" }, "widgets_values": [], "color": "#322", "bgcolor": "#533" }, { "id": 57, "type": "CLIPVisionEncode", "pos": [ 1600.4202880859375, 1181.36767578125 ], "size": [ 380.4000244140625, 78 ], "flags": {}, "order": 25, "mode": 0, "inputs": [ { "name": "clip_vision", "type": "CLIP_VISION", "link": 150 }, { "name": "image", "type": "IMAGE", "link": 151 } ], "outputs": [ { "name": "CLIP_VISION_OUTPUT", "type": "CLIP_VISION_OUTPUT", "links": [ 132 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.29", "Node name for S&R": "CLIPVisionEncode" }, "widgets_values": [ "center" ], "color": "#233", "bgcolor": "#355" }, { "id": 69, "type": "GetNode", "pos": [ 1619.6104736328125, 1137.854736328125 ], "size": [ 210, 34 ], "flags": { "collapsed": true }, "order": 6, "mode": 0, "inputs": [], "outputs": [ { "name": "VAE", "type": "VAE", "links": [ 158 ] } ], "title": "Get_VAE", "properties": {}, "widgets_values": [ "VAE" ], "color": "#322", "bgcolor": "#533" }, { "id": 65, "type": "GetNode", "pos": [ 1604.746337890625, 1306.3175048828125 ], "size": [ 210, 34 ], "flags": { "collapsed": true }, "order": 7, "mode": 0, "inputs": [], "outputs": [ { "name": "CLIP_VISION", "type": "CLIP_VISION", "links": [ 150 ] } ], "title": "Get_ClipVisionModle", "properties": {}, "widgets_values": [ "ClipVisionModle" ], "color": "#233", "bgcolor": "#355" }, { "id": 59, "type": "ImageResize+", "pos": [ 908.9832763671875, 1062.01123046875 ], "size": [ 315, 218 ], "flags": {}, "order": 20, "mode": 0, "inputs": [ { "name": "image", "type": "IMAGE", "link": 138 }, { "name": "width", "type": "INT", "widget": { "name": "width" }, "link": 136 }, { "name": "height", "type": "INT", "widget": { "name": "height" }, "link": 137 } ], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 139 ] }, { "name": "width", "type": "INT", "links": null }, { "name": "height", "type": "INT", "links": null } ], "properties": { "aux_id": 
"kijai/ComfyUI_essentials", "ver": "76e9d1e4399bd025ce8b12c290753d58f9f53e93", "Node name for S&R": "ImageResize+" }, "widgets_values": [ 512, 512, "lanczos", "stretch", "always", 0 ] }, { "id": 50, "type": "ImageResize+", "pos": [ 907.2653198242188, 593.743896484375 ], "size": [ 315, 218 ], "flags": {}, "order": 19, "mode": 0, "inputs": [ { "name": "image", "type": "IMAGE", "link": 122 }, { "name": "width", "type": "INT", "widget": { "name": "width" }, "link": 128 }, { "name": "height", "type": "INT", "widget": { "name": "height" }, "link": 127 } ], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 125 ] }, { "name": "width", "type": "INT", "links": null }, { "name": "height", "type": "INT", "links": null } ], "properties": { "aux_id": "kijai/ComfyUI_essentials", "ver": "76e9d1e4399bd025ce8b12c290753d58f9f53e93", "Node name for S&R": "ImageResize+" }, "widgets_values": [ 512, 512, "lanczos", "stretch", "always", 0 ] }, { "id": 58, "type": "LoadImage", "pos": [ 190.07057189941406, 1060.399169921875 ], "size": [ 315, 314 ], "flags": {}, "order": 8, "mode": 0, "inputs": [], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 138 ] }, { "name": "MASK", "type": "MASK", "links": null } ], "title": "Load Image: End", "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "LoadImage" }, "widgets_values": [ "sd3stag.png", "image" ] }, { "id": 51, "type": "FramePackFindNearestBucket", "pos": [ 550.0997314453125, 887.411376953125 ], "size": [ 315, 78 ], "flags": {}, "order": 16, "mode": 0, "inputs": [ { "name": "image", "type": "IMAGE", "link": 126 } ], "outputs": [ { "name": "width", "type": "INT", "links": [ 128, 136 ] }, { "name": "height", "type": "INT", "links": [ 127, 137 ] } ], "properties": { "aux_id": "kijai/ComfyUI-FramePackWrapper", "ver": "4f9030a9f4c0bd67d86adf3d3dc07e37118c40bd", "Node name for S&R": "FramePackFindNearestBucket" }, "widgets_values": [ 640 ] }, { "id": 19, "type": "LoadImage", "pos": [ 184.2612762451172, 591.6886596679688 ], "size": [ 315, 314 ], "flags": {}, "order": 9, "mode": 0, "inputs": [], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 122, 126 ] }, { "name": "MASK", "type": "MASK", "links": null } ], "title": "Load Image: Start", "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "LoadImage" }, "widgets_values": [ "sd3stag.png", "image" ] }, { "id": 18, "type": "CLIPVisionLoader", "pos": [ 33.149566650390625, 23.595293045043945 ], "size": [ 388.87139892578125, 58 ], "flags": {}, "order": 10, "mode": 0, "inputs": [], "outputs": [ { "name": "CLIP_VISION", "type": "CLIP_VISION", "links": [ 148 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "CLIPVisionLoader" }, "widgets_values": [ "sigclip_vision_patch14_384.safetensors" ], "color": "#2a363b", "bgcolor": "#3f5159" }, { "id": 63, "type": "SetNode", "pos": [ 247.1346435546875, -28.502397537231445 ], "size": [ 210, 60 ], "flags": { "collapsed": true }, "order": 17, "mode": 0, "inputs": [ { "name": "CLIP_VISION", "type": "CLIP_VISION", "link": 148 } ], "outputs": [ { "name": "*", "type": "*", "links": null } ], "title": "Set_ClipVisionModle", "properties": { "previousName": "ClipVisionModle" }, "widgets_values": [ "ClipVisionModle" ], "color": "#233", "bgcolor": "#355" }, { "id": 27, "type": "FramePackTorchCompileSettings", "pos": [ 623.3660278320312, -140.94215393066406 ], "size": [ 531.5999755859375, 202 ], "flags": {}, "order": 11, "mode": 0, "inputs": [], "outputs": [ { "name": "torch_compile_args", 
"type": "FRAMEPACKCOMPILEARGS", "links": [] } ], "properties": { "aux_id": "lllyasviel/FramePack", "ver": "0e5fe5d7ca13c76fb8e13708f4b92e7c7a34f20c", "Node name for S&R": "FramePackTorchCompileSettings" }, "widgets_values": [ "inductor", false, "default", false, 64, true, true ] }, { "id": 33, "type": "VAEDecodeTiled", "pos": [ 2328.923828125, -22.08228874206543 ], "size": [ 315, 150 ], "flags": {}, "order": 28, "mode": 0, "inputs": [ { "name": "samples", "type": "LATENT", "link": 85 }, { "name": "vae", "type": "VAE", "link": 154 } ], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 96 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "VAEDecodeTiled" }, "widgets_values": [ 256, 64, 64, 8 ], "color": "#322", "bgcolor": "#533" }, { "id": 67, "type": "GetNode", "pos": [ 2342.01806640625, -76.06847381591797 ], "size": [ 210, 60 ], "flags": { "collapsed": true }, "order": 12, "mode": 0, "inputs": [], "outputs": [ { "name": "VAE", "type": "VAE", "links": [ 154 ] } ], "title": "Get_VAE", "properties": {}, "widgets_values": [ "VAE" ], "color": "#322", "bgcolor": "#533" }, { "id": 23, "type": "VHS_VideoCombine", "pos": [ 2726.849853515625, -29.90264129638672 ], "size": [ 908.428955078125, 334 ], "flags": {}, "order": 30, "mode": 0, "inputs": [ { "name": "images", "type": "IMAGE", "link": 97 }, { "name": "audio", "shape": 7, "type": "AUDIO", "link": null }, { "name": "meta_batch", "shape": 7, "type": "VHS_BatchManager", "link": null }, { "name": "vae", "shape": 7, "type": "VAE", "link": null } ], "outputs": [ { "name": "Filenames", "type": "VHS_FILENAMES", "links": null } ], "properties": { "cnr_id": "comfyui-videohelpersuite", "ver": "0a75c7958fe320efcb052f1d9f8451fd20c730a8", "Node name for S&R": "VHS_VideoCombine" }, "widgets_values": { "frame_rate": 30, "loop_count": 0, "filename_prefix": "FramePack", "format": "video/h264-mp4", "pix_fmt": "yuv420p", "crf": 19, "save_metadata": true, "trim_to_audio": false, "pingpong": false, "save_output": false, "videopreview": { "hidden": false, "paused": false, "params": { "filename": "FramePack_00001.mp4", "subfolder": "", "type": "temp", "format": "video/h264-mp4", "frame_rate": 30, "workflow": "FramePack_00001.png", "fullpath": "N:\\AI\\ComfyUI\\temp\\FramePack_00001.mp4" } } } }, { "id": 44, "type": "GetImageSizeAndCount", "pos": [ 2501.023193359375, -178.70773315429688 ], "size": [ 277.20001220703125, 86 ], "flags": {}, "order": 29, "mode": 0, "inputs": [ { "name": "image", "type": "IMAGE", "link": 96 } ], "outputs": [ { "name": "image", "type": "IMAGE", "links": [ 97 ] }, { "label": "704 width", "name": "width", "type": "INT", "links": null }, { "label": "544 height", "name": "height", "type": "INT", "links": null }, { "label": "145 count", "name": "count", "type": "INT", "links": null } ], "properties": { "cnr_id": "comfyui-kjnodes", "ver": "8ecf5cd05e0a1012087b0da90eea9a13674668db", "Node name for S&R": "GetImageSizeAndCount" }, "widgets_values": [] }, { "id": 47, "type": "CLIPTextEncode", "pos": [ 715.3054809570312, 127.73457336425781 ], "size": [ 400, 200 ], "flags": {}, "order": 14, "mode": 0, "inputs": [ { "name": "clip", "type": "CLIP", "link": 102 } ], "outputs": [ { "name": "CONDITIONING", "type": "CONDITIONING", "links": [ 114, 118 ] } ], "properties": { "cnr_id": "comfy-core", "ver": "0.3.28", "Node name for S&R": "CLIPTextEncode" }, "widgets_values": [ "majestig stag in a forest" ], "color": "#232", "bgcolor": "#353" }, { "id": 52, "type": "LoadFramePackModel", "pos": [ 1253.046630859375, 
-82.57657623291016 ], "size": [ 480.7601013183594, 174 ], "flags": {}, "order": 13, "mode": 0, "inputs": [ { "name": "compile_args", "shape": 7, "type": "FRAMEPACKCOMPILEARGS", "link": null }, { "name": "lora", "shape": 7, "type": "FPLORA", "link": null } ], "outputs": [ { "name": "model", "type": "FramePackMODEL", "links": [ 129 ] } ], "properties": { "aux_id": "kijai/ComfyUI-FramePackWrapper", "ver": "49fe507eca8246cc9d08a8093892f40c1180e88f", "Node name for S&R": "LoadFramePackModel" }, "widgets_values": [ "Hyvid\\FramePackI2V_HY_fp8_e4m3fn.safetensors", "bf16", "fp8_e4m3fn", "offload_device", "sdpa" ] }, { "id": 39, "type": "FramePackSampler", "pos": [ 2292.58837890625, 194.90232849121094 ], "size": [ 365.07305908203125, 814.6473388671875 ], "flags": {}, "order": 27, "mode": 0, "inputs": [ { "name": "model", "type": "FramePackMODEL", "link": 129 }, { "name": "positive", "type": "CONDITIONING", "link": 114 }, { "name": "negative", "type": "CONDITIONING", "link": 108 }, { "name": "start_latent", "type": "LATENT", "link": 86 }, { "name": "image_embeds", "shape": 7, "type": "CLIP_VISION_OUTPUT", "link": 141 }, { "name": "end_latent", "shape": 7, "type": "LATENT", "link": 147 }, { "name": "end_image_embeds", "shape": 7, "type": "CLIP_VISION_OUTPUT", "link": 132 }, { "name": "initial_samples", "shape": 7, "type": "LATENT", "link": null } ], "outputs": [ { "name": "samples", "type": "LATENT", "links": [ 85 ] } ], "properties": { "aux_id": "kijai/ComfyUI-FramePackWrapper", "ver": "8e5ec6b7f3acf88255c5d93d062079f18b43aa2b", "Node name for S&R": "FramePackSampler" }, "widgets_values": [ 30, true, 0.15, 1, 10, 0, 47, "fixed", 9, 5, 6, "unipc_bh1", "weighted_average", 0.5, 1 ] } ], "links": [ [ 85, 39, 0, 33, 0, "LATENT" ], [ 86, 20, 0, 39, 3, "LATENT" ], [ 96, 33, 0, 44, 0, "IMAGE" ], [ 97, 44, 0, 23, 0, "IMAGE" ], [ 102, 13, 0, 47, 0, "CLIP" ], [ 108, 15, 0, 39, 2, "CONDITIONING" ], [ 114, 47, 0, 39, 1, "CONDITIONING" ], [ 116, 48, 0, 17, 1, "IMAGE" ], [ 118, 47, 0, 15, 0, "CONDITIONING" ], [ 122, 19, 0, 50, 0, "IMAGE" ], [ 125, 50, 0, 48, 0, "IMAGE" ], [ 126, 19, 0, 51, 0, "IMAGE" ], [ 127, 51, 1, 50, 2, "INT" ], [ 128, 51, 0, 50, 1, "INT" ], [ 129, 52, 0, 39, 0, "FramePackMODEL" ], [ 132, 57, 0, 39, 6, "CLIP_VISION_OUTPUT" ], [ 136, 51, 0, 59, 1, "INT" ], [ 137, 51, 1, 59, 2, "INT" ], [ 138, 58, 0, 59, 0, "IMAGE" ], [ 139, 59, 0, 60, 0, "IMAGE" ], [ 141, 17, 0, 39, 4, "CLIP_VISION_OUTPUT" ], [ 147, 62, 0, 39, 5, "LATENT" ], [ 148, 18, 0, 63, 0, "*" ], [ 149, 64, 0, 17, 0, "CLIP_VISION" ], [ 150, 65, 0, 57, 0, "CLIP_VISION" ], [ 151, 60, 0, 57, 1, "IMAGE" ], [ 152, 60, 0, 62, 0, "IMAGE" ], [ 153, 12, 0, 66, 0, "*" ], [ 154, 67, 0, 33, 1, "VAE" ], [ 155, 68, 0, 20, 1, "VAE" ], [ 156, 48, 0, 20, 0, "IMAGE" ], [ 158, 69, 0, 62, 1, "VAE" ] ], "groups": [ { "id": 1, "title": "End Image", "bounding": [ 12.77297592163086, 999.1203002929688, 2038.674560546875, 412.9618225097656 ], "color": "#3f789e", "font_size": 24, "flags": {} }, { "id": 2, "title": "Start Image", "bounding": [ 11.781991958618164, 531.3884887695312, 2032.7288818359375, 442.6904602050781 ], "color": "#3f789e", "font_size": 24, "flags": {} } ], "config": {}, "extra": { "ds": { "scale": 0.6115909044841659, "offset": [ 21.57747102795121, 375.7674957811538 ] }, "frontendVersion": "1.18.3", "VHS_latentpreview": true, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true }, "version": 0.4 } ``` ## /fp8_optimization.py ```py path="/fp8_optimization.py" #based on ComfyUI's and MinusZoneAI's fp8_linear 
optimization import torch import torch.nn as nn def fp8_linear_forward(cls, original_dtype, input): weight_dtype = cls.weight.dtype if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: if len(input.shape) == 3: target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn inn = input.reshape(-1, input.shape[2]).to(target_dtype) w = cls.weight.t() scale = torch.ones((1), device=input.device, dtype=torch.float32) bias = cls.bias.to(original_dtype) if cls.bias is not None else None if bias is not None: o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale) else: o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale) if isinstance(o, tuple): o = o[0] return o.reshape((-1, input.shape[1], cls.weight.shape[0])) else: return cls.original_forward(input.to(original_dtype)) else: return cls.original_forward(input) def convert_fp8_linear(module, original_dtype, params_to_keep={}): setattr(module, "fp8_matmul_enabled", True) for name, module in module.named_modules(): if not any(keyword in name for keyword in params_to_keep): if isinstance(module, nn.Linear): original_forward = module.forward setattr(module, "original_forward", original_forward) setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input)) ``` ## /nodes.py ```py path="/nodes.py" import os import torch import math from tqdm import tqdm from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device import folder_paths import comfy.model_management as mm from comfy.utils import load_torch_file, ProgressBar, common_upscale import comfy.model_base import comfy.latent_formats from comfy.cli_args import args, LatentPreviewMethod from .utils import log script_directory = os.path.dirname(os.path.abspath(__file__)) vae_scaling_factor = 0.476986 from .diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModel from .diffusers_helper.memory import DynamicSwapInstaller, move_model_to_device_with_memory_preservation from .diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan from .diffusers_helper.utils import crop_or_pad_yield_mask from .diffusers_helper.bucket_tools import find_nearest_bucket from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers class HyVideoModel(comfy.model_base.BaseModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pipeline = {} self.load_device = mm.get_torch_device() def __getitem__(self, k): return self.pipeline[k] def __setitem__(self, k, v): self.pipeline[k] = v class HyVideoModelConfig: def __init__(self, dtype): self.unet_config = {} self.unet_extra_config = {} self.latent_format = comfy.latent_formats.HunyuanVideo self.latent_format.latent_channels = 16 self.manual_cast_dtype = dtype self.sampling_settings = {"multiplier": 1.0} self.memory_usage_factor = 2.0 self.unet_config["disable_unet_model_creation"] = True class FramePackTorchCompileSettings: @classmethod def INPUT_TYPES(s): return { "required": { "backend": (["inductor","cudagraphs"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), "dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": 
"torch._dynamo.config.cache_size_limit"}), "compile_single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Enable single block compilation"}), "compile_double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Enable double block compilation"}), }, } RETURN_TYPES = ("FRAMEPACKCOMPILEARGS",) RETURN_NAMES = ("torch_compile_args",) FUNCTION = "loadmodel" CATEGORY = "HunyuanVideoWrapper" DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended" def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_single_blocks, compile_double_blocks): compile_args = { "backend": backend, "fullgraph": fullgraph, "mode": mode, "dynamic": dynamic, "dynamo_cache_size_limit": dynamo_cache_size_limit, "compile_single_blocks": compile_single_blocks, "compile_double_blocks": compile_double_blocks } return (compile_args, ) #region Model loading class DownloadAndLoadFramePackModel: @classmethod def INPUT_TYPES(s): return { "required": { "model": (["lllyasviel/FramePackI2V_HY"],), "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}), "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2'], {"default": 'disabled', "tooltip": "optional quantization method"}), }, "optional": { "attention_mode": ([ "sdpa", "flash_attn", "sageattn", ], {"default": "sdpa"}), "compile_args": ("FRAMEPACKCOMPILEARGS", ), } } RETURN_TYPES = ("FramePackMODEL",) RETURN_NAMES = ("model", ) FUNCTION = "loadmodel" CATEGORY = "FramePackWrapper" def loadmodel(self, model, base_precision, quantization, compile_args=None, attention_mode="sdpa"): base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision] device = mm.get_torch_device() model_path = os.path.join(folder_paths.models_dir, "diffusers", "lllyasviel", "FramePackI2V_HY") if not os.path.exists(model_path): print(f"Downloading clip model to: {model_path}") from huggingface_hub import snapshot_download snapshot_download( repo_id=model, local_dir=model_path, local_dir_use_symlinks=False, ) transformer = HunyuanVideoTransformer3DModel.from_pretrained(model_path, torch_dtype=base_dtype, attention_mode=attention_mode).cpu() params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} if quantization == 'fp8_e4m3fn' or quantization == 'fp8_e4m3fn_fast': transformer = transformer.to(torch.float8_e4m3fn) if quantization == "fp8_e4m3fn_fast": from .fp8_optimization import convert_fp8_linear convert_fp8_linear(transformer, base_dtype, params_to_keep=params_to_keep) elif quantization == 'fp8_e5m2': transformer = transformer.to(torch.float8_e5m2) else: transformer = transformer.to(base_dtype) DynamicSwapInstaller.install_model(transformer, device=device) if compile_args is not None: if compile_args["compile_single_blocks"]: for i, block in enumerate(transformer.single_transformer_blocks): transformer.single_transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) if compile_args["compile_double_blocks"]: for i, block in enumerate(transformer.transformer_blocks): transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], 
mode=compile_args["mode"]) #transformer = torch.compile(transformer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) pipe = { "transformer": transformer.eval(), "dtype": base_dtype, } return (pipe, ) class FramePackLoraSelect: @classmethod def INPUT_TYPES(s): return { "required": { "lora": (folder_paths.get_filename_list("loras"), {"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}), "fuse_lora": ("BOOLEAN", {"default": True, "tooltip": "Fuse the LORA model with the base model. This is recommended for better performance."}), }, "optional": { "prev_lora":("FPLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}), } } RETURN_TYPES = ("FPLORA",) RETURN_NAMES = ("lora", ) FUNCTION = "getlorapath" CATEGORY = "FramePackWrapper" DESCRIPTION = "Select a LoRA model from ComfyUI/models/loras" def getlorapath(self, lora, strength, prev_lora=None, fuse_lora=True): loras_list = [] lora = { "path": folder_paths.get_full_path("loras", lora), "strength": strength, "name": lora.split(".")[0], "fuse_lora": fuse_lora, } if prev_lora is not None: loras_list.extend(prev_lora) loras_list.append(lora) return (loras_list,) class LoadFramePackModel: @classmethod def INPUT_TYPES(s): return { "required": { "model": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), "base_precision": (["fp32", "bf16", "fp16"], {"default": "bf16"}), "quantization": (['disabled', 'fp8_e4m3fn', 'fp8_e4m3fn_fast', 'fp8_e5m2'], {"default": 'disabled', "tooltip": "optional quantization method"}), "load_device": (["main_device", "offload_device"], {"default": "cuda", "tooltip": "Initialize the model on the main device or offload device"}), }, "optional": { "attention_mode": ([ "sdpa", "flash_attn", "sageattn", ], {"default": "sdpa"}), "compile_args": ("FRAMEPACKCOMPILEARGS", ), "lora": ("FPLORA", {"default": None, "tooltip": "LORA model to load"}), } } RETURN_TYPES = ("FramePackMODEL",) RETURN_NAMES = ("model", ) FUNCTION = "loadmodel" CATEGORY = "FramePackWrapper" def loadmodel(self, model, base_precision, quantization, compile_args=None, attention_mode="sdpa", lora=None, load_device="main_device"): base_dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e4m3fn_fast": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[base_precision] device = mm.get_torch_device() offload_device = mm.unet_offload_device() if load_device == "main_device": transformer_load_device = device else: transformer_load_device = offload_device model_path = folder_paths.get_full_path_or_raise("diffusion_models", model) model_config_path = os.path.join(script_directory, "transformer_config.json") import json with open(model_config_path, "r") as f: config = json.load(f) sd = load_torch_file(model_path, device=offload_device, safe_load=True) model_weight_dtype = sd['single_transformer_blocks.0.attn.to_k.weight'].dtype with init_empty_weights(): transformer = HunyuanVideoTransformer3DModel(**config, attention_mode=attention_mode) params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast" or quantization == "fp8_scaled": dtype = 
torch.float8_e4m3fn elif quantization == "fp8_e5m2": dtype = torch.float8_e5m2 else: dtype = base_dtype if lora is not None: after_lora_dtype = dtype dtype = base_dtype print("Using accelerate to load and assign model weights to device...") param_count = sum(1 for _ in transformer.named_parameters()) for name, param in tqdm(transformer.named_parameters(), desc=f"Loading transformer parameters to {transformer_load_device}", total=param_count, leave=True): dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype set_module_tensor_to_device(transformer, name, device=transformer_load_device, dtype=dtype_to_use, value=sd[name]) if lora is not None: adapter_list = [] adapter_weights = [] for l in lora: fuse = True if l["fuse_lora"] else False lora_sd = load_torch_file(l["path"]) if "lora_unet_single_transformer_blocks_0_attn_to_k.lora_up.weight" in lora_sd: from .utils import convert_to_diffusers lora_sd = convert_to_diffusers("lora_unet_", lora_sd) if not "transformer.single_transformer_blocks.0.attn_to.k.lora_A.weight" in lora_sd: log.info(f"Converting LoRA weights from {l['path']} to diffusers format...") lora_sd = _convert_hunyuan_video_lora_to_diffusers(lora_sd) lora_rank = None for key, val in lora_sd.items(): if "lora_B" in key or "lora_up" in key: lora_rank = val.shape[1] break if lora_rank is not None: log.info(f"Merging rank {lora_rank} LoRA weights from {l['path']} with strength {l['strength']}") adapter_name = l['path'].split("/")[-1].split(".")[0] adapter_weight = l['strength'] transformer.load_lora_adapter(lora_sd, weight_name=l['path'].split("/")[-1], lora_rank=lora_rank, adapter_name=adapter_name) adapter_list.append(adapter_name) adapter_weights.append(adapter_weight) del lora_sd mm.soft_empty_cache() if adapter_list: transformer.set_adapters(adapter_list, weights=adapter_weights) if fuse: if model_weight_dtype not in [torch.float32, torch.float16, torch.bfloat16]: raise ValueError("Fusing LoRA doesn't work well with fp8 model weights. Please use a bf16 model file, or disable LoRA fusing.") lora_scale = 1 transformer.fuse_lora(lora_scale=lora_scale) transformer.delete_adapters(adapter_list) if quantization == "fp8_e4m3fn" or quantization == "fp8_e4m3fn_fast" or quantization == "fp8_e5m2": params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"} for name, param in transformer.named_parameters(): # Make sure to not cast the LoRA weights to fp8. 
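# Layers matching params_to_keep (norm, bias, time_in, vector_in, guidance_in, txt_in, img_in) are likewise skipped below, so they stay in base_dtype; only the remaining transformer weights are cast to the selected fp8 dtype.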
if not any(keyword in name for keyword in params_to_keep) and not 'lora' in name: param.data = param.data.to(after_lora_dtype) if quantization == "fp8_e4m3fn_fast": from .fp8_optimization import convert_fp8_linear convert_fp8_linear(transformer, base_dtype, params_to_keep=params_to_keep) DynamicSwapInstaller.install_model(transformer, device=device) if compile_args is not None: if compile_args["compile_single_blocks"]: for i, block in enumerate(transformer.single_transformer_blocks): transformer.single_transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) if compile_args["compile_double_blocks"]: for i, block in enumerate(transformer.transformer_blocks): transformer.transformer_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) #transformer = torch.compile(transformer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"]) pipe = { "transformer": transformer.eval(), "dtype": base_dtype, } return (pipe, ) class FramePackFindNearestBucket: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE", {"tooltip": "Image to resize"}), "base_resolution": ("INT", {"default": 640, "min": 64, "max": 2048, "step": 16, "tooltip": "Base resolution used to find the nearest bucket"}), }, } RETURN_TYPES = ("INT", "INT", ) RETURN_NAMES = ("width","height",) FUNCTION = "process" CATEGORY = "FramePackWrapper" DESCRIPTION = "Finds the closest resolution bucket as defined in the original code" def process(self, image, base_resolution): H, W = image.shape[1], image.shape[2] new_height, new_width = find_nearest_bucket(H, W, resolution=base_resolution) return (new_width, new_height, ) class FramePackSampler: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("FramePackMODEL",), "positive": ("CONDITIONING",), "negative": ("CONDITIONING",), "start_latent": ("LATENT", {"tooltip": "init Latents to use for image2video"} ), "steps": ("INT", {"default": 30, "min": 1}), "use_teacache": ("BOOLEAN", {"default": True, "tooltip": "Use teacache for faster sampling."}), "teacache_rel_l1_thresh": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The threshold for the relative L1 loss."}), "cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 30.0, "step": 0.01}), "guidance_scale": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 32.0, "step": 0.01}), "shift": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "latent_window_size": ("INT", {"default": 9, "min": 1, "max": 33, "step": 1, "tooltip": "The size of the latent window to use for sampling."}), "total_second_length": ("FLOAT", {"default": 5, "min": 1, "max": 120, "step": 0.1, "tooltip": "The total length of the video in seconds."}), "gpu_memory_preservation": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 128.0, "step": 0.1, "tooltip": "The amount of GPU memory to preserve."}), "sampler": (["unipc_bh1", "unipc_bh2"], { "default": 'unipc_bh1' }), }, "optional": { "image_embeds": ("CLIP_VISION_OUTPUT", ), "end_latent": ("LATENT", {"tooltip": "end Latents to use for image2video"} ), "end_image_embeds": ("CLIP_VISION_OUTPUT", {"tooltip": "end Image's clip embeds"} ), "embed_interpolation": (["disabled", "weighted_average", "linear"], {"default":
'disabled', "tooltip": "Image embedding interpolation type. If linear, will smoothly interpolate with time, else it'll be weighted average with the specified weight."}), "start_embed_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Weighted average constant for image embed interpolation. If end image is not set, the embed's strength won't be affected"}), "initial_samples": ("LATENT", {"tooltip": "init Latents to use for video2video"} ), "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("LATENT", ) RETURN_NAMES = ("samples",) FUNCTION = "process" CATEGORY = "FramePackWrapper" def process(self, model, shift, positive, negative, latent_window_size, use_teacache, total_second_length, teacache_rel_l1_thresh, steps, cfg, guidance_scale, seed, sampler, gpu_memory_preservation, start_latent=None, image_embeds=None, end_latent=None, end_image_embeds=None, embed_interpolation="linear", start_embed_strength=1.0, initial_samples=None, denoise_strength=1.0): total_latent_sections = (total_second_length * 30) / (latent_window_size * 4) total_latent_sections = int(max(round(total_latent_sections), 1)) print("total_latent_sections: ", total_latent_sections) transformer = model["transformer"] base_dtype = model["dtype"] device = mm.get_torch_device() offload_device = mm.unet_offload_device() mm.unload_all_models() mm.cleanup_models() mm.soft_empty_cache() if start_latent is not None: start_latent = start_latent["samples"] * vae_scaling_factor if initial_samples is not None: initial_samples = initial_samples["samples"] * vae_scaling_factor if end_latent is not None: end_latent = end_latent["samples"] * vae_scaling_factor has_end_image = end_latent is not None print("start_latent", start_latent.shape) B, C, T, H, W = start_latent.shape if image_embeds is not None: start_image_encoder_last_hidden_state = image_embeds["last_hidden_state"].to(device, base_dtype) if has_end_image: assert end_image_embeds is not None end_image_encoder_last_hidden_state = end_image_embeds["last_hidden_state"].to(device, base_dtype) else: if image_embeds is not None: end_image_encoder_last_hidden_state = torch.zeros_like(start_image_encoder_last_hidden_state) llama_vec = positive[0][0].to(device, base_dtype) clip_l_pooler = positive[0][1]["pooled_output"].to(device, base_dtype) if not math.isclose(cfg, 1.0): llama_vec_n = negative[0][0].to(device, base_dtype) clip_l_pooler_n = negative[0][1]["pooled_output"].to(device, base_dtype) else: llama_vec_n = torch.zeros_like(llama_vec, device=device) clip_l_pooler_n = torch.zeros_like(clip_l_pooler, device=device) llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512) llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512) # Sampling rnd = torch.Generator("cpu").manual_seed(seed) num_frames = latent_window_size * 4 - 3 history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, H, W), dtype=torch.float32).cpu() total_generated_latent_frames = 0 latent_paddings_list = list(reversed(range(total_latent_sections))) latent_paddings = latent_paddings_list.copy() # Create a copy for iteration comfy_model = HyVideoModel( HyVideoModelConfig(base_dtype), model_type=comfy.model_base.ModelType.FLOW, device=device, ) patcher = comfy.model_patcher.ModelPatcher(comfy_model, device, torch.device("cpu")) from latent_preview import prepare_callback callback = prepare_callback(patcher, steps) move_model_to_device_with_memory_preservation(transformer, 
target_device=device, preserved_memory_gb=gpu_memory_preservation) if total_latent_sections > 4: # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some # items looks better than expanding it when total_latent_sections > 4 # One can try to remove below trick and just # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0] latent_paddings_list = latent_paddings.copy() for i, latent_padding in enumerate(latent_paddings): print(f"latent_padding: {latent_padding}") is_last_section = latent_padding == 0 is_first_section = latent_padding == latent_paddings[0] latent_padding_size = latent_padding * latent_window_size if image_embeds is not None: if embed_interpolation != "disabled": if embed_interpolation == "linear": if total_latent_sections <= 1: frac = 1.0 # Handle case with only one section else: frac = 1 - i / (total_latent_sections - 1) # going backwards else: frac = start_embed_strength if has_end_image else 1.0 image_encoder_last_hidden_state = start_image_encoder_last_hidden_state * frac + (1 - frac) * end_image_encoder_last_hidden_state else: image_encoder_last_hidden_state = start_image_encoder_last_hidden_state * start_embed_strength else: image_encoder_last_hidden_state = None print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, is_first_section = {is_first_section}') start_latent_frames = T # 0 or 1 indices = torch.arange(0, sum([start_latent_frames, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0) clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([start_latent_frames, latent_padding_size, latent_window_size, 1, 2, 16], dim=1) clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) clean_latents_pre = start_latent.to(history_latents) clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2) clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) # Use end image latent for the first section if provided if has_end_image and is_first_section: clean_latents_post = end_latent.to(history_latents) clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) #vid2vid WIP if initial_samples is not None: total_length = initial_samples.shape[2] # Get the max padding value for normalization max_padding = max(latent_paddings_list) if is_last_section: # Last section should capture the end of the sequence start_idx = max(0, total_length - latent_window_size) else: # Calculate windows that distribute more evenly across the sequence # This normalizes the padding values to create appropriate spacing if max_padding > 0: # Avoid division by zero progress = (max_padding - latent_padding) / max_padding start_idx = int(progress * max(0, total_length - latent_window_size)) else: start_idx = 0 end_idx = min(start_idx + latent_window_size, total_length) print(f"start_idx: {start_idx}, end_idx: {end_idx}, total_length: {total_length}") input_init_latents = initial_samples[:, :, start_idx:end_idx, :, :].to(device) if use_teacache: transformer.initialize_teacache(enable_teacache=True, num_steps=steps, rel_l1_thresh=teacache_rel_l1_thresh) else: transformer.initialize_teacache(enable_teacache=False) with torch.autocast(device_type=mm.get_autocast_device(device), dtype=base_dtype, enabled=True): 
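# Sample one section of num_frames latent frames; clean_latents / clean_latents_2x / clean_latents_4x and their index tensors condition the section on the start image latent and the previously generated history.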
generated_latents = sample_hunyuan( transformer=transformer, sampler=sampler, initial_latent=input_init_latents if initial_samples is not None else None, strength=denoise_strength, width=W * 8, height=H * 8, frames=num_frames, real_guidance_scale=cfg, distilled_guidance_scale=guidance_scale, guidance_rescale=0, shift=shift if shift != 0 else None, num_inference_steps=steps, generator=rnd, prompt_embeds=llama_vec, prompt_embeds_mask=llama_attention_mask, prompt_poolers=clip_l_pooler, negative_prompt_embeds=llama_vec_n, negative_prompt_embeds_mask=llama_attention_mask_n, negative_prompt_poolers=clip_l_pooler_n, device=device, dtype=base_dtype, image_embeddings=image_encoder_last_hidden_state, latent_indices=latent_indices, clean_latents=clean_latents, clean_latent_indices=clean_latent_indices, clean_latents_2x=clean_latents_2x, clean_latent_2x_indices=clean_latent_2x_indices, clean_latents_4x=clean_latents_4x, clean_latent_4x_indices=clean_latent_4x_indices, callback=callback, ) if is_last_section: generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2) total_generated_latent_frames += int(generated_latents.shape[2]) history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2) real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :] if is_last_section: break transformer.to(offload_device) mm.soft_empty_cache() return {"samples": real_history_latents / vae_scaling_factor}, NODE_CLASS_MAPPINGS = { "DownloadAndLoadFramePackModel": DownloadAndLoadFramePackModel, "FramePackSampler": FramePackSampler, "FramePackTorchCompileSettings": FramePackTorchCompileSettings, "FramePackFindNearestBucket": FramePackFindNearestBucket, "LoadFramePackModel": LoadFramePackModel, "FramePackLoraSelect": FramePackLoraSelect, } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadFramePackModel": "(Down)Load FramePackModel", "FramePackSampler": "FramePackSampler", "FramePackTorchCompileSettings": "Torch Compile Settings", "FramePackFindNearestBucket": "Find Nearest Bucket", "LoadFramePackModel": "Load FramePackModel", "FramePackLoraSelect": "Select Lora", } ``` ## /requirements.txt accelerate>=1.6.0 diffusers>=0.33.1 transformers>=4.46.2 einops safetensors ## /transformer_config.json ```json path="/transformer_config.json" { "_class_name": "HunyuanVideoTransformer3DModelPacked", "_diffusers_version": "0.33.0.dev0", "_name_or_path": "hunyuanvideo-community/HunyuanVideo", "attention_head_dim": 128, "guidance_embeds": true, "has_clean_x_embedder": true, "has_image_proj": true, "image_proj_dim": 1152, "in_channels": 16, "mlp_ratio": 4.0, "num_attention_heads": 24, "num_layers": 20, "num_refiner_layers": 2, "num_single_layers": 40, "out_channels": 16, "patch_size": 2, "patch_size_t": 1, "pooled_projection_dim": 768, "qk_norm": "rms_norm", "rope_axes_dim": [ 16, 56, 56 ], "rope_theta": 256.0, "text_embed_dim": 4096 } ``` ## /utils.py ```py path="/utils.py" import importlib.metadata import torch import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') log = logging.getLogger(__name__) def check_diffusers_version(): try: version = importlib.metadata.version('diffusers') required_version = '0.31.0' if version < required_version: raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.") except importlib.metadata.PackageNotFoundError: raise AssertionError("diffusers is not installed.") def print_memory(device): memory = 
torch.cuda.memory_allocated(device) / 1024**3 max_memory = torch.cuda.max_memory_allocated(device) / 1024**3 max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 log.info(f"-------------------------------") log.info(f"Allocated memory: {memory=:.3f} GB") log.info(f"Max allocated memory: {max_memory=:.3f} GB") log.info(f"Max reserved memory: {max_reserved=:.3f} GB") log.info(f"-------------------------------") #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False) #log.info(f"Memory Summary:\n{memory_summary}") def convert_to_diffusers(prefix, weights_sd): # convert from default LoRA to diffusers # https://github.com/kohya-ss/musubi-tuner/blob/main/convert_lora.py # get alphas lora_alphas = {} for key, weight in weights_sd.items(): if key.startswith(prefix): lora_name = key.split(".", 1)[0] # before first dot if lora_name not in lora_alphas and "alpha" in key: lora_alphas[lora_name] = weight new_weights_sd = {} for key, weight in weights_sd.items(): if key.startswith(prefix): if "alpha" in key: continue lora_name = key.split(".", 1)[0] # before first dot module_name = lora_name[len(prefix) :] # remove "lora_unet_" module_name = module_name.replace("_", ".") # replace "_" with "." # HunyuanVideo lora name to module name: ugly but works #module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks module_name = module_name.replace("single.transformer.blocks.", "single_transformer_blocks.") # fix single blocks module_name = module_name.replace("transformer.blocks.", "transformer_blocks.") # fix double blocks module_name = module_name.replace("img.", "img_") # fix img module_name = module_name.replace("txt.", "txt_") # fix txt module_name = module_name.replace("to.q", "to_q") # fix attn module_name = module_name.replace("to.k", "to_k") module_name = module_name.replace("to.v", "to_v") module_name = module_name.replace("to.add.out", "to_add_out") module_name = module_name.replace("add.k.proj", "add_k_proj") module_name = module_name.replace("add.q.proj", "add_q_proj") module_name = module_name.replace("add.v.proj", "add_v_proj") module_name = module_name.replace("add.out.proj", "add_out_proj") module_name = module_name.replace("proj.", "proj_") # fix proj module_name = module_name.replace("to.out", "to_out") # fix to_out module_name = module_name.replace("ff.context", "ff_context") # fix ff context diffusers_prefix = "transformer" if "lora_down" in key: new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight" dim = weight.shape[0] elif "lora_up" in key: new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight" dim = weight.shape[1] else: log.warning(f"unexpected key: {key} in default LoRA format") continue # scale weight by alpha if lora_name in lora_alphas: # we scale both down and up, so scale is sqrt scale = lora_alphas[lora_name] / dim scale = scale.sqrt() weight = weight * scale else: log.warning(f"missing alpha for {lora_name}") new_weights_sd[new_key] = weight return new_weights_sd ```
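As a quick reference for the LoRA loading path in `nodes.py`, below is a minimal, hypothetical sketch of the key renaming that `convert_to_diffusers` performs on a kohya-style (`lora_unet_`-prefixed) LoRA entry. The import path, rank, and tensor shapes are illustrative assumptions, not part of the repository:

```py
# Hypothetical usage sketch; assumes the repository root is on the Python path
# and that `utils` refers to the root-level utils.py above.
import torch
from utils import convert_to_diffusers

rank = 8  # illustrative LoRA rank
kohya_sd = {
    # one kohya/musubi-tuner style entry for single_transformer_blocks.0.attn.to_k
    "lora_unet_single_transformer_blocks_0_attn_to_k.lora_down.weight": torch.zeros(rank, 3072),
    "lora_unet_single_transformer_blocks_0_attn_to_k.lora_up.weight": torch.zeros(3072, rank),
    "lora_unet_single_transformer_blocks_0_attn_to_k.alpha": torch.tensor(float(rank)),
}

diffusers_sd = convert_to_diffusers("lora_unet_", kohya_sd)
print(sorted(diffusers_sd))
# ['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight',
#  'transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight']
```

With `alpha` equal to the rank, the per-key scale `sqrt(alpha / rank)` is 1, so the tensors pass through unchanged; `LoadFramePackModel` then hands keys in this `transformer.*.lora_A/lora_B` form to `load_lora_adapter`.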