在Tensorflow中训练一个模型后:

如何拯救训练过的模型? 您以后如何恢复这个保存的模型?


当前回答

最简单的方法是使用keras api,在线保存模型和一行加载模型

from keras.models import load_model

my_model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'

del my_model  # deletes the existing model


my_model = load_model('my_model.h5') # returns a compiled model identical to the previous one

其他回答

正如Yaroslav所说,您可以通过导入图、手动创建变量,然后使用Saver来从graph_def和检查点进行恢复。

我实现这个是为了我个人使用,所以我想在这里分享一下代码。

链接:https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(当然,这是一种hack,并且不能保证以这种方式保存的模型在TensorFlow的未来版本中仍然是可读的。)

下面是我对这两种基本情况的简单解决方案,这两种情况的不同之处在于您是想从文件加载图形还是在运行时构建它。

这个答案适用于Tensorflow 0.12+(包括1.0)。

在代码中重建图形

储蓄

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

加载

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

还从文件中加载图形

当使用这种技术时,确保所有的层/变量都显式地设置了唯一的名称。否则Tensorflow将使名称本身是唯一的,因此它们将不同于存储在文件中的名称。在前一种技术中,这不是问题,因为名称在加载和保存时都以相同的方式“损坏”。

储蓄

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

加载

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

在新版本的tensorflow 2.0中,保存/加载模型的过程要容易得多。因为Keras API的实现,一个TensorFlow的高级API。

保存一个模型: 请查阅相关文档以作参考: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model

tf.keras.models.save_model(model_name, filepath, save_format)

加载一个模型:

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

model = tf.keras.models.load_model(filepath)

tensorflow - 2.0

这很简单。

import tensorflow as tf

SAVE

model.save("model_name")

恢复

model = tf.keras.models.load_model('model_name')

这里所有的答案都很棒,但我想补充两点。

首先,详细说明@user7505159的答案,“。添加到要恢复的文件名的开头可能很重要。

例如,您可以保存没有“的图形。/"在文件名中如下所示:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

但是为了恢复图形,您可能需要在前面加上一个"。/"到file_name:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

你并不总是需要“。/”,但是它会根据你的环境和TensorFlow版本而导致问题。

它还想提到sess.run(tf.global_variables_initializer())在恢复会话之前可能很重要。

如果在尝试恢复保存的会话时收到关于未初始化变量的错误,请确保在保存程序之前包含sess.run(tf.global_variables_initializer())。恢复(sess, save_file)行。这样你就不用头疼了。