I am trying to plot the correlation matrix of a Pandas DataFrame. Since the diagonal elements are always one and the matrix is symmetric, I can drop more than half of the squares without losing any useful information. I also made zero map to white by default. Is there any other way to improve the function?
import seaborn as sns
import numpy as np

def plot_correlation_matrix(df, decimal=2, **kwargs):
    kwargs["cmap"] = kwargs.get("cmap", "vlag")
    kwargs["center"] = kwargs.get("center", 0)  # 0 will be white if you use vlag cmap
    kwargs["fmt"] = kwargs.get("fmt", f".{decimal}f")
    corr = df.corr(numeric_only=True)  # skip non-numeric columns such as "species"
    corr_reduced = corr.iloc[1:, :-1]  # No need for first row, last column
    mask = np.triu(np.ones_like(corr_reduced, dtype=bool), k=1)  # No need for upper triangle
    sns.heatmap(corr_reduced, annot=True, mask=mask, **kwargs)

# Using the function
df_ir = sns.load_dataset("iris")
df_ir.petal_length = -df_ir.petal_length  # creating negative correlation
plot_correlation_matrix(df_ir)
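The claim that no information is lost can be checked numerically. The following is only a sketch on made-up data (the random DataFrame and the names `shown`, `pairs`, and `n_pairs` are assumptions for illustration, not part of the function above): after dropping the first row, the last column, and the upper triangle, every unordered pair of distinct columns should appear exactly once.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in data: four numeric columns of random values.
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(50, 4)), columns=list("abcd"))

corr = df.corr()
n = corr.shape[0]

# Same reduction as in the function: drop the first row and last column,
# then hide everything strictly above the diagonal of what is left.
corr_reduced = corr.iloc[1:, :-1]
mask = np.triu(np.ones_like(corr_reduced, dtype=bool), k=1)
shown = corr_reduced.mask(mask)  # hidden cells become NaN

# Collect the (row, column) pairs that remain visible.
pairs = {
    frozenset((r, c))
    for r in shown.index
    for c in shown.columns
    if pd.notna(shown.loc[r, c])
}

n_shown = int(shown.notna().sum().sum())
n_pairs = n * (n - 1) // 2  # number of unordered pairs of distinct columns
print(n_shown, n_pairs, len(pairs))
```

All three printed numbers agreeing means each pair of columns is shown once and only once.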
asked Dec 22, 2022 at 19:20
Improve the looks of the plot, or improve the looks/efficiency of the code? – noah1400 Commented Dec 22, 2022 at 22:58
1 Answer
There are a few things you can improve or add to your function:
- Add a title to the plot. `heatmap` has no title parameter of its own, but it returns the Matplotlib Axes, so you can call `set_title` on the result. A title makes the plot easier to interpret.
- Specify the range of values that the color map covers with the `vmin` and `vmax` parameters of the `heatmap` function. Correlations always lie between -1 and 1, so fixing the range keeps colors comparable between plots.
- Disable the color bar with the `cbar` parameter of the `heatmap` function. If you are displaying a lot of matrices, this saves some space.
- Use the `square` parameter to make every cell square. If you have a large number of columns, the plot would otherwise be very wide.
- To improve the visual appeal, use `linecolor` and `linewidths` to specify the color and width of the lines between the cells.
- Use the `xticklabels` and `yticklabels` parameters to disable the tick labels on the x- and y-axis. Again, if you have a large number of columns, you may want to save space.
- Use `annot_kws` for additional formatting of the annotations, e.g. `fontsize`.
- Use the NumPy function `triu_indices` to blank out the upper triangle in the data itself, which removes the need for the `mask` parameter of the `heatmap` function.
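As a small NumPy-only sketch of the last point: `np.triu_indices` addresses the same cells that the boolean `np.triu` mask from the question marks as `True`. The size `n = 3` and the variable names are assumptions chosen for illustration.

```python
import numpy as np

n = 3  # side length of the reduced correlation matrix (assumed square)

# Boolean mask as in the question: True strictly above the diagonal.
mask_bool = np.triu(np.ones((n, n), dtype=bool), k=1)

# The same cells, addressed by row and column index arrays.
rows, cols = np.triu_indices(n, k=1)
mask_idx = np.zeros((n, n), dtype=bool)
mask_idx[rows, cols] = True

same_cells = np.array_equal(mask_bool, mask_idx)
print(same_cells)
```

Either form selects the cells above the diagonal; the difference is only whether they are hidden through `mask` or overwritten in the data before plotting.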
import seaborn as sns
import numpy as np

def plot_correlation_matrix(df, decimal=2, **kwargs):
    kwargs["cmap"] = kwargs.get("cmap", "vlag")
    kwargs["center"] = kwargs.get("center", 0)  # 0 will be white if you use vlag cmap
    kwargs["fmt"] = kwargs.get("fmt", f".{decimal}f")
    kwargs["vmin"] = kwargs.get("vmin", -1)  # specify the minimum value for the color map
    kwargs["vmax"] = kwargs.get("vmax", 1)  # specify the maximum value for the color map
    kwargs["cbar"] = kwargs.get("cbar", False)  # specify whether or not to display a color bar
    kwargs["square"] = kwargs.get("square", True)  # specify whether or not to make each cell square
    kwargs["linecolor"] = kwargs.get("linecolor", "gray")  # specify the color of the lines separating the cells
    kwargs["linewidths"] = kwargs.get("linewidths", 0.5)  # specify the width of the lines separating the cells
    kwargs["xticklabels"] = kwargs.get("xticklabels", True)  # specify whether or not to display x-axis labels
    kwargs["yticklabels"] = kwargs.get("yticklabels", True)  # specify whether or not to display y-axis labels
    kwargs["annot_kws"] = kwargs.get("annot_kws", {"fontsize": 8})  # specify additional formatting options for the annotation text
    corr = df.corr(numeric_only=True)  # skip non-numeric columns such as "species"
    corr_reduced = corr.iloc[1:, :-1]  # No need for first row, last column
    # Mark the cells above the diagonal and replace them with NaN.
    # heatmap does not draw cells with missing values, so no mask argument is needed.
    hide = np.zeros(corr_reduced.shape, dtype=bool)
    hide[np.triu_indices(corr_reduced.shape[0], k=1)] = True
    corr_reduced = corr_reduced.mask(hide)
    ax = sns.heatmap(corr_reduced, annot=True, **kwargs)
    ax.set_title("Correlation Matrix")  # add a title to the plot

# Using the function
df_ir = sns.load_dataset("iris")
df_ir.petal_length = -df_ir.petal_length  # creating negative correlation
plot_correlation_matrix(df_ir)
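The space-saving options matter most when several matrices are shown together. As a rough sketch of that use case, a hypothetical helper `plot_corr_on` (not part of the function above) could take an existing Matplotlib Axes through `heatmap`'s `ax` parameter, so that two matrices share one figure. The random data, the helper name, and the panel titles are assumptions for illustration; the `Agg` backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

def plot_corr_on(ax, df, title):
    """Hypothetical helper: draw the reduced correlation heatmap of df on ax."""
    corr = df.corr(numeric_only=True).iloc[1:, :-1]
    mask = np.triu(np.ones_like(corr, dtype=bool), k=1)  # hide the upper triangle
    sns.heatmap(corr, mask=mask, annot=True, fmt=".2f", cmap="vlag",
                center=0, vmin=-1, vmax=1, cbar=False, square=True, ax=ax)
    ax.set_title(title)
    return ax

# Made-up data: two samples with the same four numeric columns.
rng = np.random.default_rng(0)
df_a = pd.DataFrame(rng.normal(size=(40, 4)), columns=list("abcd"))
df_b = pd.DataFrame(rng.normal(size=(40, 4)), columns=list("abcd"))

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
plot_corr_on(axes[0], df_a, "Sample A")
plot_corr_on(axes[1], df_b, "Sample B")
fig.tight_layout()
```

With `cbar=False` no extra color bar axes are created, and the fixed `vmin`/`vmax` keep the colors comparable across the two panels.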
AJNeufeld
answered Dec 22, 2022 at 23:14