Matplotlib: For creating static, animated, and interactive visualizations in Python

Dilip Kumar
14 min read · Jan 14, 2025

Key concepts

Pyplot: A collection of functions that make matplotlib work like MATLAB. It provides a simple interface for creating plots.

Figure: The overall container for a plot. It can contain multiple subplots.

Axes: Represents an individual plot within a figure. It contains the data, labels, ticks, and other graphical elements.

Subplot: An Axes placed on a regular grid within a figure. Subplots can optionally share their x- or y-axis with one another.

State-based approach (using pyplot)

The following example uses the state-based approach, in which pyplot keeps track of the current figure and Axes for you.

import matplotlib.pyplot as plt
x = [1, 2, 3, 4, 5]
y = [2, 4, 1, 5, 3]
plt.plot(x, y)
plt.xlabel("X-axis")
plt.ylabel("Y-axis")
plt.title("Simple Line Plot")
plt.show()

Pros:

  • Simpler for quick, one-off plots.
  • Mimics MATLAB’s style, familiar to some users.

Cons:

  • Can become confusing with multiple plots or complex layouts.
  • Less control and flexibility compared to the OO approach.
  • Potential for unexpected behavior if the “current state” is not managed carefully.
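The "current state" issue can be seen in a small sketch. It assumes a figure with two Axes (the names left and right are illustrative): a pyplot call such as plt.title() goes to the current Axes, which here is the most recently created one and may not be the one you meant.

```python
import matplotlib.pyplot as plt

# Two Axes side by side; the names `left` and `right` are illustrative.
fig, (left, right) = plt.subplots(1, 2)

# pyplot functions act on the current Axes, which is the last one created here.
plt.title("Set through pyplot")

print(repr(left.get_title()))   # the first Axes did not receive the title
print(repr(right.get_title()))  # the second Axes did
print(plt.gca() is right)       # gca() returns the current Axes

# Call plt.show() here to display the figure.
```

Calling right.set_title(...) or left.set_title(...) directly removes the ambiguity, which is the main argument for the object-oriented style.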

Object-Oriented (using fig and ax)

The following example uses the object-oriented approach, in which you call methods on explicit Figure and Axes objects.

import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
y = [2, 4, 1, 5, 3]

fig, ax = plt.subplots() # Create a figure containing a single Axes
ax.plot(x, y)
ax.set_xlabel("X-axis")
ax.set_ylabel("Y-axis")
ax.set_title("Simple Plot")
plt.show()

Pros:

  • More control over plot appearance and behavior.
  • Easier to manage complex plots with multiple subplots.
  • Better suited for interactive plots or when integrating with other libraries.
  • More flexible and scalable for larger projects.

Cons:

  • Slightly more verbose than the state-based approach.
  • Might have a steeper initial learning curve.

NumPy to generate input data

NumPy is the fundamental package for numerical computation in Python. Its central feature is the Ndarray, a multidimensional container data structure.

An Ndarray is a multidimensional array of items that all have the same datatype and therefore the same size in memory.

We can define the shape and datatype of the items when creating the Ndarray. Just like other data structures such as lists, we can access the contents of an Ndarray with an index.

import numpy as np
values = [1, 2, 3]  # avoid the name "list", which would shadow the built-in type
x = np.array(values, dtype=np.int16)

Here you are creating an Ndarray from a list. The datatype of the members is a 16-bit integer.
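As a minimal sketch of what "same datatype" means in practice, the snippet below inspects the dtype and the per-item size, and shows that NumPy picks one common type when the input mixes integers and floats. The variable name mixed is illustrative.

```python
import numpy as np

x = np.array([1, 2, 3], dtype=np.int16)
print(x.dtype)      # int16
print(x.itemsize)   # 2 bytes per item

# Without an explicit dtype, NumPy chooses one type that can hold every item.
mixed = np.array([1, 2.5, 3])
print(mixed.dtype)  # float64
```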

Indexing in Ndarrays

The indexing starts at 0. You can also use a negative index: -1 returns the last element, -2 returns the second to last, and so on. The following is an example:

print(x[0]); print(x[1]); print(x[2]); print(x[-1])

Note: An index outside the valid range raises an IndexError.
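A short sketch of the failure case, assuming the same three-element array as above: index 3 is out of range, and the exception can be caught like any other Python exception.

```python
import numpy as np

x = np.array([1, 2, 3], dtype=np.int16)

try:
    x[3]  # valid indices are 0..2 and -3..-1
except IndexError as exc:
    message = str(exc)
    print("IndexError:", message)
```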

Indexing in Ndarrays of More Than One Dimension

An array can have more than one dimension, as follows:

x1 = np.array([[1, 2, 3], [4, 5, 6]], np.int16)

You can access individual elements as follows:

print(x1[0, 0]); print(x1[0, 1]); print(x1[0, 2]);

You can even access entire rows as follows:

print(x1[0, :])

You can access an entire column as follows:

print(x1[:, 0])

You can even have an Ndarray with more than two dimensions. The following is a 3D array:

x2 = np.array([[[1, 2, 3], [4, 5, 6]],[[0, -1, -2], [-3, -4, -5]]], np.int16)
print(x2[0, 0, 0])
print(x2[1, 1, 2])
print(x2[:, 1, 1])

Ndarray Properties

Number of dimensions

x2.ndim

Shape of the Ndarray

x2.shape

Datatype of the members

x2.dtype

Number of elements

x2.size

Bytes required

x2.nbytes

Compute the transpose

x2.T
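For reference, here is what these properties return for the 3D array x2 defined earlier (repeated here so the snippet runs on its own).

```python
import numpy as np

x2 = np.array([[[1, 2, 3], [4, 5, 6]],
               [[0, -1, -2], [-3, -4, -5]]], np.int16)

print(x2.ndim)     # 3
print(x2.shape)    # (2, 2, 3)
print(x2.dtype)    # int16
print(x2.size)     # 12
print(x2.nbytes)   # 24 (12 items of 2 bytes each)
print(x2.T.shape)  # (3, 2, 2): the transpose reverses the order of the axes
```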

NumPy Constants

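np.inf   # Positive infinity
np.nan   # Not a Number (the lowercase spelling works across NumPy versions)
-np.inf  # Negative infinity (np.NINF was removed in NumPy 2.0)
-0.0     # Negative zero (np.NZERO was removed in NumPy 2.0)
0.0      # Positive zero (np.PZERO was removed in NumPy 2.0)
np.e     # Euler’s number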
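These special values behave differently from ordinary numbers, which matters when cleaning data before plotting. The sketch below shows a few checks; the array data is a made-up example.

```python
import numpy as np

print(np.nan == np.nan)    # False: NaN never compares equal, even to itself
print(np.isnan(np.nan))    # True: use isnan() to test for NaN
print(np.isinf(-np.inf))   # True
print(np.inf > 1e308)      # True

data = np.array([1.0, np.nan, 3.0])  # hypothetical data with a missing value
print(np.nanmean(data))    # 2.0: the NaN entry is ignored
```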

Slicing Ndarrays

You can extract a part of an Ndarray with slicing using indices as follows:

a1 = np.array([1, 2, 3, 4, 5, 6, 7])
a1[1:5]

This code displays the elements from the second position to the fifth position (index 1 through index 4, since indexing starts at 0) as follows:

array([2, 3, 4, 5])

Note: In a[i:j], the element at index j is excluded.

By default, slicing uses a step size of 1, so you retrieve consecutive elements. You can also change the step size as follows:

a1[1:6:2]

In this example, the step size is 2, so the output lists every second element starting from index 1. The output is as follows:

array([2, 4, 6])
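Two further slicing details are worth a sketch: negative steps and indices, and the fact that a basic slice is a view on the original data rather than a copy. The names part and independent are illustrative.

```python
import numpy as np

a1 = np.array([1, 2, 3, 4, 5, 6, 7])

print(a1[::-1])  # reversed: 7, 6, 5, 4, 3, 2, 1
print(a1[-3:])   # last three elements: 5, 6, 7

part = a1[1:5]   # a view: it shares memory with a1
part[0] = 99
print(a1[1])     # 99, the original array changed too

independent = a1[1:5].copy()  # an explicit copy does not share memory
independent[0] = -1
print(a1[1])     # still 99
```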

Ways for Creating Ndarrays

np.empty() returns a new array of a given shape and type without initializing its entries. Because the entries are not initialized, they hold arbitrary values (whatever was already in memory), so assign to them before reading them.

x = np.empty([3, 3], np.uint8)

np.eye() returns a 2D matrix with 1s on the diagonal and 0s for other elements. The following is an example:

y = np.eye(3, dtype=np.uint8)
y
# It prints following
array([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]], dtype=uint8)

You can change the position of the index of the diagonal. The default is 0, which refers to the main diagonal. A positive value means an upper diagonal. A negative value means a lower diagonal. The following are examples. Let’s demonstrate the upper diagonal first:

y = np.eye(5, dtype=np.uint8, k=1)
# Output as below
[[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]
[0 0 0 0 0]]

The following is the code to demonstrate the lower diagonal:

y = np.eye(5, dtype=np.uint8, k=-1)
# Output as below
[[0 0 0 0 0]
[1 0 0 0 0]
[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]]

np.identity() returns an identity matrix (where all the elements at the diagonal are 1) of the specified size, as shown here:

x = np.identity(5, dtype= np.uint8)
# It generates as below
[[1 0 0 0 0]
[0 1 0 0 0]
[0 0 1 0 0]
[0 0 0 1 0]
[0 0 0 0 1]]

np.ones() returns the matrix of the given size that has all the elements as 1s.

x = np.ones((2, 3), dtype=np.int16)
# It generates as below
[[1 1 1]
[1 1 1]]

np.zeros() returns a matrix of a given size with all the elements as 0s.

x = np.zeros((2, 3), dtype=np.int16)
# It generates as below
[[0 0 0]
[0 0 0]]

np.full() returns a new array of a given shape and type, filled with the passed argument.

x = np.full((2, 3), dtype=np.int16, fill_value = 5)
# It generates as below
[[5 5 5]
[5 5 5]]

np.tri() returns a lower triangular matrix of a given size, as shown here:

x = np.tri(3, 3, k=0, dtype=np.uint16)
# It generates as below
[[1 0 0]
[1 1 0]
[1 1 1]]

You can also change the position of the diagonal with k. All the elements above the k-th diagonal will be 0.

x = np.tri(3, 3, k=1, dtype=np.uint16)
# It generates as below
[[1 1 0]
[1 1 1]
[1 1 1]]

Another example with a negative value for the subdiagonal is as follows:

x = np.tri(3, 3, k=-1, dtype=np.uint16)
# It generates as below
[[0 0 0]
[1 0 0]
[1 1 0]]

np.tril() returns the lower triangle of an existing matrix, which it accepts as an argument. Entries above the k-th diagonal are set to 0. Here’s a demonstration:

x = np.ones((5, 5), dtype=np.uint8)
y = np.tril(x, k=-1)
print(y)
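To make the effect easier to see, here is a sketch on a small matrix with distinct entries. It also uses np.triu(), the upper-triangle counterpart, which is not covered in the text above.

```python
import numpy as np

m = np.arange(1, 10).reshape(3, 3)  # [[1 2 3] [4 5 6] [7 8 9]]

lower = np.tril(m)                  # keeps the main diagonal and below
strictly_lower = np.tril(m, k=-1)   # keeps only entries below the main diagonal
upper = np.triu(m)                  # keeps the main diagonal and above

print(lower)
print(strictly_lower)
print(upper)
```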

np.arange(start, end, interval) creates evenly spaced values with the given interval. The start value is included and the end value is excluded.

x = np.arange(6)
# It generates below
[0 1 2 3 4 5]
x = np.arange(2, 6)
# It generates below
[2 3 4 5]
x = np.arange(2, 6, 2)
# It generates below
[2 4]

np.linspace(start, stop, number_of_elements) returns an array of evenly spaced numbers over a specified interval. You must pass it the starting value, the end value, and the number of values as follows:

x = np.linspace(0, 15, 16)
# This code creates 16 numbers (0 to 15, both inclusive) as follows:
[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.]
x = np.linspace(0, 20, 4)
# This code creates 4 numbers (0 to 20, both inclusive) as follows:
[ 0. 6.66666667 13.33333333 20. ]
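The main practical difference between the two functions is how the end value is treated. The following sketch compares them on the interval from 0 to 1; the endpoint argument of np.linspace() switches the inclusive behavior off.

```python
import numpy as np

by_step = np.arange(0, 1, 0.25)                  # step given, end excluded
by_count = np.linspace(0, 1, 5)                  # count given, end included
no_endpoint = np.linspace(0, 1, 4, endpoint=False)

print(by_step)      # 0, 0.25, 0.5, 0.75
print(by_count)     # 0, 0.25, 0.5, 0.75, 1
print(no_endpoint)  # 0, 0.25, 0.5, 0.75
```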

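np.logspace() creates an array of numbers that are evenly spaced on a log scale. The start and stop arguments are exponents of the base, so the values run from base**start to base**stop.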

# Syntax
np.logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None)
x = np.logspace(0.1, 2, 4)
# It generates as below
[ 1.25892541 5.41169527 23.26305067 100. ]
N= 16
x = np.linspace(0, 15, N)
y = np.logspace(0.1, 2, N)
plt.plot(x,y, 'o--')
plt.show()
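One way to read np.logspace() is as np.linspace() applied to the exponents. The sketch below checks that equivalence and shows the base argument; the names exponents, by_hand, and powers_of_two are illustrative.

```python
import numpy as np

exponents = np.linspace(0.1, 2, 4)
by_hand = 10.0 ** exponents  # raise the base to evenly spaced exponents

print(np.allclose(np.logspace(0.1, 2, 4), by_hand))  # True

powers_of_two = np.logspace(0, 4, 5, base=2)
print(powers_of_two)  # 1, 2, 4, 8, 16
```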

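np.geomspace() creates an array of numbers that form a geometric progression, so each value is a constant multiple of the previous one. Unlike np.logspace(), start and stop are the actual end values rather than exponents.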

# Syntax
np.geomspace(start, stop, num=50, endpoint=True, dtype=None)
x = np.geomspace(1, 10, 4)
# it prints following
[ 1. 2.15443469 4.64158883 10. ]
N= 16
x = np.linspace(0, 15, N)
y = np.geomspace(0.1, 2000, N)
plt.plot(x,y, 'o--')
plt.show()
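A quick sketch of the geometric property: the ratio between neighboring values is constant, and the result matches np.logspace() when the end values are written as powers of 10.

```python
import numpy as np

g = np.geomspace(1, 1000, 4)
print(g)                # 1, 10, 100, 1000

ratios = g[1:] / g[:-1]
print(ratios)           # every ratio is 10

print(np.allclose(g, np.logspace(0, 3, 4)))  # True
```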

Visualization

Single-line plots

When there is only one visualization in a figure that uses the function plot(), then it is known as a single-line plot.

import matplotlib.pyplot as plt
x = [4, 5, 3, 1, 6, 7]
plt.plot(x)
plt.show()

Note: In this case, the x-axis values are assumed: Matplotlib uses the indices 0, 1, 2, ... of the list.
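The values Matplotlib fills in can be inspected on the returned line object. This sketch uses the object-oriented interface and reads back the x and y data of the line.

```python
import matplotlib.pyplot as plt
import numpy as np

values = [4, 5, 3, 1, 6, 7]

fig, ax = plt.subplots()
(line,) = ax.plot(values)  # plot() returns a list of Line2D objects

print(line.get_xdata())    # implicit x-values: 0, 1, 2, 3, 4, 5
print(line.get_ydata())    # the values passed in

# Call plt.show() here to display the figure.
```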

Let’s visualize the cubic graph y = f(x) = x^3 + 1. The code is as follows:

x = np.arange(25)
y = x**3 + 1
plt.plot(x,y)
plt.show()

Multiline Plots

It is possible to visualize multiple plots in the same output. Let’s see how to show multiple curves in the same visualization. The following is a simple example:

x = np.arange(7)
plt.plot(x, -x**2,'o--')
plt.plot(x, -x**3, 'o--')
plt.plot(x, -2*x, 'o--')
plt.plot(x, -2**x, 'o--')
plt.show()

Note: Matplotlib automatically assigns colors to the curves separately.

We can also write the same code in a simple way as follows:

plt.plot(x, -x**2, x, -x**3,
x, -2*x, x, -2**x)
plt.show()

Let’s see another example:

x = np.array([[3, 2, 5, 6], [7, 4, 1, 5]])
plt.plot(x)
plt.show()

This plots four different lines because:

x is a 2D array:

  • x has the shape (2, 4), meaning it's a 2D array with two rows and four columns.
  • Each column represents a separate set of y-values.

plt.plot() interprets each column as a separate data series:

  • When you call plt.plot(x), Matplotlib interprets each column of the array as a separate set of y-values.
  • Since x has four columns, it plots four lines, each with two points (one per row).

Implicit x-axis:

  • When you only provide a single argument to plt.plot(), Matplotlib plots against a default x-axis, which is a sequence of integers starting from 0. Here that is the row index, so the x-values are 0 and 1.
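The number of lines can be checked directly, since plot() returns one line object per series. The sketch below also plots the transposed array, which is one way to get a line per row instead; the Axes names are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.array([[3, 2, 5, 6], [7, 4, 1, 5]])

fig, (ax_cols, ax_rows) = plt.subplots(1, 2)
column_lines = ax_cols.plot(x)    # one line per column of x
row_lines = ax_rows.plot(x.T)     # transposing gives one line per row of x

print(len(column_lines), len(row_lines))  # 4 2
print(column_lines[0].get_ydata())        # first column: 3, 7
print(row_lines[0].get_ydata())           # first row: 3, 2, 5, 6

# Call plt.show() here to display the figure.
```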

Colors

Color Names: You can specify colors using their names:

plt.plot(x, y, color='red') 
plt.plot(x, y, color='blue')
plt.plot(x, y, color='green')

Hex Codes: Use hexadecimal color codes for more precise control:

plt.plot(x, y, color='#FF0000')  # Red
plt.plot(x, y, color='#00FF00') # Green
plt.plot(x, y, color='#0000FF') # Blue

RGB/RGBA: Specify colors using Red-Green-Blue (and optionally Alpha for transparency) values:

plt.plot(x, y, color=(1, 0, 0))    # Red
plt.plot(x, y, color=(0, 1, 0)) # Green
plt.plot(x, y, color=(0, 0, 1)) # Blue
plt.plot(x, y, color=(0.5, 0.5, 0.5, 0.5)) # Gray with 50% transparency
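The three notations describe the same thing, and matplotlib.colors can convert between them. One detail the sketch below brings out: the named color 'green' is a darker shade (#008000) than the pure green #00FF00 used in the hex example.

```python
from matplotlib import colors as mcolors

print(mcolors.to_rgba("red"))        # (1.0, 0.0, 0.0, 1.0)
print(mcolors.to_rgba("#FF0000"))    # the same RGBA tuple
print(mcolors.to_hex((0, 0, 1)))     # '#0000ff'

print(mcolors.same_color("green", "#008000"))  # True
print(mcolors.same_color("green", "#00FF00"))  # False
```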

Line Styles

  • Solid: '-' (default)
  • Dashed: '--'
  • Dotted: ':'
  • Dash-dot: '-.'
plt.plot(x, y, linestyle='-') 
plt.plot(x, y, linestyle='--')
plt.plot(x, y, linestyle=':')
plt.plot(x, y, linestyle='-.')

Markers

  • Circle: 'o'
  • Square: 's'
  • Triangle: '^'
  • Diamond: 'D'
  • Cross: 'x'
  • Plus: '+'
  • Dot: '.'
plt.plot(x, y, marker='o') 
plt.plot(x, y, marker='s')
plt.plot(x, y, marker='^')

Combining Colors, Styles, and Markers

You can combine these options to create various line appearances:

plt.plot(x, y, color='red', linestyle='--', marker='o')

Format String

A format string is a concise way to specify the color, linestyle, and marker in a single string.

Format: [color][marker][linestyle]

  • color: e.g., 'r' for red, 'b' for blue, 'g' for green, 'k' for black, etc.
  • marker: e.g., 'o' for circle, 's' for square, '^' for triangle, 'x' for cross, etc.
  • linestyle: e.g., '-' for solid, '--' for dashed, ':' for dotted, '-.' for dash-dot.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Using format string
plt.plot(x, y, 'ro-') # Red circles connected by solid lines

plt.show()
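A format string is shorthand for the separate keyword arguments. As a sketch, the two lines below are styled identically, which can be confirmed by reading the properties back; the second curve is shifted by 1 only so that both remain visible.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors as mcolors

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots()
(short,) = ax.plot(x, y, "ro-")  # format string
(long,) = ax.plot(x, y + 1, color="red", marker="o", linestyle="-")  # keywords

print(mcolors.same_color(short.get_color(), long.get_color()))  # True
print(short.get_marker(), long.get_marker())
print(short.get_linestyle(), long.get_linestyle())

# Call plt.show() here to display the figure.
```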

Layouts

In Matplotlib, subplots are used to create multiple plots within a single figure. This is incredibly useful when you want to:

  • Compare different datasets
  • Visualize multiple aspects of the same data
  • Create complex figures with multiple panels

The primary way to create subplots is using the plt.subplots() function:

import matplotlib.pyplot as plt
# Create four Axes in a 2x2 grid
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 6), sharex=True)

  • fig: Represents the overall figure object.
  • axs: A NumPy array containing the individual Axes objects. You can access individual subplots by index, e.g. axs[0, 1].
  • Number of Rows and Columns: Adjust nrows and ncols to change the grid dimensions.
  • Figure Size: Control the overall size of the figure, in inches, using the figsize argument.
  • Sharing Axes: sharex=True (or sharey=True) makes the subplots use the same x-axis (or y-axis) limits and ticks.
  • Tight Layout: plt.tight_layout() automatically adjusts subplot parameters to minimize overlaps between subplots.

The following code puts these pieces together.

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.tan(x)
y4 = x

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 6), sharex=True)
axs[0,0].plot(x, y1)
axs[0,0].set_title("Sine Wave")
axs[0,1].plot(x, y2)
axs[0,1].set_title("Cosine Wave")
axs[1,0].plot(x, y3)
axs[1,0].set_title("Tan Wave")
axs[1,1].plot(x, y4)
axs[1,1].set_title("Line")

fig.supxlabel("X-axis")  # one label for the whole figure (Matplotlib 3.4+); plt.xlabel() would label only the last Axes
plt.tight_layout()
plt.show()
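When the grid grows, indexing every subplot by hand becomes repetitive. A possible alternative, sketched here with a made-up dictionary of curves, is to loop over axs.flat. The sketch also shows what sharex=True does: changing the x-limits on one Axes changes them on all.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
# Illustrative data: title -> y-values
curves = {"Sine": np.sin(x), "Cosine": np.cos(x), "Line": x, "Square": x**2}

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 6), sharex=True)
for ax, (title, y) in zip(axs.flat, curves.items()):
    ax.plot(x, y)
    ax.set_title(title)

axs[0, 0].set_xlim(0, 5)     # with sharex=True this applies to every subplot
print(axs.shape)             # (2, 2)
print(axs[1, 1].get_xlim())  # also 0 to 5

fig.tight_layout()
# Call plt.show() here to display the figure.
```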

Styles

Styles in Matplotlib control the overall appearance of your plots. They define things like:

  • Color palettes: The colors used for lines, markers, and other elements.
  • Line styles: The appearance of lines (solid, dashed, dotted, etc.).
  • Font sizes and styles: The appearance of text elements (labels, titles, etc.).
  • Gridlines: Whether and how gridlines are displayed.

Available Styles: Matplotlib comes with a variety of built-in styles. You can view them using:

print(plt.style.available)

This will print a list of available styles, which depends on the installed Matplotlib version, such as:

['Solarize_Light2', '_classic_test_patch', 'classic', 'dark_background', 'fast', 'fivethirtyeight', 'ggplot', 'grayscale', 'bmh', 'tableau-colorblind10']

Applying a Style: Use the plt.style.use() function to apply a style to your plots:

plt.style.use('ggplot')

Apply Style Before Creating Subplots

Apply the desired style using plt.style.use() before creating the figure and axes:

import matplotlib.pyplot as plt
import numpy as np

plt.style.use('ggplot') # Apply the 'ggplot' style

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel("X-axis")
ax.set_ylabel("Y-axis")
ax.set_title("Sine Wave")

plt.show()

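Apply a Style's Color Cycle to Individual Axes (Less Common)

While less common, you can also apply part of a style to individual Axes. The set_prop_cycle() call below copies only the color cycle of the 'ggplot' style; the background, grid, and fonts stay unchanged. This is generally not recommended for most use cases: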

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots()

# Apply the 'ggplot' color cycle to this specific Axes
ax.set_prop_cycle(plt.style.library['ggplot']['axes.prop_cycle'])

ax.plot(x, y)
ax.set_xlabel("X-axis")
ax.set_ylabel("Y-axis")
ax.set_title("Sine Wave")

plt.show()
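If the goal is to style one figure without changing the global settings, a possible alternative is the plt.style.context() context manager, which applies a style only inside a with block. The sketch below reads one rcParams entry before, inside, and after the block to show that the change is temporary.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
before = plt.rcParams["axes.facecolor"]

with plt.style.context("ggplot"):
    inside = plt.rcParams["axes.facecolor"]  # the ggplot background color
    fig, ax = plt.subplots()                 # created while the style is active
    ax.plot(x, np.sin(x))
    ax.set_title("Styled with ggplot")

after = plt.rcParams["axes.facecolor"]       # restored on leaving the block
print(before, inside, after)

# Call plt.show() here to display the figure.
```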

Lines, Bars, and Scatter Plots

A line plot is used to visualize trends and relationships between two variables.

Logarithmic scale on the x-axis

plt.semilogx() creates a plot with a logarithmic scale on the x-axis and a linear scale on the y-axis.

# Syntax plt.semilogx(x, y, ...)
x = np.logspace(0, 3, 50) # Generate x-values on a log scale
y = np.exp(-x)

plt.semilogx(x, y)
plt.show()

Key Points:

  • Logarithmic X-axis: This is ideal when the x-values span several orders of magnitude (e.g., from 1 to 1,000,000). Using a logarithmic scale compresses the x-axis, making it easier to visualize data with large variations in the x-values.
  • Linear Y-axis: The y-axis remains on a linear scale.

When to Use semilogx()

  • Analyzing data with widely varying x-values: For example, frequency spectra, time series with exponentially increasing or decreasing values.
  • Visualizing logarithmic relationships: A relationship of the form y = a*log(x) + b appears as a straight line on a logarithmic x-axis, making trends easier to identify.

Note: Similarly, plt.semilogy() uses a logarithmic scale on the y-axis only, and plt.loglog() uses one on both axes.
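As a sketch of the other two variants: exponential growth becomes a straight line with semilogy(), and a power law becomes a straight line with loglog(). The data and Axes names are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, (ax_y, ax_xy) = plt.subplots(1, 2, figsize=(8, 3))

x = np.linspace(0, 5, 50)
ax_y.semilogy(x, np.exp(2 * x))  # exponential growth: straight on a log y-axis
ax_y.set_title("semilogy")

p = np.logspace(0, 3, 50)
ax_xy.loglog(p, p**2)            # power law: straight on log-log axes
ax_xy.set_title("loglog")

print(ax_y.get_xscale(), ax_y.get_yscale())    # linear log
print(ax_xy.get_xscale(), ax_xy.get_yscale())  # log log

# Call plt.show() here to display the figure.
```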

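Error Bars

Error bars show the uncertainty of each data point as a line extending above and below (or left and right of) the point. plt.errorbar() draws them from the yerr and xerr arguments. The following example uses random values as stand-in error data:

x = np.linspace(0, 2 * np.pi, 100)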
y = np.sin(x)
ye = np.random.rand(len(x))/10
plt.errorbar(x, y, yerr = ye)
plt.show()

Similarly, you can show the error data on the x-axis.

xe = np.random.rand(len(x))/10
plt.errorbar(x, y, xerr = xe)
plt.show()

You can show errors on both axes as follows:

plt.errorbar(x, y, xerr = xe, yerr = ye)
plt.show()
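In practice the error values usually come from the data rather than from a random generator. The sketch below assumes synthetic repeated measurements (20 noisy samples per x position, generated here purely for illustration) and uses their standard deviation as the error bar.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 2 * np.pi, 10)

# Synthetic data: 20 noisy "measurements" of sin(x) at each x position.
samples = np.sin(x) + rng.normal(scale=0.2, size=(20, x.size))
mean = samples.mean(axis=0)
std = samples.std(axis=0, ddof=1)  # sample standard deviation per position

fig, ax = plt.subplots()
ax.errorbar(x, mean, yerr=std, fmt="o", capsize=3)
ax.set_title("Mean with one standard deviation")

# Call plt.show() here to display the figure.
```

The fmt="o" argument draws markers without a connecting line, and capsize adds small caps at the ends of each bar.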

Bar graphs

A bar graph is a representation of discrete and categorical data items with bars. You can represent the data with vertical or horizontal bars. The height or length of bars is always in proportion to the magnitude of the data. The following is a simple example of a bar graph:

x = np.arange(4)
y = np.random.rand(4)
plt.bar(x, y)
plt.show()

You can have a combined bar graph as follows:

y = np.random.rand(3, 4)
plt.bar(x + 0.00, y[0], color = 'b', width = 0.25)
plt.bar(x + 0.25, y[1], color = 'g', width = 0.25)
plt.bar(x + 0.50, y[2], color = 'r', width = 0.25)
plt.show()
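A grouped bar chart is easier to read with category names and a legend. The following sketch assumes made-up group and series names, computes the bar offsets in a loop, and centers the tick labels under each group.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
groups = ["A", "B", "C", "D"]         # hypothetical category names
series = ["first", "second", "third"]  # hypothetical series names
values = rng.random((len(series), len(groups)))

x = np.arange(len(groups))
width = 0.25

fig, ax = plt.subplots()
for i, name in enumerate(series):
    ax.bar(x + i * width, values[i], width=width, label=name)

ax.set_xticks(x + width)    # the middle bar of each group
ax.set_xticklabels(groups)
ax.legend()

# Call plt.show() here to display the figure.
```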

Similarly, you can have horizontal bar graphs as follows:

x = np.arange(4)
y = np.random.rand(4)
plt.barh(x, y)
plt.show()

Scatter Plot

A scatter plot is a type of graph that uses dots to represent values for two different numeric variables. The position of each dot on the horizontal and vertical axis indicates values for an individual data point.

N = 1000
x = np.random.rand(N)
y = np.random.rand(N)
colors = np.random.rand(N)
size = (20)
plt.scatter(x, y, s=size, c=colors, alpha=1)
plt.show()

The size of the points is fixed in this example. You can also set the size per the place on the graph (which depends on the values of the x and y coordinates). Here is an example:

N = 1000
x = np.random.rand(N)
y = np.random.rand(N)
colors = np.random.rand(N)
size = (50 * x * y)
plt.scatter(x, y, s=size, c=colors, alpha=1)
plt.show()
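When color encodes a value, a colorbar tells the reader what the colors mean. This sketch assumes an illustrative quantity (x + y) mapped to color, and lowers alpha so overlapping points stay visible.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(2)
N = 200
x = rng.random(N)
y = rng.random(N)
values = x + y  # hypothetical quantity to show as color

fig, ax = plt.subplots()
points = ax.scatter(x, y, s=20, c=values, cmap="viridis", alpha=0.7)
fig.colorbar(points, ax=ax, label="x + y")

# Call plt.show() here to display the figure.
```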

Histograms, Contours, and Stream Plots

Histograms

A histogram is a graphical representation of the distribution of a numerical dataset. It visually summarizes the frequency of data points within specific intervals or “bins.”

x = [1, 3, 5, 1, 2, 4, 4, 2, 5, 4, 3, 1, 2]
n_bins = 5
plt.hist(x, bins=n_bins)
plt.show()
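The binning itself can be reproduced with np.histogram(), which returns the counts and bin edges without drawing anything. For the data above and five bins, the range 1 to 5 is split into bins of width 0.8.

```python
import numpy as np

x = [1, 3, 5, 1, 2, 4, 4, 2, 5, 4, 3, 1, 2]

counts, edges = np.histogram(x, bins=5)
print(edges)   # 1.0, 1.8, 2.6, 3.4, 4.2, 5.0
print(counts)  # 3, 3, 2, 3, 2
```

Each bin includes its left edge and excludes its right edge, except the last bin, which includes both.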

The histogram of one-dimensional data is a 2D figure. For 2D data, plt.hist2d() divides the x-y plane into rectangular bins and uses color to show how many points fall into each bin, so the result is still drawn in 2D.

n_points = 10000
x = np.random.randn(n_points)  # x and y must have the same length
y = np.random.randn(n_points)
plt.hist2d(x, y, bins=50)
plt.show()
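hist2d() also returns the bin counts, the bin edges, and the image it drew. A sketch using those return values, with synthetic normally distributed data, adds a colorbar and checks that every point landed in a bin.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(3)
n_points = 10_000
x = rng.standard_normal(n_points)
y = rng.standard_normal(n_points)

fig, ax = plt.subplots()
counts, xedges, yedges, image = ax.hist2d(x, y, bins=50)
fig.colorbar(image, ax=ax, label="points per bin")

print(counts.shape)       # (50, 50)
print(int(counts.sum()))  # 10000

# Call plt.show() here to display the figure.
```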

Contour

Contours represent the outline of an object. Contours are continuous (and closed, in many cases) lines highlighting the shape of objects.

In Matplotlib, contour plots are used to visualize 3D data on a 2D plane. They show lines of constant value, connecting points with the same “height” or “level” in the data.

Let’s draw a simple contour. We will create and visualize our own data by drawing circular contours as follows:

x = np.arange(-3, 3, 0.005)
y = np.arange(-3, 3, 0.005)
X, Y = np.meshgrid(x, y)
Z = (X**2 + Y**2)
out = plt.contour(X, Y, Z)
plt.clabel(out, inline=True,
fontsize=10)
plt.show()
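Two common variations, sketched here on a coarser grid: contourf() fills the regions between levels, and the levels argument picks the contour values explicitly. For Z = X**2 + Y**2, the levels 1, 4, and 9 are circles of radius 1, 2, and 3.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-3, 3, 200)
y = np.linspace(-3, 3, 200)
X, Y = np.meshgrid(x, y)
Z = X**2 + Y**2

fig, ax = plt.subplots()
filled = ax.contourf(X, Y, Z, levels=20, cmap="viridis")       # filled bands
lines = ax.contour(X, Y, Z, levels=[1, 4, 9], colors="white")  # chosen levels
ax.clabel(lines, inline=True, fontsize=10)
ax.set_aspect("equal")  # keep the circles round
fig.colorbar(filled, ax=ax)

# Call plt.show() here to display the figure.
```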

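Visualizing Vectors with Stream Plots

A stream plot draws the streamlines of a 2D vector field. plt.streamplot() takes the grid coordinates X and Y and the x- and y-components of the field, U and V: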

Y, X = np.mgrid[-5:5:200j, -5:5:300j]
U = X**2 + Y**2
V = X + Y

plt.streamplot(X, Y, U, V)
plt.show()
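Streamlines show direction but not magnitude. One way to add magnitude, sketched here for a simple rotation field chosen for illustration, is to color the lines by the local speed.

```python
import matplotlib.pyplot as plt
import numpy as np

Y, X = np.mgrid[-3:3:100j, -3:3:100j]
U = -Y  # illustrative field: counterclockwise rotation about the origin
V = X
speed = np.hypot(U, V)  # magnitude of the vector at each grid point

fig, ax = plt.subplots()
stream = ax.streamplot(X, Y, U, V, color=speed, cmap="viridis", density=1.2)
fig.colorbar(stream.lines, ax=ax, label="speed")
ax.set_aspect("equal")

# Call plt.show() here to display the figure.
```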

Image and Audio Visualization

TBD

Reference

https://matplotlib.org/stable/tutorials/index.html

Written by Dilip Kumar

With 18+ years of experience as a software engineer. Enjoy teaching, writing, leading team. Last 4+ years, working at Google as a backend Software Engineer.