Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 50d7eb3

Browse files
Merge pull request #605 from alexpvpmindustry/master
API for adding labels: `mpf.make_addplot(..., label="myLabel")`
2 parents 46dcc89 + cbda0af commit 50d7eb3

File tree

4 files changed

+626
-9
lines changed

4 files changed

+626
-9
lines changed

‎examples/addplot_legends.ipynb

Lines changed: 576 additions & 0 deletions
Large diffs are not rendered by default.

‎src/mplfinance/_arg_validators.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import matplotlib as mpl
88
import warnings
99

10+
1011
def _check_and_prepare_data(data, config):
1112
'''
1213
Check and Prepare the data input:
@@ -94,6 +95,19 @@ def _check_and_prepare_data(data, config):
9495

9596
return dates, opens, highs, lows, closes, volumes
9697

98+
99+
def _label_validator(label_value):
100+
''' Validates the input of [legend] label for added plots.
101+
label_value may be a str or a sequence of str.
102+
'''
103+
if isinstance(label_value,str):
104+
return True
105+
if isinstance(label_value,(list,tuple,np.ndarray)):
106+
if all([isinstance(v,str) for v in label_value]):
107+
return True
108+
return False
109+
110+
97111
def _get_valid_plot_types(plottype=None):
98112

99113
_alias_types = {

‎src/mplfinance/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
version_info = (0, 12, 9, 'beta', 9)
1+
version_info = (0, 12, 10, 'beta', 0)
22

33
_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}
44

‎src/mplfinance/plotting.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from mplfinance import _styles
3434

35-
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator
35+
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator, _label_validator
3636
from mplfinance._arg_validators import _get_valid_plot_types, _fill_between_validator
3737
from mplfinance._arg_validators import _process_kwargs, _validate_vkwargs_dict
3838
from mplfinance._arg_validators import _kwarg_not_implemented, _bypass_kwarg_validation
@@ -765,6 +765,8 @@ def plot( data, **kwargs ):
765765

766766
elif not _list_of_dict(addplot):
767767
raise TypeError('addplot must be `dict`, or `list of dict`, NOT '+str(type(addplot)))
768+
769+
contains_legend_label=[] # a list of axes that contains legend labels
768770

769771
for apdict in addplot:
770772

@@ -788,10 +790,28 @@ def plot( data, **kwargs ):
788790
else:
789791
havedf = False # must be a single series or array
790792
apdata = [apdata,] # make it iterable
793+
if havedf and apdict['label']:
794+
if not isinstance(apdict['label'],(list,tuple,np.ndarray)):
795+
nlabels = 1
796+
else:
797+
nlabels = len(apdict['label'])
798+
ncolumns = len(apdata.columns)
799+
#print('nlabels=',nlabels,'ncolumns=',ncolumns)
800+
if nlabels < ncolumns:
801+
warnings.warn('\n =======================================\n'+
802+
' addplot MISMATCH between data and labels:\n'+
803+
' have '+str(ncolumns)+' columns to plot \n'+
804+
' BUT '+str(nlabels)+' labels for them.\n')
805+
colcount = 0
791806
for column in apdata:
792807
ydata = apdata.loc[:,column] if havedf else column
793-
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config)
808+
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount)
794809
_addplot_apply_supplements(ax,apdict,xdates)
810+
colcount += 1
811+
if apdict['label']: # not supported for aptype == 'ohlc' or 'candle'
812+
contains_legend_label.append(ax)
813+
for ax in set(contains_legend_label): # there might be duplicates
814+
ax.legend()
795815

796816
# fill_between is NOT supported for external_axes_mode
797817
# (caller can easily call ax.fill_between() themselves).
@@ -1079,7 +1099,7 @@ def _addplot_collections(panid,panels,apdict,xdates,config):
10791099
ax.autoscale_view()
10801100
return ax
10811101

1082-
def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
1102+
def _addplot_columns(panid,panels,ydata,apdict,xdates,config,colcount):
10831103
external_axes_mode = apdict['ax'] is not None
10841104
if not external_axes_mode:
10851105
secondary_y = False
@@ -1101,6 +1121,10 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
11011121
ax = apdict['ax']
11021122

11031123
aptype = apdict['type']
1124+
if isinstance(apdict['label'],(list,tuple,np.ndarray)):
1125+
label = apdict['label'][colcount]
1126+
else: # isinstance(...,str)
1127+
label = apdict['label']
11041128
if aptype == 'scatter':
11051129
size = apdict['markersize']
11061130
mark = apdict['marker']
@@ -1111,27 +1135,27 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
11111135

11121136
if isinstance(mark,(list,tuple,np.ndarray)):
11131137
_mscatter(xdates, ydata, ax=ax, m=mark, s=size, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
1114-
else:
1115-
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
1138+
else:
1139+
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths,label=label)
11161140
elif aptype == 'bar':
11171141
width = 0.8 if apdict['width'] is None else apdict['width']
11181142
bottom = apdict['bottom']
11191143
color = apdict['color']
11201144
alpha = apdict['alpha']
1121-
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
1145+
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha,label=label)
11221146
elif aptype == 'line':
11231147
ls = apdict['linestyle']
11241148
color = apdict['color']
11251149
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
11261150
alpha = apdict['alpha']
1127-
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
1151+
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
11281152
elif aptype == 'step':
11291153
stepwhere = apdict['stepwhere']
11301154
ls = apdict['linestyle']
11311155
color = apdict['color']
11321156
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
11331157
alpha = apdict['alpha']
1134-
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha)
1158+
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
11351159
else:
11361160
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')
11371161

@@ -1384,6 +1408,9 @@ def _valid_addplot_kwargs():
13841408
'fill_between': { 'Default' : None, # added by Wen
13851409
'Description' : " fill region",
13861410
'Validator' : _fill_between_validator },
1411+
'label' : { 'Default' : None,
1412+
'Description' : 'Label for the added plot. One per added plot.',
1413+
'Validator' : _label_validator },
13871414

13881415
}
13891416

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /