Adding batch inference to the Reyes multimodal LLM to speed up inference

Published 2025-01-16 15:35

I previously pretrained a multimodal LLM called Reyes; see "[Multimodal & LLM] Reyes: a multimodal LLM trained from scratch (technical report)" for details. This article adds a batch inference method to Reyes to speed up its inference.

Reyes-8B open-source links:

  • ModelScope weights: https://modelscope.cn/models/yujunhuinlp/Reyes-8B
  • GitHub: https://github.com/yujunhuics/Reyes

Usage

Replace the modeling_reyes.py downloaded from ModelScope with the modeling_reyes.py from this repository, then run as usual. For the full batch inference example, see batch_inference.ipynb on GitHub.
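
For reference, a minimal sketch of this replacement step, assuming the modelscope SDK is installed and the GitHub repo has been cloned locally (the local path 'Reyes/modeling_reyes.py' is illustrative and may need adjusting):

import shutil
from modelscope import snapshot_download

# Download (or locate the cached) Reyes-8B weights from ModelScope.
model_dir = snapshot_download('yujunhuinlp/Reyes-8B')

# Overwrite the downloaded modeling_reyes.py with the patched version from the GitHub repo.
shutil.copy('Reyes/modeling_reyes.py', model_dir)
print(f'Patched modeling_reyes.py in {model_dir}')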

Additions to modeling_reyes.py:

    def chat_batch(
            self,
            tokenizer,
            pixel_values_list,
            questions,
            generation_config,
            histories=None,
            return_histories=False,
            num_patches_lists=None,
            IMG_START_TOKEN='<|vision_start|>',
            IMG_END_TOKEN='<|vision_end|>',
            IMG_CONTEXT_TOKEN='<|vision_pad|>',
            verbose=False,
            visual_features_list=None
    ):

        if histories is None:
            histories = [[] for _ in questions]

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id
        # Get eos_token_id from the template
        template = get_conv_template(self.template)
        template.system_message = self.system_message
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
        generation_config['eos_token_id'] = eos_token_id

        queries = []
        input_ids_list = []
        attention_mask_list = []

        for idx in range(len(questions)):
            question = questions[idx]
            history = histories[idx]
            pixel_values = pixel_values_list[idx] if pixel_values_list[idx] is not None else None
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []

            if not history and pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question

            template_i = get_conv_template(self.template)
            template_i.system_message = self.system_message
            for (old_question, old_answer) in history:
                template_i.append_message(template_i.roles[0], old_question)
                template_i.append_message(template_i.roles[1], old_answer)
            template_i.append_message(template_i.roles[0], question)
            template_i.append_message(template_i.roles[1], None)
            query = template_i.get_prompt()
            # Handle image tokens
            if pixel_values is not None:
                for num_patches in num_patches_list:
                    tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)] + ["<tile_global_thumbnail>"]
                    image_tokens = ''
                    for tile_pos_identifier in tile_pos_identifiers:
                        image_tokens += tile_pos_identifier + IMG_CONTEXT_TOKEN * self.num_image_token
                    image_tokens = IMG_START_TOKEN + image_tokens + IMG_END_TOKEN
                    query = query.replace('<image>', image_tokens, 1)

            queries.append(query)
            model_inputs = tokenizer(
                query,
                return_tensors='pt',
                padding=True,
                truncation=True
            )
            input_ids = model_inputs['input_ids'].cuda()
            attention_mask = model_inputs['attention_mask'].cuda()
            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)

        # Call the generate function
        generation_output = self.generate_batch(
            pixel_values_list=pixel_values_list,
            input_ids_list=input_ids_list,
            attention_mask_list=attention_mask_list,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)

        outputs = []
        for idx, response in enumerate(responses):
            response = response.split(template.sep)[0].strip()
            histories[idx].append((questions[idx], response))
            outputs.append(response)

        if return_histories:
            return outputs, histories
        else:
            if verbose:
                for idx, query in enumerate(queries):
                    query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
                    query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
                    print(query_to_print, outputs[idx])
            return outputs

    @torch.no_grad()
    def generate_batch(
            self,
            pixel_values_list: Optional[List[torch.FloatTensor]] = None,
            input_ids_list: Optional[List[torch.FloatTensor]] = None,
            attention_mask_list: Optional[List[torch.LongTensor]] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        input_embeds_list = []
        attention_mask_padded_list = []

        max_seq_length = max(input_ids.shape[1] for input_ids in input_ids_list)

        for pixel_values, input_ids, attention_mask in zip(pixel_values_list, input_ids_list, attention_mask_list):
            if pixel_values is not None:
                if visual_features is not None:
                    vit_embeds = visual_features.cuda()
                    vit_embeds = self.mlp1(vit_embeds)
                else:
                    vit_embeds = self.extract_feature(pixel_values)

                input_embeds = self.language_model.get_input_embeddings()(input_ids)
                B, N, C = input_embeds.shape
                input_embeds = input_embeds.reshape(B * N, C)

                input_ids = input_ids.reshape(B * N)
                selected = (input_ids == self.img_context_token_id)
                assert selected.sum() != 0, "No valid image context token IDs found."
                input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

                input_embeds = input_embeds.reshape(B, N, C)
            else:
                input_embeds = self.language_model.get_input_embeddings()(input_ids)

            seq_length = input_embeds.shape[1]
            if seq_length < max_seq_length:
                pad_size = max_seq_length - seq_length
                input_embeds = F.pad(input_embeds, (0, 0, 0, pad_size))
                attention_mask = F.pad(attention_mask, (0, pad_size))

            input_embeds_list.append(input_embeds)
            attention_mask_padded_list.append(attention_mask)

        input_embeds = torch.cat(input_embeds_list, dim=0)
        attention_mask = torch.cat(attention_mask_padded_list, dim=0)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs
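
In generate_batch, each sample's token embeddings are right-padded to the longest sequence in the batch and its attention mask is extended with zeros, so the language model can take a single (batch, seq, hidden) tensor via inputs_embeds. A standalone toy check of that padding step (not part of the model code):

import torch
import torch.nn.functional as F

# A (B=1, N=3, C=4) embedding padded to N=5; the padded positions are masked out.
embeds = torch.randn(1, 3, 4)
mask = torch.ones(1, 3, dtype=torch.long)

pad_size = 5 - embeds.shape[1]
embeds_padded = F.pad(embeds, (0, 0, 0, pad_size))  # pad the sequence dim on the right
mask_padded = F.pad(mask, (0, pad_size))            # new positions get attention mask 0

print(embeds_padded.shape)  # torch.Size([1, 5, 4])
print(mask_padded)          # tensor([[1, 1, 1, 0, 0]])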

Batch inference:

import torch
from modelscope import AutoTokenizer, AutoModel
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image_file, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


def preprocess_image(file_path, dynamic=True, max_num=6, image_size=448):
    try:
        if dynamic:
            return load_image(file_path, max_num=max_num).to(torch.bfloat16).cuda()
        else:
            img = Image.open(file_path).convert('RGB')
            transform = build_transform(image_size)
            pixel_values = transform(img)
            return torch.stack([pixel_values]).to(torch.bfloat16).cuda()
    except Exception as e:
        raise RuntimeError(f"Error processing image: {e}")


path = "Reyes-8B"

model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval().cuda()

# print(model)

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
generation_config = dict(max_new_tokens=2048, do_sample=False)
questions = [
    "<image>\nDescribe this image.",
    "<image>\nDescribe this image.",
    "<image>\nDescribe this image.",
]

images_path = ["t6.png", "t6.png", "t6.png"]


def conversation(model, tokenizer, questions, images_path, generation_config, histories):
    pixel_values_list = []

    for i in range(len(questions)):
        if images_path[i] is not None:
            pixel_values = preprocess_image(images_path[i], dynamic=True)
            pixel_values_list.append(pixel_values)
        else:
            # keep pixel_values_list aligned with questions for text-only turns
            pixel_values_list.append(None)

    return model.chat_batch(tokenizer, pixel_values_list, questions, generation_config, histories, return_histories=False)

responses = conversation(model, tokenizer, questions, images_path, generation_config, histories=None)
for question, response in zip(questions, responses):
    print(f"User: {question}\nAssistant: {response}\n")



This article is reproduced from the WeChat public account 大模型自然语言处理. Author: 余俊晖

Original link: https://mp.weixin.qq.com/s/IeDUGzTOnOEONrFoLvXFcg


© Copyright belongs to the author. Please credit the source when reposting; otherwise legal liability may be pursued.