在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
我在版本:
tensorflow (1.13.1)
tensorflow-gpu (1.13.1)
简单的方法是
拯救策略:
model.save("model.h5")
恢复:
model = tf.keras.models.load_model("model.h5")
其他回答
我正在改进我的回答,以添加更多关于保存和恢复模型的细节。
在Tensorflow 0.11版本中(及之后):
保存模型:
import tensorflow as tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#Create a saver object which will save all the variables
saver = tf.train.Saver()
#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1
#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)
恢复模型:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated
这里已经很好地解释了这一点和一些更高级的用例。
一个快速完整的教程,保存和恢复Tensorflow模型
对于张量流2.0,它非常简单
#保存模型 model.save(“path_to_my_model.h5”)
恢复:
new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')
如第6255期所述:
use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')
而不是
saver.restore('my_model_final.ckpt')
我在版本:
tensorflow (1.13.1)
tensorflow-gpu (1.13.1)
简单的方法是
拯救策略:
model.save("model.h5")
恢复:
model = tf.keras.models.load_model("model.h5")
你也可以用更简单的方法。
步骤1:初始化所有变量
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
Similarly, W2, B2, W3, .....
步骤2:在模型Saver中保存会话并保存它
model_saver = tf.train.Saver()
# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")
步骤3:恢复模型
with tf.Session(graph=graph_cnn) as session:
model_saver.restore(session, "saved_models/CNN_New.ckpt")
print("Model restored.")
print('Initialized')
步骤4:检查变量
W1 = session.run(W1)
print(W1)
在不同的python实例中运行时,使用
with tf.Session() as sess:
# Restore latest checkpoint
saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))
# Initalize the variables
sess.run(tf.global_variables_initializer())
# Get default graph (supply your custom graph if you have one)
graph = tf.get_default_graph()
# It will give tensor object
W1 = graph.get_tensor_by_name('W1:0')
# To get the value (numpy array)
W1_value = session.run(W1)