如何在PyTorch中保存训练好的模型?我读到过:
Torch.save ()/torch.load()用于保存/加载可序列化对象。 model.state_dict()/model.load_state_dict()用于保存/加载模型状态。
如何在PyTorch中保存训练好的模型?我读到过:
Torch.save ()/torch.load()用于保存/加载可序列化对象。 model.state_dict()/model.load_state_dict()用于保存/加载模型状态。
当前回答
现在所有内容都写在官方教程中: https://pytorch.org/tutorials/beginner/saving_loading_models.html
关于如何保存和保存什么,您有几个选项,所有这些都在本教程中进行了解释。
其他回答
在他们的github回购上找到了这个页面:
Recommended approach for saving a model There are two main approaches for serializing and restoring a model. The first (recommended) saves and loads only the model parameters: torch.save(the_model.state_dict(), PATH) Then later: the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH)) The second saves and loads the entire model: torch.save(the_model, PATH) Then later: the_model = torch.load(PATH) However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
请参见官方PyTorch教程中的保存和加载模型部分。
一个常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。
保存/加载整个模型
拯救策略:
path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)
负载:
(模型类必须在某处定义)
model.load_state_dict(torch.load(PATH))
model.eval()
PIP安装火炬闪电
确保你的父模型使用pl.LightningModule而不是nn。模块
使用pytorch闪电保存和加载检查点
import pytorch_lightning as pl
model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
这取决于你想做什么。
案例# 1:保存模型以便自己使用它进行推理:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为你通常有BatchNorm和Dropout图层,默认情况下在构造时处于火车模式:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
案例2:保存模型以稍后恢复训练:如果您需要继续训练即将保存的模型,那么您需要保存的不仅仅是模型。您还需要保存优化器的状态、epoch、分数等。你会这样做:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
为了恢复训练,你可以这样做:state = torch.load(filepath),然后,恢复每个单独对象的状态,就像这样:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
由于您正在恢复训练,因此在加载时恢复状态后不要调用model.eval()。
案例# 3:模型被其他人使用,而不能访问您的代码: 在Tensorflow中,你可以创建一个.pb文件来定义模型的架构和权重。这非常方便,特别是在使用Tensorflow服务时。在Pytorch中这样做的等效方法是:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
这种方式仍然不是防弹的,因为pytorch仍然在经历很多变化,我不推荐它。
保存在本地
如何保存模型取决于将来如何访问它。如果你可以调用模型类的一个新实例,那么你所需要做的就是用model.state_dict()保存/加载模型的权重:
# Save:
torch.save(old_model.state_dict(), PATH)
# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))
如果你因为任何原因(或者更喜欢简单的语法)不能,那么你可以使用torch.save()保存整个模型(实际上是对定义模型的文件的引用,以及它的state_dict):
# Save:
torch.save(old_model, PATH)
# Load:
new_model = torch.load(PATH)
但是由于这是对定义模型类的文件位置的引用,所以这段代码是不可移植的,除非这些文件也移植到相同的目录结构中。
保存到云- TorchHub
如果您希望您的模型是可移植的,您可以轻松地使用torch.hub导入它。如果你在github repo中添加了一个适当定义的hubconf.py文件,这可以很容易地在PyTorch中调用,使用户能够加载你的模型,带/不带权重:
hubconf.py (github.com/repo_owner/repo_name)
dependencies = ['torch']
from my_module import mymodel as _mymodel
def mymodel(pretrained=False, **kwargs):
return _mymodel(pretrained=pretrained, **kwargs)
加载模型:
new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)