Adding Batch Inference to the Multimodal LLM Reyes to Speed Up Inference
In a previous post I pretrained a multimodal LLM called Reyes; see "[Multimodal & LLM] Reyes: A Multimodal LLM Trained from Scratch (Technical Report)" for details. This post adds a batch inference mode to Reyes to speed up its inference.
Reyes-8B is open-sourced at:
- ModelScope weights: https://modelscope.cn/models/yujunhuinlp/Reyes-8B
- GitHub: https://github.com/yujunhuics/Reyes
Usage
Replace the modeling_reyes.py downloaded from ModelScope with the modeling_reyes.py from this repository, then run as usual. For the details of batch inference, see batch_inference.ipynb on GitHub; a minimal download-and-patch sketch is shown below.
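For reference, here is a minimal sketch of the download-and-replace step. It assumes the ModelScope model ID yujunhuinlp/Reyes-8B from the link above and that the patched modeling_reyes.py sits in your current directory; adjust paths to your setup.

import shutil
from modelscope import snapshot_download

# Download (or reuse the cached) Reyes-8B weights from ModelScope.
model_dir = snapshot_download('yujunhuinlp/Reyes-8B')

# Overwrite the downloaded modeling_reyes.py with the patched version
# from the GitHub repository (source path is an assumption; adjust as needed).
shutil.copy('modeling_reyes.py', f'{model_dir}/modeling_reyes.py')
print('Patched model directory:', model_dir)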
Additions to modeling_reyes.py:
def chat_batch(
        self,
        tokenizer,
        pixel_values_list,
        questions,
        generation_config,
        histories=None,
        return_histories=False,
        num_patches_lists=None,
        IMG_START_TOKEN='<|vision_start|>',
        IMG_END_TOKEN='<|vision_end|>',
        IMG_CONTEXT_TOKEN='<|vision_pad|>',
        verbose=False,
        visual_features_list=None
):
    if histories is None:
        histories = [[] for _ in questions]

    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.img_context_token_id = img_context_token_id

    # Get eos_token_id from the template
    template = get_conv_template(self.template)
    template.system_message = self.system_message
    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
    generation_config['eos_token_id'] = eos_token_id

    queries = []
    input_ids_list = []
    attention_mask_list = []

    for idx in range(len(questions)):
        question = questions[idx]
        history = histories[idx]
        pixel_values = pixel_values_list[idx] if pixel_values_list[idx] is not None else None
        num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []

        if not history and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        template_i = get_conv_template(self.template)
        template_i.system_message = self.system_message
        for (old_question, old_answer) in history:
            template_i.append_message(template_i.roles[0], old_question)
            template_i.append_message(template_i.roles[1], old_answer)
        template_i.append_message(template_i.roles[0], question)
        template_i.append_message(template_i.roles[1], None)
        query = template_i.get_prompt()

        # Handle image tokens
        if pixel_values is not None:
            for num_patches in num_patches_list:
                tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)] + ["<tile_global_thumbnail>"]
                image_tokens = ''
                for tile_pos_identifier in tile_pos_identifiers:
                    image_tokens += tile_pos_identifier + IMG_CONTEXT_TOKEN * self.num_image_token
                image_tokens = IMG_START_TOKEN + image_tokens + IMG_END_TOKEN
                query = query.replace('<image>', image_tokens, 1)

        queries.append(query)  # keep the prompt for verbose printing below

        model_inputs = tokenizer(
            query,
            return_tensors='pt',
            padding=True,
            truncation=True
        )
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)

    # Call the generate function
    generation_output = self.generate_batch(
        pixel_values_list=pixel_values_list,
        input_ids_list=input_ids_list,
        attention_mask_list=attention_mask_list,
        **generation_config
    )
    responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)

    outputs = []
    for idx, response in enumerate(responses):
        response = response.split(template.sep)[0].strip()
        histories[idx].append((questions[idx], response))
        outputs.append(response)

    if return_histories:
        return outputs, histories
    else:
        if verbose:
            for idx, query in enumerate(queries):
                query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
                query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
                print(query_to_print, outputs[idx])
        return outputs
@torch.no_grad()
def generate_batch(
        self,
        pixel_values_list: Optional[List[torch.FloatTensor]] = None,
        input_ids_list: Optional[List[torch.FloatTensor]] = None,
        attention_mask_list: Optional[List[torch.LongTensor]] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
) -> torch.LongTensor:
    input_embeds_list = []
    attention_mask_padded_list = []
    max_seq_length = max(input_ids.shape[1] for input_ids in input_ids_list)

    for pixel_values, input_ids, attention_mask in zip(pixel_values_list, input_ids_list, attention_mask_list):
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features.cuda()
                vit_embeds = self.mlp1(vit_embeds)
            else:
                vit_embeds = self.extract_feature(pixel_values)

            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)
            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            assert selected.sum() != 0, "No valid image context token IDs found."
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        seq_length = input_embeds.shape[1]
        if seq_length < max_seq_length:
            pad_size = max_seq_length - seq_length
            input_embeds = F.pad(input_embeds, (0, 0, 0, pad_size))
            attention_mask = F.pad(attention_mask, (0, pad_size))

        input_embeds_list.append(input_embeds)
        attention_mask_padded_list.append(attention_mask)

    input_embeds = torch.cat(input_embeds_list, dim=0)
    attention_mask = torch.cat(attention_mask_padded_list, dim=0)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        use_cache=True,
        **generate_kwargs,
    )
    return outputs
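The core trick in generate_batch is to splice the ViT features into the token embeddings and then right-pad each sample's embeddings and attention mask to a common length so they can be stacked into one batch for a single generate call. Below is a minimal, self-contained sketch of that padding-and-stacking step using dummy tensors (not the real model), just to illustrate what F.pad does here.

import torch
import torch.nn.functional as F

# Two "samples" with different sequence lengths, right-padded to a common length.
emb_a = torch.randn(1, 5, 8)                  # (batch=1, seq_len=5, hidden=8)
emb_b = torch.randn(1, 3, 8)                  # shorter sample
mask_a = torch.ones(1, 5, dtype=torch.long)
mask_b = torch.ones(1, 3, dtype=torch.long)

max_len = max(emb_a.shape[1], emb_b.shape[1])
pad = max_len - emb_b.shape[1]
emb_b = F.pad(emb_b, (0, 0, 0, pad))          # pad the sequence dim, not the hidden dim
mask_b = F.pad(mask_b, (0, pad))              # padded positions get attention mask 0

batch_embeds = torch.cat([emb_a, emb_b], dim=0)   # (2, 5, 8)
batch_mask = torch.cat([mask_a, mask_b], dim=0)   # (2, 5)
print(batch_embeds.shape, batch_mask.shape)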
Batch inference script:
import torch
from modelscope import AutoTokenizer, AutoModel
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

def preprocess_image(file_path, dynamic=True, max_num=6, image_size=448):
    try:
        if dynamic:
            return load_image(file_path, max_num=max_num).to(torch.bfloat16).cuda()
        else:
            img = Image.open(file_path).convert('RGB')
            transform = build_transform(image_size)
            pixel_values = transform(img)
            return torch.stack([pixel_values]).to(torch.bfloat16).cuda()
    except Exception as e:
        raise RuntimeError(f"Error processing image: {e}")
path = "Reyes-8B"
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().cuda()
# print(model)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=2048, do_sample=False)
questions = [
"<image>\nDescribe this image.",
"<image>\nDescribe this image.",
"<image>\nDescribe this image.",
]
images_path = ["t6.png","t6.png","t6.png"]
def conversation(model, tokenizer, questions, images_path, generation_config, histories):
    pixel_values_list = []
    for i in range(len(questions)):
        if images_path[i] is not None:
            # Preprocess the image belonging to the i-th question.
            pixel_values = preprocess_image(images_path[i], dynamic=True)
            pixel_values_list.append(pixel_values)
    return model.chat_batch(tokenizer, pixel_values_list, questions, generation_config, histories, return_histories=False)

responses = conversation(model, tokenizer, questions, images_path, generation_config, histories=None)

for question, response in zip(questions, responses):
    print(f"User: {question}\nAssistant: {response}\n")
This article is reposted from the WeChat public account 大模型自然语言处理; author: 余俊晖.
Original link: https://mp.weixin.qq.com/s/IeDUGzTOnOEONrFoLvXFcg
© Copyright belongs to the author. Please credit the source when reposting; unauthorized reproduction may lead to legal liability.