inference_with_vllm.py
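# Runs single-turn, image-plus-text inference with a Skywork-R1V series model
# served through vLLM: images are loaded with PIL, the question is prefixed
# with one "<image>" placeholder per image, and a single response is printed.
#
# Example invocation (illustrative paths and values; adjust to your setup):
#   python inference_with_vllm.py \
#       --model_path Skywork/Skywork-R1V2-38B \
#       --image_paths ./example.jpg \
#       --question "Describe this image." \
#       --tensor_parallel_size 4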
import argparse
from typing import List, Union

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments for model inference.

    Returns:
        argparse.Namespace: Parsed command line arguments
    """
    parser = argparse.ArgumentParser(
        description="Run inference with Skywork-R1V series model using vLLM."
    )

    # Model configuration
    parser.add_argument(
        "--model_path",
        type=str,
        default="Skywork/Skywork-R1V2-38B",
        help="Path to the model"
    )
    parser.add_argument(
        "--tensor_parallel_size",
        type=int,
        default=4,
        help="Number of GPUs for tensor parallelism"
    )

    # Input parameters
    parser.add_argument(
        "--image_paths",
        type=str,
        nargs="+",
        required=True,
        help="Path(s) to the input image(s)"
    )
    parser.add_argument(
        "--question",
        type=str,
        required=True,
        help="Question to ask the model"
    )

    # Generation parameters
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for sampling (higher = more creative)"
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=8000,
        help="Maximum number of tokens to generate"
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.05,
        help="Penalty for repeated tokens (1.0 = no penalty)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="Top-p (nucleus) sampling probability"
    )

    return parser.parse_args()


def load_images(image_paths: List[str]) -> Union[Image.Image, List[Image.Image]]:
    """Load images from given paths.

    Args:
        image_paths: List of image file paths

    Returns:
        Single image if one path provided, else list of images
    """
    images = [Image.open(img_path) for img_path in image_paths]
    return images[0] if len(images) == 1 else images


def prepare_question(question: str, num_images: int) -> str:
    """Format the question with appropriate image tags.

    Args:
        question: Original question string
        num_images: Number of images being processed

    Returns:
        Formatted question string
    """
    if not question.startswith("<image>\n"):
        return "<image>\n" * num_images + question
    return question


def initialize_model(args: argparse.Namespace) -> tuple[LLM, AutoTokenizer]:
    """Initialize the LLM model and tokenizer.

    Args:
        args: Parsed command line arguments

    Returns:
        Tuple of (LLM instance, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    llm = LLM(
        model=args.model_path,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=True,
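        # Accept up to 20 images per prompt and reserve roughly 70% of each
        # GPU's memory for vLLM (weights + KV cache); tune for your hardware.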
        limit_mm_per_prompt={"image": 20},
        gpu_memory_utilization=0.7,
    )
    return llm, tokenizer


def generate_response(
    llm: LLM,
    tokenizer: AutoTokenizer,
    question: str,
    images: Union[Image.Image, List[Image.Image]],
    sampling_params: SamplingParams
) -> str:
    """Generate response from the model.

    Args:
        llm: Initialized LLM instance
        tokenizer: Initialized tokenizer
        question: Formatted question string
        images: Input image(s)
        sampling_params: Generation parameters

    Returns:
        Generated response text
    """
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
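    # The images are passed alongside the prompt via multi_modal_data; vLLM
    # matches them to the "<image>" placeholders embedded in the prompt text.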
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": images},
        },
        sampling_params=sampling_params
    )
    return outputs[0].outputs[0].text


def main() -> None:
    """Main execution function."""
    args = parse_arguments()
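    # Note: the default temperature of 0.0 makes decoding effectively greedy,
    # in which case top_p has no practical effect.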
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        repetition_penalty=args.repetition_penalty,
    )

    llm, tokenizer = initialize_model(args)
    images = load_images(args.image_paths)
    question = prepare_question(args.question, len(args.image_paths))

    response = generate_response(llm, tokenizer, question, images, sampling_params)
    print(f"User: {args.question}\nAssistant: {response}")


if __name__ == "__main__":
    main()