在Tensorflow中训练一个模型后:

如何拯救训练过的模型? 您以后如何恢复这个保存的模型?


当前回答

我正在改进我的回答,以添加更多关于保存和恢复模型的细节。

在Tensorflow 0.11版本中(及之后):

保存模型:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

恢复模型:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

这里已经很好地解释了这一点和一些更高级的用例。

一个快速完整的教程,保存和恢复Tensorflow模型

其他回答

无论你想把模型保存在哪里,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

确保你所有的任务。变量有名称,因为您可能希望稍后使用它们的名称来恢复它们。 在你想预测的地方,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

确保该保护程序在相应的会话中运行。 请记住,如果使用tf.train.latest_checkpoint('./'),那么将只使用最新的检查点。

正如Yaroslav所说,您可以通过导入图、手动创建变量,然后使用Saver来从graph_def和检查点进行恢复。

我实现这个是为了我个人使用,所以我想在这里分享一下代码。

链接:https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(当然,这是一种hack,并且不能保证以这种方式保存的模型在TensorFlow的未来版本中仍然是可读的。)

你也可以在TensorFlow/skflow中查看例子,它提供了保存和恢复方法,可以帮助你轻松地管理模型。它具有一些参数,您还可以控制备份模型的频率。

在大多数情况下,使用tf.train.Saver从磁盘保存和恢复是最好的选择:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

您还可以保存/恢复图结构本身(详细信息请参阅MetaGraph文档)。默认情况下,保存程序将图形结构保存到.meta文件中。您可以调用import_meta_graph()来恢复它。它恢复图形结构并返回一个你可以用来恢复模型状态的保护程序:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

然而,在某些情况下,您需要更快的方法。例如,如果您实现了早期停止,那么您希望在训练期间每次模型改进时都保存检查点(在验证集上测量),然后如果一段时间内没有进展,则希望回滚到最佳模型。如果每次模型改进时都将其保存到磁盘,则会极大地降低训练速度。诀窍是将变量状态保存到内存中,然后稍后恢复它们:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just make the graph messy), we just use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value for any op, not just for placeholders, so this works fine.

这里所有的答案都很棒,但我想补充两点。

首先,详细说明@user7505159的答案,“。添加到要恢复的文件名的开头可能很重要。

例如,您可以保存没有“的图形。/"在文件名中如下所示:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

但是为了恢复图形,您可能需要在前面加上一个"。/"到file_name:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

你并不总是需要“。/”,但是它会根据你的环境和TensorFlow版本而导致问题。

它还想提到sess.run(tf.global_variables_initializer())在恢复会话之前可能很重要。

如果在尝试恢复保存的会话时收到关于未初始化变量的错误,请确保在保存程序之前包含sess.run(tf.global_variables_initializer())。恢复(sess, save_file)行。这样你就不用头疼了。