在Tensorflow中训练一个模型后:

如何拯救训练过的模型? 您以后如何恢复这个保存的模型?


当前回答

在@Vishnuvardhan Janapati的回答之后,这里是另一种在TensorFlow 2.0.0下保存和重载自定义层/度量/损失模型的方法

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects

# custom loss (for example)  
def custom_loss(y_true,y_pred):
  return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss}) 

# custom loss (for example) 
class CustomLayer(Layer):
  def __init__(self, ...):
      ...
  # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})  

通过这种方式,一旦您执行了这些代码,并使用tf.keras.models保存了您的模型。Save_model或model。save或ModelCheckpoint回调,您可以重新加载您的模型,而不需要精确的自定义对象,就像这样简单

new_model = tf.keras.models.load_model("./model.h5"})

其他回答

我的环境:Python 3.6, Tensorflow 1.3.0

虽然有很多解决方案,但大多数都是基于tf.train.Saver。当我们加载由Saver保存的.ckpt文件时,我们必须要么重新定义张量流网络,要么使用一些奇怪且难以记住的名称,例如:“placehold_0:0”,“密集/亚当/重量:0”。这里我推荐使用tf。saved_model,下面给出的一个最简单的例子,你可以从为TensorFlow模型服务中学到更多:

保存模型:

import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

加载模型:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

如果您使用tf.train.MonitoredTrainingSession作为默认会话,则不需要添加额外的代码来执行保存/恢复操作。只需将检查点目录名称传递给MonitoredTrainingSession的构造函数,它将使用会话挂钩来处理这些。

无论你想把模型保存在哪里,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

确保你所有的任务。变量有名称,因为您可能希望稍后使用它们的名称来恢复它们。 在你想预测的地方,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

确保该保护程序在相应的会话中运行。 请记住,如果使用tf.train.latest_checkpoint('./'),那么将只使用最新的检查点。

如果它是一个内部保存的模型,您只需为所有变量指定一个恢复器为

restorer = tf.train.Saver(tf.all_variables())

并使用它来恢复当前会话中的变量:

restorer.restore(self._sess, model_file)

对于外部模型,您需要指定从它的变量名到您的变量名的映射。您可以使用该命令查看模型变量名

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

inspect_checkpoint.py脚本可以在`。tensorflow源码的/tensorflow/python/tools文件夹。

为了指定映射,你可以使用我的Tensorflow-Worklab,它包含一组类和脚本来训练和再训练不同的模型。它包括一个再训练ResNet模型的例子,位于这里

tensorflow - 2.0

这很简单。

import tensorflow as tf

SAVE

model.save("model_name")

恢复

model = tf.keras.models.load_model('model_name')