在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
这里所有的答案都很棒,但我想补充两点。
首先,详细说明@user7505159的答案,“。添加到要恢复的文件名的开头可能很重要。
例如,您可以保存没有“的图形。/"在文件名中如下所示:
# Some graph defined up here with specific names
saver = tf.train.Saver()
save_file = 'model.ckpt'
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, save_file)
但是为了恢复图形,您可能需要在前面加上一个"。/"到file_name:
# Same graph defined up here
saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, save_file)
你并不总是需要“。/”,但是它会根据你的环境和TensorFlow版本而导致问题。
它还想提到sess.run(tf.global_variables_initializer())在恢复会话之前可能很重要。
如果在尝试恢复保存的会话时收到关于未初始化变量的错误,请确保在保存程序之前包含sess.run(tf.global_variables_initializer())。恢复(sess, save_file)行。这样你就不用头疼了。
其他回答
最简单的方法是使用keras api,在线保存模型和一行加载模型
from keras.models import load_model
my_model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
del my_model # deletes the existing model
my_model = load_model('my_model.h5') # returns a compiled model identical to the previous one
下面是一个使用Tensorflow 2.0 SavedModel格式(根据文档,这是推荐的格式)的简单MNIST数据集分类器的简单示例,使用Keras函数式API,没有太多的花哨操作:
# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()
# Normalize pixels [0,255] -> [0,1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)
# Create model
input = Input(shape=(28,28), dtype='float64', name='graph_input')
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=input, outputs=output)
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# Train
model.fit(x_train, y_train, epochs=3)
# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)
# ... possibly another python program
# Reload model
loaded_model = tf.keras.models.load_model(export_path)
# Get image sample for testing
index = 0
img = x_test[index] # I normalized the image on a previous step
# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))
# Show results
print(np.argmax(prediction['graph_output'])) # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary) # prints the image
serving_default是什么?
它是所选标记的签名定义的名称(在本例中,选择了默认的服务标记)。此外,本文还解释了如何使用saved_model_cli查找模型的标记和签名。
免责声明
这只是一个基本的例子,如果你只是想让它运行起来,但这绝不是一个完整的答案-也许我可以在未来更新它。我只是想给出一个在TF 2.0中使用SavedModel的简单示例,因为我在任何地方都没有见过这样简单的SavedModel。
@Tom的回答是一个SavedModel的例子,但它在Tensorflow 2.0上不起作用,因为不幸的是有一些突破性的变化。
@Vishnuvardhan Janapati的回答是TF 2.0,但它不适合SavedModel格式。
在大多数情况下,使用tf.train.Saver从磁盘保存和恢复是最好的选择:
... # build your model
saver = tf.train.Saver()
with tf.Session() as sess:
... # train the model
saver.save(sess, "/tmp/my_great_model")
with tf.Session() as sess:
saver.restore(sess, "/tmp/my_great_model")
... # use the model
您还可以保存/恢复图结构本身(详细信息请参阅MetaGraph文档)。默认情况下,保存程序将图形结构保存到.meta文件中。您可以调用import_meta_graph()来恢复它。它恢复图形结构并返回一个你可以用来恢复模型状态的保护程序:
saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")
with tf.Session() as sess:
saver.restore(sess, "/tmp/my_great_model")
... # use the model
然而,在某些情况下,您需要更快的方法。例如,如果您实现了早期停止,那么您希望在训练期间每次模型改进时都保存检查点(在验证集上测量),然后如果一段时间内没有进展,则希望回滚到最佳模型。如果每次模型改进时都将其保存到磁盘,则会极大地降低训练速度。诀窍是将变量状态保存到内存中,然后稍后恢复它们:
... # build your model
# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]
with tf.Session() as sess:
... # train the model
# when needed, save the model state to memory
gvars_state = sess.run(gvars)
# when needed, restore the model state
feed_dict = {init_value: val
for init_value, val in zip(init_values, gvars_state)}
sess.run(assign_ops, feed_dict=feed_dict)
A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just make the graph messy), we just use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value for any op, not just for placeholders, so this works fine.
你也可以用更简单的方法。
步骤1:初始化所有变量
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
Similarly, W2, B2, W3, .....
步骤2:在模型Saver中保存会话并保存它
model_saver = tf.train.Saver()
# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")
步骤3:恢复模型
with tf.Session(graph=graph_cnn) as session:
model_saver.restore(session, "saved_models/CNN_New.ckpt")
print("Model restored.")
print('Initialized')
步骤4:检查变量
W1 = session.run(W1)
print(W1)
在不同的python实例中运行时,使用
with tf.Session() as sess:
# Restore latest checkpoint
saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))
# Initalize the variables
sess.run(tf.global_variables_initializer())
# Get default graph (supply your custom graph if you have one)
graph = tf.get_default_graph()
# It will give tensor object
W1 = graph.get_tensor_by_name('W1:0')
# To get the value (numpy array)
W1_value = session.run(W1)
正如Yaroslav所说,您可以通过导入图、手动创建变量,然后使用Saver来从graph_def和检查点进行恢复。
我实现这个是为了我个人使用,所以我想在这里分享一下代码。
链接:https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(当然,这是一种hack,并且不能保证以这种方式保存的模型在TensorFlow的未来版本中仍然是可读的。)