Seq2Seq(Sequence-to-Sequence)模型是一种用于处理序列数据的神经网络架构,广泛应用于自然语言处理(NLP)任务,如机器翻译、文本生成、对话系统等。
它通过编码器-解码器架构将输入序列(如一个句子)映射到输出序列(另一个句子或序列)。
图片
模型结构
Seq2Seq 模型由两个主要部分组成。
编码器(Encoder)
编码器是一个循环神经网络(RNN)或其变体,如LSTM或GRU,用于接收输入序列并将其转换为一个固定大小的上下文向量。
编码器逐步处理输入序列的每个时间步,通过隐藏层状态不断更新输入信息的表示,直到编码到达输入序列的结尾。
这一过程的最后一个隐藏状态通常被认为是整个输入序列的摘要,传递给解码器。
图片
class Encoder(nn.Module):
def __init__(self,input_dim,embedding_dim,hidden_size,num_layers,dropout):
super(Encoder,self).__init__()
#note hidden size and num layers
self.hidden_size = hidden_size
self.num_layers = num_layers
#create a dropout layer
self.dropout = nn.Dropout(dropout)
#embedding to convert input token into dense vectors
self.embedding = nn.Embedding(input_dim,embedding_dim)
#bilstm layer
self.lstm = nn.LSTM(embedding_dim,hidden_size,num_layers=num_layers,bidirectinotallow=True,dropout=dropout)
def forward(self,src):
embedded = self.dropout(self.embedding(src))
out,(hidden,cell) = self.lstm(embedded)
return hidden,cell
解码器(Decoder)
解码器也是一个RNN网络,接受编码器输出的上下文向量,并生成目标序列。
解码器在每一步会生成一个输出,并将上一步的输出作为下一步的输入,直到产生特定的终止符。
解码器的初始状态来自编码器的最后一个隐藏状态,因此可以理解为解码器是基于编码器生成的全局信息来预测输出序列。
class Decoder(nn.Module):
def __init__(self,output_dim,embedding_dim,hidden_size,num_layers,dropout):
super(Decoder,self).__init__()
self.output_dim = output_dim
#note hidden size and num layers for seq2seq class
self.hidden_size = hidden_size
self.num_layers = num_layers
self.dropout = nn.Dropout(dropout)
#note inputs of embedding layer
self.embedding = nn.Embedding(output_dim,embedding_dim)
self.lstm = nn.LSTM(embedding_dim,hidden_size,num_layers=num_layers,bidirectinotallow=True,dropout=dropout)
#we apply softmax over target vocab size
self.fc = nn.Linear(hidden_size*2,output_dim)
def forward(self,input_token,hidden,cell):
#adjust dimensions of input token
input_token = input_token.unsqueeze(0)
emb = self.embedding(input_token)
emb = self.dropout(emb)
#note hidden and cell along with output
out,(hidden,cell) = self.lstm(emb,(hidden,cell))
out = out.squeeze(0)
pred = self.fc(out)
return pred,hidden,cell
工作流程
Seq2Seq 模型的基本工作流程如下
- 输入处理
将输入序列(如源语言句子)逐步传入编码器的 RNN 层,编码器的最后一层的隐藏状态会保留输入序列的上下文信息。 - 生成上下文向量
编码器输出的隐藏状态向量(通常是最后一个隐藏状态)称为上下文向量,它包含了输入序列的信息。 - 解码过程
解码器接收上下文向量作为初始状态,然后通过自身的RNN结构逐步生成目标序列。
每一步解码器生成一个输出token,并将其作为下一步的输入,直到生成结束token。 - 序列生成
解码器生成的序列作为模型的最终输出。
优缺点
优点
- 通用性强
Seq2Seq 模型可以处理可变长度的输入和输出序列,适用于许多任务,例如机器翻译、文本摘要、对话生成、语音识别等。
它的编码器-解码器结构使得输入和输出不必同长,具有高度的灵活性。 - 适应复杂序列任务
Seq2Seq 模型通过编码器-解码器的分离,能够更好地学习序列映射关系。
编码器负责捕获输入序列的信息,而解码器则生成符合输出序列特征的内容。
缺点
- 信息压缩损失
传统 Seq2Seq 模型通过编码器最后一个隐藏状态来表示整个输入序列信息,当输入序列较长时,这种单一的上下文向量难以全面表示输入内容,导致信息丢失。这会导致模型在长序列任务上表现欠佳。 - 对长序列敏感
在没有注意力机制的情况下,Seq2Seq模型难以有效处理长序列,因为解码器需要依赖于编码器的固定向量,而这个向量可能无法完全涵盖长序列的细节。 - 训练难度大
Seq2Seq 模型在训练时面临梯度消失和爆炸的问题,尤其是在长序列的情况下。
案例分享
下面是一个使用 Seq2Seq 进行机器翻译的示例代码。
首先,我们从 HuggingFace 导入了数据集,并将其分为训练集和测试集
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
import tqdm,datasets
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
import spacy
dataset = datasets.load_dataset('bentrevett/multi30k')
train_data,val_data,test_data = dataset['train'],dataset['validation'],dataset['test']
加载源语言和目标语言的 spaCy 模型。
spaCy 是一个功能强大、可用于生产的 Python 高级自然语言处理库。
与许多其他 NLP 库不同,spaCy 专为实际使用而设计,而非研究实验。
它擅长使用预先训练的模型进行高效的文本处理,可完成标记化、词性标记、命名实体识别和依赖性解析等任务。
en_nlp = spacy.load('en_core_web_sm')
de_nlp = spacy.load('de_core_news_sm')
#tokenizer
def sample_tokenizer(sample,en_nlp,de_nlp,lower,max_length,sos_token,eos_token):
en_tokens = [token.text for token in en_nlp.tokenizer(sample["en"])][:max_length]
de_tokens = [token.text for token in de_nlp.tokenizer(sample["de"])][:max_length]
if lower == True:
en_tokens = [token.lower() for token in en_tokens]
de_tokens = [token.lower() for token in de_tokens]
en_tokens = [sos_token] + en_tokens + [eos_token]
de_tokens = [sos_token] + de_tokens + [eos_token]
return {"en_tokens":en_tokens,"de_tokens":de_tokens}
fn_kwargs = {
"en_nlp":en_nlp,
"de_nlp":de_nlp,
"lower":True,
"max_length":1000,
"sos_token":'<sos>',
"eos_token":'<eos>'
}
train_data = train_data.map(sample_tokenizer,fn_kwargs=fn_kwargs)
val_data = val_data.map(sample_tokenizer,fn_kwargs=fn_kwargs)
test_data = test_data.map(sample_tokenizer,fn_kwargs=fn_kwargs)
min_freq = 2
specials = ['<unk>','<pad>','<sos>','<eos>']
en_vocab = build_vocab_from_iterator(train_data['en_tokens'],specials=specials,min_freq=min_freq)
de_vocab = build_vocab_from_iterator(train_data['de_tokens'],specials=specials,min_freq=min_freq)
assert en_vocab['<unk>'] == de_vocab['<unk>']
assert en_vocab['<pad>'] == de_vocab['<pad>']
unk_index = en_vocab['<unk>']
pad_index = en_vocab['<pad>']
en_vocab.set_default_index(unk_index)
de_vocab.set_default_index(unk_index)
def sample_num(sample,en_vocab,de_vocab):
en_ids = en_vocab.lookup_indices(sample["en_tokens"])
de_ids = de_vocab.lookup_indices(sample["de_tokens"])
return {"en_ids":en_ids,"de_ids":de_ids}
fn_kwargs = {"en_vocab":en_vocab,"de_vocab":de_vocab}
train_data = train_data.map(sample_num,fn_kwargs=fn_kwargs)
val_data = val_data.map(sample_num,fn_kwargs=fn_kwargs)
test_data = test_data.map(sample_num,fn_kwargs=fn_kwargs)
train_data = train_data.with_format(type="torch",columns=['en_ids','de_ids'],output_all_columns=True)
val_data = val_data.with_format(type="torch",columns=['en_ids','de_ids'],output_all_columns=True)
test_data = test_data.with_format(type="torch",columns=['en_ids','de_ids'],output_all_columns=True)
def get_collate_fn(pad_index):
def collate_fn(batch):
batch_en_ids = [sample["en_ids"] for sample in batch]
batch_de_ids = [sample["de_ids"] for sample in batch]
batch_en_ids = pad_sequence(batch_en_ids,padding_value=pad_index)
batch_de_ids = pad_sequence(batch_de_ids,padding_value=pad_index)
batch = {"en_ids":batch_en_ids,"de_ids":batch_de_ids}
return batch
return collate_fn
def get_dataloader(dataset,batch_size,shuffle,pad_index):
collate_fn = get_collate_fn(pad_index)
dataloader = DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=collate_fn)
return dataloader
train_loader = get_dataloader(train_data,batch_size=512,shuffle=True,pad_index=pad_index)
val_loader = get_dataloader(val_data,batch_size=512,shuffle=True,pad_index=pad_index)
test_loader = get_dataloader(test_data,batch_size=512,shuffle=True,pad_index=pad_index)
接下来构建 seq2seq 模型。
class Seq2Seq(nn.Module):
def __init__(self,encoder,decoder,device):
super(Seq2Seq,self).__init__()
self.encoder = encoder
self.decoder = decoder
assert encoder.num_layers == decoder.num_layers
assert encoder.hidden_size == decoder.hidden_size
def forward(self,src,trg,teacher_forcing_ratio):
#exctract dim for out vector
trg_len = trg.shape[0]
batch_size = trg.shape[1]
vocab_size = self.decoder.output_dim
outputs = torch.zeros(trg_len,batch_size,vocab_size).to(device)
#get input and hidden
input_token = trg[0,:]
hidden,cell = self.encoder(src)
for t in range(1,trg_len):
out,hidden,cell = self.decoder(input_token,hidden,cell)
outputs[t] = out
#decide what passes as input
top1 = out.argmax(1)
teacher_force = np.random.randn()<teacher_forcing_ratio
input_token = trg[t] if teacher_force else top1
return outputs
input_dim = len(de_vocab)
output_dim = len(en_vocab)
encoder_embedding_dim = 256
decoder_embedding_dim = 256
hidden_size = 512
num_layers = 3
encoder_dropout = 0.5
decoder_dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(
input_dim,
encoder_embedding_dim,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=encoder_dropout,
)
decoder = Decoder(
output_dim,
decoder_embedding_dim,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=decoder_dropout,
)
model = Seq2Seq(encoder, decoder, device).to(device)
模型训练
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)
def train_fn(
model, data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device):
model.train()
epoch_loss = 0
for i, batch in enumerate(data_loader):
src = batch["de_ids"].to(device)
trg = batch["en_ids"].to(device)
# src = [src length, batch size]
# trg = [trg length, batch size]
optimizer.zero_grad()
output = model(src, trg, teacher_forcing_ratio)
# output = [trg length, batch size, trg vocab size]
output_dim = output.shape[-1]
output = output[1:].view(-1, output_dim)
# output = [(trg length - 1) * batch size, trg vocab size]
trg = trg[1:].view(-1)
# trg = [(trg length - 1) * batch size]
loss = criterion(output, trg)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(data_loader)
def evaluate_fn(model, data_loader, criterion, device):
model.eval()
epoch_loss = 0
with torch.no_grad():
for i, batch in enumerate(data_loader):
src = batch["de_ids"].to(device)
trg = batch["en_ids"].to(device)
# src = [src length, batch size]
# trg = [trg length, batch size]
output = model(src, trg, 0) # turn off teacher forcing
# output = [trg length, batch size, trg vocab size]
output_dim = output.shape[-1]
output = output[1:].view(-1, output_dim)
# output = [(trg length - 1) * batch size, trg vocab size]
trg = trg[1:].view(-1)
# trg = [(trg length - 1) * batch size]
loss = criterion(output, trg)
epoch_loss += loss.item()
return epoch_loss / len(data_loader)
n_epochs = 10
clip = 1.0
teacher_forcing_ratio = 1
best_valid_loss = float("inf")
for epoch in tqdm.tqdm(range(n_epochs)):
train_loss = train_fn(
model,
train_loader,
optimizer,
criterion,
clip,
teacher_forcing_ratio,
device,
)
valid_loss = evaluate_fn(
model,
val_loader,
criterion,
device,
)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), "tut1-model.pt")
print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")
model.load_state_dict(torch.load("tut1-model.pt"))
test_loss = evaluate_fn(model, test_loader, criterion, device)
print(f"| Test Loss: {test_loss:.3f} | Test PPL: {np.exp(test_loss):7.3f} |")
接下来,我们看一下最终的效果
def translate_sentence(
sentence,
model,
en_nlp,
de_nlp,
en_vocab,
de_vocab,
lower,
sos_token,
eos_token,
device,
max_output_length=25,
):
model.eval()
with torch.no_grad():
if isinstance(sentence, str):
tokens = [token.text for token in de_nlp.tokenizer(sentence)]
else:
tokens = [token for token in sentence]
if lower:
tokens = [token.lower() for token in tokens]
tokens = [sos_token] + tokens + [eos_token]
ids = de_vocab.lookup_indices(tokens)
tensor = torch.LongTensor(ids).unsqueeze(-1).to(device)
hidden, cell = model.encoder(tensor)
inputs = en_vocab.lookup_indices([sos_token])
for _ in range(max_output_length):
inputs_tensor = torch.LongTensor([inputs[-1]]).to(device)
output, hidden, cell = model.decoder(inputs_tensor, hidden, cell)
predicted_token = output.argmax(-1).item()
inputs.append(predicted_token)
if predicted_token == en_vocab[eos_token]:
break
tokens = en_vocab.lookup_tokens(inputs)
return tokens
sentence ='Der Mann ist am Weisheitsspross'
sos_token='<sos>'
eos_token='<eos>'
lower=True
translation = translate_sentence(
sentence,
model,
en_nlp,
de_nlp,
en_vocab,
de_vocab,
lower,
sos_token,
eos_token,
device,
)
print(translation)
#['<sos>', 'the', 'woman', 'is', 'looking', 'at', 'the', 'camera', '.', '<eos>']