```
├── .gitignore
├── LICENSE
├── README.md
├── assets/
    ├── demo.jpg
├── gradio_demo.py
├── hi_diffusers/
    ├── __init__.py
    ├── models/
        ├── attention.py
        ├── attention_processor.py
        ├── embeddings.py
        ├── moe.py
        ├── transformers/
            ├── transformer_hidream_image.py
    ├── pipelines/
        ├── hidream_image/
            ├── pipeline_hidream_image.py
            ├── pipeline_output.py
    ├── schedulers/
        ├── flash_flow_match.py
        ├── fm_solvers_unipc.py
├── inference.py
├── requirements.txt
```

## /.gitignore

```gitignore path="/.gitignore"
__pycache__
tmp
*_local.py
*.jpg
*.png
*.tar
*.txt
```

## /LICENSE

``` path="/LICENSE"
MIT License

Copyright (c) 2025 HiDream.ai

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
```

## /README.md

# HiDream-I1

![HiDream-I1 Demo](assets/demo.jpg)

`HiDream-I1` is a new open-source image generative foundation model with 17B parameters that achieves state-of-the-art image generation quality within seconds.

For more features and to experience the full capabilities of our product, please visit [https://vivago.ai/](https://vivago.ai/).

## Project Updates
- 🤗 **April 11, 2025**: HiDream is now officially supported in the `diffusers` library. Check out the docs [here](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream).
- 🤗 **April 8, 2025**: We've launched a Hugging Face Space for **HiDream-I1-Dev**. Experience our model firsthand at [https://huggingface.co/spaces/HiDream-ai/HiDream-I1-Dev](https://huggingface.co/spaces/HiDream-ai/HiDream-I1-Dev)!
- 🚀 **April 7, 2025**: We've open-sourced the text-to-image model **HiDream-I1**.

## Models

We offer both the full version and distilled models. For more information about each model, please refer to the Hugging Face repositories linked in the table below.

| Name            | Script                         | Inference Steps | HuggingFace repo |
| --------------- | ------------------------------ | --------------- | ---------------- |
| HiDream-I1-Full | [inference.py](./inference.py) | 50              | 🤗 [HiDream-I1-Full](https://huggingface.co/HiDream-ai/HiDream-I1-Full) |
| HiDream-I1-Dev  | [inference.py](./inference.py) | 28              | 🤗 [HiDream-I1-Dev](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) |
| HiDream-I1-Fast | [inference.py](./inference.py) | 16              | 🤗 [HiDream-I1-Fast](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) |

## Quick Start

Please make sure you have installed [Flash Attention](https://github.com/Dao-AILab/flash-attention). We recommend CUDA 12.4 for manual installation.
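If Flash Attention is missing, the attention code in `hi_diffusers/models/attention_processor.py` does not fail outright: at import time it tries Flash Attention 3 (`flash_attn_interface`), then Flash Attention 2 (`flash_attn`), and otherwise falls back to PyTorch's `scaled_dot_product_attention`. The snippet below is a minimal sketch of a hypothetical helper, `detect_attention_backend` (not part of this repository), that mirrors that import order so you can check which backend your environment would select once the installation commands have been run.

```python
import importlib


def detect_attention_backend() -> str:
    """Return the backend name hi_diffusers would pick, mirroring its import order.

    Hypothetical helper: it re-implements the try/except chain from
    hi_diffusers/models/attention_processor.py and is not part of the repo.
    """
    candidates = [
        ("flash_attn_interface", "FLASH_ATTN_3"),  # Flash Attention 3 Python API
        ("flash_attn", "FLASH_ATTN_2"),            # Flash Attention 2 package
    ]
    for module_name, backend in candidates:
        try:
            module = importlib.import_module(module_name)
            getattr(module, "flash_attn_func")  # same symbol the repo imports
            return backend
        except Exception:
            # Missing package, missing symbol, or a broken CUDA build: try the next one.
            continue
    # Fallback used by the repo: torch.nn.functional.scaled_dot_product_attention
    return "VANILLA"


if __name__ == "__main__":
    print(detect_attention_backend())
```

Run as a script, it prints one of `FLASH_ATTN_3`, `FLASH_ATTN_2`, or `VANILLA`; the last one means generation will still work but without the Flash Attention kernels.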
```sh
pip install -r requirements.txt
pip install -U flash-attn --no-build-isolation
```

Then you can run the inference script to generate images:

```sh
# For full model inference
python ./inference.py --model_type full

# For distilled dev model inference
python ./inference.py --model_type dev

# For distilled fast model inference
python ./inference.py --model_type fast
```

> [!NOTE]
> The inference script will try to automatically download `meta-llama/Llama-3.1-8B-Instruct` model files. You need to [agree to the license of the Llama model](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) on your HuggingFace account and log in using `huggingface-cli login` in order to use the automatic downloader.

## Gradio Demo

We also provide a Gradio demo for interactive image generation. You can run the demo with:

```sh
python gradio_demo.py
```

## Inference with Diffusers

We recommend installing Diffusers from source for better compatibility.

```shell
pip install git+https://github.com/huggingface/diffusers.git
```

Then you can run inference with **HiDream-I1** using the following code:

```python
import torch
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import HiDreamImagePipeline

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # "HiDream-ai/HiDream-I1-Dev" | "HiDream-ai/HiDream-I1-Fast"
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)

pipe = pipe.to('cuda')

image = pipe(
    'A cat holding a sign that says "HiDream.ai".',
    height=1024,
    width=1024,
    guidance_scale=5.0,  # 0.0 for Dev and Fast
    num_inference_steps=50,  # 28 for Dev and 16 for Fast
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("output.png")
```

## Evaluation Metrics

### DPG-Bench
| Model          | Overall   | Global | Entity | Attribute | Relation | Other |
| -------------- | --------- | ------ | ------ | --------- | -------- | ----- |
| PixArt-alpha   | 71.11     | 74.97  | 79.32  | 78.60     | 82.57    | 76.96 |
| SDXL           | 74.65     | 83.27  | 82.43  | 80.91     | 86.76    | 80.41 |
| DALL-E 3       | 83.50     | 90.97  | 89.61  | 88.39     | 90.58    | 89.83 |
| Flux.1-dev     | 83.79     | 85.80  | 86.79  | 89.98     | 90.04    | 89.90 |
| SD3-Medium     | 84.08     | 87.90  | 91.01  | 88.83     | 80.70    | 88.68 |
| Janus-Pro-7B   | 84.19     | 86.90  | 88.90  | 89.40     | 89.32    | 89.48 |
| CogView4-6B    | 85.13     | 83.85  | 90.35  | 91.17     | 91.14    | 87.29 |
| **HiDream-I1** | **85.89** | 76.44  | 90.22  | 89.48     | 93.74    | 91.83 |

### GenEval
| Model          | Overall  | Single Obj. | Two Obj. | Counting | Colors | Position | Color attribution |
| -------------- | -------- | ----------- | -------- | -------- | ------ | -------- | ----------------- |
| SDXL           | 0.55     | 0.98        | 0.74     | 0.39     | 0.85   | 0.15     | 0.23              |
| PixArt-alpha   | 0.48     | 0.98        | 0.50     | 0.44     | 0.80   | 0.08     | 0.07              |
| Flux.1-dev     | 0.66     | 0.98        | 0.79     | 0.73     | 0.77   | 0.22     | 0.45              |
| DALL-E 3       | 0.67     | 0.96        | 0.87     | 0.47     | 0.83   | 0.43     | 0.45              |
| CogView4-6B    | 0.73     | 0.99        | 0.86     | 0.66     | 0.79   | 0.48     | 0.58              |
| SD3-Medium     | 0.74     | 0.99        | 0.94     | 0.72     | 0.89   | 0.33     | 0.60              |
| Janus-Pro-7B   | 0.80     | 0.99        | 0.89     | 0.59     | 0.90   | 0.79     | 0.66              |
| **HiDream-I1** | **0.83** | 1.00        | 0.98     | 0.79     | 0.91   | 0.60     | 0.72              |

### HPSv2.1 benchmark
| Model                 | Averaged  | Animation | Concept-art | Painting | Photo |
| --------------------- | --------- | --------- | ----------- | -------- | ----- |
| Stable Diffusion v2.0 | 26.38     | 27.09     | 26.02       | 25.68    | 26.73 |
| Midjourney V6         | 30.29     | 32.02     | 30.29       | 29.74    | 29.10 |
| SDXL                  | 30.64     | 32.84     | 31.36       | 30.86    | 27.48 |
| Dall-E3               | 31.44     | 32.39     | 31.09       | 31.18    | 31.09 |
| SD3                   | 31.53     | 32.60     | 31.82       | 32.06    | 29.62 |
| Midjourney V5         | 32.33     | 34.05     | 32.47       | 32.24    | 30.56 |
| CogView4-6B           | 32.31     | 33.23     | 32.60       | 32.89    | 30.52 |
| Flux.1-dev            | 32.47     | 33.87     | 32.27       | 32.62    | 31.11 |
| stable cascade        | 32.95     | 34.58     | 33.13       | 33.29    | 30.78 |
| **HiDream-I1**        | **33.82** | 35.05     | 33.74       | 33.88    | 32.61 |

## License

The code in this repository and the HiDream-I1 models are licensed under [MIT License](./LICENSE).

## /assets/demo.jpg

Binary file available at https://raw.githubusercontent.com/HiDream-ai/HiDream-I1/refs/heads/main/assets/demo.jpg

## /gradio_demo.py

```py path="/gradio_demo.py"
import torch
import gradio as gr
from hi_diffusers import HiDreamImagePipeline
from hi_diffusers import HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

MODEL_PREFIX = "HiDream-ai"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}

# Resolution options
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]

# Load models
def load_models(model_type):
    config = MODEL_CONFIGS[model_type]
    pretrained_model_name_or_path = config["path"]
    scheduler = MODEL_CONFIGS[model_type]["scheduler"](num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)

    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
        LLAMA_MODEL_NAME,
        use_fast=False)

    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16).to("cuda")

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16).to("cuda")

    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model_name_or_path,
        scheduler=scheduler,
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.bfloat16
    ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer

    return pipe, config

# Parse resolution string to get height and width
def parse_resolution(resolution_str):
    # Option labels are "width × height"; return (height, width) to match the caller.
    if "1024 × 1024" in resolution_str:
        return 1024, 1024
    elif "768 × 1360" in resolution_str:
        return 1360, 768
    elif "1360 × 768" in resolution_str:
        return 768, 1360
    elif "880 × 1168" in resolution_str:
        return 1168, 880
    elif "1168 × 880" in resolution_str:
        return 880, 1168
    elif "1248 × 832" in resolution_str:
        return 832, 1248
    elif "832 × 1248" in resolution_str:
        return 1248, 832
    else:
        return 1024, 1024  # Default fallback

# Generate image function
def generate_image(model_type, prompt, resolution, seed):
    global pipe, current_model

    # Reload model if needed
    if model_type != current_model:
        del pipe
        torch.cuda.empty_cache()
        print(f"Loading {model_type} model...")
        pipe, config = load_models(model_type)
        current_model = model_type
        print(f"{model_type} model loaded successfully!")

    # Get configuration for current model
    config = MODEL_CONFIGS[model_type]
    guidance_scale = config["guidance_scale"]
    num_inference_steps = config["num_inference_steps"]

    # Parse resolution
    height, width = parse_resolution(resolution)

    # Handle seed
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()

    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images
    return images[0], seed

# Initialize with default model
print("Loading default model (full)...")
current_model = "full"
pipe, _ = load_models(current_model)
print("Model loaded successfully!")

# Create Gradio interface
with gr.Blocks(title="HiDream Image Generator") as demo:
    gr.Markdown("# HiDream Image Generator")

    with gr.Row():
        with gr.Column():
            model_type = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                value="full",
                label="Model Type",
                info="Select model variant"
            )

            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
                lines=3
            )

            resolution = gr.Radio(
                choices=RESOLUTION_OPTIONS,
                value=RESOLUTION_OPTIONS[0],
                label="Resolution",
                info="Select image resolution"
            )

            seed = gr.Number(
                label="Seed (use -1 for random)",
                value=-1,
                precision=0
            )

            generate_btn = gr.Button("Generate Image")
            seed_used = gr.Number(label="Seed Used", interactive=False)

        with gr.Column():
            output_image = gr.Image(label="Generated Image", type="pil")

    generate_btn.click(
        fn=generate_image,
        inputs=[model_type, prompt, resolution, seed],
        outputs=[output_image, seed_used]
    )

# Launch app
if __name__ == "__main__":
    demo.launch()
```

## /hi_diffusers/__init__.py

```py path="/hi_diffusers/__init__.py"
from .models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
from .pipelines.hidream_image.pipeline_hidream_image import HiDreamImagePipeline
```

## /hi_diffusers/models/attention.py

```py path="/hi_diffusers/models/attention.py"
import torch
from torch import nn
from typing import Optional
from diffusers.models.attention_processor import Attention
from diffusers.utils.torch_utils import maybe_allow_in_graph

@maybe_allow_in_graph
class HiDreamAttention(Attention):
    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        scale_qk: bool = True,
        eps: float = 1e-5,
        processor = None,
        out_dim: int = None,
        single: bool = False
    ):
        # Deliberately skip Attention.__init__ and only run nn.Module.__init__.
        super(Attention, self).__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.out_dim = out_dim if out_dim is not None else query_dim

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.single = single

        linear_cls = nn.Linear
        self.linear_cls = linear_cls
        self.to_q = linear_cls(query_dim, self.inner_dim)
        self.to_k = linear_cls(self.inner_dim, self.inner_dim)
        self.to_v = linear_cls(self.inner_dim, self.inner_dim)
        self.to_out = linear_cls(self.inner_dim, self.out_dim)
        self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
        self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)

        if not single:
            self.to_q_t = linear_cls(query_dim, self.inner_dim)
            self.to_k_t = linear_cls(self.inner_dim, self.inner_dim)
            self.to_v_t = linear_cls(self.inner_dim, self.inner_dim)
            self.to_out_t = linear_cls(self.inner_dim, self.out_dim)
            self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
            self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)

        self.set_processor(processor)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        norm_image_tokens: torch.FloatTensor,
        image_tokens_masks: torch.FloatTensor = None,
        norm_text_tokens: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.Tensor:
        return self.processor(
            self,
            image_tokens = norm_image_tokens,
            image_tokens_masks = image_tokens_masks,
            text_tokens = norm_text_tokens,
            rope = rope,
        )


class FeedForwardSwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round hidden_dim up to the nearest multiple of `multiple_of`
        hidden_dim = multiple_of * (
            (hidden_dim + multiple_of - 1) // multiple_of
        )

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
```

## /hi_diffusers/models/attention_processor.py

```py path="/hi_diffusers/models/attention_processor.py"
from typing import Optional
import torch
from .attention import HiDreamAttention

ATTN_FUNC_BACKEND = None
import einops

# Pick the fastest available backend: Flash Attention 3, then 2, then PyTorch SDPA.
try:
    try:
        from flash_attn_interface import flash_attn_func
        ATTN_FUNC_BACKEND = "FLASH_ATTN_3"
    except Exception:
        from flash_attn import flash_attn_func
        ATTN_FUNC_BACKEND = "FLASH_ATTN_2"
except Exception:
    import torch.nn.functional as F
    ATTN_FUNC_BACKEND = "VANILLA"

# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    if ATTN_FUNC_BACKEND == "FLASH_ATTN_3":
        hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
    elif ATTN_FUNC_BACKEND == "FLASH_ATTN_2":
        hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
    elif ATTN_FUNC_BACKEND == "VANILLA":
        # Use einops for transpose: b s h d -> b h s d
        query = einops.rearrange(query, 'b s h d -> b h s d')
        key = einops.rearrange(key, 'b s h d -> b h s d')
        value = einops.rearrange(value, 'b s h d -> b h s d')
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        # Use einops for transpose: b h s d -> b s h d
        hidden_states = einops.rearrange(hidden_states, 'b h s d -> b s h d')
    else:
        raise RuntimeError(f"Unknown attention backend: {ATTN_FUNC_BACKEND}")
    hidden_states = hidden_states.flatten(-2)
    hidden_states = hidden_states.to(query.dtype)
    return hidden_states

class HiDreamAttnProcessor_flashattn:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __call__(
        self,
        attn: HiDreamAttention,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        dtype = image_tokens.dtype
        batch_size = image_tokens.shape[0]

        query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
        value_i = attn.to_v(image_tokens)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if image_tokens_masks is not None:
            key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
            value_t = attn.to_v_t(text_tokens)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]
            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            query = query_i
            key = key_i
            value = value_i

        if query.shape[-1] == rope.shape[-3] * 2:
            query, key = apply_rope(query, key, rope)
        else:
            # RoPE covers only the first half of each head; the second half is left unrotated.
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, rope)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = attention(query, key, value)

        if not attn.single:
            hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
            hidden_states_i = attn.to_out(hidden_states_i)
            hidden_states_t = attn.to_out_t(hidden_states_t)
            return hidden_states_i, hidden_states_t
        else:
            hidden_states = attn.to_out(hidden_states)
            return hidden_states
```

## /hi_diffusers/models/embeddings.py

```py path="/hi_diffusers/models/embeddings.py"
import torch
from torch import nn
from typing import List
from diffusers.models.embeddings import Timesteps, TimestepEmbedding

# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    assert dim % 2 == 0, "The dimension must be even."

    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    # 2x2 rotation matrix [[cos, -sin], [sin, cos]] per frequency
    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()

# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class EmbedND(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(2)

class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size=2,
        in_channels=4,
        out_channels=1024,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, latent):
        latent = self.proj(latent)
        return latent

class PooledEmbed(nn.Module):
    def __init__(self, text_emb_dim, hidden_size):
        super().__init__()
        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, pooled_embed):
        return self.pooled_embedder(pooled_embed)

class TimestepEmbed(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True,
                                   downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, timesteps, wdtype):
        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
        t_emb = self.timestep_embedder(t_emb)
        return t_emb

class OutEmbed(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, adaln_input):
        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.linear(x)
        return x
```

## /hi_diffusers/models/moe.py

```py path="/hi_diffusers/models/moe.py"
import math
import torch
from torch import nn
import torch.nn.functional as F
from .attention import FeedForwardSwiGLU
from torch.distributed.nn.functional import all_gather

_LOAD_BALANCING_LOSS = []
def save_load_balancing_loss(loss):
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.append(loss)

def clear_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.clear()

def get_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    return _LOAD_BALANCING_LOSS

def batched_load_balancing_loss():
    aux_losses_arr = get_load_balancing_loss()
    alpha = aux_losses_arr[0][-1]
    Pi = torch.stack([ent[1] for ent in aux_losses_arr], dim=0)
    fi = torch.stack([ent[2] for ent in aux_losses_arr], dim=0)

    # average expert load across all distributed ranks
    fi_list = all_gather(fi)
    fi = torch.stack(fi_list, 0).mean(0)

    aux_loss = (Pi * fi).sum(-1).mean() * alpha
    return aux_loss

# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
    def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
        super().__init__()
        self.top_k = num_activated_experts
        self.n_routed_experts = num_routed_experts

        self.scoring_func = 'softmax'
        self.alpha = aux_loss_alpha
        self.seq_aux = False

        # topk selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        # print(bsz, seq_len, h)
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        ### expert-level computation auxiliary loss
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            # always compute aux loss based on the naive greedy topk method
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)

                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
                save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss

# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MOEFeedForwardSwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_routed_experts: int,
        num_activated_experts: int,
    ):
        super().__init__()
        self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2)
        self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)])
        self.gate = MoEGate(
            embed_dim = dim,
            num_routed_experts = num_routed_experts,
            num_activated_experts = num_activated_experts
        )
        self.num_activated_experts = num_activated_experts

    def forward(self, x):
        wtype = x.dtype
        identity = x
        orig_shape = x.shape
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.num_activated_experts, dim=0)
            y = torch.empty_like(x, dtype=wtype)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape).to(dtype=wtype)
            #y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        # group token slots by expert so each expert runs once on a contiguous batch
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.num_activated_experts
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i-1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # for fp16 and other dtype
            expert_cache = expert_cache.to(expert_out.dtype)
            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')

        return expert_cache
```

## /hi_diffusers/models/transformers/transformer_hidream_image.py

```py path="/hi_diffusers/models/transformers/transformer_hidream_image.py"
from typing import Any, Dict, Optional, Tuple, List

import torch
import torch.nn as nn
import einops
from einops import repeat

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from ..embeddings import PatchEmbed, PooledEmbed, TimestepEmbed, EmbedND, OutEmbed
from ..attention import HiDreamAttention, FeedForwardSwiGLU
from ..attention_processor import HiDreamAttnProcessor_flashattn
from ..moe import MOEFeedForwardSwiGLU

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class TextProjection(nn.Module):
    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)

    def forward(self, caption):
        hidden_states = self.linear(caption)
        return hidden_states

class BlockType:
    TransformerBlock = 1
    SingleTransformerBlock = 2

@maybe_allow_in_graph
class HiDreamImageSingleTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True)
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = True
        )

        # 3. Feed-forward
        self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        attn_output_i = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            rope = rope,
        )
        image_tokens = gate_msa_i * attn_output_i + image_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
        image_tokens = ff_output_i + image_tokens
        return image_tokens


@maybe_allow_in_graph
class HiDreamImageTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 12 * dim, bias=True)
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        # 1. Attention
        self.norm1_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.norm1_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.attn1 = HiDreamAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            processor = HiDreamAttnProcessor_flashattn(),
            single = False
        )

        # 3. Feed-forward
        self.norm3_i = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        if num_routed_experts > 0:
            self.ff_i = MOEFeedForwardSwiGLU(
                dim = dim,
                hidden_dim = 4 * dim,
                num_routed_experts = num_routed_experts,
                num_activated_experts = num_activated_experts,
            )
        else:
            self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)
        self.norm3_t = nn.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
        self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim)

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: Optional[torch.FloatTensor] = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        wtype = image_tokens.dtype
        shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
        shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
            self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)

        # 1. MM-Attention
        norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
        norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t

        attn_output_i, attn_output_t = self.attn1(
            norm_image_tokens,
            image_tokens_masks,
            norm_text_tokens,
            rope = rope,
        )

        image_tokens = gate_msa_i * attn_output_i + image_tokens
        text_tokens = gate_msa_t * attn_output_t + text_tokens

        # 2. Feed-forward
        norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
        norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
        norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
        norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t

        ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
        ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
        image_tokens = ff_output_i + image_tokens
        text_tokens = ff_output_t + text_tokens
        return image_tokens, text_tokens


@maybe_allow_in_graph
class HiDreamImageBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        block_type: BlockType = BlockType.TransformerBlock,
    ):
        super().__init__()
        block_classes = {
            BlockType.TransformerBlock: HiDreamImageTransformerBlock,
            BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
        }
        self.block = block_classes[block_type](
            dim,
            num_attention_heads,
            attention_head_dim,
            num_routed_experts,
            num_activated_experts
        )

    def forward(
        self,
        image_tokens: torch.FloatTensor,
        image_tokens_masks: Optional[torch.FloatTensor] = None,
        text_tokens: Optional[torch.FloatTensor] = None,
        adaln_input: torch.FloatTensor = None,
        rope: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        return self.block(
            image_tokens,
            image_tokens_masks,
            text_tokens,
            adaln_input,
            rope,
        )


class HiDreamImageTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
    _supports_gradient_checkpointing = True
    _no_split_modules = ["HiDreamImageBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: Optional[int] = None,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 16,
        num_single_layers: int = 32,
        attention_head_dim: int = 128,
        num_attention_heads: int = 20,
        caption_channels: List[int] = None,
        text_emb_dim: int = 2048,
        num_routed_experts: int = 4,
        num_activated_experts: int = 2,
        axes_dims_rope: Tuple[int, int] = (32, 32),
        max_resolution: Tuple[int, int] = (128, 128),
        llama_layers: List[int] = None,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.llama_layers = llama_layers

        self.t_embedder = TimestepEmbed(self.inner_dim)
        self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim)
        self.x_embedder = PatchEmbed(
            patch_size = patch_size,
            in_channels = in_channels,
            out_channels = self.inner_dim,
        )
        self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)

        self.double_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.config.num_attention_heads,
                    attention_head_dim = self.config.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.TransformerBlock
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_stream_blocks = nn.ModuleList(
            [
                HiDreamImageBlock(
                    dim = self.inner_dim,
                    num_attention_heads = self.config.num_attention_heads,
                    attention_head_dim = self.config.attention_head_dim,
                    num_routed_experts = num_routed_experts,
                    num_activated_experts = num_activated_experts,
                    block_type = BlockType.SingleTransformerBlock
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels)

        caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
        caption_projection = []
        for caption_channel in caption_channels:
            caption_projection.append(TextProjection(in_features = caption_channel, hidden_size = self.inner_dim))
        self.caption_projection = nn.ModuleList(caption_projection)
        self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def expand_timesteps(self, timesteps, batch_size, device):
        if not torch.is_tensor(timesteps):
            is_mps = device.type == "mps"
            if isinstance(timesteps, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(batch_size)
        return timesteps

    def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
        if is_training:
            x = einops.rearrange(x, 'B S (p1 p2 C) -> B C S (p1 p2)', p1=self.config.patch_size, p2=self.config.patch_size)
        else:
            x_arr = []
            for i, img_size in enumerate(img_sizes):
                pH, pW = img_size
                x_arr.append(
                    einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
                        p1=self.config.patch_size, p2=self.config.patch_size)
                )
            x = torch.cat(x_arr, dim=0)
        return x

    def patchify(self, x, max_seq, img_sizes=None):
        pz2 = self.config.patch_size * self.config.patch_size
        if isinstance(x, torch.Tensor):
            B, C = x.shape[0], x.shape[1]
            device = x.device
            dtype = x.dtype
        else:
            B, C = len(x), x[0].shape[0]
            device = x[0].device
            dtype = x[0].dtype
        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)

        if img_sizes is not None:
            for i, img_size in enumerate(img_sizes):
                x_masks[i, 0:img_size[0] * img_size[1]] = 1
            x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
        elif isinstance(x, torch.Tensor):
            pH, pW = x.shape[-2] // self.config.patch_size, x.shape[-1] // self.config.patch_size
            x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.config.patch_size, p2=self.config.patch_size)
            img_sizes = [[pH, pW]] * B
            x_masks = None
        else:
            raise NotImplementedError
        return x, x_masks, img_sizes

    def forward(
        self,
        hidden_states: torch.Tensor,
        timesteps: torch.LongTensor = None,
        encoder_hidden_states: torch.Tensor = None,
        pooled_embeds: torch.Tensor = None,
        img_sizes: Optional[List[Tuple[int, int]]] = None,
        img_ids: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # spatial forward
        batch_size = hidden_states.shape[0]
        hidden_states_type = hidden_states.dtype

        # 0. time
        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
        timesteps = self.t_embedder(timesteps, hidden_states_type)
        p_embedder = self.p_embedder(pooled_embeds)
        adaln_input = timesteps + p_embedder

        hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
        if image_tokens_masks is None:
            pH, pW = img_sizes[0]
            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
            img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
        hidden_states = self.x_embedder(hidden_states)

        T5_encoder_hidden_states = encoder_hidden_states[0]
        encoder_hidden_states = encoder_hidden_states[-1]
        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]

        if self.caption_projection is not None:
            new_encoder_hidden_states = []
            for i, enc_hidden_state in enumerate(encoder_hidden_states):
                enc_hidden_state = self.caption_projection[i](enc_hidden_state)
                enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                new_encoder_hidden_states.append(enc_hidden_state)
            encoder_hidden_states = new_encoder_hidden_states
            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
            encoder_hidden_states.append(T5_encoder_hidden_states)

        txt_ids = torch.zeros(
            batch_size,
            encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
            3,
            device=img_ids.device, dtype=img_ids.dtype
        )
        ids = torch.cat((img_ids, txt_ids), dim=1)
        rope = self.pe_embedder(ids)

        # 2. Blocks
        block_id = 0
        initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
        initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
        for bid, block in enumerate(self.double_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, initial_encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    image_tokens_masks,
                    cur_encoder_hidden_states,
                    adaln_input,
                    rope,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, initial_encoder_hidden_states = block(
                    image_tokens = hidden_states,
                    image_tokens_masks = image_tokens_masks,
                    text_tokens = cur_encoder_hidden_states,
                    adaln_input = adaln_input,
                    rope = rope,
                )
            initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
            block_id += 1

        image_tokens_seq_len = hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
        hidden_states_seq_len = hidden_states.shape[1]
        if image_tokens_masks is not None:
            encoder_attention_mask_ones = torch.ones(
                (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
                device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
            )
            image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)

        for bid, block in enumerate(self.single_stream_blocks):
            cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
            hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    image_tokens_masks,
                    None,
                    adaln_input,
                    rope,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    image_tokens = hidden_states,
                    image_tokens_masks = image_tokens_masks,
                    text_tokens = None,
                    adaln_input = adaln_input,
                    rope = rope,
                )
            hidden_states = hidden_states[:, :hidden_states_seq_len]
            block_id += 1

        hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
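        # Worked example (illustrative only; assumes patch_size=2, a square 1024x1024
        # request, i.e. a 128x128 latent, and the default max_sequence_length=128 for
        # both the T5 and the Llama prompt embeddings):
        #   image tokens                    = (128 // 2) * (128 // 2)                 = 4096
        #   initial text tokens             = T5 (128) + last llama_layers entry (128) = 256
        #   single-stream input per block   = 4096 + 256 + 128 (current Llama layer)   = 4480
        # Each single-stream output is cut back to hidden_states_seq_len = 4352, and the
        # slice above keeps only the first 4096 (image) positions for self.final_layer.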
        output = self.final_layer(hidden_states, adaln_input)
        output = self.unpatchify(output, img_sizes, self.training)
        if image_tokens_masks is not None:
            image_tokens_masks = image_tokens_masks[:, :image_tokens_seq_len]

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output, image_tokens_masks)
        return Transformer2DModelOutput(sample=output, mask=image_tokens_masks)
```

## /hi_diffusers/pipelines/hidream_image/pipeline_hidream_image.py

```py path="/hi_diffusers/pipelines/hidream_image/pipeline_hidream_image.py"
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import math
import einops
import torch
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
    LlamaForCausalLM,
    PreTrainedTokenizerFast
)

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from .pipeline_output import HiDreamImagePipelineOutput
from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel
from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin):
    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae"
    _optional_components = ["image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer_2: CLIPTokenizer,
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5Tokenizer,
        text_encoder_4: LlamaForCausalLM,
        tokenizer_4: PreTrainedTokenizerFast,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3,
            text_encoder_4=text_encoder_4,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3,
            tokenizer_4=tokenizer_4,
            scheduler=scheduler,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height have to be
        # divisible by the patch size, so the VAE scale factor is multiplied by the patch size to account for this.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
        self.default_sample_size = 128
        self.tokenizer_4.pad_token = self.tokenizer_4.eos_token

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 128,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder_3.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_3(
            prompt,
            padding="max_length",
            max_length=min(max_sequence_length, self.tokenizer_3.model_max_length),
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        attention_mask = text_inputs.attention_mask
        untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        tokenizer,
        text_encoder,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 128,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=min(max_sequence_length, 218),
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {218} tokens: {removed_text}"
            )
        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
        return prompt_embeds

    def _get_llama3_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 128,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder_4.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_4(
            prompt,
            padding="max_length",
            max_length=min(max_sequence_length, self.tokenizer_4.model_max_length),
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        attention_mask = text_inputs.attention_mask
        untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
            )

        outputs = self.text_encoder_4(
            text_input_ids.to(device),
            attention_mask=attention_mask.to(device),
            output_hidden_states=True,
            output_attentions=True
        )

        prompt_embeds = outputs.hidden_states[1:]
        prompt_embeds = torch.stack(prompt_embeds, dim=0)
        _, _, seq_len, dim = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        prompt_3: Union[str, List[str]],
        prompt_4: Union[str, List[str]],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        negative_prompt_4: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 128,
        lora_scale: Optional[float] = None,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
            prompt = prompt,
            prompt_2 = prompt_2,
            prompt_3 = prompt_3,
            prompt_4 = prompt_4,
            device = device,
            dtype = dtype,
            num_images_per_prompt = num_images_per_prompt,
            prompt_embeds = prompt_embeds,
            pooled_prompt_embeds = pooled_prompt_embeds,
            max_sequence_length = max_sequence_length,
        )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt
            negative_prompt_3 = negative_prompt_3 or negative_prompt
            negative_prompt_4 = negative_prompt_4 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )
            negative_prompt_3 = (
                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
            )
            negative_prompt_4 = (
                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
            )

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
                prompt = negative_prompt,
                prompt_2 = negative_prompt_2,
                prompt_3 = negative_prompt_3,
                prompt_4 = negative_prompt_4,
                device = device,
                dtype = dtype,
                num_images_per_prompt = num_images_per_prompt,
                prompt_embeds = negative_prompt_embeds,
                pooled_prompt_embeds = negative_pooled_prompt_embeds,
                max_sequence_length = max_sequence_length,
            )
        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        prompt_3: Union[str, List[str]],
        prompt_4: Union[str, List[str]],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 128,
    ):
        device = device or self._execution_device

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            prompt_3 = prompt_3 or prompt
            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3

            prompt_4 = prompt_4 or prompt
            prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4

            pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
                self.tokenizer,
                self.text_encoder,
                prompt = prompt,
                num_images_per_prompt = num_images_per_prompt,
                max_sequence_length = max_sequence_length,
                device = device,
                dtype = dtype,
            )

            pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
                self.tokenizer_2,
                self.text_encoder_2,
                prompt = prompt_2,
                num_images_per_prompt = num_images_per_prompt,
                max_sequence_length = max_sequence_length,
                device = device,
                dtype = dtype,
            )

            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)

            t5_prompt_embeds = self._get_t5_prompt_embeds(
                prompt = prompt_3,
                num_images_per_prompt = num_images_per_prompt,
                max_sequence_length = max_sequence_length,
                device = device,
                dtype = dtype
            )
            llama3_prompt_embeds = self._get_llama3_prompt_embeds(
                prompt = prompt_4,
                num_images_per_prompt = num_images_per_prompt,
                max_sequence_length = max_sequence_length,
                device = device,
                dtype = dtype
            )
            prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]

        return prompt_embeds, pooled_prompt_embeds

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
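        # Worked example (illustrative only; assumes vae_scale_factor=8, which is what
        # __init__ computes for a VAE whose config lists four block_out_channels):
        # a 1024x1000 request gives
        #   height = 2 * (1024 // 16) = 2 * 64 = 128
        #   width  = 2 * (1000 // 16) = 2 * 62 = 124
        # so both latent sides are even and split cleanly into 2x2 patches.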
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        shape = (batch_size, num_channels_latents, height, width)

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        prompt_3: Optional[Union[str, List[str]]] = None,
        prompt_4: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        sigmas: Optional[List[float]] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        negative_prompt_4: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 128,
    ):
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        division = self.vae_scale_factor * 2
        S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
        scale = S_max / (width * height)
        scale = math.sqrt(scale)
        width, height = int(width * scale // division * division), int(height * scale // division * division)

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            prompt_4=prompt_4,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            negative_prompt_4=negative_prompt_4,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            prompt_embeds_arr = []
            for n, p in zip(negative_prompt_embeds, prompt_embeds):
                if len(n.shape) == 3:
                    prompt_embeds_arr.append(torch.cat([n, p], dim=0))
                else:
                    prompt_embeds_arr.append(torch.cat([n, p], dim=1))
            prompt_embeds = prompt_embeds_arr
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            pooled_prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        if latents.shape[-2] != latents.shape[-1]:
            B, C, H, W = latents.shape
            pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size

            img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
            img_ids = torch.zeros(pH, pW, 3)
            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
            img_ids = img_ids.reshape(pH * pW, -1)
            img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
            img_ids_pad[:pH*pW, :] = img_ids

            img_sizes = img_sizes.unsqueeze(0).to(latents.device)
            img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
            if self.do_classifier_free_guidance:
                img_sizes = img_sizes.repeat(2 * B, 1)
                img_ids = img_ids.repeat(2 * B, 1, 1)
        else:
            img_sizes = img_ids = None

        # 5. Prepare timesteps
        mu = calculate_shift(self.transformer.max_seq)
        scheduler_kwargs = {"mu": mu}
        if isinstance(self.scheduler, FlowUniPCMultistepScheduler):
            self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu))
            timesteps = self.scheduler.timesteps
        else:
            timesteps, num_inference_steps = retrieve_timesteps(
                self.scheduler,
                num_inference_steps,
                device,
                sigmas=sigmas,
                **scheduler_kwargs,
            )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6.
        # Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                if latent_model_input.shape[-2] != latent_model_input.shape[-1]:
                    B, C, H, W = latent_model_input.shape
                    patch_size = self.transformer.config.patch_size
                    pH, pW = H // patch_size, W // patch_size
                    out = torch.zeros(
                        (B, C, self.transformer.max_seq, patch_size * patch_size),
                        dtype=latent_model_input.dtype,
                        device=latent_model_input.device
                    )
                    latent_model_input = einops.rearrange(
                        latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size
                    )
                    out[:, :, 0:pH*pW] = latent_model_input
                    latent_model_input = out

                noise_pred = self.transformer(
                    hidden_states = latent_model_input,
                    timesteps = timestep,
                    encoder_hidden_states = prompt_embeds,
                    pooled_embeds = pooled_prompt_embeds,
                    img_sizes = img_sizes,
                    img_ids = img_ids,
                    return_dict = False,
                )[0]
                noise_pred = -noise_pred

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g.
                        # apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return HiDreamImagePipelineOutput(images=image)
```

## /hi_diffusers/pipelines/hidream_image/pipeline_output.py

```py path="/hi_diffusers/pipelines/hidream_image/pipeline_output.py"
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from diffusers.utils import BaseOutput


@dataclass
class HiDreamImagePipelineOutput(BaseOutput):
    """
    Output class for HiDreamImage pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
""" images: Union[List[PIL.Image.Image], np.ndarray] ``` ## /hi_diffusers/schedulers/flash_flow_match.py ```py path="/hi_diffusers/schedulers/flash_flow_match.py" # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import BaseOutput, is_scipy_available, logging from diffusers.utils.torch_utils import randn_tensor if is_scipy_available(): import scipy.stats logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class FlashFlowMatchEulerDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. """ prev_sample: torch.FloatTensor class FlashFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Euler scheduler. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. 
    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
        invert_sigmas: bool = False,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
    ):
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )

        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep.
        It is incremented by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow matching: returns `sigma * noise + (1 - sigma) * sample` for the sigma that
        corresponds to `timestep`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain.
            noise (`torch.FloatTensor`):
                The noise tensor to mix into `sample`.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use instead of the default linear schedule (they are still shifted afterwards). If
                passed, `num_inference_steps` is set to `len(sigmas)`.
            mu (`float`, *optional*):
                The shift parameter passed to `time_shift`. Required when `use_dynamic_shifting=True`.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`")

        if sigmas is None:
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
            )

            sigmas = timesteps / self.config.num_train_timesteps
        else:
            sigmas = np.array(sigmas).astype(np.float32)
            num_inference_steps = len(sigmas)
        self.num_inference_steps = num_inference_steps

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)

        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)

        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        if self.config.invert_sigmas:
            sigmas = 1.0 - sigmas
            timesteps = sigmas * self.config.num_train_timesteps
            sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
        else:
            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.timesteps = timesteps.to(device=device)
        self.sigmas = sigmas
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlashFlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Accepted but not used by this scheduler.
            s_tmin (`float`):
                Accepted but not used by this scheduler.
            s_tmax (`float`):
                Accepted but not used by this scheduler.
            s_noise (`float`, defaults to 1.0):
                Accepted but not used by this scheduler.
            generator (`torch.Generator`, *optional*):
                A random number generator for the noise added between steps.
            return_dict (`bool`):
                Whether or not to return a [`FlashFlowMatchEulerDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlashFlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlashFlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlashFlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        denoised = sample - model_output * sigma

        if self.step_index < self.num_inference_steps - 1:
            sigma_next = self.sigmas[self.step_index + 1]
            noise = randn_tensor(
                model_output.shape,
                generator=generator,
                device=model_output.device,
                dtype=denoised.dtype,
            )
            sample = sigma_next * noise + (1.0 - sigma_next) * denoised

        self._step_index += 1
        sample = sample.to(model_output.dtype)

        if not return_dict:
            return (sample,)

        return FlashFlowMatchEulerDiscreteSchedulerOutput(prev_sample=sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al.
(2022).""" # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers if hasattr(self.config, "sigma_min"): sigma_min = self.config.sigma_min else: sigma_min = None if hasattr(self.config, "sigma_max"): sigma_max = self.config.sigma_max else: sigma_max = None sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, num_inference_steps) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: """Constructs an exponential noise schedule.""" # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers if hasattr(self.config, "sigma_min"): sigma_min = self.config.sigma_min else: sigma_min = None if hasattr(self.config, "sigma_max"): sigma_max = self.config.sigma_max else: sigma_max = None sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) return sigmas # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta def _convert_to_beta( self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 ) -> torch.Tensor: """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. 
al, 2024)""" # Hack to make sure that other schedulers which copy this function don't break # TODO: Add this logic to the other schedulers if hasattr(self.config, "sigma_min"): sigma_min = self.config.sigma_min else: sigma_min = None if hasattr(self.config, "sigma_max"): sigma_max = self.config.sigma_max else: sigma_max = None sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() sigmas = np.array( [ sigma_min + (ppf * (sigma_max - sigma_min)) for ppf in [ scipy.stats.beta.ppf(timestep, alpha, beta) for timestep in 1 - np.linspace(0, 1, num_inference_steps) ] ] ) return sigmas def __len__(self): return self.config.num_train_timesteps ``` ## /hi_diffusers/schedulers/fm_solvers_unipc.py ```py path="/hi_diffusers/schedulers/fm_solvers_unipc.py" # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py # Convert unipc for flow matching # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import math from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput) from diffusers.utils import deprecate, is_scipy_available if is_scipy_available(): import scipy.stats class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. 
        solver_order (`int`, default `2`):
            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
            unconditional sampling.
        prediction_type (`str`, defaults to "flow_prediction"):
            Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
            the flow of the diffusion process.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, defaults to `True`):
            Whether to use the updating algorithm on the predicted x0.
        solver_type (`str`, default `bh2`):
            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
            otherwise.
        lower_order_final (`bool`, default `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            The steps at which to disable the corrector, to mitigate the misalignment between `epsilon_theta(x_t, c)`
            and `epsilon_theta(x_t^c, c)`, which can hurt convergence for a large guidance scale. The corrector is
            usually disabled during the first few steps.
        solver_p (`SchedulerMixin`, default `None`):
            Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process.
            If `True`, the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: Optional[float] = 1.0,
        use_dynamic_shifting=False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: List[int] = [],
        solver_p: SchedulerMixin = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
    ):

        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(
                    f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # settable values
        self.num_inference_steps = None
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)

        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None

        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It is incremented by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
    def set_timesteps(
        self,
        num_inference_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to.
                If `None`, the timesteps are not moved.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError(
                "You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
            )

        if sigmas is None:
            sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]  # pyright: ignore

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore
        else:
            if shift is None:
                shift = self.config.shift
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

        if self.config.final_sigmas_type == "sigma_min":
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0])**0.5
        elif self.config.final_sigmas_type == "zero":
            sigma_last = 0
        else:
            raise ValueError(
                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
            )

        timesteps = sigmas * self.config.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)  # pyright: ignore

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.config.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        if self.solver_p:
            self.solver_p.set_timesteps(self.num_inference_steps, device=device)

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s.
        Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def _sigma_to_alpha_sigma_t(self, sigma):
        return 1 - sigma, sigma

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Convert the model output to the corresponding type the UniPC algorithm needs.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

        if self.predict_x0:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be `flow_prediction`"
                    " for the FlowUniPCMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be `flow_prediction`"
                    " for the FlowUniPCMultistepScheduler."
                )

            if self.config.thresholding:
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,  # pyright: ignore
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version).
        Alternatively, `self.solver_p` is used if it is specified.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep.
            prev_timestep (`int`):
                The previous discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
            "prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(
                    "missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError(
                    "missing `order` as a required keyword argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
            self.step_index]  # pyright: ignore
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i  # pyright: ignore
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)  # pyright: ignore

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1],
                                            b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
                                        D1s)  # pyright: ignore
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
                                        D1s)  # pyright: ignore
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: int = None,  # pyright: ignore
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniC (B(h) version).

        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                The current timestep `t`.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.

        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep.
        """
        this_timestep = args[0] if len(args) > 0 else kwargs.pop(
            "this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError(
                    "missing `last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError(
                    "missing `this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError(
                    "missing `order` as a required keyword argument")
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
            self.step_index - 1]  # pyright: ignore
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)  # pyright: ignore
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)  # pyright: ignore

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.
        """

        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(self,
             model_output: torch.Tensor,
             timestep: Union[int, torch.Tensor],
             sample: torch.Tensor,
             return_dict: bool = True,
             generator=None) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep UniPC.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        use_corrector = (
            self.step_index > 0 and
            self.step_index - 1 not in self.disable_corrector and
            self.last_sample is not None  # pyright: ignore
        )

        model_output_convert = self.convert_model_output(
            model_output, sample=sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep  # pyright: ignore

        if self.config.lower_order_final:
            this_order = min(self.config.solver_order,
                             len(self.timesteps) - self.step_index)  # pyright: ignore
        else:
            this_order = self.config.solver_order

        self.this_order = min(this_order,
                              self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        self._step_index += 1  # pyright: ignore

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args,
                          **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(
            device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(
                timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(
                original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(
                original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps)
                for t in timesteps
            ]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
```

## /inference.py

```py path="/inference.py"
import torch
import argparse
from hi_diffusers import HiDreamImagePipeline
from hi_diffusers import HiDreamImageTransformer2DModel
from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default="dev", choices=["dev", "full", "fast"])
args = parser.parse_args()
model_type = args.model_type
MODEL_PREFIX = "HiDream-ai"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Model configurations
MODEL_CONFIGS = {
    "dev": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
        "guidance_scale": 0.0,
        "num_inference_steps": 28,
        "shift": 6.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    },
    "full": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
        "guidance_scale": 5.0,
        "num_inference_steps": 50,
        "shift": 3.0,
        "scheduler": FlowUniPCMultistepScheduler
    },
    "fast": {
        "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
        "guidance_scale": 0.0,
        "num_inference_steps": 16,
        "shift": 3.0,
        "scheduler": FlashFlowMatchEulerDiscreteScheduler
    }
}

# Resolution options, written as "width × height"
RESOLUTION_OPTIONS = [
    "1024 × 1024 (Square)",
    "768 × 1360 (Portrait)",
    "1360 × 768 (Landscape)",
    "880 × 1168 (Portrait)",
    "1168 × 880 (Landscape)",
    "1248 × 832 (Landscape)",
    "832 × 1248 (Portrait)"
]

# Load models
def load_models(model_type):
    config = MODEL_CONFIGS[model_type]
    pretrained_model_name_or_path = config["path"]
    scheduler = config["scheduler"](num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)

    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
        LLAMA_MODEL_NAME,
        use_fast=False)

    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_NAME,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=torch.bfloat16).to("cuda")

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16).to("cuda")

    pipe = HiDreamImagePipeline.from_pretrained(
        pretrained_model_name_or_path,
        scheduler=scheduler,
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=torch.bfloat16
    ).to("cuda", torch.bfloat16)
    pipe.transformer = transformer

    return pipe, config

# Parse resolution string ("width × height") and return (height, width)
def parse_resolution(resolution_str):
    if "1024 × 1024" in resolution_str:
        return 1024, 1024
    elif "768 × 1360" in resolution_str:
        return 1360, 768
    elif "1360 × 768" in resolution_str:
        return 768, 1360
    elif "880 × 1168" in resolution_str:
        return 1168, 880
    elif "1168 × 880" in resolution_str:
        return 880, 1168
    elif "1248 × 832" in resolution_str:
        return 832, 1248
    elif "832 × 1248" in resolution_str:
        return 1248, 832
    else:
        return 1024, 1024  # Default fallback

# Generate image function
def generate_image(pipe, model_type, prompt, resolution, seed):
    # Get configuration for current model
    config = MODEL_CONFIGS[model_type]
    guidance_scale = config["guidance_scale"]
    num_inference_steps = config["num_inference_steps"]

    # Parse resolution
    height, width = parse_resolution(resolution)

    # Handle seed
    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()

    generator = torch.Generator("cuda").manual_seed(seed)

    images = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=1,
        generator=generator
    ).images

    return images[0], seed

# Initialize with the selected model (defaults to "dev")
print(f"Loading model ({model_type})...")
pipe, _ = load_models(model_type)
print("Model loaded successfully!")
prompt = "A cat holding a sign that says \"Hi-Dreams.ai\"."
resolution = "1024 × 1024 (Square)"
seed = -1
image, seed = generate_image(pipe, model_type, prompt, resolution, seed)
image.save("output.png")
```

## /requirements.txt

``` path="/requirements.txt"
torch>=2.5.1
torchvision>=0.20.1
diffusers>=0.32.1
transformers>=4.47.1
einops>=0.7.0
accelerate>=1.2.1
```
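The static-shift sigma schedule that `FlowUniPCMultistepScheduler.set_timesteps` builds (the `use_dynamic_shifting=False` branch, which is how `inference.py` constructs the scheduler) and the flow-prediction estimate `x0 = sample - sigma * model_output` in `convert_model_output` can be inspected without loading any model. The NumPy sketch below reproduces only those formulas. It is illustrative, not part of `hi_diffusers`: the helper names `shifted_sigmas` and `flow_x0` are hypothetical, `sigma_max=1.0` / `sigma_min=0.0` are assumed (the real values come from the scheduler's constructor, which is not shown here), and the velocity convention `v = noise - x0` is inferred from the `x0` formula rather than stated in the source.

```python
import math

import numpy as np


def shifted_sigmas(num_inference_steps: int, shift: float,
                   sigma_max: float = 1.0, sigma_min: float = 0.0) -> np.ndarray:
    """Hypothetical helper mirroring the static-shift branch of `set_timesteps`
    with final_sigmas_type="zero". sigma_max/sigma_min are assumed values."""
    # Evenly spaced sigmas from sigma_max toward sigma_min, dropping the endpoint.
    sigmas = np.linspace(sigma_max, sigma_min, num_inference_steps + 1)[:-1]
    # Static shift: for shift > 1, every intermediate sigma is raised (more noise).
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
    # Append the terminal sigma (0 for final_sigmas_type="zero").
    return np.concatenate([sigmas, [0.0]]).astype(np.float32)


def time_shift(mu: float, sigma: float, t: np.ndarray) -> np.ndarray:
    """Same formula as the scheduler's `time_shift` (dynamic shifting)."""
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def flow_x0(sample: np.ndarray, model_output: np.ndarray, sigma: float) -> np.ndarray:
    """Hypothetical helper: x0 estimate used for prediction_type="flow_prediction"."""
    return sample - sigma * model_output


if __name__ == "__main__":
    sigmas = shifted_sigmas(num_inference_steps=50, shift=3.0)
    # Integer timesteps, built as in set_timesteps with num_train_timesteps=1000.
    timesteps = (sigmas[:-1] * 1000).astype(np.int64)
    print(sigmas[:3], timesteps[:3])

    # With sigma=1.0, dynamic shifting with mu = ln(shift) equals the static shift.
    t = np.linspace(1.0, 0.02, 50)
    assert np.allclose(time_shift(math.log(3.0), 1.0, t), 3.0 * t / (1 + 2.0 * t))

    # Round trip under the assumed convention x_t = (1 - sigma) * x0 + sigma * noise
    # with velocity v = noise - x0.
    rng = np.random.default_rng(0)
    x0, noise = rng.standard_normal(4), rng.standard_normal(4)
    sigma = 0.7
    x_t = (1 - sigma) * x0 + sigma * noise
    assert np.allclose(flow_x0(x_t, noise - x0, sigma), x0)
```

In `inference.py`, only the `full` configuration uses `FlowUniPCMultistepScheduler` (with `shift=3.0`), so `shifted_sigmas(50, 3.0)` approximates that setting under the sigma range assumed above. The algebraic identity checked in the `__main__` block also shows that the static shift is the special case `mu = ln(shift)`, `sigma = 1.0` of `time_shift`.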