在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
在Tensorflow中训练一个模型后:
如何拯救训练过的模型? 您以后如何恢复这个保存的模型?
当前回答
我的环境:Python 3.6, Tensorflow 1.3.0
虽然有很多解决方案,但大多数都是基于tf.train.Saver。当我们加载由Saver保存的.ckpt文件时,我们必须要么重新定义张量流网络,要么使用一些奇怪且难以记住的名称,例如:“placehold_0:0”,“密集/亚当/重量:0”。这里我推荐使用tf。saved_model,下面给出的一个最简单的例子,你可以从为TensorFlow模型服务中学到更多:
保存模型:
import tensorflow as tf
# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")
h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# save the model
export_path = './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'x_input': tensor_info_x},
outputs={'y_output': tensor_info_y},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
prediction_signature
},
)
builder.save()
加载模型:
import tensorflow as tf
sess=tf.Session()
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'
export_path = './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
sess,
[tf.saved_model.tag_constants.SERVING],
export_path)
signature = meta_graph_def.signature_def
x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name
x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)
y_out = sess.run(y, {x: 3.0})
其他回答
模型有两个部分,模型定义,由Supervisor保存为图。模型目录中的PBTXT和张量的数值,保存到检查点文件,如model.ckpt-1003418。
可以使用tf恢复模型定义。import_graph_def,并且使用Saver恢复权重。
然而,Saver使用特殊的集合保存附加到模型Graph的变量列表,并且这个集合没有使用import_graph_def初始化,所以您目前不能同时使用这两者(这在我们的路线图中进行修复)。现在,您必须使用Ryan Sepassi的方法——手动构造具有相同节点名称的图,并使用Saver将权重加载到其中。
(或者,您可以通过使用import_graph_def,手动创建变量和使用tf.add_to_collection(tf.GraphKeys. collection)来破解它。变量,变量)为每个变量,然后使用Saver)
tensorflow - 2.0
这很简单。
import tensorflow as tf
SAVE
model.save("model_name")
恢复
model = tf.keras.models.load_model('model_name')
你也可以在TensorFlow/skflow中查看例子,它提供了保存和恢复方法,可以帮助你轻松地管理模型。它具有一些参数,您还可以控制备份模型的频率。
在@Vishnuvardhan Janapati的回答之后,这里是另一种在TensorFlow 2.0.0下保存和重载自定义层/度量/损失模型的方法
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects
# custom loss (for example)
def custom_loss(y_true,y_pred):
return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss})
# custom loss (for example)
class CustomLayer(Layer):
def __init__(self, ...):
...
# define custom layer and all necessary custom operations inside custom layer
get_custom_objects().update({'CustomLayer': CustomLayer})
通过这种方式,一旦您执行了这些代码,并使用tf.keras.models保存了您的模型。Save_model或model。save或ModelCheckpoint回调,您可以重新加载您的模型,而不需要精确的自定义对象,就像这样简单
new_model = tf.keras.models.load_model("./model.h5"})
对于张量流2.0,它非常简单
#保存模型 model.save(“path_to_my_model.h5”)
恢复:
new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')