在Tensorflow中训练一个模型后:

如何拯救训练过的模型? 您以后如何恢复这个保存的模型?


当前回答

你也可以在TensorFlow/skflow中查看例子,它提供了保存和恢复方法,可以帮助你轻松地管理模型。它具有一些参数,您还可以控制备份模型的频率。

其他回答

根据新的Tensorflow版本,tf.train.Checkpoint是保存和恢复模型的最佳方式:

Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to tf.train.Saver which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables when executing eagerly. Prefer tf.train.Checkpoint over tf.train.Saver for new code.

这里有一个例子:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

这里有更多信息和示例。

对于TensorFlow版本< 0.11.0RC1:

保存的检查点包含模型中的变量值,而不是模型/图本身,这意味着当您恢复检查点时,图应该是相同的。

这里有一个线性回归的例子,其中有一个训练循环,保存变量检查点,还有一个评估部分,将恢复之前运行中保存的变量并计算预测。当然,如果你愿意,你也可以恢复变量并继续训练。

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

下面是变量文档,涵盖了保存和恢复。这是保存程序的文档。

在@Vishnuvardhan Janapati的回答之后,这里是另一种在TensorFlow 2.0.0下保存和重载自定义层/度量/损失模型的方法

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects

# custom loss (for example)  
def custom_loss(y_true,y_pred):
  return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss}) 

# custom loss (for example) 
class CustomLayer(Layer):
  def __init__(self, ...):
      ...
  # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})  

通过这种方式,一旦您执行了这些代码,并使用tf.keras.models保存了您的模型。Save_model或model。save或ModelCheckpoint回调,您可以重新加载您的模型,而不需要精确的自定义对象,就像这样简单

new_model = tf.keras.models.load_model("./model.h5"})

在新版本的tensorflow 2.0中,保存/加载模型的过程要容易得多。因为Keras API的实现,一个TensorFlow的高级API。

保存一个模型: 请查阅相关文档以作参考: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model

tf.keras.models.save_model(model_name, filepath, save_format)

加载一个模型:

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

model = tf.keras.models.load_model(filepath)

如果它是一个内部保存的模型,您只需为所有变量指定一个恢复器为

restorer = tf.train.Saver(tf.all_variables())

并使用它来恢复当前会话中的变量:

restorer.restore(self._sess, model_file)

对于外部模型,您需要指定从它的变量名到您的变量名的映射。您可以使用该命令查看模型变量名

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

inspect_checkpoint.py脚本可以在`。tensorflow源码的/tensorflow/python/tools文件夹。

为了指定映射,你可以使用我的Tensorflow-Worklab,它包含一组类和脚本来训练和再训练不同的模型。它包括一个再训练ResNet模型的例子,位于这里