如何在PyTorch中保存训练好的模型?我读到过:
Torch.save ()/torch.load()用于保存/加载可序列化对象。 model.state_dict()/model.load_state_dict()用于保存/加载模型状态。
如何在PyTorch中保存训练好的模型?我读到过:
Torch.save ()/torch.load()用于保存/加载可序列化对象。 model.state_dict()/model.load_state_dict()用于保存/加载模型状态。
当前回答
PIP安装火炬闪电
确保你的父模型使用pl.LightningModule而不是nn。模块
使用pytorch闪电保存和加载检查点
import pytorch_lightning as pl
model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
其他回答
现在所有内容都写在官方教程中: https://pytorch.org/tutorials/beginner/saving_loading_models.html
关于如何保存和保存什么,您有几个选项,所有这些都在本教程中进行了解释。
我用这个方法,希望对大家有用。
num_labels = len(test_label_cols)
robertaclassificationtrain = '/dbfs/FileStore/tables/PM/TC/roberta_model'
robertaclassificationpath = "/dbfs/FileStore/tables/PM/TC/ROBERTACLASSIFICATION"
model = RobertaForSequenceClassification.from_pretrained(robertaclassificationpath,
num_labels=num_labels)
model.cuda()
model.load_state_dict(torch.load(robertaclassificationtrain))
model.eval()
我保存我的火车模型已经在“roberta_model”路径。保存一个火车模型。
torch.save(model.state_dict(), '/dbfs/FileStore/tables/PM/TC/roberta_model')
在他们的github回购上找到了这个页面:
Recommended approach for saving a model There are two main approaches for serializing and restoring a model. The first (recommended) saves and loads only the model parameters: torch.save(the_model.state_dict(), PATH) Then later: the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH)) The second saves and loads the entire model: torch.save(the_model, PATH) Then later: the_model = torch.load(PATH) However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.
请参见官方PyTorch教程中的保存和加载模型部分。
PIP安装火炬闪电
确保你的父模型使用pl.LightningModule而不是nn。模块
使用pytorch闪电保存和加载检查点
import pytorch_lightning as pl
model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")
如果您想保存模型,并希望稍后恢复训练:
单一的GPU: 拯救策略:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)
负载:
checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
多个GPU: 保存
state = {
'epoch': epoch,
'state_dict': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)
负载:
checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
#Don't call DataParallel before loading the model otherwise you will get an error
model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU