如何在PyTorch中保存训练好的模型?我读到过:
Torch.save ()/torch.load()用于保存/加载可序列化对象。 model.state_dict()/model.load_state_dict()用于保存/加载模型状态。
如何在PyTorch中保存训练好的模型?我读到过:
Torch.save ()/torch.load()用于保存/加载可序列化对象。 model.state_dict()/model.load_state_dict()用于保存/加载模型状态。
当前回答
这取决于你想做什么。
案例# 1:保存模型以便自己使用它进行推理:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为你通常有BatchNorm和Dropout图层,默认情况下在构造时处于火车模式:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
案例2:保存模型以稍后恢复训练:如果您需要继续训练即将保存的模型,那么您需要保存的不仅仅是模型。您还需要保存优化器的状态、epoch、分数等。你会这样做:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
为了恢复训练,你可以这样做:state = torch.load(filepath),然后,恢复每个单独对象的状态,就像这样:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
由于您正在恢复训练,因此在加载时恢复状态后不要调用model.eval()。
案例# 3:模型被其他人使用,而不能访问您的代码: 在Tensorflow中,你可以创建一个.pb文件来定义模型的架构和权重。这非常方便,特别是在使用Tensorflow服务时。在Pytorch中这样做的等效方法是:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
这种方式仍然不是防弹的,因为pytorch仍然在经历很多变化,我不推荐它。
其他回答
我总是喜欢使用Torch7 (.t7)或Pickle (.pth, .pt)来保存pytorch模型的权重。
如果您想保存模型,并希望稍后恢复训练:
单一的GPU: 拯救策略:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)
负载:
checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
多个GPU: 保存
state = {
'epoch': epoch,
'state_dict': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)
负载:
checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
#Don't call DataParallel before loading the model otherwise you will get an error
model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
这取决于你想做什么。
案例# 1:保存模型以便自己使用它进行推理:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为你通常有BatchNorm和Dropout图层,默认情况下在构造时处于火车模式:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
案例2:保存模型以稍后恢复训练:如果您需要继续训练即将保存的模型,那么您需要保存的不仅仅是模型。您还需要保存优化器的状态、epoch、分数等。你会这样做:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
为了恢复训练,你可以这样做:state = torch.load(filepath),然后,恢复每个单独对象的状态,就像这样:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
由于您正在恢复训练,因此在加载时恢复状态后不要调用model.eval()。
案例# 3:模型被其他人使用,而不能访问您的代码: 在Tensorflow中,你可以创建一个.pb文件来定义模型的架构和权重。这非常方便,特别是在使用Tensorflow服务时。在Pytorch中这样做的等效方法是:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
这种方式仍然不是防弹的,因为pytorch仍然在经历很多变化,我不推荐它。
我用这个方法,希望对大家有用。
num_labels = len(test_label_cols)
robertaclassificationtrain = '/dbfs/FileStore/tables/PM/TC/roberta_model'
robertaclassificationpath = "/dbfs/FileStore/tables/PM/TC/ROBERTACLASSIFICATION"
model = RobertaForSequenceClassification.from_pretrained(robertaclassificationpath,
num_labels=num_labels)
model.cuda()
model.load_state_dict(torch.load(robertaclassificationtrain))
model.eval()
我保存我的火车模型已经在“roberta_model”路径。保存一个火车模型。
torch.save(model.state_dict(), '/dbfs/FileStore/tables/PM/TC/roberta_model')
保存在本地
如何保存模型取决于将来如何访问它。如果你可以调用模型类的一个新实例,那么你所需要做的就是用model.state_dict()保存/加载模型的权重:
# Save:
torch.save(old_model.state_dict(), PATH)
# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))
如果你因为任何原因(或者更喜欢简单的语法)不能,那么你可以使用torch.save()保存整个模型(实际上是对定义模型的文件的引用,以及它的state_dict):
# Save:
torch.save(old_model, PATH)
# Load:
new_model = torch.load(PATH)
但是由于这是对定义模型类的文件位置的引用,所以这段代码是不可移植的,除非这些文件也移植到相同的目录结构中。
保存到云- TorchHub
如果您希望您的模型是可移植的,您可以轻松地使用torch.hub导入它。如果你在github repo中添加了一个适当定义的hubconf.py文件,这可以很容易地在PyTorch中调用,使用户能够加载你的模型,带/不带权重:
hubconf.py (github.com/repo_owner/repo_name)
dependencies = ['torch']
from my_module import mymodel as _mymodel
def mymodel(pretrained=False, **kwargs):
return _mymodel(pretrained=pretrained, **kwargs)
加载模型:
new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)