为什么加载之前保存的Keras模型得出不一样的结果:经验和教训

译文
人工智能 机器学习
现在,机器学习模型在生产环境中的使用比以往都要广泛。Keras就是这样一种流行的库,用于创建强大的机器学习和深度学习模型。

现在,机器学习模型在生产环境中的使用比以往都要广泛。Keras就是这样一种流行的库,用于创建强大的机器学习和深度学习模型。然而,这些模型的训练过程常常计算开销大,还费时,具体取决于实际处理的数据和模型架构。一些模型需要数周到数月的时间来训练。因此,能够在本地存储模型、需要进行预测时再次检索它们变得至关重要。但如果由于某种原因保存的模型没有正确加载,该怎么办?我会根据本人的经验试着给出答案。

我不会详细介绍如何使用和保存Keras模型,只是假设读者熟悉该过程,直接介绍如何处理加载时意外的模型行为。也就是说,在训练存储在Model变量中的Keras模型之后,我们希望将其保存为原样,那样下次加载时我们可以跳过训练,就进行预测。

我首选的方法是保存模型的权重,权重在模型创建开始时是随机的,随着模型的训练而加以更新。于是我点击了model.save_weights(“model.h5”)。创建了“model.h5”文件,含有模型学习到的权重。接下来,在另一个会话中,我使用与以前相同的架构重建模型,并使用 new_model.load_weights(“model.h5”)加载我保存的训练权重。一切似乎都很好。只是我点击 new_model.predict(test_data)后,得到的准确性为零,不知道为什么。

事实证明,模型无法做出正确的预测有诸多原因。我在本文试着总结最常见的原因,并介绍如何解决。

1. 先仔细检查数据。

我知道这似乎很明显,但是从磁盘重新加载模型时,一有疏忽就会导致性能下降。比如说,如果您在构建语言模型,应确保在每个新会话中,您执行以下操作:

  • 重新检查类标签的顺序。如果您将它们映射到数字,重新检查在每个会话中每个类标签都有相同的数字。如果您使用list(set())函数来检索,可能会发生这种情况,该函数每次都会以不同的顺序返回您的标签。这最终可能会搞乱您的标签预测。
  • 检查数据集。如果您的测试数据还没有在另一个文件中,检查训练-测试拆分不是随机的,以便每次进行预测时,您根据不同的数据进行预测,因此您的预测准确性最终会不一致。

当然,您可能会遇到其他与数据相关的问题,具体取决于您从事的领域。然而,请始终检查数据表示的一致性。

2. 度量指标问题

导致错误或结果不一致的另一个原因是,准确性度量指标的选择。在构建模型并保存其权重时,我们通常执行以下操作:

def build_model(max_len, n_tags): 
input_layer = Input(shape=(max_len, ))
output_layer = Dense(n_tags, activation = 'softmax')(input_layer)
model = Model(input_layer, output_layer)

return model

model = build_model()
model.compile(optimizer="adam",
loss="sparse_categorical_crossentropy", metrics=["accuracy"])

model.fit(..)
model.save_weights("model.h5")

如果我们需要在新的会话/脚本中打开它,需要执行以下操作:

def build_model(max_len, n_tags): 
input_layer = Input(shape=(max_len, ))
output_layer = Dense(n_tags, activation = 'softmax')(input_layer)
model = Model(input_layer, output_layer)

return model
model = build_model()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.load_weights("model.h5")
model.evaluate()

这可能抛出错误,具体视所使用的特定的Keras/Tensorflow版本而定。编译模型并选择“准确性”作为指标时,会出现问题。Keras识别准确性的各种定义:“稀疏分类准确性”、“分类准确性”等;视您使用的数据而定,不同的定义是优选的解决方案。这是由于如果我们将度量指标设为“准确性”,Keras将试着分配其中一种特定的准确性类型,具体取决于它认为哪一种最适合数据分布。它可能会在不同的运行中推断出不同的准确性指标。这里最好的解决方法是,始终明确设置准确性指标,而不是让Keras自行选择。比如说,把

model.compile(optimizer="adam", 
loss="sparse_categorical_crossentropy", metrics=["accuracy"])

换成:

model.compile(optimizer="adam", 
loss="sparse_categorical_crossentropy", metrics=["sparse_categorical_accuracy"])

3. 随机性

在与以前相同的数据上重新训练Keras神经网络时,您很少两次获得同样的结果。这是由于Keras中的神经网络在初始化权重时使用随机性,因此每次运行时权重的初始化方式都不同,因此在学习过程中这些权重会以不同方式更新,于是在进行预测时不太可能获得相同的准确性结果。

如果出于某种原因,您需要在训练之前使权重相等,可以在代码前面设置随机数生成器:

from numpy.random import seed
seed(42)
from tensorflow import set_random_seed
set_random_seed(42)

numpy随机种子用于Keras,而至于Tensorflow后端,我们需要将其自己的随机数生成器设置为相等的种子。该代码片段将确保每次运行代码时,您的神经网络权重都会被同等地初始化。

4. 留意自定义层的使用

Keras提供了众多层(Dense、LSTM、Dropout和BatchNormalizaton等),但有时我们希望对模型中的数据采取某种特定的操作,但又没有为它定义的特定层。一般来说,Keras提供了两种类型的层:Lambda和基础层类。但对这两种层要很小心,如果您将模型架构保存为json格式更要小心。Lambda层的棘手地方在于序列化限制。由于它与Python字节码的序列化一同保存,它只能加载到保存它的同一个环境中,即它不可移植。遇到该问题时,通常建议覆盖keras.layers.Layer层,或者只保存其权重,从头开始重建模型,而不是保存整个模型。

5. 自定义对象

很多时候,您会想要使用自定义函数应用于数据,或计算损失/准确性等指标的函数。

Keras允许这种使用,为此让我们可以在保存/加载模型时指定额外的参数。假设我们想要将我们自行创建的特殊的损失函数与之前保存的模型一并加载:

model = load_model("model.h5", custom_objects=
{"custom_loss":custom_loss})

如果我们在新环境中加载该模型,必须在新环境中小心定义custom_loss函数,因为默认情况下,保存模型时不会记住这些函数。即使我们保存了模型的整个架构,它也会保存该自定义函数的名称,但函数体是我们需要额外提供的东西。

6. 全局变量初始化器

如果您使用Tensorflow 1.x作为后端——您可能仍然需要该后端用于许多应用程序,这点尤为重要。运行tf 1.x会话时,您需要运行tf.global_variables_initializer(),它随机初始化所有变量。这么做的副作用是,当您尝试保存模型时,它可能重新初始化所有权重。您可以手动停止该行为,只需运行:

from keras.backend import manual_variable_initialization manual_variable_initialization(True)

结语

本文列出了最常导致您的Keras模型无法在新环境中正确加载的几个因素。有时这些问题导致不可预测的结果,而在其他情况下,它们只会抛出错误。它们何时发生、如何发生,在很大程度上也取决于您使用的Python版本以及Tensorflow和Keras版本,因为其中一些版本不相兼容,从而导致意外的行为。但愿读完本文后,下次遇到此类问题时您知道从何处入手。

原文标题:Why Loading a Previously Saved Keras Model Gives Different Results: Lessons Learned,作者:Kristina Popova

责任编辑:华轩 来源: 51CTO
相关推荐

2021-07-12 23:53:22

Python交换变量

2012-03-07 17:24:10

戴尔咨询

2012-12-20 10:17:32

IT运维

2021-01-11 14:02:22

dudf运维

2017-05-25 15:02:46

联宇益通SD-WAN

2011-09-02 10:12:36

网速测试结果网速测试网速测试方法

2015-10-19 12:33:01

华三/新IT

2016-05-09 18:40:26

VIP客户缉拿

2020-02-14 14:36:23

DevOps落地认知

2021-12-23 15:11:46

Web 3.0元宇宙Metaverse

2018-05-09 15:42:24

新零售

2011-03-14 16:51:24

2009-12-01 16:42:27

Gentoo Linu

2009-02-04 15:43:45

敏捷开发PHPFleaPHP

2012-07-18 02:05:02

函数语言编程语言

2011-02-28 10:38:13

Windows 8

2009-06-12 15:26:02

2016-03-24 18:51:40

2023-03-20 08:19:23

GPT-4OpenAI

2015-08-25 09:52:36

云计算云计算产业云计算政策
点赞
收藏

51CTO技术栈公众号