我有点困惑这段代码是如何工作的:

fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()

在这种情况下,无花果轴是如何工作的?它能做什么?

还有,为什么这不能做同样的事情:

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

阅读文档:matplotlib.pyplot.subplots

Pyplot.subplots()返回一个元组图ax,它使用表示法在两个变量中解包

fig, axes = plt.subplots(nrows=2, ncols=2)

代码:

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

不能工作,因为subplots()是pyplot中的函数,而不是对象图的成员。


有几种方法可以做到这一点。subplots方法创建图形和子图,然后存储在ax数组中。例如:

import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig, ax = plt.subplots(nrows=2, ncols=2)

for row in ax:
    for col in row:
        col.plot(x, y)

plt.show()

然而,像这样的东西也可以工作,虽然它不是那么“干净”,因为你创建了一个带有子图的图形,然后在它们上面添加:

fig = plt.figure()

plt.subplot(2, 2, 1)
plt.plot(x, y)

plt.subplot(2, 2, 2)
plt.plot(x, y)

plt.subplot(2, 2, 3)
plt.plot(x, y)

plt.subplot(2, 2, 4)
plt.plot(x, y)

plt.show()


您可能会对这样一个事实感兴趣,即在matplotlib 2.1版本中,问题中的第二个代码也可以正常工作。

从更改日志中:

图形类现在有subplots方法 Figure类现在有一个subplots()方法,其行为与pyplot.subplots()相同,但是是在一个现有的图形上。

例子:

import matplotlib.pyplot as plt

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)

plt.show()

import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2)

ax[0, 0].plot(range(10), 'r') #row=0, col=0
ax[1, 0].plot(range(10), 'b') #row=1, col=0
ax[0, 1].plot(range(10), 'g') #row=0, col=1
ax[1, 1].plot(range(10), 'k') #row=1, col=1
plt.show()


还可以在subplots调用中解包坐标轴 并设置是否要在子图之间共享x轴和y轴

是这样的:

import matplotlib.pyplot as plt
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1, ax2, ax3, ax4 = axes.flatten()

ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()


如果你真的想使用循环,请执行以下操作:

def plot(data):
    fig = plt.figure(figsize=(100, 100))
    for idx, k in enumerate(data.keys(), 1):
        x, y = data[k].keys(), data[k].values
        plt.subplot(63, 10, idx)
        plt.bar(x, y)  
    plt.show()

依次遍历所有子图:

fig, axes = plt.subplots(nrows, ncols)

for ax in axes.flatten():
    ax.plot(x,y)

访问特定索引:

for row in range(nrows):
    for col in range(ncols):
        axes[row,col].plot(x[row], y[col])

你可以使用以下语句:

import numpy as np
import matplotlib.pyplot as plt

fig, _ = plt.subplots(nrows=2, ncols=2)

for i, ax in enumerate(fig.axes):
  ax.plot(np.sin(np.linspace(0,2*np.pi,100) + np.pi/2*i))

或者用第二个变量plt。次要情节的回报:

fig, ax_mat = plt.subplots(nrows=2, ncols=2)
for i, ax in enumerate(ax_mat.flatten()):
    ...

Ax_mat是一个轴的矩阵。它的形状是nrows x ncols。


熊猫的次要情节

这个答案适用于使用pandas的子图,它使用matplotlib作为默认的绘图后端。 这里有四种选择,可以创建以熊猫为开头的子情节。DataFrame 实现1。和2。用于宽格式的数据,为每列创建子图。 实现3。和4。用于长格式的数据,为列中的每个惟一值创建子图。 在python 3.8.11, pandas 1.3.2, matplotlib 3.4.3, seaborn 0.11.2中测试

导入和数据

import seaborn as sns  # data only
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# wide dataframe
df = sns.load_dataset('planets').iloc[:, 2:5]

   orbital_period   mass  distance
0         269.300   7.10     77.40
1         874.774   2.21     56.95
2         763.000   2.60     19.84
3         326.030  19.40    110.62
4         516.220  10.50    119.47

# long dataframe
dfm = sns.load_dataset('planets').iloc[:, 2:5].melt()

         variable    value
0  orbital_period  269.300
1  orbital_period  874.774
2  orbital_period  763.000
3  orbital_period  326.030
4  orbital_period  516.220

1. subplots=True和layout,为每列

使用pandas.DataFrame.plot中的参数subplots=True和layout=(rows, cols) 本例使用kind='density',但kind有不同的选项,这适用于所有类型。如果不指定类型,则默认为线形图。 ax是pandas.DataFrame.plot返回的AxesSubplot数组 如果需要,请参阅如何获取Figure对象。 如何拯救熊猫支线

axes = df.plot(kind='density', subplots=True, layout=(2, 2), sharex=False, figsize=(10, 6))

# extract the figure object; only used for tight_layout in this example
fig = axes[0][0].get_figure() 

# set the individual titles
for ax, title in zip(axes.ravel(), df.columns):
    ax.set_title(title)
fig.tight_layout()
plt.show()

2. plt。每个列的子图

Create an array of Axes with matplotlib.pyplot.subplots and then pass axes[i, j] or axes[n] to the ax parameter. This option uses pandas.DataFrame.plot, but can use other axes level plot calls as a substitute (e.g. sns.kdeplot, plt.plot, etc.) It's easiest to collapse the subplot array of Axes into one dimension with .ravel or .flatten. See .ravel vs .flatten. Any variables applying to each axes, that need to be iterate through, are combined with .zip (e.g. cols, axes, colors, palette, etc.). Each object must be the same length.

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))  # define the figure and subplots
axes = axes.ravel()  # array to 1D
cols = df.columns  # create a list of dataframe columns to use
colors = ['tab:blue', 'tab:orange', 'tab:green']  # list of colors for each subplot, otherwise all subplots will be one color

for col, color, ax in zip(cols, colors, axes):
    df[col].plot(kind='density', ax=ax, color=color, label=col, title=col)
    ax.legend()
    
fig.delaxes(axes[3])  # delete the empty subplot
fig.tight_layout()
plt.show()

结果是1。和2。

3.plt。子图,用于.groupby中的每个组

这和2类似。,除了它压缩颜色和轴到一个.groupby对象。

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6))  # define the figure and subplots
axes = axes.ravel()  # array to 1D
dfg = dfm.groupby('variable')  # get data for each unique value in the first column
colors = ['tab:blue', 'tab:orange', 'tab:green']  # list of colors for each subplot, otherwise all subplots will be one color

for (group, data), color, ax in zip(dfg, colors, axes):
    data.plot(kind='density', ax=ax, color=color, title=group, legend=False)

fig.delaxes(axes[3])  # delete the empty subplot
fig.tight_layout()
plt.show()

4. Seaborn数字级图

使用海运数字级绘图,并使用col或row参数。seaborn是matplotlib的高级API。参见seaborn: API参考

p = sns.displot(data=dfm, kind='kde', col='variable', col_wrap=2, x='value', hue='variable',
                facet_kws={'sharey': False, 'sharex': False}, height=3.5, aspect=1.75)
sns.move_legend(p, "upper left", bbox_to_anchor=(.55, .45))

将坐标轴数组转换为1D

Generating subplots with plt.subplots(nrows, ncols), where both nrows and ncols is greater than 1, returns a nested array of <AxesSubplot:> objects. It’s not necessary to flatten axes in cases where either nrows=1 or ncols=1, because axes will already be 1 dimensional, which is a result of the default parameter squeeze=True The easiest way to access the objects, is to convert the array to 1 dimension with .ravel(), .flatten(), or .flat. .ravel vs. .flatten flatten always returns a copy. ravel returns a view of the original array whenever possible. Once the array of axes is converted to 1-d, there are a number of ways to plot. This answer is relevant to seaborn axes-level plots, which have the ax= parameter (e.g. sns.barplot(…, ax=ax[0]). seaborn is a high-level API for matplotlib. See Figure-level vs. axes-level functions and seaborn is not plotting within defined subplots

import matplotlib.pyplot as plt
import numpy as np  # sample data only

# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]

# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)

# axes before
array([[<AxesSubplot:>, <AxesSubplot:>],
       [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)

# convert the array to 1 dimension
axes = axes.ravel()

# axes after
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
      dtype=object)

遍历扁平数组 如果子图比数据多,这将导致IndexError: list索引超出范围 尝试选择3。或者选择坐标轴的子集(例如坐标轴[:-2])

for i, ax in enumerate(axes):
    ax.plot(x_data[i], y_data[i])

按索引访问每个轴

axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])

索引数据和坐标轴

for i in range(len(x_data)):
    axes[i].plot(x_data[i], y_data[i])

压缩轴和数据,然后遍历元组列表。

for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)

Ouput


An option is to assign each axes to a variable, fig, (ax1, ax2, ax3) = plt.subplots(1, 3). However, as written, this only works in cases with either nrows=1 or ncols=1. This is based on the shape of the array returned by plt.subplots, and quickly becomes cumbersome. fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) for a 2 x 2 array. This option is most useful for two subplots (e.g.: fig, (ax1, ax2) = plt.subplots(1, 2) or fig, (ax1, ax2) = plt.subplots(2, 1)). For more subplots, it's more efficient to flatten and iterate through the array of axes.


这里有一个简单的解决办法

fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=False)
for sp in fig.axes:
    sp.plot(range(10))


另一个简洁的解决方案是:

// set up structure of plots
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20,10))

// for plot 1
ax1.set_title('Title A')
ax1.plot(x, y)

// for plot 2
ax2.set_title('Title B')
ax2.plot(x, y)

// for plot 3
ax3.set_title('Title C')
ax3.plot(x,y)