在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
我的环境:Python 3.6, Tensorflow 1.3.0
虽然有很多解决方案,但大多数都是基于tf.train.Saver。当我们加载由Saver保存的.ckpt文件时,我们必须要么重新定义张量流网络,要么使用一些奇怪且难以记住的名称,例如:“placehold_0:0”,“密集/亚当/重量:0”。这里我推荐使用tf。saved_model,下面给出的一个最简单的例子,你可以从为TensorFlow模型服务中学到更多:
保存模型:
import tensorflow as tf
# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")
h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# save the model
export_path = './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'x_input': tensor_info_x},
outputs={'y_output': tensor_info_y},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
prediction_signature
},
)
builder.save()
加载模型:
import tensorflow as tf
sess=tf.Session()
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'
export_path = './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
sess,
[tf.saved_model.tag_constants.SERVING],
export_path)
signature = meta_graph_def.signature_def
x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name
x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)
y_out = sess.run(y, {x: 3.0})
其他回答
在TensorFlow 0.11.0RC1版本中,你可以通过调用tf.train直接保存和恢复你的模型。Export_meta_graph和tf.train。根据https://www.tensorflow.org/programmers_guide/meta_graph的Import_meta_graph。
保存模型
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta
恢复模型
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
v_ = sess.run(v)
print(v_)
您可以保存网络中的变量使用
saver = tf.train.Saver()
saver.save(sess, 'path of save/fileName.ckpt')
要恢复网络以供以后或在另一个脚本中重用,请使用:
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/')
sess.run(....)
重要的几点:
第一次运行和以后运行之间的Sess必须相同(一致的结构)。 储蓄者。还原需要保存文件的文件夹路径,而不是单个文件路径。
我正在改进我的回答,以添加更多关于保存和恢复模型的细节。
在Tensorflow 0.11版本中(及之后):
保存模型:
import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1
#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)
恢复模型:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated
这里已经很好地解释了这一点和一些更高级的用例。
一个快速完整的教程,保存和恢复Tensorflow模型
对于张量流2.0,它非常简单
#保存模型 model.save(“path_to_my_model.h5”)
恢复:
new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')
正如Yaroslav所说,您可以通过导入图、手动创建变量,然后使用Saver来从graph_def和检查点进行恢复。
我实现这个是为了我个人使用,所以我想在这里分享一下代码。
链接:https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(当然,这是一种hack,并且不能保证以这种方式保存的模型在TensorFlow的未来版本中仍然是可读的。)