我有一个带有两个y轴的图,使用twinx()。我也给了线条标签,并想用legend()显示它们,但我只成功地获得了图例中一个轴的标签:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('mathtext', default='regular')
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(time, Swdown, '-', label = 'Swdown')
ax.plot(time, Rn, '-', label = 'Rn')
ax2 = ax.twinx()
ax2.plot(time, temp, '-r', label = 'temp')
ax.legend(loc=0)
ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()
所以我只得到图例中第一个轴的标签,而不是第二个轴的标签“temp”。如何将第三个标签添加到图例中?
您可以通过添加以下行轻松添加第二个图例:
ax2.legend(loc=0)
你会得到这个:
但是如果你想要所有的标签都在一个图例上,那么你应该这样做:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('mathtext', default='regular')
time = np.arange(10)
temp = np.random.random(10)*30
Swdown = np.random.random(10)*100-10
Rn = np.random.random(10)*100-10
fig = plt.figure()
ax = fig.add_subplot(111)
lns1 = ax.plot(time, Swdown, '-', label = 'Swdown')
lns2 = ax.plot(time, Rn, '-', label = 'Rn')
ax2 = ax.twinx()
lns3 = ax2.plot(time, temp, '-r', label = 'temp')
# added these three lines
lns = lns1+lns2+lns3
labs = [l.get_label() for l in lns]
ax.legend(lns, labs, loc=0)
ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()
它会给你这个:
目前提出的解决方案有一两个不便之处:
在绘图时需要单独收集句柄,例如lns1 = ax。plot(time, Swdown, '-', label = 'Swdown')。在更新代码时,有忘记句柄的风险。
图例是为整个图形绘制的,而不是通过子图绘制的,如果你有多个子图,这可能是不可取的。
这个新的解决方案利用了Axes.get_legend_handles_labels()来收集主轴和双轴的现有句柄和标签。
自动收集手柄和标签
这个numpy操作将扫描所有与ax共享相同subplot区域的轴,包括ax并返回合并的句柄和标签:
hl = np.hstack([axis.get_legend_handles_labels()
for axis in ax.figure.axes
if axis.bbox.bounds == ax.bbox.bounds])
它可以用这样的方式来提供legend()参数:
import numpy as np
import matplotlib.pyplot as plt
t = np.arange(1, 200)
signals = [np.exp(-t/20) * np.cos(t*k) for k in (1, 2)]
fig, axes = plt.subplots(nrows=2, figsize=(10, 3), layout='constrained')
axes = axes.flatten()
for i, (ax, signal) in enumerate(zip(axes, signals)):
# Plot as usual, no change to the code
ax.plot(t, signal, label=f'plotted on axes[{i}]', c='C0', lw=9, alpha=0.3)
ax2 = ax.twinx()
ax2.plot(t, signal, label=f'plotted on axes[{i}].twinx()', c='C1')
# The only specificity of the code is when plotting the legend
h, l = np.hstack([axis.get_legend_handles_labels()
for axis in ax.figure.axes
if axis.bbox.bounds == ax.bbox.bounds]).tolist()
ax2.legend(handles=h, labels=l, loc='upper right')
您可以通过添加以下行轻松添加第二个图例:
ax2.legend(loc=0)
你会得到这个:
但是如果你想要所有的标签都在一个图例上,那么你应该这样做:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('mathtext', default='regular')
time = np.arange(10)
temp = np.random.random(10)*30
Swdown = np.random.random(10)*100-10
Rn = np.random.random(10)*100-10
fig = plt.figure()
ax = fig.add_subplot(111)
lns1 = ax.plot(time, Swdown, '-', label = 'Swdown')
lns2 = ax.plot(time, Rn, '-', label = 'Rn')
ax2 = ax.twinx()
lns3 = ax2.plot(time, temp, '-r', label = 'temp')
# added these three lines
lns = lns1+lns2+lns3
labs = [l.get_label() for l in lns]
ax.legend(lns, labs, loc=0)
ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()
它会给你这个:
准备
import numpy as np
from matplotlib import pyplot as plt
fig, ax1 = plt.subplots( figsize=(15,6) )
Y1, Y2 = np.random.random((2,100))
ax2 = ax1.twinx()
内容
我很惊讶它没有显示到目前为止,但最简单的方法是手动收集它们到一个轴objs(躺在彼此的顶部)
l1 = ax1.plot( range(len(Y1)), Y1, label='Label 1' )
l2 = ax2.plot( range(len(Y2)), Y2, label='Label 2', color='orange' )
ax1.legend( handles=l1+l2 )
或者通过fig.legend()将它们自动收集到周围的图形中,并摆弄bbox_to_anchor参数:
ax1.plot( range(len(Y1)), Y1, label='Label 1' )
ax2.plot( range(len(Y2)), Y2, label='Label 2', color='orange' )
fig.legend( bbox_to_anchor=(.97, .97) )
终结
fig.tight_layout()
fig.savefig('stackoverflow.png', bbox_inches='tight')
这里有另一种方法:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('mathtext', default='regular')
fig = plt.figure()
ax = fig.add_subplot(111)
pl_1, = ax.plot(time, Swdown, '-')
label_1 = 'Swdown'
pl_2, = ax.plot(time, Rn, '-')
label_2 = 'Rn'
ax2 = ax.twinx()
pl_3, = ax2.plot(time, temp, '-r')
label_3 = 'temp'
ax.legend([pl[enter image description here][1]_1, pl_2, pl_3], [label_1, label_2, label_3], loc=0)
ax.grid()
ax.set_xlabel("Time (h)")
ax.set_ylabel(r"Radiation ($MJ\,m^{-2}\,d^{-1}$)")
ax2.set_ylabel(r"Temperature ($^\circ$C)")
ax2.set_ylim(0, 35)
ax.set_ylim(-20,100)
plt.show()
在这里输入图像描述