在Tensorflow中训练一个模型后:

如何拯救训练过的模型? 您以后如何恢复这个保存的模型?


当前回答

无论你想把模型保存在哪里,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

确保你所有的任务。变量有名称,因为您可能希望稍后使用它们的名称来恢复它们。 在你想预测的地方,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

确保该保护程序在相应的会话中运行。 请记住,如果使用tf.train.latest_checkpoint('./'),那么将只使用最新的检查点。

其他回答

你也可以用更简单的方法。

步骤1:初始化所有变量

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

步骤2:在模型Saver中保存会话并保存它

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

步骤3:恢复模型

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

步骤4:检查变量

W1 = session.run(W1)
print(W1)

在不同的python实例中运行时,使用

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

你可以使用Tensorflow中的saver对象来保存你训练过的模型。该对象提供保存和恢复模型的方法。

在TensorFlow中保存一个训练好的模型:

tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None,
                    meta_graph_suffix='meta', write_meta_graph=True,
                    write_state=True, strip_default_attrs=False,
                    save_debug_info=False)

在TensorFlow中恢复已保存的模型:

tf.train.Saver.restore(sess, save_path, latest_filename=None,
                       meta_graph_suffix='meta', clear_devices=False,
                       import_scope=None)

对于TensorFlow版本< 0.11.0RC1:

保存的检查点包含模型中的变量值,而不是模型/图本身,这意味着当您恢复检查点时,图应该是相同的。

这里有一个线性回归的例子,其中有一个训练循环,保存变量检查点,还有一个评估部分,将恢复之前运行中保存的变量并计算预测。当然,如果你愿意,你也可以恢复变量并继续训练。

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

下面是变量文档,涵盖了保存和恢复。这是保存程序的文档。

使用tf.train.Saver保存模型。记住,如果想要减小模型大小,就需要指定var_list。val_list可以是:

特遣部队。trainable_variables或 tf.global_variables。

如第6255期所述:

use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')

而不是

saver.restore('my_model_final.ckpt')