在Tensorflow中训练一个模型后:

如何拯救训练过的模型? 您以后如何恢复这个保存的模型?


当前回答

你也可以在TensorFlow/skflow中查看例子,它提供了保存和恢复方法,可以帮助你轻松地管理模型。它具有一些参数,您还可以控制备份模型的频率。

其他回答

你可以使用Tensorflow中的saver对象来保存你训练过的模型。该对象提供保存和恢复模型的方法。

在TensorFlow中保存一个训练好的模型:

tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None,
                    meta_graph_suffix='meta', write_meta_graph=True,
                    write_state=True, strip_default_attrs=False,
                    save_debug_info=False)

在TensorFlow中恢复已保存的模型:

tf.train.Saver.restore(sess, save_path, latest_filename=None,
                       meta_graph_suffix='meta', clear_devices=False,
                       import_scope=None)

这里所有的答案都很棒,但我想补充两点。

首先,详细说明@user7505159的答案,“。添加到要恢复的文件名的开头可能很重要。

例如,您可以保存没有“的图形。/"在文件名中如下所示:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

但是为了恢复图形,您可能需要在前面加上一个"。/"到file_name:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

你并不总是需要“。/”,但是它会根据你的环境和TensorFlow版本而导致问题。

它还想提到sess.run(tf.global_variables_initializer())在恢复会话之前可能很重要。

如果在尝试恢复保存的会话时收到关于未初始化变量的错误,请确保在保存程序之前包含sess.run(tf.global_variables_initializer())。恢复(sess, save_file)行。这样你就不用头疼了。

你也可以用更简单的方法。

步骤1:初始化所有变量

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

步骤2:在模型Saver中保存会话并保存它

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

步骤3:恢复模型

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

步骤4:检查变量

W1 = session.run(W1)
print(W1)

在不同的python实例中运行时,使用

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

根据新的Tensorflow版本,tf.train.Checkpoint是保存和恢复模型的最佳方式:

Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to tf.train.Saver which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables when executing eagerly. Prefer tf.train.Checkpoint over tf.train.Saver for new code.

这里有一个例子:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

这里有更多信息和示例。

最简单的方法是使用keras api,在线保存模型和一行加载模型

from keras.models import load_model

my_model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'

del my_model  # deletes the existing model


my_model = load_model('my_model.h5') # returns a compiled model identical to the previous one