```
├── LICENSE
├── README.md
├── SkyworkVL_RM.pdf
├── Skywork_R1V.pdf
├── Skywork_R1V2.pdf
├── imgs/
│   ├── Chemistry_cn.gif
│   ├── comparsion.png
│   ├── math_r1v.gif
│   ├── multi_reasoning_osm.png
│   ├── multi_reasoning_pm.png
│   ├── r1v_comp.png
│   └── text_reasoning.png
└── inference/
    ├── __init__.py
    ├── inference_with_transformers.py
    ├── inference_with_vllm.py
    ├── setup.sh
    └── utils.py
```
## /LICENSE
``` path="/LICENSE"
MIT License
Copyright (c) 2025 Skywork
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
## /README.md
# Skywork-R1V: Pioneering Multimodal Reasoning with CoT
[[🤗 Skywork-R1V2-38B](https://huggingface.co/Skywork/Skywork-R1V2-38B)] [[🤖 R1V2 ModelScope](https://modelscope.cn/models/Skywork/Skywork-R1V2-38B)] [[📖 R1V2 Report](https://arxiv.org/abs/2504.16656)]
[[🤗 Skywork-R1V-38B](https://huggingface.co/Skywork/Skywork-R1V-38B)] [[📖 R1V1 Report](https://arxiv.org/abs/2504.05599)]
Welcome to the Skywork-R1V repository! Here you'll find the model weights and inference code for our state-of-the-art open-source multimodal reasoning model, enabling advanced visual and textual reasoning.
## 🔥 News
**April 24, 2025**: We released **Skywork-R1V2**, a state-of-the-art, open-source multimodal reasoning model that achieves leading performance across multiple vision-language benchmarks. [[🤗 Skywork-R1V2-38B](https://huggingface.co/Skywork/Skywork-R1V2-38B)] [[📖 R1V2 Report](https://arxiv.org/abs/2504.16656)]
**April 9, 2025**: Our technical report is now available on arXiv: [[Skywork-R1V: Pioneering Multimodal Reasoning with CoT](https://arxiv.org/abs/2504.05599)].
**April 1, 2025**: Skywork-R1V now supports inference with [[vLLM](https://github.com/vllm-project/vllm)]. On 4×L20Y GPUs, vLLM generates 1k tokens in ~12.3 s, at least 5× faster than the Transformers backend.
**Mar 26, 2025**: We released the AWQ-quantized version of Skywork-R1V [[🤗 Skywork-R1V-38B-AWQ](https://huggingface.co/Skywork/Skywork-R1V-38B-AWQ)], supporting single-GPU (above 30 GB) inference.
**Mar 18, 2025**: We are thrilled to introduce Skywork-R1V, the industry's first open-source multimodal reasoning model with advanced visual chain-of-thought capabilities, pushing the boundaries of AI-driven vision and logical inference! 🚀
## R1V2-38B Evaluation
Skywork-R1V2-38B demonstrates state-of-the-art performance on both text and multimodal reasoning tasks.
Comparison of Skywork-R1V2 with multimodal open-source and proprietary models. AIME24, LiveCodebench, LiveBench, IFEVAL, and BFCL are text reasoning benchmarks (pass@1 or %); MMMU(val), MathVista(mini), MathVision(mini), OlympiadBench, and MMMU-Pro are multimodal reasoning benchmarks (%).

| Model | AIME24 | LiveCodebench | LiveBench | IFEVAL | BFCL | MMMU(val) | MathVista(mini) | MathVision(mini) | OlympiadBench | MMMU-Pro |
|---|---|---|---|---|---|---|---|---|---|---|
| Skywork-R1V2-38B | 78.9 | 63.6 | 73.2 | 82.9 | 66.3 | 73.6 | 74.0 | 49.0 | 62.6 | 52.0 |
| OpenAI-4o | 74.6 | 9.3 | 49.9 | — | — | 69.1 | 63.8 | 58.0 | — | — |
| Claude 3.5 Sonnet | 16.0 | — | 65.0 | — | — | 66.4 | 65.3 | — | — | — |
| Kimi k1.5 | 77.5 | — | — | — | — | 70.0 | 74.9 | — | — | — |
| Qwen2.5-VL-72B | — | — | — | — | — | 70.2 | 74.8 | 38.1 | 40.4 | — |
| InternVL3-38B | — | — | — | — | — | 70.1 | 75.1 | 34.2 | — | — |
Accompanying figures (see `imgs/`): text reasoning performance; multimodal reasoning vs. proprietary models; multimodal reasoning vs. open-source models.
## How to Run Locally
### 1. Clone the Repository
```shell
git clone https://github.com/SkyworkAI/Skywork-R1V.git
cd Skywork-R1V/inference
```
### 2. Set Up the Environment
```shell
# For Transformers
conda create -n r1-v python=3.10 && conda activate r1-v
bash setup.sh
# For vLLM
conda create -n r1v-vllm python=3.10 && conda activate r1v-vllm
pip install -U vllm
```
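After either environment is set up, a quick sanity check helps confirm that PyTorch sees your GPUs (a minimal sketch; the pinned versions come from `setup.sh`):
```python
# Optional environment check (illustrative; versions are pinned in setup.sh).
import torch
import transformers

print("torch:", torch.__version__)                # setup.sh pins 2.6.0
print("transformers:", transformers.__version__)  # setup.sh pins 4.37.2
print("CUDA available:", torch.cuda.is_available())
print("GPU count:", torch.cuda.device_count())
```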
### 3. Run the Inference Script
#### Using Transformers
```shell
CUDA_VISIBLE_DEVICES="0,1" python inference_with_transformers.py \
--model_path path \
--image_paths image1_path \
--question "your question"
```
#### Using vLLM
```shell
python inference_with_vllm.py \
--model_path path \
--image_paths image1_path image2_path \
--question "your question" \
--tensor_parallel_size 4
```
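For the AWQ-quantized checkpoint ([Skywork-R1V-38B-AWQ](https://huggingface.co/Skywork/Skywork-R1V-38B-AWQ)), a single GPU with more than 30 GB of memory is sufficient. A minimal vLLM sketch, assuming the AWQ weights are loaded through vLLM's `quantization="awq"` option; prompt and image handling follow `inference_with_vllm.py`:
```python
# Hedged sketch: single-GPU inference with the AWQ checkpoint via vLLM.
# Only the model loading differs from inference_with_vllm.py.
from vllm import LLM

llm = LLM(
    model="Skywork/Skywork-R1V-38B-AWQ",
    quantization="awq",                  # pre-quantized AWQ weights
    tensor_parallel_size=1,              # single GPU (above 30 GB)
    trust_remote_code=True,
    limit_mm_per_prompt={"image": 20},
    gpu_memory_utilization=0.9,
)
```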
## License
This code repository is licensed under [the MIT License](https://github.com/SkyworkAI/Skywork-R1V/blob/main/LICENSE).
✅ Commercial use permitted
✅ Modification allowed
✅ Distribution allowed
❌ No liability
## Citation
If you use Skywork-R1V in your research, please cite:
```
@misc{chris2025skyworkr1v2multimodalhybrid,
title={Skywork R1V2: Multimodal Hybrid Reinforcement Learning for Reasoning},
author={Chris and Yichen Wei and Yi Peng and Xiaokun Wang and Weijie Qiu and Wei Shen and Tianyidan Xie and Jiangbo Pei and Jianhao Zhang and Yunzhuo Hao and Xuchen Song and Yang Liu and Yahui Zhou},
year={2025},
eprint={2504.16656},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.16656},
}
```
```
@misc{peng2025skyworkr1vpioneeringmultimodal,
title={Skywork R1V: Pioneering Multimodal Reasoning with Chain-of-Thought},
author={Yi Peng and Chris and Xiaokun Wang and Yichen Wei and Jiangbo Pei and Weijie Qiu and Ai Jian and Yunzhuo Hao and Jiachun Pan and Tianyidan Xie and Li Ge and Rongxian Zhuang and Xuchen Song and Yang Liu and Yahui Zhou},
year={2025},
eprint={2504.05599},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.05599},
}
```
## /SkyworkVL_RM.pdf
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/SkyworkVL_RM.pdf
## /Skywork_R1V.pdf
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/Skywork_R1V.pdf
## /Skywork_R1V2.pdf
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/Skywork_R1V2.pdf
## /imgs/Chemistry_cn.gif
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/Chemistry_cn.gif
## /imgs/comparsion.png
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/comparsion.png
## /imgs/math_r1v.gif
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/math_r1v.gif
## /imgs/multi_reasoning_osm.png
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/multi_reasoning_osm.png
## /imgs/multi_reasoning_pm.png
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/multi_reasoning_pm.png
## /imgs/r1v_comp.png
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/r1v_comp.png
## /imgs/text_reasoning.png
Binary file available at https://raw.githubusercontent.com/SkyworkAI/Skywork-R1V/refs/heads/main/imgs/text_reasoning.png
## /inference/__init__.py
```py path="/inference/__init__.py"
```
## /inference/inference_with_transformers.py
```py path="/inference/inference_with_transformers.py"
import torch
from transformers import AutoModel, AutoTokenizer
from utils import load_image, split_model
import argparse
def main():
    # Parse command-line arguments
parser = argparse.ArgumentParser(description="Run inference with Skywork-R1V model.")
parser.add_argument('--model_path', type=str, default='Skywork/Skywork-R1V-38B', help="Path to the model.")
parser.add_argument('--image_paths', type=str, nargs='+', required=True, help="Path(s) to the image(s).")
parser.add_argument('--question', type=str, required=True, help="Question to ask the model.")
args = parser.parse_args()
device_map = split_model(args.model_path)
model = AutoModel.from_pretrained(
args.model_path,
torch_dtype=torch.bfloat16,
load_in_8bit=False,
low_cpu_mem_usage=True,
use_flash_attn=True,
trust_remote_code=True,
device_map=device_map
).eval()
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True, use_fast=False)
pixel_values = [load_image(img_path, max_num=12).to(torch.bfloat16).cuda() for img_path in args.image_paths]
if len(pixel_values) > 1:
num_patches_list = [img.size(0) for img in pixel_values]
pixel_values = torch.cat(pixel_values, dim=0)
else:
pixel_values = pixel_values[0]
num_patches_list = None
prompt = "\n"*len(args.image_paths) + args.question
generation_config = dict(max_new_tokens=64000, do_sample=True, temperature=0.6, top_p=0.95, repetition_penalty=1.05)
response = model.chat(tokenizer, pixel_values, prompt, generation_config, num_patches_list=num_patches_list)
print(f'User: {args.question}\nAssistant: {response}')
if __name__ == '__main__':
main()
```
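For reference, the prompt built by this script simply prepends one `<image>` placeholder per input image; the placeholders are consumed by the model's remote-code `chat()` implementation. A minimal illustration with made-up paths and question:
```python
# Illustration of the prompt layout for two images (paths and question are invented).
image_paths = ["diagram_a.png", "diagram_b.png"]
question = "Which diagram shows the correct reaction mechanism?"
prompt = "<image>\n" * len(image_paths) + question
print(prompt)
# <image>
# <image>
# Which diagram shows the correct reaction mechanism?
```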
## /inference/inference_with_vllm.py
```py path="/inference/inference_with_vllm.py"
import argparse
from typing import List, Union
from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
def parse_arguments() -> argparse.Namespace:
"""Parse command line arguments for model inference.
Returns:
argparse.Namespace: Parsed command line arguments
"""
parser = argparse.ArgumentParser(
description="Run inference with Skywork-R1V series model using vLLM."
)
# Model configuration
parser.add_argument(
"--model_path",
type=str,
default="Skywork/Skywork-R1V2-38B",
help="Path to the model"
)
parser.add_argument(
"--tensor_parallel_size",
type=int,
default=4,
help="Number of GPUs for tensor parallelism"
)
# Input parameters
parser.add_argument(
"--image_paths",
type=str,
nargs="+",
required=True,
help="Path(s) to the input image(s)"
)
parser.add_argument(
"--question",
type=str,
required=True,
help="Question to ask the model"
)
# Generation parameters
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Temperature for sampling (higher = more creative)"
)
parser.add_argument(
"--max_tokens",
type=int,
default=8000,
help="Maximum number of tokens to generate"
)
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.05,
help="Penalty for repeated tokens (1.0 = no penalty)"
)
parser.add_argument(
"--top_p",
type=float,
default=0.95,
help="Top-p (nucleus) sampling probability"
)
return parser.parse_args()
def load_images(image_paths: List[str]) -> Union[Image.Image, List[Image.Image]]:
"""Load images from given paths.
Args:
image_paths: List of image file paths
Returns:
Single image if one path provided, else list of images
"""
images = [Image.open(img_path) for img_path in image_paths]
return images[0] if len(images) == 1 else images
def prepare_question(question: str, num_images: int) -> str:
"""Format the question with appropriate image tags.
Args:
question: Original question string
num_images: Number of images being processed
Returns:
Formatted question string
"""
    if not question.startswith("<image>"):
        return "<image>\n" * num_images + question
return question
def initialize_model(args: argparse.Namespace) -> tuple[LLM, AutoTokenizer]:
"""Initialize the LLM model and tokenizer.
Args:
args: Parsed command line arguments
Returns:
Tuple of (LLM instance, tokenizer)
"""
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
llm = LLM(
model=args.model_path,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=True,
limit_mm_per_prompt={"image": 20},
gpu_memory_utilization=0.7,
)
return llm, tokenizer
def generate_response(
llm: LLM,
tokenizer: AutoTokenizer,
question: str,
images: Union[Image.Image, List[Image.Image]],
sampling_params: SamplingParams
) -> str:
"""Generate response from the model.
Args:
llm: Initialized LLM instance
tokenizer: Initialized tokenizer
question: Formatted question string
images: Input image(s)
sampling_params: Generation parameters
Returns:
Generated response text
"""
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {"image": images},
},
sampling_params=sampling_params
)
return outputs[0].outputs[0].text
def main() -> None:
"""Main execution function."""
args = parse_arguments()
sampling_params = SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
max_tokens=args.max_tokens,
repetition_penalty=args.repetition_penalty,
)
llm, tokenizer = initialize_model(args)
images = load_images(args.image_paths)
question = prepare_question(args.question, len(args.image_paths))
response = generate_response(llm, tokenizer, question, images, sampling_params)
print(f"User: {args.question}\nAssistant: {response}")
if __name__ == "__main__":
main()
```
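The helpers above can also be driven from another Python script instead of the CLI. A minimal sketch, assuming this file is importable as `inference_with_vllm` and using an illustrative image path:
```python
# Hedged sketch: reusing the functions defined in inference_with_vllm.py programmatically.
import argparse
from vllm import SamplingParams
from inference_with_vllm import (
    initialize_model, load_images, prepare_question, generate_response,
)

# initialize_model() only reads model_path and tensor_parallel_size from args.
args = argparse.Namespace(model_path="Skywork/Skywork-R1V2-38B", tensor_parallel_size=4)
llm, tokenizer = initialize_model(args)

images = load_images(["imgs/comparsion.png"])                      # illustrative path
question = prepare_question("Summarize this chart.", num_images=1)
params = SamplingParams(temperature=0.0, top_p=0.95,
                        max_tokens=8000, repetition_penalty=1.05)
print(generate_response(llm, tokenizer, question, images, params))
```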
## /inference/setup.sh
```sh path="/inference/setup.sh"
pip install torch==2.6.0
pip install pillow==11.1.0
pip install torchvision==0.21.0
pip install transformers==4.37.2
pip install einops==0.6.1
pip install einops-exts==0.0.4
pip install timm==0.9.12
pip install bitsandbytes==0.42.0
pip install accelerate==1.5.2
pip install flash-attn --no-build-isolation
```
## /inference/utils.py
```py path="/inference/utils.py"
import math
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoConfig
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image_file, input_size=448, max_num=12):
image = Image.open(image_file).convert('RGB')
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
def split_model(model_path):
device_map = {}
world_size = torch.cuda.device_count()
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
num_layers = config.llm_config.num_hidden_layers
num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
num_layers_per_gpu = [num_layers_per_gpu] * world_size
num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
layer_cnt = 0
for i, num_layer in enumerate(num_layers_per_gpu):
for j in range(num_layer):
device_map[f'language_model.model.layers.{layer_cnt}'] = i
layer_cnt += 1
device_map['vision_model'] = 0
device_map['mlp1'] = 0
device_map['language_model.model.tok_embeddings'] = 0
device_map['language_model.model.embed_tokens'] = 0
device_map['language_model.output'] = 0
device_map['language_model.model.norm'] = 0
device_map['language_model.model.rotary_emb'] = 0
device_map['language_model.lm_head'] = 0
device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
return device_map
```
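A short usage sketch for these helpers (hedged; the image path is arbitrary and the exact tile count depends on the input aspect ratio):
```python
# Hedged usage sketch for load_image() and split_model() above.
from utils import load_image, split_model

# load_image tiles the image into up to max_num 448x448 crops plus a thumbnail.
pixel_values = load_image("imgs/comparsion.png", max_num=12)
print(pixel_values.shape)   # e.g. torch.Size([N, 3, 448, 448]) with N <= 13

# split_model spreads LLM layers across all visible GPUs (at least one required)
# and pins the vision tower, embeddings, and output head to GPU 0.
device_map = split_model("Skywork/Skywork-R1V-38B")
print(list(device_map.items())[:5])
```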