在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
无论你想把模型保存在哪里,
self.saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
...
self.saver.save(sess, filename)
确保你所有的任务。变量有名称,因为您可能希望稍后使用它们的名称来恢复它们。 在你想预测的地方,
saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file'
with tf.Session() as sess:
saver.restore(sess, name)
print(sess.run('W1:0')) #example to retrieve by variable name
确保该保护程序在相应的会话中运行。 请记住,如果使用tf.train.latest_checkpoint('./'),那么将只使用最新的检查点。
其他回答
在@Vishnuvardhan Janapati的回答之后,这里是另一种在TensorFlow 2.0.0下保存和重载自定义层/度量/损失模型的方法
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects
# custom loss (for example)
def custom_loss(y_true,y_pred):
return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss})
# custom loss (for example)
class CustomLayer(Layer):
def __init__(self, ...):
...
# define custom layer and all necessary custom operations inside custom layer
get_custom_objects().update({'CustomLayer': CustomLayer})
通过这种方式,一旦您执行了这些代码,并使用tf.keras.models保存了您的模型。Save_model或model。save或ModelCheckpoint回调,您可以重新加载您的模型,而不需要精确的自定义对象,就像这样简单
new_model = tf.keras.models.load_model("./model.h5"})
无论你想把模型保存在哪里,
self.saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
...
self.saver.save(sess, filename)
确保你所有的任务。变量有名称,因为您可能希望稍后使用它们的名称来恢复它们。 在你想预测的地方,
saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file'
with tf.Session() as sess:
saver.restore(sess, name)
print(sess.run('W1:0')) #example to retrieve by variable name
确保该保护程序在相应的会话中运行。 请记住,如果使用tf.train.latest_checkpoint('./'),那么将只使用最新的检查点。
如第6255期所述:
use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')
而不是
saver.restore('my_model_final.ckpt')
我在版本:
tensorflow (1.13.1)
tensorflow-gpu (1.13.1)
简单的方法是
拯救策略:
model.save("model.h5")
恢复:
model = tf.keras.models.load_model("model.h5")
在TensorFlow 0.11.0RC1版本中,你可以通过调用tf.train直接保存和恢复你的模型。Export_meta_graph和tf.train。根据https://www.tensorflow.org/programmers_guide/meta_graph的Import_meta_graph。
保存模型
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta
恢复模型
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
v_ = sess.run(v)
print(v_)