在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
使用tf.train.Saver保存模型。记住,如果想要减小模型大小,就需要指定var_list。val_list可以是:
特遣部队。trainable_variables或 tf.global_variables。
其他回答
你也可以用更简单的方法。
步骤1:初始化所有变量
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
Similarly, W2, B2, W3, .....
步骤2:在模型Saver中保存会话并保存它
model_saver = tf.train.Saver()
# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")
步骤3:恢复模型
with tf.Session(graph=graph_cnn) as session:
model_saver.restore(session, "saved_models/CNN_New.ckpt")
print("Model restored.")
print('Initialized')
步骤4:检查变量
W1 = session.run(W1)
print(W1)
在不同的python实例中运行时,使用
with tf.Session() as sess:
# Restore latest checkpoint
saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))
# Initalize the variables
sess.run(tf.global_variables_initializer())
# Get default graph (supply your custom graph if you have one)
graph = tf.get_default_graph()
# It will give tensor object
W1 = graph.get_tensor_by_name('W1:0')
# To get the value (numpy array)
W1_value = session.run(W1)
这里所有的答案都很棒,但我想补充两点。
首先,详细说明@user7505159的答案,“。添加到要恢复的文件名的开头可能很重要。
例如,您可以保存没有“的图形。/"在文件名中如下所示:
# Some graph defined up here with specific names
saver = tf.train.Saver()
save_file = 'model.ckpt'
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, save_file)
但是为了恢复图形,您可能需要在前面加上一个"。/"到file_name:
# Same graph defined up here
saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_file)
你并不总是需要“。/”,但是它会根据你的环境和TensorFlow版本而导致问题。
它还想提到sess.run(tf.global_variables_initializer())在恢复会话之前可能很重要。
如果在尝试恢复保存的会话时收到关于未初始化变量的错误,请确保在保存程序之前包含sess.run(tf.global_variables_initializer())。恢复(sess, save_file)行。这样你就不用头疼了。
在TensorFlow 0.11.0RC1版本中,你可以通过调用tf.train直接保存和恢复你的模型。Export_meta_graph和tf.train。根据https://www.tensorflow.org/programmers_guide/meta_graph的Import_meta_graph。
保存模型
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta
恢复模型
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
v_ = sess.run(v)
print(v_)
最简单的方法是使用keras api,在线保存模型和一行加载模型
from keras.models import load_model
my_model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
del my_model # deletes the existing model
my_model = load_model('my_model.h5') # returns a compiled model identical to the previous one
下面是我对这两种基本情况的简单解决方案,这两种情况的不同之处在于您是想从文件加载图形还是在运行时构建它。
这个答案适用于Tensorflow 0.12+(包括1.0)。
在代码中重建图形
储蓄
graph = ... # build the graph
saver = tf.train.Saver() # create the saver after the graph
with ... as sess: # your session object
saver.save(sess, 'my-model')
加载
graph = ... # build the graph
saver = tf.train.Saver() # create the saver after the graph
with ... as sess: # your session object
saver.restore(sess, tf.train.latest_checkpoint('./'))
# now you can use the graph, continue training or whatever
还从文件中加载图形
当使用这种技术时,确保所有的层/变量都显式地设置了唯一的名称。否则Tensorflow将使名称本身是唯一的,因此它们将不同于存储在文件中的名称。在前一种技术中,这不是问题,因为名称在加载和保存时都以相同的方式“损坏”。
储蓄
graph = ... # build the graph
for op in [ ... ]: # operators you want to use after restoring the model
tf.add_to_collection('ops_to_restore', op)
saver = tf.train.Saver() # create the saver after the graph
with ... as sess: # your session object
saver.save(sess, 'my-model')
加载
with ... as sess: # your session object
saver = tf.train.import_meta_graph('my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
ops = tf.get_collection('ops_to_restore') # here are your operators in the same order in which you saved them to the collection