在机器学习领域,确定图像之间的相似度在各种应用中至关重要,从检测重复项到面部识别。解决这个问题的一个强大方法是使用暹罗网络结合三元组损失函数。在本文中,我们将探索如何构建和训练暹罗网络以估计图像相似度,并通过一个来自GitHub仓库的实际示例进行说明。
什么是暹罗网络?
暹罗网络是一种包含两个或更多相同子网络的神经网络架构。这些子网络旨在为每个输入生成特征向量,然后可以比较这些向量以估计相似度。关键思想是使用相同的网络处理每个输入,确保输出一致且可比较。
这种架构特别适合于检测重复项、寻找异常和面部识别等任务。在我们将要探索的实现中,网络设置有三个相同的子网络。每个网络处理三张图像中的一张:锚点图像、正样本(与锚点相似)和负样本(与锚点无关)。
什么是三元组损失?
为了有效地训练暹罗网络,我们使用三元组损失函数。这种损失函数鼓励网络在特征空间中拉近锚点和正样本的距离,同时将锚点和负样本推得更远。损失函数定义如下:
L(A, P, N) = max(‖f(A) — f(P)‖² — ‖f(A) — f(N)‖² + margin, 0)
这里,A是锚点图像,P是正图像,N是负图像。函数f(x)代表网络生成的embedding,而margin是一个小的正值,有助于确保网络不会将所有嵌入压缩到同一点。
设置暹罗网络
在这次实现中,我们首先加载Totally Looks Like数据集,其中包含我们用来创建训练网络的三元组图像。
1. 数据准备
使用TensorFlow的tf.data API处理数据集以创建图像三元组。这涉及到设置一个数据管道,其中每个三元组由锚点、正样本和负样本图像组成。通过调整图像大小到目标形状并归一化像素值来预处理图像。
def preprocess_image(filename):
image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image
def preprocess_triplets(anchor, positive, negative):
return (
preprocess_image(anchor),
preprocess_image(positive),
preprocess_image(negative),
)
以下是从数据集中生成的三元组示例,每行的前两张图像相似(锚点和正样本),第三张不同(负样本):
图1:在数据准备期间生成的三元组。每行的前两张图像相似(锚点和正样本),第三张不同(负样本)
2.构建 embedding 生成器
我们暹罗网络的核心是嵌入生成器,它使用在ImageNet上预训练的ResNet50模型构建。通过冻结ResNet50中的大部分层的权重,并且仅微调最后几层,我们可以利用迁移学习来减少训练时间并提高性能。
base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
# Freeze all layers until the layer conv5_block1_out
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
3.构建暹罗网络
暹罗网络设置为一次输入三张图像(锚点、正样本和负样本)。自定义的DistanceLayer计算锚点-正样本对和锚点-负样本对之间的距离。然后训练模型以最小化相似图像之间的距离,并最大化不相似图像之间的距离。
class DistanceLayer(layers.Layer):
def call(self, anchor, positive, negative):
ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
return (ap_distance, an_distance)
anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))
distances = DistanceLayer()(
embedding(resnet.preprocess_input(anchor_input)),
embedding(resnet.preprocess_input(positive_input)),
embedding(resnet.preprocess_input(negative_input)),
)
siamese_network = Model(
inputs=[anchor_input, positive_input, negative_input], outputs=distances
)
4.训练和评估
模型使用自定义训练循环进行训练,其中计算三元组损失并用于更新网络的权重。仔细监控训练过程,并通过对学习到的嵌入进行检查来评估模型的性能。
class SiameseModel(Model):
def __init__(self, siamese_network, margin=0.5):
super(SiameseModel, self).__init__()
self.siamese_network = siamese_network
self.margin = margin
self.loss_tracker = metrics.Mean(name="loss")
def train_step(self, data):
with tf.GradientTape() as tape:
loss = self._compute_loss(data)
gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
self.optimizer.apply_gradients(
zip(gradients, self.siamese_network.trainable_weights)
)
self.loss_tracker.update_state(loss)
return {"loss": self.loss_tracker.result()}
def _compute_loss(self, data):
ap_distance, an_distance = self.siamese_network(data)
loss = ap_distance - an_distance
loss = tf.maximum(loss + self.margin, 0.0)
return loss
5.检查结果
训练完成后,我们可以通过比较锚点-正样本对和锚点-负样本对的嵌入之间的余弦相似度来评估网络学习分离相似和不相似图像的能力。
cosine_similarity = metrics.CosineSimilarity()
positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
print("Positive similarity:", positive_similarity.numpy())
negative_similarity = cosine_similarity(anchor_embedding, negative_embedding)
print("Negative similarity:", negative_similarity.numpy())
以下是经过训练的模型评估的三元组示例。网络成功识别出图像之间的相似性和差异:
图2:经过训练的暹罗网络的输出,其中每行的前两张图像被模型识别为相似,第三张为不同
结论
本文展示了使用三元组损失的暹罗网络如何有效地估计图像相似度。通过使用预训练的ResNet50模型并微调其层,我们可以创建一个可以应用于需要相似度估计的各种任务。
完整代码和解释,参考:https://github.com/elcaiseri/Siamese-Network