使用 OCR 识别手写文本

人工智能
GNHK手写笔记数据集由GoodNotes提供,包含来自世界各地学生的数百份英文手写笔记。本文实现了基于微调TrOCR模型进行手写文本识别。

本文实现了基于微调TrOCR模型进行手写文本识别。

1.GNHK手写笔记数据集

GNHK(GoodNotes Handwriting Kollection)手写笔记数据集由GoodNotes提供,包含来自世界各地学生的数百份英文手写笔记。

下载数据集

访问GNHK数据集官方网站:

(https://www.goodnotes.com/gnhk),滚动到底部,同意使用条款和条件;点击第二个链接下载数据集。

下载后会得到两个文件:train_data.zip 和 test_data.zip。解压这两个文件后,数据集的目录结构如下:

├── test_data
│   └── test
│       ├── eng_AF_004.jpg
│       ├── eng_AF_004.json
│       ├── eng_AF_007.jpg
│       ├── eng_AF_007.json
│       ...
│       ├── eng_NA_142.jpg
│       └── eng_NA_142.json
├── train_data
    └── train
        ├── eng_AF_001.jpg
        ├── eng_AF_001.json
        ├── eng_AF_002.jpg
        ├── eng_AF_002.json
        ...
        ├── eng_NA_146.jpg
        └── eng_NA_146.json
4 directories, 1375 files
  • 训练集:包含515个样本
  • 测试集:包含172个样本
  • 图像文件:从1080p到4K的高分辨率图像
  • 标注文件:每个图像文件对应一个JSON文件,包含图像中每个单词的标注信息

以下是数据集中的一些手写笔记图像样本。

每个图像文件对应一个JSON文件,文件内容格式如下:

[
    {
        "text": "%math%",
        "polygon": {
            "x0": 112, "y0": 556,
            "x1": 285, "y1": 563,
            "x2": 245, "y2": 776,
            "x3": 112, "y3": 783
        },
        "line_idx": 1,
        "type": "H"
    },
    {
        "text": "%math%",
        "polygon": {
            "x0": 2365, "y0": 202,
            "x1": 2350, "y1": 509,
            "x2": 2588, "y2": 527,
            "x3": 2632, "y3": 195
        },
        "line_idx": 0,
        "type": "H"
    },
    ...
    {
        "text": "ownership",
        "polygon": {
            "x0": 1347, "y0": 1606,
            "x1": 2238, "y1": 1574,
            "x2": 2170, "y2": 1884,
            "x3": 1300, "y3": 1747
        },
        "line_idx": 4,
        "type": "H"
    }
]

其中:

  • text:表示单词的内容。如果单词是数学符号、特殊字符或不可理解的内容(例如划线),则用%%符号包裹的特殊词表示。否则,text键包含实际的单词。
  • polygon:表示单词的多边形坐标,用于精确标注单词的位置。
  • line_idx:表示单词所在的行索引。
  • type:表示单词的类型,通常为"H"(手写)。

2.项目目录结构

├── input
│   └── gnhk_dataset
│       ├── test_data
│       ├── test_processed
│       ├── train_data
│       ├── train_processed
│       ├── test_processed.csv
│       └── train_processed.csv
├── pretrained_model_inference  [10066 entries exceeds filelimit, not opening dir]
├── trocr_handwritten
│   ├── checkpoint-6093
│   │   ├── config.json
│   │   ├── generation_config.json
│   │   ├── model.safetensors
│   │   ├── optimizer.pt
│   │   ├── preprocessor_config.json
│   │   ├── rng_state.pth
│   │   ├── scheduler.pt
│   │   ├── trainer_state.json
│   │   └── training_args.bin
│   ├── checkpoint-6770
│   │   ├── config.json
│   │   ├── generation_config.json
│   │   ├── model.safetensors
│   │   ├── optimizer.pt
│   │   ├── preprocessor_config.json
│   │   ├── rng_state.pth
│   │   ├── scheduler.pt
│   │   ├── trainer_state.json
│   │   └── training_args.bin
│   └── runs
│       └── Aug27_11-30-05_f57a2dab37c7
├── Fine_Tune_TrOCR_Handwritten.ipynb
├── preprocess_gnhk_dataset.py
└── Pretrained_Model_Inference.ipynb

目录说明:

  • input/gnhk_dataset:包含下载并解压的数据集
  • pretrained_model_inference:包含使用预训练的TrOCR手写模型对验证数据集进行推理的结果。
  • trocr_handwritten:包含微调TrOCR模型后的结果。
  • Fine_Tune_TrOCR_Handwritten.ipynb:用于微调TrOCR模型的Jupyter Notebook
  • preprocess_gnhk_dataset.py:包含预处理GNHK数据集的Python脚本
  • Pretrained_Model_Inference.ipynb:用于使用预训练模型进行推理的Jupyter Notebook

3.安装依赖项

在继续进行数据预处理、推理和训练之前,我们需要安装以下依赖项。

pip install transformers
pip install sentencepiece
pip install jiwer
pip install datasets
pip install evaluate
pip install -U accelerate

pip install matplotlib
pip install protobuf==3.20.1
pip install tensorboard

4.GNHK数据集预处理

预训练的TrOCR模型只能识别单个单词或单行句子,而GNHK数据集中的图像是整个文档的图像。因此需要对数据集进行预处理,以便模型能够更好地处理这些图像。

数据集预处理的关键步骤如下:

  • 转换多边形坐标为四点边界框坐标。
  • 裁剪每个单词并存储在单独的目录中。
  • 创建两个 CSV 文件,一个用于训练集,一个用于测试集。这些文件将包含裁剪后的图像名称和标签文本。

代码实现:

import os
import json
import csv
import cv2
import numpy as np
from tqdm import tqdm
 
def create_directories():
   """
   创建必要的目录
   """
   dirs = [
       'input/gnhk_dataset/train_processed/images',
       'input/gnhk_dataset/test_processed/images',
   ]
   for dir_path in dirs:
       os.makedirs(dir_path, exist_ok=True)
 
def polygon_to_bbox(polygon):
    """
    将多边形坐标转换为四点边界框坐标
    """
   points = np.array([(polygon[f'x{i}'], polygon[f'y{i}']) for i in range(4)], dtype=np.int32)
   x, y, w, h = cv2.boundingRect(points)
   return x, y, w, h
 
def process_dataset(input_folder, output_folder, csv_path):
    """
    处理数据集,裁剪图像并生成 CSV 文件
    """
   with open(csv_path, 'w', newline='') as csvfile:
       csv_writer = csv.writer(csvfile)
       csv_writer.writerow(['image_filename', 'text'])
       
       for filename in tqdm(os.listdir(input_folder), desc=f"Processing {os.path.basename(input_folder)}"):
           if filename.endswith('.json'):
               json_path = os.path.join(input_folder, filename)
               img_path = os.path.join(input_folder, filename.replace('.json', '.jpg'))
               
               with open(json_path, 'r') as f:
                   data = json.load(f)
               
               img = cv2.imread(img_path)
               
               for idx, item in enumerate(data):
                   text = item['text']
                   if text.startswith('%') and text.endswith('%'):
                       text = 'SPECIAL_CHARACTER'
                   
                   x, y, w, h = polygon_to_bbox(item['polygon'])
                   
                   cropped_img = img[y:y+h, x:x+w]
                   
                   output_filename = f"{filename.replace('.json', '')}_{idx}.jpg"
                   output_path = os.path.join(output_folder, output_filename)
                   cv2.imwrite(output_path, cropped_img)
                   
                   csv_writer.writerow([output_filename, text])
                   
def main():
    """
    主函数,创建目录并处理数据集
    """
   create_directories()

   process_dataset(
       'input/gnhk_dataset/train_data/train',
       'input/gnhk_dataset/train_processed/images',
       'input/gnhk_dataset/train_processed.csv'
   )
   process_dataset(
       'input/gnhk_dataset/test_data/test',
       'input/gnhk_dataset/test_processed/images',
       'input/gnhk_dataset/test_processed.csv'
   )

if __name__ == '__main__':
   main()

将上述代码保存为preprocess_gnhk_dataset.py文件。在终端中运行脚本。

python preprocess_gnhk_dataset.py

运行脚本后,将会在 input/gnhk_dataset 目录下创建以下子目录和文件:

子目录:

  • train_processed/images:存储训练集的裁剪图像。
  • test_processed/images:存储测试集的裁剪图像。

CSV 文件:

  • train_processed.csv:包含训练集的图像文件名和对应的标签文本。
  • test_processed.csv:包含测试集的图像文件名和对应的标签文本。

以下是一些经过处理后的裁剪图像示例:

csv文件示例如下图所示:

每个csv文件包括裁剪后的图像文件名和对应图像的标签文本。每一行表示一个裁剪后的图像及其对应的标签文本。

处理后的数据集包括:

  • 训练集:32495张裁剪图像
  • 测试集:10066张裁剪图像

5.微调TrOCR模型

首先,导入必要的库,并定义一些全局设置。

import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob as glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
 
from PIL import Image
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import (
   VisionEncoderDecoderModel,
   TrOCRProcessor,
   Seq2SeqTrainer,
   Seq2SeqTrainingArguments,
   default_data_collator
)
 
block_plot = False
plt.rcParams['figure.figsize'] = (12, 9)
 
os.environ["TOKENIZERS_PARALLELISM"] = 'false'

接着,为确保实验的可重复性,设置随机种子,并初始化计算设备。

def seed_everything(seed_value):
   np.random.seed(seed_value)
   torch.manual_seed(seed_value)
   torch.cuda.manual_seed_all(seed_value)
   torch.backends.cudnn.deterministic = True
   torch.backends.cudnn.benchmark = False

seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

定义一些重要的配置项,包括训练和数据集的路径。这里设置批次大小batch size为48,训练轮数10,基础学习率0.00005。

@dataclass(frozen=True)
class TrainingConfig:
   BATCH_SIZE:    int = 48
   EPOCHS:        int = 10
   LEARNING_RATE: float = 0.00005
 
@dataclass(frozen=True)
class DatasetConfig:
   DATA_ROOT:     str = 'input/gnhk_dataset'
 
@dataclass(frozen=True)
class ModelConfig:
   MODEL_NAME: str = 'microsoft/trocr-small-handwritten'

可视化训练样本,以帮助我们验证路径、CSV文件准备和标签是否正确。

def visualize(dataset_path, df):
   all_images = df.image_filename
   all_labels = df.text
   
   plt.figure(figsize=(15, 3))
   for i in range(15):
       plt.subplot(3, 5, i+1)
       image = plt.imread(f"{dataset_path}/test_processed/images/{all_images[i]}")
       label = all_labels[i]
       plt.imshow(image)
       plt.axis('off')
       plt.title(label)
   plt.show()
sample_df = pd.read_csv(
   os.path.join(DatasetConfig.DATA_ROOT, 'test_processed.csv'),
   header=None,
   skiprows=1,
   names=['image_filename', 'text'],
   nrows=50
)
 
visualize(DatasetConfig.DATA_ROOT, sample_df)

GNHK手写文本识别数据集具有自定义的目录结构和CSV文件,我们需要编写自定义的数据集准备代码。

  • 读取csv文件
train_df = pd.read_csv(
   os.path.join(DatasetConfig.DATA_ROOT, 'train_processed.csv'),
   header=None,
   skiprows=1,
   names=['image_filename', 'text']
)
 
 
test_df = pd.read_csv(
   os.path.join(DatasetConfig.DATA_ROOT, 'test_processed.csv'),
   header=None,
   skiprows=1,
   names=['image_filename', 'text']
)
  • 为了减少过拟合,应用一些轻微的数据增强,主要包括颜色抖动和高斯模糊。
# 定义数据增强
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=0.5, hue=0.3),
    transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 5)),
])
  • 需要创建一个自定义的PyTorch数据集类。
class CustomOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

        # 填充空值
        self.df['text'] = self.df['text'].fillna('')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 图像文件名
        file_name = self.df['image_filename'][idx]
        # 文本(标签)
        text = self.df['text'][idx]

        # 读取图像,应用数据增强,并获取转换后的像素值
        image = Image.open(os.path.join(self.root_dir, file_name)).convert('RGB')
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values

        # 通过分词器对文本进行分词,并获取标签
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length
        ).input_ids

        # 使用 -100 作为填充标记
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels)
        }
        return encoding
  • 初始化TrOCR处理器,并准备训练和验证数据集。
# 初始化处理器
processor = TrOCRProcessor.from_pretrained(ModelConfig['MODEL_NAME'])

# 准备训练数据集
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig['DATA_ROOT'], 'train_processed/images/'),
    df=train_df,
    processor=processor
)

# 准备验证数据集
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig['DATA_ROOT'], 'test_processed/images/'),
    df=test_df,
    processor=processor
)

初始化和配置模型,并统计模型的参数数量。

  • 加载模型
# 初始化模型
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig['MODEL_NAME'])
model.to(device)
print(model)

# 统计总参数和可训练参数
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")

total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
print(f"{total_trainable_params:,} training parameters.")
  • 手动设置一些配置。
# 设置特殊 token 用于从标签创建 decoder_input_ids
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# 设置正确的词汇表大小
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# 设置最大输出长度
model.config.max_length = 64
# 启用提前停止
model.config.early_stopping = True
# 设置不重复 n-gram 的大小
model.config.no_repeat_ngram_size = 3
# 设置长度惩罚
model.config.length_penalty = 2.0
# 设置 beam search 的束宽
model.config.num_beams = 4

# 打印模型配置
print(model.config)
  • 定义AdamW优化器,并配置学习率和权重衰减。
# 定义 AdamW 优化器
optimizer = optim.AdamW(
    model.parameters(), lr=TrainingConfig['LEARNING_RATE'], weight_decay=0.0005
)
  • 使用字符错误率CER对模型进行评估。
cer_metric = evaluate.load('cer')


def compute_cer(pred):
   # 提取标签的 ID
   labels_ids = pred.label_ids
   # 提取预测的 ID
   pred_ids = pred.predictions

   # 将预测的 ID 解码为字符串,跳过特殊 token
   pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
   # 将标签中的 -100 转换为 pad_token_id,以避免影响评估结果
   labels_ids[labels_ids == -100] = processor.tokenizer.
   # 将标签的 ID 解码为字符串,跳过特殊 token
   label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

   # 使用 cer_metric 计算 CER
   cer = cer_metric.compute(predictions=pred_str, references=label_str)

   return {"cer": cer}

训练和验证模型。在开始训练之前,需要初始化训练参数和 Trainer API。

  • 定义 Seq2SeqTrainingArguments 对象,设置训练和验证的相关参数。
# 初始化训练参数
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig['BATCH_SIZE'],
    per_device_eval_batch_size=TrainingConfig['BATCH_SIZE'],
    fp16=True,
    output_dir='trocr_handwritten/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig['EPOCHS'],
    dataloader_num_workers=8
)
  • 使用 Seq2SeqTrainer API 初始化训练器。Seq2SeqTrainer 接受模型、处理器、训练参数、数据集和数据收集器作为参数。
# 初始化训练器
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator
)
  • 开始微调模型。
# 开始训练
trainer.train()

以下是训练10个epoch后的日志示例:

在训练完成后,我们得到了最佳的验证 CER 值。接下来,我们将使用最后一个epoch的检查点对验证集进行推理。

如图所示,验证CER图表在整个训练过程中持续下降,直到最后一个 epoch。这表明模型仍在学习,并且可能通过适当的学习率调度进一步训练几个 epoch 以获得更好的性能。

6.使用训练好的TrOCR模型推理

接下来,将使用训练好的trOCR模型对一组图像进行推理。

  • 加载处理器和训练好的模型检查点。
# 定义模型和处理器
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained('trocr_handwritten/checkpoint-'+str(res.global_step)).to(device)
  • 定义一些辅助函数,用于读取图像、通过模型进行前向传播以及绘制结果。
def read_and_show(image_path):
    """
    :param image_path: String, path to the input image.
    
    Returns:
        image: PIL Image.
    """
    image = Image.open(image_path).convert('RGB')
    return image
def ocr(image, processor, model):
    """
    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.
    
    Returns:
        generated_text: the OCR'd text string.
    """
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
def eval_new_data(data_path=None, num_samples=50, df=None):
    all_images = df.image_filename
    all_labels = df.text
    
    plt.figure(figsize=(15, 3))
    for i in range(num_samples):
        plt.subplot(3, 5, i+1)
        image = read_and_show(os.path.join(data_path, all_images[i]))
        text = ocr(image, processor, trained_model)
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
    plt.show()
  • 运行推理并可视化结果
# 运行推理并可视化结果
eval_new_data(
    data_path=data_path,
    num_samples=num_samples,
    df=sample_df
)

推理结果如下图所示。

由此可以看出,模型成功地正确预测了所有单词。这表明经过微调后,模型在验证集上的表现非常出色。

附录(完整代码)

链接:https://pan.baidu.com/s/1R5-JB7zKTeb1pJ0kS2Tmnw

提取码:d388

责任编辑:赵宁宁 来源: 小喵学AI
相关推荐

2023-09-07 10:37:43

OCR项目字符串

2023-09-29 08:45:38

截图工具Windows

2015-07-09 13:58:28

tesseract教程OCR教程

2023-12-25 19:21:55

ocr人工智能

2021-04-09 20:49:44

PythonOCR图像

2009-03-28 09:13:11

AndroidGoogle移动OS

2023-09-12 14:46:24

人工智能自然语言

2023-06-25 07:37:54

谷歌Chrome

2020-03-27 20:22:53

数据集装箱网络

2023-07-06 08:41:20

TTS​Mac​系统

2018-04-02 10:45:11

深度学习PaddlePaddl手写数字识别

2022-10-08 08:36:02

UbuntuLinux语音识别

2017-09-21 15:43:02

深度序列学习

2023-03-16 17:19:50

开源OCR识别项目

2023-05-17 15:22:45

识别开源工具

2023-04-17 08:59:14

OCRChatGPT识别食品

2009-09-22 12:16:29

ibmdwLotus

2023-09-27 08:51:52

PythonOCR技术

2023-11-12 23:01:44

PaddleOCR深度学习

2014-11-12 10:16:43

人工智能灵云
点赞
收藏

51CTO技术栈公众号