什么时候应该使用.eval()?我知道它应该允许我“评估我的模型”。我如何在训练时关闭它?

使用.eval()的示例训练代码。


当前回答

model.eval()是一种针对模型中某些特定层/部分的开关,这些层/部分在训练和推断(评估)时间中表现不同。例如,Dropouts层,BatchNorm层等。您需要在模型计算期间关闭它们,而.eval()将为您做这件事。此外,评估/验证的常用做法是使用torch.no_grad()和model.eval()来关闭梯度计算:

# evaluate model:
model.eval()

with torch.no_grad():
    ...
    out_data = model(data)
    ...

但是,不要忘记在评估步骤后回到训练模式:

# training step
...
model.train()
...

其他回答

model.eval()是一种针对模型中某些特定层/部分的开关,这些层/部分在训练和推断(评估)时间中表现不同。例如,Dropouts层,BatchNorm层等。您需要在模型计算期间关闭它们,而.eval()将为您做这件事。此外,评估/验证的常用做法是使用torch.no_grad()和model.eval()来关闭梯度计算:

# evaluate model:
model.eval()

with torch.no_grad():
    ...
    out_data = model(data)
    ...

但是,不要忘记在评估步骤后回到训练模式:

# training step
...
model.train()
...

除了上述答案之外,还有一个额外的答案:

我最近开始使用Pytorch-lightning,它在训练-验证-测试管道中封装了许多样板文件。

此外,它允许使用train_step和validation_step回调来包装eval和train,从而使model.eval()和model.train()近乎多余,因此您永远不会忘记。

模型。eval是torch.nn的一个方法。

eval () 将模块设置为评估模式。 这只对某些模块有影响。具体模块在训练/评估模式下的行为,如Dropout, BatchNorm等,请参见具体模块的文档。 这相当于self.train(False)。

相反的方法是模型。Umang Gupta很好地解释了火车。

model.train() model.eval()
Sets model in training mode:

• normalisation layers1 use per-batch statistics
• activates Dropout layers2
Sets model in evaluation (inference) mode:

• normalisation layers use running statistics
• de-activates Dropout layers
Equivalent to model.train(False).

你可以通过运行model.train()来关闭评估模式。你应该在运行你的模型时使用它作为推理引擎——即在测试、验证和预测时使用(尽管实际上,如果你的模型不包括任何行为不同的层,它也没有什么区别)。


例如:BatchNorm, InstanceNorm 这包括RNN模块的子模块等。