在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
在@Vishnuvardhan Janapati的回答之后,这里是另一种在TensorFlow 2.0.0下保存和重载自定义层/度量/损失模型的方法
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects
# custom loss (for example)
def custom_loss(y_true,y_pred):
return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss})
# custom loss (for example)
class CustomLayer(Layer):
def __init__(self, ...):
...
# define custom layer and all necessary custom operations inside custom layer
get_custom_objects().update({'CustomLayer': CustomLayer})
通过这种方式,一旦您执行了这些代码,并使用tf.keras.models保存了您的模型。Save_model或model。save或ModelCheckpoint回调,您可以重新加载您的模型,而不需要精确的自定义对象,就像这样简单
new_model = tf.keras.models.load_model("./model.h5"})
其他回答
如果它是一个内部保存的模型,您只需为所有变量指定一个恢复器为
restorer = tf.train.Saver(tf.all_variables())
并使用它来恢复当前会话中的变量:
restorer.restore(self._sess, model_file)
对于外部模型,您需要指定从它的变量名到您的变量名的映射。您可以使用该命令查看模型变量名
python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt
inspect_checkpoint.py脚本可以在`。tensorflow源码的/tensorflow/python/tools文件夹。
为了指定映射,你可以使用我的Tensorflow-Worklab,它包含一组类和脚本来训练和再训练不同的模型。它包括一个再训练ResNet模型的例子,位于这里
我在版本:
tensorflow (1.13.1)
tensorflow-gpu (1.13.1)
简单的方法是
拯救策略:
model.save("model.h5")
恢复:
model = tf.keras.models.load_model("model.h5")
对于TensorFlow版本< 0.11.0RC1:
保存的检查点包含模型中的变量值,而不是模型/图本身,这意味着当您恢复检查点时,图应该是相同的。
这里有一个线性回归的例子,其中有一个训练循环,保存变量检查点,还有一个评估部分,将恢复之前运行中保存的变量并计算预测。当然,如果你愿意,你也可以恢复变量并继续训练。
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))
...more setup for optimization and what not...
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if FLAGS.train:
for i in xrange(FLAGS.training_steps):
...training loop...
if (i + 1) % FLAGS.checkpoint_steps == 0:
saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
global_step=i+1)
else:
# Here's where you're restoring the variables w and b.
# Note that the graph is exactly as it was when the variables were
# saved in a prior training run.
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
...no checkpoint found...
# Now you can run the model to get predictions
batch_x = ...load some data...
predictions = sess.run(y_hat, feed_dict={x: batch_x})
下面是变量文档,涵盖了保存和恢复。这是保存程序的文档。
如果您使用tf.train.MonitoredTrainingSession作为默认会话,则不需要添加额外的代码来执行保存/恢复操作。只需将检查点目录名称传递给MonitoredTrainingSession的构造函数,它将使用会话挂钩来处理这些。
你可以使用Tensorflow中的saver对象来保存你训练过的模型。该对象提供保存和恢复模型的方法。
在TensorFlow中保存一个训练好的模型:
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None,
meta_graph_suffix='meta', write_meta_graph=True,
write_state=True, strip_default_attrs=False,
save_debug_info=False)
在TensorFlow中恢复已保存的模型:
tf.train.Saver.restore(sess, save_path, latest_filename=None,
meta_graph_suffix='meta', clear_devices=False,
import_scope=None)