如何在PyTorch中保存训练好的模型?我读到过:

Torch.save ()/torch.load()用于保存/加载可序列化对象。 model.state_dict()/model.load_state_dict()用于保存/加载模型状态。


当前回答

PIP安装火炬闪电

确保你的父模型使用pl.LightningModule而不是nn。模块

使用pytorch闪电保存和加载检查点

import pytorch_lightning as pl

model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")

其他回答

PIP安装火炬闪电

确保你的父模型使用pl.LightningModule而不是nn。模块

使用pytorch闪电保存和加载检查点

import pytorch_lightning as pl

model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")

在他们的github回购上找到了这个页面:

Recommended approach for saving a model There are two main approaches for serializing and restoring a model. The first (recommended) saves and loads only the model parameters: torch.save(the_model.state_dict(), PATH) Then later: the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH)) The second saves and loads the entire model: torch.save(the_model, PATH) Then later: the_model = torch.load(PATH) However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.


请参见官方PyTorch教程中的保存和加载模型部分。

现在所有内容都写在官方教程中: https://pytorch.org/tutorials/beginner/saving_loading_models.html

关于如何保存和保存什么,您有几个选项,所有这些都在本教程中进行了解释。

一个常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。

保存/加载整个模型

拯救策略:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

负载:

(模型类必须在某处定义)

model.load_state_dict(torch.load(PATH))
model.eval()

我用这个方法,希望对大家有用。

num_labels = len(test_label_cols)
robertaclassificationtrain = '/dbfs/FileStore/tables/PM/TC/roberta_model'
robertaclassificationpath = "/dbfs/FileStore/tables/PM/TC/ROBERTACLASSIFICATION"

model = RobertaForSequenceClassification.from_pretrained(robertaclassificationpath, 
num_labels=num_labels)
model.cuda()

model.load_state_dict(torch.load(robertaclassificationtrain))
model.eval()

我保存我的火车模型已经在“roberta_model”路径。保存一个火车模型。

torch.save(model.state_dict(), '/dbfs/FileStore/tables/PM/TC/roberta_model')