```
├── LICENSE
├── README.md
├── SkyworkVL_RM.pdf
├── Skywork_R1V.pdf
├── Skywork_R1V2.pdf
├── imgs/
│   ├── Chemistry_cn.gif
│   ├── comparsion.png
│   ├── math_r1v.gif
│   ├── multi_reasoning_osm.png
│   ├── multi_reasoning_pm.png
│   ├── r1v_comp.png
│   └── text_reasoning.png
└── inference/
    ├── __init__.py
    ├── inference_with_transformers.py
    ├── inference_with_vllm.py
    ├── setup.sh
    └── utils.py
```

## /LICENSE

``` path="/LICENSE"
MIT License

Copyright (c) 2025 Skywork

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

## /README.md

# Skywork-R1V: Pioneering Multimodal Reasoning with CoT
[[🤗 Skywork-R1V2-38B](https://huggingface.co/Skywork/Skywork-R1V2-38B)] [[🤖 R1V2 ModelScope](https://modelscope.cn/models/Skywork/Skywork-R1V2-38B)] [[📖 R1V2 Report](https://arxiv.org/abs/2504.16656)]

[[🤗 Skywork-R1V-38B](https://huggingface.co/Skywork/Skywork-R1V-38B)] [[📖 R1V1 Report](https://arxiv.org/abs/2504.05599)]
*Demo animations: math reasoning (`imgs/math_r1v.gif`) and chemistry reasoning (`imgs/Chemistry_cn.gif`).*


Welcome to the Skywork-R1V repository! Here you'll find the model weights and inference code for our state-of-the-art open-source multimodal reasoning models, which combine advanced visual and textual thinking.

## 🔥 News

**April 24, 2025**: We released **Skywork-R1V2**, a state-of-the-art, open-source multimodal reasoning model that achieves leading performance across multiple vision-language benchmarks. [[🤗 Skywork-R1V2-38B](https://huggingface.co/Skywork/Skywork-R1V2-38B)] [[📖 R1V2 Report](https://arxiv.org/abs/2504.16656)]

**April 9, 2025**: Our technical report is now available on arXiv: [[Skywork-R1V: Pioneering Multimodal Reasoning with CoT](https://arxiv.org/abs/2504.05599)].

**April 1, 2025**: Skywork-R1V now supports inference with [[vLLM](https://github.com/vllm-project/vllm)]. On 4×L20Y GPUs, vLLM generates 1k tokens in ~12.3s, at least 5× faster than the transformers backend.

**Mar 26, 2025**: We released an AWQ-quantized version of Skywork-R1V [[🤗 Skywork-R1V-38B-AWQ](https://huggingface.co/Skywork/Skywork-R1V-38B-AWQ)], supporting single-card (above 30 GB) inference.

**Mar 18, 2025**: We are thrilled to introduce Skywork-R1V, the industry's first open-source multimodal reasoning model with advanced visual chain-of-thought capabilities, pushing the boundaries of AI-driven vision and logical inference! 🚀

## R1V2-38B Evaluation

Skywork-R1V2-38B demonstrates state-of-the-art performance on both text and multimodal reasoning tasks.
**Comparison of Skywork-R1V2 with Multimodal Open-Source and Proprietary Models**

| Model | AIME24 | LiveCodebench | liveBench | IFEVAL | BFCL | MMMU(val) | MathVista(mini) | MathVision(mini) | OlympiadBench | mmmu-pro |
|---|---|---|---|---|---|---|---|---|---|---|
| Skywork-R1V2-38B | 78.9 | 63.6 | 73.2 | 82.9 | 66.3 | 73.6 | 74.0 | 49.0 | 62.6 | 52.0 |
| OpenAI-4o | 74.6 | 9.3 | 49.9 | - | - | 69.1 | 63.8 | 58.0 | - | - |
| Claude 3.5 Sonnet | 16.0 | 65.0 | - | - | - | 66.4 | 65.3 | - | - | - |
| Kimi k1.5 | 77.5 | - | - | - | - | 70.0 | 74.9 | - | - | - |
| Qwen2.5-VL-72B | - | - | - | - | - | 70.2 | 74.8 | 38.1 | 40.4 | - |
| InternVL3-38B | - | - | - | - | - | 70.1 | 75.1 | 34.2 | - | - |

Columns AIME24 through BFCL report text reasoning (pass@1 or %); columns MMMU(val) through mmmu-pro report multimodal reasoning (%).


**Text Reasoning Performance** (see `imgs/text_reasoning.png`)


**Multimodal Reasoning vs Proprietary Models** (see `imgs/multi_reasoning_pm.png`)


**Multimodal Reasoning vs Open-Source Models** (see `imgs/multi_reasoning_osm.png`)
## How to Run Locally

### 1. Clone the Repository

```shell
git clone https://github.com/SkyworkAI/Skywork-R1V.git
cd Skywork-R1V/inference
```

### 2. Set Up the Environment

```shell
# For Transformers
conda create -n r1-v python=3.10 && conda activate r1-v
bash setup.sh

# For vLLM
conda create -n r1v-vllm python=3.10 && conda activate r1v-vllm
pip install -U vllm
```

### 3. Run the Inference Script

#### Using Transformers
```shell
CUDA_VISIBLE_DEVICES="0,1" python inference_with_transformers.py \
    --model_path path \
    --image_paths image1_path \
    --question "your question"
```

#### Using vLLM
```shell
python inference_with_vllm.py \
    --model_path path \
    --image_paths image1_path image2_path \
    --question "your question" \
    --tensor_parallel_size 4
```

## License

This code repository is licensed under [the MIT License](https://github.com/SkyworkAI/Skywork-R1V/blob/main/LICENSE).

✅ Commercial use permitted

✅ Modification allowed

✅ Distribution allowed

❌ No liability

## Citation

If you use Skywork-R1V in your research, please cite:

```
@misc{chris2025skyworkr1v2multimodalhybrid,
      title={Skywork R1V2: Multimodal Hybrid Reinforcement Learning for Reasoning},
      author={Chris and Yichen Wei and Yi Peng and Xiaokun Wang and Weijie Qiu and Wei Shen and Tianyidan Xie and Jiangbo Pei and Jianhao Zhang and Yunzhuo Hao and Xuchen Song and Yang Liu and Yahui Zhou},
      year={2025},
      eprint={2504.16656},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2504.16656},
}
```

```
@misc{peng2025skyworkr1vpioneeringmultimodal,
      title={Skywork R1V: Pioneering Multimodal Reasoning with Chain-of-Thought},
      author={Yi Peng and Chris and Xiaokun Wang and Yichen Wei and Jiangbo Pei and Weijie Qiu and Ai Jian and Yunzhuo Hao and Jiachun Pan and Tianyidan Xie and Li Ge and Rongxian Zhuang and Xuchen Song and Yang Liu and Yahui Zhou},
      year={2025},
      eprint={2504.05599},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2504.05599},
}
```

## /SkyworkVL_RM.pdf

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/SkyworkVL_RM.pdf

## /Skywork_R1V.pdf

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/Skywork_R1V.pdf

## /Skywork_R1V2.pdf

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/Skywork_R1V2.pdf

## /imgs/Chemistry_cn.gif

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/Chemistry_cn.gif

## /imgs/comparsion.png

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/comparsion.png

## /imgs/math_r1v.gif

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/math_r1v.gif

## /imgs/multi_reasoning_osm.png

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/multi_reasoning_osm.png

## /imgs/multi_reasoning_pm.png

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/multi_reasoning_pm.png

## /imgs/r1v_comp.png

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/r1v_comp.png

## /imgs/text_reasoning.png

Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/text_reasoning.png

## /inference/__init__.py

```py path="/inference/__init__.py"
```

## /inference/inference_with_transformers.py

```py path="/inference/inference_with_transformers.py"
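# Usage sketch (image paths and question below are illustrative, not shipped defaults):
#   CUDA_VISIBLE_DEVICES="0,1" python inference_with_transformers.py \
#       --model_path Skywork/Skywork-R1V-38B \
#       --image_paths demo1.png demo2.png \
#       --question "What do these two images have in common?"
# With several images, the script prepends one image placeholder per input image to the
# question and passes num_patches_list so the model can group the tiles of each image.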
path="/inference/inference_with_transformers.py" import torch from transformers import AutoModel, AutoTokenizer from utils import load_image, split_model import argparse def main(): # 解析命令行参数 parser = argparse.ArgumentParser(description="Run inference with Skywork-R1V model.") parser.add_argument('--model_path', type=str, default='Skywork/Skywork-R1V-38B', help="Path to the model.") parser.add_argument('--image_paths', type=str, nargs='+', required=True, help="Path(s) to the image(s).") parser.add_argument('--question', type=str, required=True, help="Question to ask the model.") args = parser.parse_args() device_map = split_model(args.model_path) model = AutoModel.from_pretrained( args.model_path, torch_dtype=torch.bfloat16, load_in_8bit=False, low_cpu_mem_usage=True, use_flash_attn=True, trust_remote_code=True, device_map=device_map ).eval() tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True, use_fast=False) pixel_values = [load_image(img_path, max_num=12).to(torch.bfloat16).cuda() for img_path in args.image_paths] if len(pixel_values) > 1: num_patches_list = [img.size(0) for img in pixel_values] pixel_values = torch.cat(pixel_values, dim=0) else: pixel_values = pixel_values[0] num_patches_list = None prompt = "\n"*len(args.image_paths) + args.question generation_config = dict(max_new_tokens=64000, do_sample=True, temperature=0.6, top_p=0.95, repetition_penalty=1.05) response = model.chat(tokenizer, pixel_values, prompt, generation_config, num_patches_list=num_patches_list) print(f'User: {args.question}\nAssistant: {response}') if __name__ == '__main__': main() ``` ## /inference/inference_with_vllm.py ```py path="/inference/inference_with_vllm.py" import argparse from typing import List, Union from PIL import Image from transformers import AutoTokenizer from vllm import LLM, SamplingParams def parse_arguments() -> argparse.Namespace: """Parse command line arguments for model inference. Returns: argparse.Namespace: Parsed command line arguments """ parser = argparse.ArgumentParser( description="Run inference with Skywork-R1V series model using vLLM." ) # Model configuration parser.add_argument( "--model_path", type=str, default="Skywork/Skywork-R1V2-38B", help="Path to the model" ) parser.add_argument( "--tensor_parallel_size", type=int, default=4, help="Number of GPUs for tensor parallelism" ) # Input parameters parser.add_argument( "--image_paths", type=str, nargs="+", required=True, help="Path(s) to the input image(s)" ) parser.add_argument( "--question", type=str, required=True, help="Question to ask the model" ) # Generation parameters parser.add_argument( "--temperature", type=float, default=0.0, help="Temperature for sampling (higher = more creative)" ) parser.add_argument( "--max_tokens", type=int, default=8000, help="Maximum number of tokens to generate" ) parser.add_argument( "--repetition_penalty", type=float, default=1.05, help="Penalty for repeated tokens (1.0 = no penalty)" ) parser.add_argument( "--top_p", type=float, default=0.95, help="Top-p (nucleus) sampling probability" ) return parser.parse_args() def load_images(image_paths: List[str]) -> Union[Image.Image, List[Image.Image]]: """Load images from given paths. Args: image_paths: List of image file paths Returns: Single image if one path provided, else list of images """ images = [Image.open(img_path) for img_path in image_paths] return images[0] if len(images) == 1 else images def prepare_question(question: str, num_images: int) -> str: """Format the question with appropriate image tags. 

    Args:
        question: Original question string
        num_images: Number of images being processed

    Returns:
        Formatted question string
    """
    if not question.startswith("<image>\n"):
        return "<image>\n" * num_images + question
    return question

def initialize_model(args: argparse.Namespace) -> tuple[LLM, AutoTokenizer]:
    """Initialize the LLM model and tokenizer.

    Args:
        args: Parsed command line arguments

    Returns:
        Tuple of (LLM instance, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    llm = LLM(
        model=args.model_path,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 20},
        gpu_memory_utilization=0.7,
    )
    return llm, tokenizer

def generate_response(
    llm: LLM,
    tokenizer: AutoTokenizer,
    question: str,
    images: Union[Image.Image, List[Image.Image]],
    sampling_params: SamplingParams
) -> str:
    """Generate response from the model.

    Args:
        llm: Initialized LLM instance
        tokenizer: Initialized tokenizer
        question: Formatted question string
        images: Input image(s)
        sampling_params: Generation parameters

    Returns:
        Generated response text
    """
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": images},
        },
        sampling_params=sampling_params
    )

    return outputs[0].outputs[0].text

def main() -> None:
    """Main execution function."""
    args = parse_arguments()

    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        repetition_penalty=args.repetition_penalty,
    )

    llm, tokenizer = initialize_model(args)
    images = load_images(args.image_paths)
    question = prepare_question(args.question, len(args.image_paths))

    response = generate_response(llm, tokenizer, question, images, sampling_params)
    print(f"User: {args.question}\nAssistant: {response}")

if __name__ == "__main__":
    main()
```

## /inference/setup.sh

```sh path="/inference/setup.sh"
pip install torch==2.6.0
pip install pillow==11.1.0
pip install torchvision==0.21.0
pip install transformers==4.37.2
pip install einops==0.6.1
pip install einops-exts==0.0.4
pip install timm==0.9.12
pip install bitsandbytes==0.42.0
pip install accelerate==1.5.2
pip install flash-attn --no-build-isolation
```

## /inference/utils.py

```py path="/inference/utils.py"
import math
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoConfig

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
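    # Enumerate all (i, j) tile grids with min_num <= i*j <= max_num, pick the one whose
    # aspect ratio is closest to the input image, resize the image to that grid of
    # image_size x image_size tiles, and crop out each tile.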
    aspect_ratio = orig_width / orig_height

    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

def split_model(model_path):
    device_map = {}
    world_size = torch.cuda.device_count()
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_layers = config.llm_config.num_hidden_layers

    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.model.rotary_emb'] = 0
    device_map['language_model.lm_head'] = 0
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0

    return device_map
```
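The dynamic tiling in `utils.py` is easiest to see with a concrete input. The sketch below (not part of the repository) feeds a synthetic 1000×750 image through `dynamic_preprocess`; since the image's 4:3 aspect ratio matches a 4×3 grid within the `max_num=12` budget, it should yield 12 tiles of 448×448 plus one global thumbnail. Run it from the `inference/` directory so `utils` is importable.

```py
# Minimal sketch: inspect how dynamic_preprocess tiles an image (assumes it is run
# from the inference/ directory so that utils.py is on the import path).
from PIL import Image
from utils import dynamic_preprocess

# A synthetic 1000x750 image has a 4:3 aspect ratio, so with max_num=12 the closest
# tile grid is 4x3: the image is resized to 1792x1344 and cut into 12 crops of 448x448.
img = Image.new("RGB", (1000, 750), color=(128, 128, 128))
tiles = dynamic_preprocess(img, min_num=1, max_num=12, image_size=448, use_thumbnail=True)

# 12 grid tiles + 1 global thumbnail = 13 crops, each 448x448.
print(len(tiles), tiles[0].size)  # expected: 13 (448, 448)
```

`load_image` then applies the ImageNet-normalized transform to each crop and stacks them, which is where the `num_patches_list` entries in `inference_with_transformers.py` come from (one crop count per input image).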