如何在PyTorch中保存训练好的模型?我读到过:

Torch.save ()/torch.load()用于保存/加载可序列化对象。 model.state_dict()/model.load_state_dict()用于保存/加载模型状态。


当前回答

这取决于你想做什么。

案例# 1:保存模型以便自己使用它进行推理:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为你通常有BatchNorm和Dropout图层,默认情况下在构造时处于火车模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例2:保存模型以稍后恢复训练:如果您需要继续训练即将保存的模型,那么您需要保存的不仅仅是模型。您还需要保存优化器的状态、epoch、分数等。你会这样做:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

为了恢复训练,你可以这样做:state = torch.load(filepath),然后,恢复每个单独对象的状态,就像这样:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由于您正在恢复训练,因此在加载时恢复状态后不要调用model.eval()。

案例# 3:模型被其他人使用,而不能访问您的代码: 在Tensorflow中,你可以创建一个.pb文件来定义模型的架构和权重。这非常方便,特别是在使用Tensorflow服务时。在Pytorch中这样做的等效方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

这种方式仍然不是防弹的,因为pytorch仍然在经历很多变化,我不推荐它。

其他回答

我用这个方法,希望对大家有用。

num_labels = len(test_label_cols)
robertaclassificationtrain = '/dbfs/FileStore/tables/PM/TC/roberta_model'
robertaclassificationpath = "/dbfs/FileStore/tables/PM/TC/ROBERTACLASSIFICATION"

model = RobertaForSequenceClassification.from_pretrained(robertaclassificationpath, 
num_labels=num_labels)
model.cuda()

model.load_state_dict(torch.load(robertaclassificationtrain))
model.eval()

我保存我的火车模型已经在“roberta_model”路径。保存一个火车模型。

torch.save(model.state_dict(), '/dbfs/FileStore/tables/PM/TC/roberta_model')

在他们的github回购上找到了这个页面:

Recommended approach for saving a model There are two main approaches for serializing and restoring a model. The first (recommended) saves and loads only the model parameters: torch.save(the_model.state_dict(), PATH) Then later: the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH)) The second saves and loads the entire model: torch.save(the_model, PATH) Then later: the_model = torch.load(PATH) However in this case, the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects, or after some serious refactors.


请参见官方PyTorch教程中的保存和加载模型部分。

PIP安装火炬闪电

确保你的父模型使用pl.LightningModule而不是nn。模块

使用pytorch闪电保存和加载检查点

import pytorch_lightning as pl

model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")

这取决于你想做什么。

案例# 1:保存模型以便自己使用它进行推理:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为你通常有BatchNorm和Dropout图层,默认情况下在构造时处于火车模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例2:保存模型以稍后恢复训练:如果您需要继续训练即将保存的模型,那么您需要保存的不仅仅是模型。您还需要保存优化器的状态、epoch、分数等。你会这样做:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

为了恢复训练,你可以这样做:state = torch.load(filepath),然后,恢复每个单独对象的状态,就像这样:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由于您正在恢复训练,因此在加载时恢复状态后不要调用model.eval()。

案例# 3:模型被其他人使用,而不能访问您的代码: 在Tensorflow中,你可以创建一个.pb文件来定义模型的架构和权重。这非常方便,特别是在使用Tensorflow服务时。在Pytorch中这样做的等效方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

这种方式仍然不是防弹的,因为pytorch仍然在经历很多变化,我不推荐它。

如果您想保存模型,并希望稍后恢复训练:

单一的GPU: 拯救策略:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

负载:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

多个GPU: 保存

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

负载:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU