在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
如果您使用tf.train.MonitoredTrainingSession作为默认会话,则不需要添加额外的代码来执行保存/恢复操作。只需将检查点目录名称传递给MonitoredTrainingSession的构造函数,它将使用会话挂钩来处理这些。
其他回答
对于TensorFlow版本< 0.11.0RC1:
保存的检查点包含模型中的变量值,而不是模型/图本身,这意味着当您恢复检查点时,图应该是相同的。
这里有一个线性回归的例子,其中有一个训练循环,保存变量检查点,还有一个评估部分,将恢复之前运行中保存的变量并计算预测。当然,如果你愿意,你也可以恢复变量并继续训练。
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))
...more setup for optimization and what not...
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if FLAGS.train:
for i in xrange(FLAGS.training_steps):
...training loop...
if (i + 1) % FLAGS.checkpoint_steps == 0:
saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
global_step=i+1)
else:
# Here's where you're restoring the variables w and b.
# Note that the graph is exactly as it was when the variables were
# saved in a prior training run.
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
...no checkpoint found...
# Now you can run the model to get predictions
batch_x = ...load some data...
predictions = sess.run(y_hat, feed_dict={x: batch_x})
下面是变量文档,涵盖了保存和恢复。这是保存程序的文档。
模型有两个部分,模型定义,由Supervisor保存为图。模型目录中的PBTXT和张量的数值,保存到检查点文件,如model.ckpt-1003418。
可以使用tf恢复模型定义。import_graph_def,并且使用Saver恢复权重。
然而,Saver使用特殊的集合保存附加到模型Graph的变量列表,并且这个集合没有使用import_graph_def初始化,所以您目前不能同时使用这两者(这在我们的路线图中进行修复)。现在,您必须使用Ryan Sepassi的方法——手动构造具有相同节点名称的图,并使用Saver将权重加载到其中。
(或者,您可以通过使用import_graph_def,手动创建变量和使用tf.add_to_collection(tf.GraphKeys. collection)来破解它。变量,变量)为每个变量,然后使用Saver)
正如Yaroslav所说,您可以通过导入图、手动创建变量,然后使用Saver来从graph_def和检查点进行恢复。
我实现这个是为了我个人使用,所以我想在这里分享一下代码。
链接:https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(当然,这是一种hack,并且不能保证以这种方式保存的模型在TensorFlow的未来版本中仍然是可读的。)
在TensorFlow 0.11.0RC1版本中,你可以通过调用tf.train直接保存和恢复你的模型。Export_meta_graph和tf.train。根据https://www.tensorflow.org/programmers_guide/meta_graph的Import_meta_graph。
保存模型
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta
恢复模型
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
v_ = sess.run(v)
print(v_)
对于张量流2.0,它非常简单
#保存模型 model.save(“path_to_my_model.h5”)
恢复:
new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')