Learning to use matplotlib

Back in the early 1990s I occasionally used MATLAB™ as part of my work on image processing. More recently I've used hand-crafted SVG images, mostly to represent geometry but sometimes to plot graphs of data. I'd also heard of matplotlib, with capabilities apparently comparable to those of MATLAB™, plus the ability to emit results to SVG (which did not exist in the early 1990s; for all I know, MATLAB™ may also support it now, but when I checked at Easter 2024 its website was too busy trying to sell it to me with buzzwords to make that easy to find out).

Then in the spring of 2024 I was working out how to parameterise a particular representation of the two-torus, T(2) and, on finding what I thought might make a viable approach, wanted to actually visualise what my formulæ produced. So I finally decided to give matplotlib a go.

Rather than following the package's own pip-based installation instructions, I'd just installed the Debian package for it. I went to the web-site and found an example to try out, but it did not work out of the box – failing with an import error for tkinter until I installed the packages Debian merely Recommends.

Then, trying out an example that had me call matplotlib.pyplot.subplots(subplot_kw={"projection": "3d"}), I hit an exception because the 3d projection type wasn't known. That probably means some other module is meant to register it but hadn't been loaded, so nothing had. By luck I stumbled on another example that included the magic incantation from mpl_toolkits.mplot3d import axes3d that was needed to make this work. I still got an exception when I tried matplotlib.pyplot.style.use("_mpl-gallery"), but I was fairly sure that wouldn't be a huge problem.
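
For reference, the combination described above can be sketched as follows. The Agg backend line is an assumption added here so that the snippet runs without a display (and without tkinter); leave it out for interactive use.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no display or tkinter is needed

import matplotlib.pyplot as plt
# Imported for its side effect: it makes pyplot aware of the "3d" projection.
from mpl_toolkits.mplot3d import axes3d  # noqa: F401

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
print(type(ax).__name__, ax.name)
```

If the projection is known, ax is a 3d axes object offering plot_surface() and friends; if not, the subplots() call is where the exception shows up.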

So then I was finally able to do something that looked like it might have some chance of showing me my data. (Note: see below – there are mistakes here.)

>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> v = np.linspace(-np.pi/2, np.pi/2, 101)
>>> u = np.linspace(-np.pi/2, np.pi/2, 101)
>>> u, v = np.meshgrid(u, v)
>>> x = np.tan(u/2)/(1 + np.tan(u/2)**2 * (1 - np.cos(v))/2)
>>> y = np.tan(v/2)/(1 + np.tan(v/2)**2 * (1 - np.cos(u))/2)
>>> z = np.cos(v) - np.cos(u)
>>> fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
>>> ax.plot_surface(x, y, z)
<mpl_toolkits.mplot3d.art3d.Poly3DCollection object at 0x7f7162275630>
>>> plt.show()

Unfortunately the results did not look like what I was expecting. After a bit of playing around and visualising some slices through the data, I concluded that I needed to learn more and went in search of tutorials. It also occurred to me to actually document my experience, so here's this page.

So I started on the quick start guide. After showing one trivial example, it gave an overview of the various high-level features, each described only superficially. That prepared the way for an introduction to the different ways of using the system – a low-level one and the matplotlib.pyplot convenience API.

Around this point I noticed I shouldn't have halved the π ends of the ranges above, so I fixed that and it showed the axes and no sign of my data, presumably thanks to the axes having multiples of 1e16 as their visible tick-mark values – because, indeed, x and y values that big do now appear in the data. That prompted me to try limiting the range of u and v to ±7π/8, which at least gave me something I could see – which was manifestly wrong. So I studied my formulæ a bit more closely and noticed a couple more errors: an unimportant factor of two on the x and y co-ordinates and a crucial error in the signs of the cosines. So I fixed those and tried again.

>>> x = 2 * np.tan(u/2)/(1 + np.tan(u/2)**2 * (1 + np.cos(v))/2)
>>> y = 2 * np.tan(v/2)/(1 + np.tan(v/2)**2 * (1 + np.cos(u))/2)

A pair of diagonal lines showed up in the z = 0 plane. That's sort of expected. The scale still reached out to the O(1e16) maximal values of x and y, so all the fine detail near the origin was invisible. So I limited the range of u and v to ±7π/8 again, producing results that look, at least somewhat, like they might be sensible if they weren't missing the neighbourhoods of ±π in {u, v} (skipped so as to avoid the infinite neighbourhood of u = π = v). That's how it looks in matplotlib's viewer; the axes seem to have been lost in saving to SVG. Maybe all I need to do is to tell ax to only show the region where x, y are less than about ten in magnitude.

[Image: torus minus two strips]
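
To see where values of order 1e16 come from, here is a small sketch using only numpy. It evaluates the corrected x formula by hand at u = π, for two choices of v; the variable names are made up for illustration.

```python
import numpy as np

# pi/2 is not exactly representable as a float, so tan() there returns a
# huge finite number rather than overflowing to infinity.
t = np.tan(np.pi / 2)
print(t)  # roughly 1.6e16

def x_at_u_pi(v):
    """Corrected x formula, evaluated with tan(u/2) = t, i.e. at u = pi."""
    return 2 * t / (1 + t**2 * (1 + np.cos(v)) / 2)

x_edge = x_at_u_pi(0.0)      # far from v = +/-pi: the t**2 term dominates
x_corner = x_at_u_pi(np.pi)  # at v = pi, 1 + cos(v) vanishes
print(x_edge, x_corner)
```

So along most of the edge u = ±π the value is tiny, but near the corners u = ±π = v it is enormous, and that handful of points is what sets the autoscaled axis range.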

I'm not entirely convinced my formulæ are doing what I intended – if I could work out how to view the scene from various directions, that would give me a better understanding. Then my attempt to use the UI to zoom in (to a rectangle; I've no clue how to select it so I try click-and-drag, imagining this selects corners) leads me to discover I can rotate it to different views – yay! I can even zoom in, holding the right mouse button down and moving. Unfortunately, when I try that with the u and v ranges restored to ±π, I discover that z is scaled by the same factor as x and y, but the graph has been plotted at a scale where 1e16 is easy to see, so the thickness of its lines is much bigger than the range of z values, making the whole exercise futile. So now I just need to work out how to limit the region displayed.

So I resume the quick-start guide and rapidly realise it won't do me any good; nor does the Matplotlib tutorial. Everyone believes the area to be plotted can just be the range of values, no problem, without considering cases where all the displayed variables are computed from some non-displayed ones and, to see all of the interesting region of the displayed data, one needs the range of underlying data to be wide enough to deliver parts of the derived data that we don't want to see. I find the documentation of the Axes type, which offers me a plethora of ways to prettify the results but no clue how to control the range plotted. But at least I've managed to get far enough to know what question I need to answer.

So, time for the reference manual. It turns out I can control autoscaling of the axes, which I need to turn off, but it's not clear how I can go about actually setting the ranges of the axes. I resort to introspection in my Python session and ask dir(ax) what it has to offer; the word bound catches my eye and finally I know what to search for in the reference manual. Sure enough, ax.set_xbound(-10, 10), and its equivalent for y, improve the situation; that makes things actually look vaguely like I expect. However, it hasn't clipped the data, only the axes, which makes the UI rather unwieldy because, if I'm looking at it all anywhere but horizontally, the large expanse of surface asymptoting to the z = 0 plane away from the origin dominates what I get to see. It would seem I need to clip the data, rather than the axes (they can go back to autoscaling if the data's clipped). So I need to look at the numpy docs, which are enormous and impenetrable (for someone who doesn't already know all about how numpy works).
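
The introspection step can be sketched like this. The Agg backend and the name bound_methods are assumptions made for the sake of a self-contained example; the exact list printed may differ between matplotlib versions.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so this runs without a display

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d  # noqa: F401  (makes the "3d" projection known)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Search the public attributes of the axes object for anything mentioning "bound".
bound_methods = sorted(
    name for name in dir(ax)
    if "bound" in name and not name.startswith("_")
)
print(bound_methods)

# Turn off autoscaling, then fix the visible range on each axis.
ax.autoscale(False)
ax.set_xbound(-10, 10)
ax.set_ybound(-10, 10)
ax.set_zbound(-10, 10)
print(ax.get_xbound())
```

Note that this only changes what the axes show; the arrays handed to plot_surface() are untouched, which is why the far-away parts of the surface still get drawn.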

I manage to discover there is such a thing as a MaskedArray, which sounds like it might just do the trick, and I'm able to construct one that represents the part of my [u, v] space I want to mask out. I eventually found how to apply one to an array: xm = ma.array(x, mask=mask.mask) worked; but applying the same to y and z, then passing the resulting masked arrays to ax.plot_surface(), just ignored the masking and gave me the same unhelpful result as using x, y and z directly. I could use MaskedArray.filled() to set the masked-out values to zero, but this just led to a mesh connecting the origin to the mesh-adjacent points of the surface, messing up the diagram with a filled-in z = 0 plane.
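
A minimal sketch of those masked-array steps, on made-up one-dimensional data rather than the real grid. Using ma.masked_outside() to build the mask is an assumption here; it is one of several ways numpy offers.

```python
import numpy as np
from numpy import ma

# Illustrative values only: two entries of x lie outside [-10, 10].
x = np.array([-50.0, -3.0, 0.5, 4.0, 1e16])
z = np.array([0.1, 0.2, 0.3, 0.4, 0.5])

# Mask every entry of x that lies outside [-10, 10].
mask = ma.masked_outside(x, -10, 10)
print(mask.mask)

# Apply the same mask to another array of the same shape.
zm = ma.array(z, mask=mask.mask)
print(zm)

# filled() replaces each masked entry with the given value, here zero.
z_filled = zm.filled(0.0)
print(z_filled)
```

The filled array shows the problem described above: the masked points have not gone away, they have just been moved to zero, so anything that joins neighbouring points will draw lines to them.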

Then I found ma.clip() and tried using it to at least map the points outside the desired range onto its boundary. That produces a lip at the edge of the shape, where points outside the clip-box have different z-values from those at its boundary, but it's less bad than what I've managed before. It's barely noticeable when clipping to within ten of the origin, or even four, but really messes up the view when clipped to within two of the origin. In particular, the latter should have made it possible to see through the two holes in the middle of the shape, but the lips block out that view.
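
Here is a small sketch of what the clipping does, again on made-up one-dimensional data, to show where the lip comes from.

```python
import numpy as np

# Illustrative values only: the first and last x lie outside [-10, 10].
x = np.array([-50.0, -3.0, 0.5, 4.0, 1e16])
z = np.array([0.9, 0.2, 0.3, 0.4, -0.7])

# Clamp x into [-10, 10]; z is left alone.
xc = np.asarray(np.ma.clip(x, -10, 10))
print(xc)

# The clamped points now sit on the boundary of the box but keep the z
# values they had far outside it, so they need not line up with the
# surface's own height where it crosses that boundary.
for xi, zi in zip(xc, z):
    print(xi, zi)
```

The tighter the box, the more points get moved onto its boundary, which fits with the lip being barely noticeable at ten and obstructive at two.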

So it looks like masked arrays don't really help, although they come with a tool that lets me construct a less-unhelpful array to display. I guess I need to actively delete the parts of the [u, v] grid that I need to avoid, rather than trying to mask them out or clip them to the boundary.

Unfortunately, the amount I need to read in order to do something simple appears to be enormous, which inclines me towards giving up. I have only finite time for these explorations.

Still, combining fragments learned above, with x, y, z computed correctly now, I managed to display them vaguely sensibly with:

>>> from numpy import ma
>>> from matplotlib import cm
>>> fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
>>> ax.autoscale(False)
>>> ax.plot_surface(ma.clip(x, -10, 10), ma.clip(y, -10, 10), z, vmin=-3, cmap=cm.Blues)
<mpl_toolkits.mplot3d.art3d.Poly3DCollection object at 0x7f7161e9bd30>
>>> ax.set_zbound(-10, 10)
>>> ax.set_xbound(-10, 10)
>>> ax.set_ybound(-10, 10)
>>> plt.show()

producing the following images (though the colour variation was smoother in the original display):

Finally, here's the full code for a dumb script to reproduce that (using my rearranged form of the formulæ):

import numpy as np
from matplotlib import pyplot, cm
# Side-effect - causes pyplot to know about 3d projection:
from mpl_toolkits.mplot3d import axes3d

def angles(steps, count = 2):
    """Yield count arrays, each sampling [-pi, pi] at steps points."""
    while count:
        count -= 1
        yield np.linspace(-np.pi, np.pi, steps)

def chart(u, v, clip):
    """Map angles u, v to x, y, z, clamping x and y to [-clip, clip]."""
    den = 1 - (1 - np.cos(u)) * (1 - np.cos(v)) / 4
    x, y = np.sin(u) / den, np.sin(v) / den
    z = np.cos(v) - np.cos(u)
    return np.ma.clip(x, -clip, clip), np.ma.clip(y, -clip, clip), z

def display(u, v, clip, **kw):
    """Plot the surface on 3d axes with bounds fixed at +/- clip."""
    fig, ax = pyplot.subplots(subplot_kw={"projection": "3d"})
    ax.autoscale(False)
    ax.plot_surface(*chart(*np.meshgrid(u, v), clip), **kw)
    ax.set_xbound(-clip, clip)
    ax.set_ybound(-clip, clip)
    ax.set_zbound(-clip, clip)
    pyplot.show()

if __name__ == '__main__':
    import sys
    def params(*defaults):
        """Yield each default, or the matching integer command-line argument if given."""
        for i, v in enumerate(defaults, 1):
            try:
                yield int(sys.argv[i])
            except IndexError:
                yield v
    steps, clip, top, bot = params(21, 6, 6, 9)
    display(*angles(steps), clip, vmin=-bot, vmax=top, cmap=cm.Blues)
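
As a sanity check on the rearranged form, here is a short sketch comparing it numerically with the half-angle form from the interactive session. Multiplying numerator and denominator of the half-angle form by cos(u/2)**2 turns 2.tan(u/2) into sin(u), and the denominator works out to the same expression in both cases. The grid limits of ±3 are an arbitrary choice that stays strictly inside ±π, so that tan(u/2) remains finite.

```python
import numpy as np

# Stay strictly inside (-pi, pi) so that tan(u/2) and tan(v/2) are finite.
u, v = np.meshgrid(np.linspace(-3, 3, 25), np.linspace(-3, 3, 25))

# Half-angle form, as typed into the interactive session.
x_tan = 2 * np.tan(u / 2) / (1 + np.tan(u / 2) ** 2 * (1 + np.cos(v)) / 2)
y_tan = 2 * np.tan(v / 2) / (1 + np.tan(v / 2) ** 2 * (1 + np.cos(u)) / 2)

# Rearranged form, as used by chart() in the script.
den = 1 - (1 - np.cos(u)) * (1 - np.cos(v)) / 4
x_sin, y_sin = np.sin(u) / den, np.sin(v) / den

x_match = np.allclose(x_tan, x_sin)
y_match = np.allclose(y_tan, y_sin)
print(x_match, y_match)
```

The rearranged form has the advantage that its numerator stays bounded at u = ±π, so x and y only blow up where den itself approaches zero, near u = ±π = v.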

Written by Eddy.