How to set the size of a figure in matplotlib and seaborn

TL;DR

  • if you're using plot() on a pandas Series or DataFrame, use the figsize keyword
  • if you're using matplotlib directly, use matplotlib.pyplot.figure with the figsize keyword
  • if you're using a seaborn function that draws a single plot, use matplotlib.pyplot.figure with the figsize keyword
  • if you're using a seaborn function that draws multiple plots, use the height and aspect keyword arguments

Introduction

Setting figure sizes, like rotating axis tick labels, is one of those things that feels like it should be very straightforward. However, it still manages to show up on the first page of Stack Overflow questions for both matplotlib and seaborn. Part of the confusion arises because there are so many ways to do the same thing - this highly upvoted question has six suggested solutions:

  • manually create an Axes object with the desired size
  • pass some configuration parameters to seaborn so that the size you want is the default
  • call a method on the figure once it's been created
  • pass height and aspect keywords to the seaborn plotting function
  • use the matplotlib.pyplot interface and call the figure() function
  • use the matplotlib.pyplot interface to get the current figure then set its size using a method

each of which will work in some circumstances but not others!
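
To make a few of these concrete, here is a minimal sketch of three of the approaches using plain matplotlib and made-up numbers. The Agg backend line is only there so that the code runs without a display, and the variable names are just for illustration. All three end up with a figure of the same size:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt


def size_of(fig):
    """Return the figure size in inches as a plain (width, height) tuple."""
    width, height = fig.get_size_inches()
    return (float(width), float(height))


# 1. Ask for the size up front when the figure is created.
fig_a = plt.figure(figsize=(8, 3))
plt.plot([1, 2, 3], [2, 4, 8])

# 2. Call a method on the figure once it has been created.
fig_b, ax_b = plt.subplots()
ax_b.plot([1, 2, 3], [2, 4, 8])
fig_b.set_size_inches(8, 3)

# 3. Fetch the current figure through pyplot, then resize it.
plt.figure()
plt.plot([1, 2, 3], [2, 4, 8])
fig_c = plt.gcf()
fig_c.set_size_inches(8, 3)

sizes = [size_of(fig) for fig in (fig_a, fig_b, fig_c)]
print(sizes)
```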

Drawing a figure using pandas

Let's jump in. As an example we'll use the Olympic medal dataset, which we can load directly from a URL:

import pandas as pd
pd.options.display.max_rows = 10
pd.options.display.max_columns = 6
data = pd.read_csv("https://raw.githubusercontent.com/mojones/binders/master/olympics.csv", sep="\t")
data
          City  Year      Sport  ...   Medal    Country  Int Olympic Committee code
0       Athens  1896   Aquatics  ...    Gold    Hungary                         HUN
1       Athens  1896   Aquatics  ...  Silver    Austria                         AUT
2       Athens  1896   Aquatics  ...  Bronze     Greece                         GRE
3       Athens  1896   Aquatics  ...    Gold     Greece                         GRE
4       Athens  1896   Aquatics  ...  Silver     Greece                         GRE
...        ...   ...        ...  ...     ...        ...                         ...
29211  Beijing  2008  Wrestling  ...  Silver    Germany                         GER
29212  Beijing  2008  Wrestling  ...  Bronze  Lithuania                         LTU
29213  Beijing  2008  Wrestling  ...  Bronze    Armenia                         ARM
29214  Beijing  2008  Wrestling  ...    Gold       Cuba                         CUB
29215  Beijing  2008  Wrestling  ...  Silver     Russia                         RUS

29216 rows × 12 columns

For our first figure, we'll count how many medals have been won in total by each country, then take the top thirty:

data['Country'].value_counts().head(30)
United States     4335
Soviet Union      2049
United Kingdom    1594
France            1314
Italy             1228
                  ... 
Spain              377
Switzerland        376
Brazil             372
Bulgaria           331
Czechoslovakia     329
Name: Country, Length: 30, dtype: int64

And turn it into a bar chart:

data['Country'].value_counts().head(30).plot(kind='barh')


Ignoring other aesthetic aspects of the plot, it's obvious that we need to change the size - or rather the shape. Part of the confusion over sizes in plotting is that sometimes we need to just make the chart bigger or smaller, and sometimes we need to make it thinner or fatter. If we just scaled up this plot so that it was big enough to read the names on the vertical axis, then it would also be very wide. We can set the size by adding a figsize keyword argument to our pandas plot() function. The value has to be a (width, height) tuple - the sizes are actually in inches, but for most purposes we can think of them as arbitrary units.
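
To see that the figsize values really are inches, we can ask the figure for its size after plotting. This is a small sketch with a made-up Series rather than the medal data. Multiplying by the figure's dpi gives the size in pixels that a raster image would have:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import pandas as pd

# Made-up counts standing in for the medal totals
counts = pd.Series({"A": 30, "B": 20, "C": 10})

# pandas returns the matplotlib Axes it drew on
ax = counts.plot(kind="barh", figsize=(6, 10))
fig = ax.figure

width_in, height_in = (float(v) for v in fig.get_size_inches())

# Pixel dimensions of a raster image are inches multiplied by dots per inch
width_px = width_in * fig.dpi
height_px = height_in * fig.dpi

print(width_in, height_in, fig.dpi, width_px, height_px)
```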

Here's what happens if we make the plot bigger, but keep the original shape:

import matplotlib.pyplot as plt
data['Country'].value_counts().head(30).plot(kind='barh', figsize=(20,10))


And here's a version that keeps the large vertical size but shrinks the chart horizontally so it doesn't take up so much space:

import matplotlib.pyplot as plt
data['Country'].value_counts().head(30).plot(kind='barh', figsize=(6,10))


Drawing a figure using matplotlib

OK, but what if we aren't using pandas' convenient plot() method but drawing the chart using matplotlib directly? Let's look at the number of medals awarded in each year:

plt.plot(data['Year'].value_counts().sort_index())


This time, we'll say that we want to make the plot longer in the horizontal direction, to better see the pattern over time. If we search the documentation for the matplotlib plot() function, we won't find any mention of size or shape. This actually makes sense in the design of matplotlib - plots don't really have a size, figures do. So to change it we have to call the figure() function:

plt.figure(figsize=(15,4))
plt.plot(data['Year'].value_counts().sort_index())


Notice that with the figure() function we have to call it before we make the call to plot(), otherwise it won't take effect:

plt.plot(data['Year'].value_counts().sort_index())

# no effect on the plot above - this just creates a new, empty figure
plt.figure(figsize=(15,4))
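
We can check what actually happens with a small experiment (made-up numbers again, and the Agg backend so that it runs without a display). The late call to figure() doesn't resize anything; it opens a second, empty figure and leaves the original one untouched:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

plt.close("all")  # start from a clean slate

# Plot first: pyplot creates a figure at the default size
plt.plot([1896, 1900, 1904], [151, 512, 470])  # made-up values
first = plt.gcf()

# Calling figure() afterwards opens a second figure
late = plt.figure(figsize=(15, 4))

print(len(plt.get_fignums()))            # how many figures exist now
print(len(first.axes), len(late.axes))   # which figure holds the plot
```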


Drawing a figure with seaborn

OK, now what if we're using seaborn rather than matplotlib? Well, happily the same technique will work. We know from our first plot which countries have won the most medals overall, but now let's look at how this varies by year. We'll create a summary table to show the number of medals per year for all countries that have won more than 500 medals in total.

(ignore this pandas stuff if it seems confusing, and just look at the final table)

summary = (
    data
    .groupby('Country')
    .filter(lambda x : len(x) > 500)
    .groupby(['Country', 'Year'])
    .size()
    .to_frame('medal count')
    .reset_index()
)

# wrap long country names
summary['Country'] = summary['Country'].str.replace(' ', '\n')
summary
            Country  Year  medal count
0         Australia  1896            2
1         Australia  1900            5
2         Australia  1920            6
3         Australia  1924           10
4         Australia  1928            4
...             ...   ...          ...
309  United\nStates  1992          224
310  United\nStates  1996          260
311  United\nStates  2000          248
312  United\nStates  2004          264
313  United\nStates  2008          315

314 rows × 3 columns

Now we can do a box plot to show the distribution of yearly medal totals for each country:

import seaborn as sns
sns.boxplot(
    data=summary,
    x='Country',
    y='medal count',
    color='red')


This is hard to read because of all the names, so let's space them out a bit:

plt.figure(figsize=(20,5))
sns.boxplot(
    data=summary,
    x='Country',
    y='medal count',
    color='red')
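
If you'd rather be explicit about which figure seaborn draws on, another option is to create the figure and axes yourself and hand the axes to seaborn through the ax keyword, which single-plot functions such as boxplot() accept. Here's a sketch with a tiny made-up table standing in for summary:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Tiny made-up table with the same column names as `summary`
toy = pd.DataFrame({
    "Country": ["A", "A", "A", "B", "B", "B"],
    "medal count": [2, 5, 9, 10, 14, 30],
})

# Create the figure at the size we want, then tell seaborn where to draw
fig, ax = plt.subplots(figsize=(20, 5))
sns.boxplot(data=toy, x="Country", y="medal count", color="red", ax=ax)

fig_size = tuple(float(v) for v in fig.get_size_inches())
print(fig_size)
```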


Now we come to the final complication; let's say we want to look at the distributions of the different medal types separately. We'll make a new summary table - again, ignore the pandas stuff if it's confusing, and just look at the final table:

summary_by_medal = (
    data
    .groupby('Country')
    .filter(lambda x : len(x) > 500)
    .groupby(['Country', 'Year', 'Medal'])
    .size()
    .to_frame('medal count')
    .reset_index()
)
summary_by_medal['Country'] = summary_by_medal['Country'].str.replace(' ', '\n')
summary_by_medal
            Country  Year   Medal  medal count
0         Australia  1896    Gold            2
1         Australia  1900  Bronze            3
2         Australia  1900    Gold            2
3         Australia  1920  Bronze            1
4         Australia  1920  Silver            5
...             ...   ...     ...          ...
881  United\nStates  2004    Gold          116
882  United\nStates  2004  Silver           75
883  United\nStates  2008  Bronze           81
884  United\nStates  2008    Gold          125
885  United\nStates  2008  Silver          109

886 rows × 4 columns

Now we will switch from boxplot() to the higher level catplot(), as this makes it easy to switch between different plot types. But notice that now our call to plt.figure() gets ignored:

plt.figure(figsize=(20,5))

sns.catplot(
    data=summary_by_medal,
    x='Country',
    y='medal count',
    hue='Medal',
    kind='box')


The reason for this is that the higher level plotting functions in seaborn (what the documentation calls Figure-level interfaces) have a different way of managing size, largely because they often produce multiple subplots. jointplot() takes a height keyword too, but no aspect, since it always draws a square figure. To set the size when using catplot() or relplot() (also pairplot() and lmplot()), use the height keyword to control the size and the aspect keyword to control the shape:

sns.catplot(
    data=summary_by_medal,
    x='Country',
    y='medal count',
    hue='Medal',
    kind='box',
    height=5, # make the plot 5 units high
    aspect=3) # width should be three times height


Because we often end up drawing small multiples with catplot() and relplot(), being able to control the shape separately from the size is very convenient. The height and aspect keywords apply to each subplot separately, not to the figure as a whole. So if we put each medal on a separate row rather than using hue, we'll end up with three subplots, so we'll want to set the height to be smaller, but the aspect ratio to be bigger:

sns.catplot(
    data=summary_by_medal,
    x='Country',
    y='medal count',
    row='Medal',
    kind='box',
    height=3, 
    aspect=4,
    color='blue')
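
As a rough rule of thumb, each subplot is height inches tall and height × aspect inches wide, so the whole figure grows with the number of rows and columns. The helper below is hypothetical (it isn't part of seaborn) and ignores any extra space added for a legend, but it's handy for estimating the overall size before plotting:

```python
def facet_figure_size(height, aspect, nrows=1, ncols=1):
    """Estimate the (width, height) in inches of a grid of seaborn facets.

    Hypothetical helper, not a seaborn function: each facet is `height`
    tall and `height * aspect` wide. Legend space is ignored.
    """
    facet_width = height * aspect
    return (ncols * facet_width, nrows * height)


# Single facet with height=5, aspect=3
print(facet_figure_size(5, 3))           # (15, 5)

# Three rows (one per medal type) with height=3, aspect=4
print(facet_figure_size(3, 4, nrows=3))  # (12, 9)
```

So the three-row version comes out at roughly 12 by 9 inches overall, even though each individual subplot is only 3 inches tall.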


Printing a figure

Finally, a word about printing. If you need to change the size of a plot, rather than its shape, because you want to print it, then don't worry about the size - get the shape that you want, then use savefig() to save the plot in SVG format:

plt.savefig('medals.svg')

This will give you a plot in Scalable Vector Graphics format, which stores the actual lines and shapes of the chart so that you can print it at any size - even a giant poster - and it will look sharp. As a nice bonus, you can also edit individual bits of the chart using a graphical SVG editor (Inkscape is free and powerful, though it takes a bit of effort to learn).
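
Here's a slightly fuller sketch of the same idea, using made-up data and writing to an in-memory buffer so that no file is created (passing a file name such as 'medals.svg' instead of the buffer works the same way). Calling savefig() on the figure object avoids relying on whichever figure happens to be current, and bbox_inches='tight' trims the surrounding whitespace:

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Get the shape we want; the absolute size matters less for vector output
fig, ax = plt.subplots(figsize=(6, 2))
ax.plot([1896, 1900, 1904], [151, 512, 470])  # made-up values

# Write the figure as SVG into an in-memory buffer
buffer = io.BytesIO()
fig.savefig(buffer, format="svg", bbox_inches="tight")

svg_text = buffer.getvalue().decode("utf-8")
print(svg_text[:60])
```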