How do I print a model summary in PyTorch, the way model.summary() does in Keras?

Model Summary:
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 1, 15, 27)     0                                            
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 8, 15, 27)     872         input_1[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D)    (None, 8, 7, 27)      0           convolution2d_1[0][0]            
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 1512)          0           maxpooling2d_1[0][0]             
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 1)             1513        flatten_1[0][0]                  
====================================================================================================
Total params: 2,385
Trainable params: 2,385
Non-trainable params: 0

Current answer

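You can use the torchsummary package:

import torch
from torchsummary import summary

You can specify the device:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Then create a network and pass it to summary together with the input size (without the batch dimension). For example, if you are using the MNIST dataset (1 channel, 28x28 pixels), the following prints the summary:

model = Network().to(device)
summary(model, (1, 28, 28))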
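The snippet above does not define Network. A minimal sketch of the whole flow might look like this; the Network class here is hypothetical (not the answerer's model), and the code falls back to print(model) when the third-party torchsummary package is not installed.

```python
import torch
import torch.nn as nn


class Network(nn.Module):
    """Hypothetical MNIST classifier, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(8 * 14 * 14, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))  # (N, 8, 14, 14)
        return self.fc(x.flatten(1))             # (N, 10)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Network().to(device)

try:
    # Third-party package: pip install torchsummary
    from torchsummary import summary
    # torchsummary expects the device as the string "cuda" or "cpu"
    summary(model, (1, 28, 28), device=device.type)
except ImportError:
    # Fallback: the built-in repr lists layers but not output shapes
    print(model)
```

Passing device=device.type keeps the dummy input that torchsummary builds on the same device as the model.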

Other answers

The easiest to remember (not as pretty as Keras):

print(model)

This also works (repr returns the same string that print displays):

repr(model)

If you only want the number of parameters:

sum([param.nelement() for param in model.parameters()])
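To mirror the "Trainable params" and "Non-trainable params" lines of the Keras summary, the same idea can be split by requires_grad. This is a small sketch; count_parameters is a hypothetical helper name, and the two-layer model is only an example.

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> dict:
    """Count total, trainable and non-trainable parameters of a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {
        "total": total,
        "trainable": trainable,
        "non_trainable": total - trainable,
    }


model = nn.Sequential(nn.Linear(1024, 256), nn.ReLU(), nn.Linear(256, 10))

# Freeze the first layer so that it counts as non-trainable
for p in model[0].parameters():
    p.requires_grad = False

counts = count_parameters(model)
print(counts)
```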

See also: "Is there a PyTorch function similar to model.summary() in Keras?" (forum.PyTorch.org)

After creating an instance of your model class, simply print the model:

class RNN(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
    def forward(self, x):
        ...

model = RNN(input_dim, embedding_dim, hidden_dim, output_dim)
print(model)
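print(model) lists the layers but not the parameter counts. One way to add them is to iterate over named_parameters(). The sketch below fills in the RNN class with a plausible forward pass and made-up dimensions; both are assumptions for illustration only.

```python
import torch
import torch.nn as nn


class RNN(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        # text: (seq_len, batch) tensor of token indices
        embedded = self.embedding(text)       # (seq_len, batch, embedding_dim)
        _, hidden = self.rnn(embedded)        # hidden: (1, batch, hidden_dim)
        return self.fc(hidden.squeeze(0))     # (batch, output_dim)


# Hypothetical sizes, chosen only to keep the numbers small
model = RNN(input_dim=100, embedding_dim=16, hidden_dim=32, output_dim=2)
print(model)

# One line per parameter tensor: name, shape, number of elements
total = 0
for name, param in model.named_parameters():
    print(f"{name:<22} {str(tuple(param.shape)):<12} {param.numel()}")
    total += param.numel()
print("Total parameters:", total)
```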

Calling summary(my_model, (3, 224, 224), device='cpu') from torchsummary solves this.

This shows the model's weights and parameters (but not the output shapes).

from torch.nn.modules.module import _addindent
import torch
import numpy as np
def torch_summarize(model, show_weights=True, show_parameters=True):
    """Summarizes torch model by showing trainable parameters and weights."""
    tmpstr = model.__class__.__name__ + ' (\n'
    for key, module in model._modules.items():
        # If the module is a container, recurse to get its params and weights
        if type(module) in [
            torch.nn.modules.container.Container,
            torch.nn.modules.container.Sequential
        ]:
            modstr = torch_summarize(module, show_weights, show_parameters)
        else:
            modstr = module.__repr__()
        modstr = _addindent(modstr, 2)

        params = sum([np.prod(p.size()) for p in module.parameters()])
        weights = tuple([tuple(p.size()) for p in module.parameters()])

        tmpstr += '  (' + key + '): ' + modstr 
        if show_weights:
            tmpstr += ', weights={}'.format(weights)
        if show_parameters:
            tmpstr +=  ', parameters={}'.format(params)
        tmpstr += '\n'   

    tmpstr = tmpstr + ')'
    return tmpstr

# Test
import torchvision.models as models
model = models.alexnet()
print(torch_summarize(model))

# Output
# AlexNet (
#   (features): Sequential (
#     (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)), weights=((64, 3, 11, 11), (64,)), parameters=23296
#     (1): ReLU (inplace), weights=(), parameters=0
#     (2): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
#     (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)), weights=((192, 64, 5, 5), (192,)), parameters=307392
#     (4): ReLU (inplace), weights=(), parameters=0
#     (5): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
#     (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((384, 192, 3, 3), (384,)), parameters=663936
#     (7): ReLU (inplace), weights=(), parameters=0
#     (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((256, 384, 3, 3), (256,)), parameters=884992
#     (9): ReLU (inplace), weights=(), parameters=0
#     (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((256, 256, 3, 3), (256,)), parameters=590080
#     (11): ReLU (inplace), weights=(), parameters=0
#     (12): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
#   ), weights=((64, 3, 11, 11), (64,), (192, 64, 5, 5), (192,), (384, 192, 3, 3), (384,), (256, 384, 3, 3), (256,), (256, 256, 3, 3), (256,)), parameters=2469696
#   (classifier): Sequential (
#     (0): Dropout (p = 0.5), weights=(), parameters=0
#     (1): Linear (9216 -> 4096), weights=((4096, 9216), (4096,)), parameters=37752832
#     (2): ReLU (inplace), weights=(), parameters=0
#     (3): Dropout (p = 0.5), weights=(), parameters=0
#     (4): Linear (4096 -> 4096), weights=((4096, 4096), (4096,)), parameters=16781312
#     (5): ReLU (inplace), weights=(), parameters=0
#     (6): Linear (4096 -> 1000), weights=((1000, 4096), (1000,)), parameters=4097000
#   ), weights=((4096, 9216), (4096,), (4096, 4096), (4096,), (1000, 4096), (1000,)), parameters=58631144
# )
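The function above reports weights and parameter counts but not output shapes. One common way to get the shapes is to register forward hooks and run a dummy input through the model. The following is a rough sketch of that idea, not the method used by any particular library; summarize_shapes is a hypothetical helper, and it assumes the model takes a single tensor input and lives on the CPU.

```python
import torch
import torch.nn as nn


def summarize_shapes(model: nn.Module, input_size: tuple) -> list:
    """Return (name, type, output shape, param count) for each leaf module."""
    rows, hooks = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            n_params = sum(p.numel() for p in module.parameters())
            shape = tuple(output.shape) if isinstance(output, torch.Tensor) else None
            rows.append((name, module.__class__.__name__, shape, n_params))
        return hook

    # Hook only leaf modules (modules without children)
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            hooks.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))  # dummy batch of size 1
    finally:
        for h in hooks:
            h.remove()                          # always clean up the hooks
        model.train(was_training)
    return rows


# Small example model, assumed here only for demonstration
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 14 * 14, 10),
)
rows = summarize_shapes(model, (1, 28, 28))
for row in rows:
    print(row)
```

Rows are recorded in execution order, so a module that is called twice in forward appears twice.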

Edit: isaykatsman has a PyTorch PR that adds a model.summary() just like the one in Keras: https://github.com/pytorch/pytorch/pull/3043/files

I prefer this simple snippet:

net = model
modules = [module for module in net.modules()]
params = [param.shape for param in net.parameters()]

# Print the model summary.
# Assumes every submodule is a Linear layer with a bias, so that
# net.parameters() yields (weight, bias) pairs in layer order.
print(modules[0])
total_params = 0
for i in range(1, len(modules)):
    j = 2 * i
    # Weight elements (out_features * in_features) plus bias elements
    param = (params[j-2][1] * params[j-2][0]) + params[j-1][0]
    total_params += param
    print("Layer", i, "->\t", end="")
    print("Weights:", params[j-2][0], "x", params[j-2][1],
          "\tBias: ", params[j-1][0], "\tParameters: ", param)
print("\nTotal Parameters: ", total_params)

This prints everything I need:

Net(
  (hLayer1): Linear(in_features=1024, out_features=256, bias=True)
  (hLayer2): Linear(in_features=256, out_features=128, bias=True)
  (hLayer3): Linear(in_features=128, out_features=64, bias=True)
  (outLayer): Linear(in_features=64, out_features=10, bias=True)
)
Layer 1 ->  Weights: 256 x 1024     Bias:  256  Parameters:  262400
Layer 2 ->  Weights: 128 x 256      Bias:  128  Parameters:  32896
Layer 3 ->  Weights: 64 x 128       Bias:  64   Parameters:  8256
Layer 4 ->  Weights: 10 x 64        Bias:  10   Parameters:  650

Total Parameters:  304202
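The per-layer counts in this output can be checked by hand: a Linear layer with a bias has in_features * out_features weights plus out_features biases. A quick check in plain Python (linear_params is a hypothetical helper name):

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of parameters in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


# (in_features, out_features) for hLayer1, hLayer2, hLayer3, outLayer
layers = [(1024, 256), (256, 128), (128, 64), (64, 10)]
counts = [linear_params(i, o) for i, o in layers]

print(counts)       # [262400, 32896, 8256, 650]
print(sum(counts))  # 304202
```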

For complex models, or for more in-depth model statistics:

Install torchstat:

pip install torchstat

Get the statistics:

from torchstat import stat
import torchvision.models as models

model = models.vgg19()
stat(model, (3, 224, 224))

Output:

        module name  input shape output shape       params memory(MB)              MAdd             Flops   MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0        features.0    3 224 224   64 224 224       1792.0      12.25     173,408,256.0      89,915,392.0     609280.0   12845056.0      21.56%   13454336.0
1        features.1   64 224 224   64 224 224          0.0      12.25       3,211,264.0       3,211,264.0   12845056.0   12845056.0       0.92%   25690112.0
2        features.2   64 224 224   64 224 224      36928.0      12.25   3,699,376,128.0   1,852,899,328.0   12992768.0   12845056.0       4.74%   25837824.0
3        features.3   64 224 224   64 224 224          0.0      12.25       3,211,264.0       3,211,264.0   12845056.0   12845056.0       0.92%   25690112.0
4        features.4   64 224 224   64 112 112          0.0       3.06       2,408,448.0       3,211,264.0   12845056.0    3211264.0       1.22%   16056320.0
5        features.5   64 112 112  128 112 112      73856.0       6.12   1,849,688,064.0     926,449,664.0    3506688.0    6422528.0       4.71%    9929216.0
6        features.6  128 112 112  128 112 112          0.0       6.12       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.94%   12845056.0
7        features.7  128 112 112  128 112 112     147584.0       6.12   3,699,376,128.0   1,851,293,696.0    7012864.0    6422528.0       4.36%   13435392.0
8        features.8  128 112 112  128 112 112          0.0       6.12       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.91%   12845056.0
9        features.9  128 112 112  128  56  56          0.0       1.53       1,204,224.0       1,605,632.0    6422528.0    1605632.0       1.51%    8028160.0
10      features.10  128  56  56  256  56  56     295168.0       3.06   1,849,688,064.0     925,646,848.0    2786304.0    3211264.0       3.57%    5997568.0
11      features.11  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.90%    6422528.0
12      features.12  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       4.30%    8782848.0
13      features.13  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.90%    6422528.0
14      features.14  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       4.38%    8782848.0
15      features.15  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.94%    6422528.0
16      features.16  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       4.33%    8782848.0
17      features.17  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.90%    6422528.0
18      features.18  256  56  56  256  28  28          0.0       0.77         602,112.0         802,816.0    3211264.0     802816.0       1.44%    4014080.0
19      features.19  256  28  28  512  28  28    1180160.0       1.53   1,849,688,064.0     925,245,440.0    5523456.0    1605632.0       3.60%    7129088.0
20      features.20  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.92%    3211264.0
21      features.21  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       4.45%   12650496.0
22      features.22  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.94%    3211264.0
23      features.23  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       4.39%   12650496.0
24      features.24  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.90%    3211264.0
25      features.25  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       4.34%   12650496.0
26      features.26  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.90%    3211264.0
27      features.27  512  28  28  512  14  14          0.0       0.38         301,056.0         401,408.0    1605632.0     401408.0       0.96%    2007040.0
28      features.28  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       0.99%   10242048.0
29      features.29  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
30      features.30  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       0.11%   10242048.0
31      features.31  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
32      features.32  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       0.11%   10242048.0
33      features.33  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
34      features.34  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       0.11%   10242048.0
35      features.35  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
36      features.36  512  14  14  512   7   7          0.0       0.10          75,264.0         100,352.0     401408.0     100352.0       0.01%     501760.0
37          avgpool  512   7   7  512   7   7          0.0       0.10               0.0               0.0          0.0          0.0       0.49%          0.0
38     classifier.0        25088         4096  102764544.0       0.02     205,516,800.0     102,760,448.0  411158528.0      16384.0      11.27%  411174912.0
39     classifier.1         4096         4096          0.0       0.02           4,096.0           4,096.0      16384.0      16384.0       0.00%      32768.0
40     classifier.2         4096         4096          0.0       0.02               0.0               0.0          0.0          0.0       0.01%          0.0
41     classifier.3         4096         4096   16781312.0       0.02      33,550,336.0      16,777,216.0   67141632.0      16384.0       1.08%   67158016.0
42     classifier.4         4096         4096          0.0       0.02           4,096.0           4,096.0      16384.0      16384.0       0.00%      32768.0
43     classifier.5         4096         4096          0.0       0.02               0.0               0.0          0.0          0.0       0.00%          0.0
44     classifier.6         4096         1000    4097000.0       0.00       8,191,000.0       4,096,000.0   16404384.0       4000.0       0.93%   16408384.0
total                                          143667240.0     119.34  39,283,567,128.0  19,667,896,320.0   16404384.0       4000.0     100.00%  825282624.0
============================================================================================================================================================
Total params: 143,667,240
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 119.34 MB
Total MAdd: 39.28 GMAdd
Total Flops: 19.67 GFlops
Total MemR+W: 787.05 MB
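As a sanity check, the figures for the convolution rows can be reproduced with simple arithmetic. The formulas below match rows 0 and 2 of the table; whether torchstat computes them in exactly this way internally is an assumption, and conv2d_stats is a hypothetical helper.

```python
def conv2d_stats(c_in: int, c_out: int, k: int, h_out: int, w_out: int) -> dict:
    """Estimate params, output memory, FLOPs and MAdd for a k x k Conv2d with bias."""
    kernel_ops = c_in * k * k                 # multiplications per output element
    out_elems = c_out * h_out * w_out         # elements in the output feature map
    return {
        "params": c_out * kernel_ops + c_out,       # weights + biases
        "memory_mb": out_elems * 4 / 1024 ** 2,     # float32 output, in MB
        "flops": out_elems * (kernel_ops + 1),      # multiplies + 1 for the bias
        "madd": out_elems * (2 * kernel_ops),       # multiplies + adds (bias included)
    }


# features.0 of VGG19: 3 -> 64 channels, 3x3 kernel, 224x224 output
row0 = conv2d_stats(3, 64, 3, 224, 224)
print(row0)

# features.2 of VGG19: 64 -> 64 channels, 3x3 kernel, 224x224 output
row2 = conv2d_stats(64, 64, 3, 224, 224)
print(row2)
```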