Interactive Coronavirus Map With Jupyter Notebook and Plotly

This article updates automatically each day with the latest data.

In [1]:
from datetime import datetime, timezone
f"Last updated: {datetime.now(tz=timezone.utc):%d %B %Y %H:%M:%S %Z}"
Out[1]:
'Last updated: 11 May 2020 20:13:28 UTC'

Early on during what has since become the first global pandemic of my career, I started to come across some really good-looking charts plotting the spread of coronavirus from its origins in Wuhan, China, to almost every other country in the world. Having recently started using Jupyter Notebooks myself, it seemed like a good opportunity to increase my familiarity with Jupyter by seeing what I could do with the wealth of data that this pandemic has produced.

I have previously done most of my plotting with Matplotlib, but I have since stumbled across Plotly and I noticed that it seems to have very good support for map-based charts out of the box. As such, Plotly seemed to be the go-to library for the charts that I wanted to produce.

Detailed below is the process of taking the raw time-series data from the widely used Johns Hopkins repo, processing it and then using Plotly to graphically show the worldwide spread of coronavirus over time.

You can find my original GitHub repo here.

(Word of warning: whilst the maps render nicely on larger screens, mobile users' mileage may vary.)

In [2]:
import re
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import display
from plotly.subplots import make_subplots

pd.options.display.max_columns = 12
In [3]:
date_pattern = re.compile(r"\d{1,2}/\d{1,2}/\d{2}")
def reformat_dates(col_name: str) -> str:
    # for columns which are dates, I'd much rather they were in day/month/year format
    try:
        return date_pattern.sub(datetime.strptime(col_name, "%m/%d/%y").strftime("%d/%m/%Y"), col_name, count=1)
    except ValueError:
        return col_name
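
As a quick check of the helper above, the sketch below repeats the function so that it runs on its own and applies it to two sample column names. The sample names are illustrative: "1/22/20" follows the month/day/year pattern that the helper expects for date columns, and "Country/Region" stands in for any non-date column.

```python
import re
from datetime import datetime

date_pattern = re.compile(r"\d{1,2}/\d{1,2}/\d{2}")


def reformat_dates(col_name: str) -> str:
    # same helper as above, repeated so this snippet is self-contained
    try:
        return date_pattern.sub(
            datetime.strptime(col_name, "%m/%d/%y").strftime("%d/%m/%Y"), col_name, count=1
        )
    except ValueError:
        # not a date column, so leave the name untouched
        return col_name


print(reformat_dates("1/22/20"))         # 22/01/2020
print(reformat_dates("Country/Region"))  # Country/Region
```

Because non-date names fall through the ValueError branch unchanged, the function is safe to pass straight to DataFrame.rename(columns=...), which calls it on every column name.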
In [4]:
# this github repo contains timeseries data for all coronavirus cases: https://github.com/CSSEGISandData/COVID-19
confirmed_cases_url = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data" \
                      "/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv"
deaths_url = "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data" \
             "/csse_covid_19_time_series/time_series_covid19_deaths_global.csv"

Chart 1 - A Choropleth Chart

In [5]:
renamed_columns_map = {
    "Country/Region": "country",
    "Province/State": "location",
    "Lat": "latitude",
    "Long": "longitude"
}

cols_to_drop = ["location", "latitude", "longitude"]

confirmed_cases_df = (
    pd.read_csv(confirmed_cases_url)
    .rename(columns=renamed_columns_map)
    .rename(columns=reformat_dates)
    .drop(columns=cols_to_drop)
)
deaths_df = (
    pd.read_csv(deaths_url)
    .rename(columns=renamed_columns_map)
    .rename(columns=reformat_dates)
    .drop(columns=cols_to_drop)
)

display(confirmed_cases_df.head())
display(deaths_df.head())
country 22/01/2020 23/01/2020 24/01/2020 25/01/2020 26/01/2020 ... 05/05/2020 06/05/2020 07/05/2020 08/05/2020 09/05/2020 10/05/2020
0 Afghanistan 0 0 0 0 0 ... 3224 3392 3563 3778 4033 4402
1 Albania 0 0 0 0 0 ... 820 832 842 850 856 868
2 Algeria 0 0 0 0 0 ... 4838 4997 5182 5369 5558 5723
3 Andorra 0 0 0 0 0 ... 751 751 752 752 754 755
4 Angola 0 0 0 0 0 ... 36 36 36 43 43 45

5 rows × 111 columns

country 22/01/2020 23/01/2020 24/01/2020 25/01/2020 26/01/2020 ... 05/05/2020 06/05/2020 07/05/2020 08/05/2020 09/05/2020 10/05/2020
0 Afghanistan 0 0 0 0 0 ... 95 104 106 109 115 120
1 Albania 0 0 0 0 0 ... 31 31 31 31 31 31
2 Algeria 0 0 0 0 0 ... 470 476 483 488 494 502
3 Andorra 0 0 0 0 0 ... 46 46 47 47 48 48
4 Angola 0 0 0 0 0 ... 2 2 2 2 2 2

5 rows × 111 columns

In [6]:
# extract out just the relevant geographical data and join it to another .csv which has the country codes.
# The country codes are required for the plotting function to identify countries on the map
geo_data_df = confirmed_cases_df[["country"]].drop_duplicates()
country_codes_df = (
    pd.read_csv(
        "country_code_mapping.csv",
        usecols=["country", "alpha-3_code"],
        index_col="country")
)
geo_data_df = geo_data_df.join(country_codes_df, how="left", on="country").set_index("country")
In [7]:
# my .csv file of country codes and the COVID-19 data source disagree on the names of some countries. This 
# dataframe should be empty, otherwise it means I need to edit the country name in the .csv to match
geo_data_df[(pd.isnull(geo_data_df["alpha-3_code"])) & (~geo_data_df.index.isin(
    ["Diamond Princess", "MS Zaandam", "West Bank and Gaza"]
))]
Out[7]:
alpha-3_code
country
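
The same kind of mismatch check can be sketched with a pandas merge and its indicator column. The two small dataframes below, including the differing spellings of one country, are made-up sample data rather than the real files, and expected_missing is a hypothetical name for the entries that have no country code by design.

```python
import pandas as pd

# hypothetical sample data: country names as they might appear in the case data
cases_countries = pd.DataFrame({"country": ["Albania", "Korea, South", "Diamond Princess"]})
# hypothetical sample data: a country code mapping that spells one name differently
codes = pd.DataFrame({"country": ["Albania", "South Korea"], "alpha-3_code": ["ALB", "KOR"]})

# names that have no country code by design (cruise ships, for example)
expected_missing = ["Diamond Princess"]

# indicator=True adds a "_merge" column saying where each row was found
merged = cases_countries.merge(codes, how="left", on="country", indicator=True)
unmatched = merged[
    (merged["_merge"] == "left_only") & (~merged["country"].isin(expected_missing))
]
print(unmatched["country"].to_list())  # ['Korea, South']
```

An empty list here would correspond to the empty dataframe shown in the output above.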
In [8]:
dates_list = (
    deaths_df.filter(regex=r"(\d{2}/\d{2}/\d{4})", axis=1)
    .columns
    .to_list()
)

# create a mapping of date -> dataframe, where each df holds the daily counts of cases and deaths per country
cases_by_date = {}
for date in dates_list:
    confirmed_cases_day_df = (
        confirmed_cases_df
        .filter(like=date, axis=1)
        .rename(columns=lambda col: "confirmed_cases")
    )
    deaths_day_df = deaths_df.filter(like=date, axis=1).rename(columns=lambda col: "deaths")
    cases_df = confirmed_cases_day_df.join(deaths_day_df).set_index(confirmed_cases_df["country"])

    date_df = (
        geo_data_df.join(cases_df)
        .groupby("country")
        .agg({"confirmed_cases": "sum", "deaths": "sum", "alpha-3_code": "first"})
    )
    date_df = date_df[date_df["confirmed_cases"] > 0].reset_index()
    
    cases_by_date[date] = date_df
    
# the dataframe for each day looks something like this:
cases_by_date[dates_list[-1]].head()
Out[8]:
country confirmed_cases deaths alpha-3_code
0 Afghanistan 4402 120 AFG
1 Albania 868 31 ALB
2 Algeria 5723 502 DZA
3 Andorra 755 48 AND
4 Angola 45 2 AGO
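
To see what the groupby step in the loop does, consider a country that is reported as several rows, one per province. The sketch below uses made-up numbers for a single date to show those rows being summed into one row per country, with the country code carried through by "first".

```python
import pandas as pd

# made-up figures for a single date: one country is split into two province rows
rows = pd.DataFrame(
    {
        "country": ["Australia", "Australia", "Albania"],
        "confirmed_cases": [10, 5, 3],
        "deaths": [1, 0, 0],
        "alpha-3_code": ["AUS", "AUS", "ALB"],
    }
).set_index("country")

# sum the counts per country; every row of a country shares one code, so "first" is enough
per_country = (
    rows.groupby("country")
    .agg({"confirmed_cases": "sum", "deaths": "sum", "alpha-3_code": "first"})
    .reset_index()
)
print(per_country)
```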
In [9]:
# helper function for when we produce the frames for the map animation
def frame_args(duration):
    return {
        "frame": {"duration": duration},
        "mode": "immediate",
        "fromcurrent": True,
        "transition": {"duration": duration, "easing": "linear"},
    }
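
For reference, the snippet below repeats the helper so that it runs on its own and inspects the plain dictionaries it returns for two sample durations. It makes no Plotly calls; the values 100 and 0 are simply examples of a non-zero and a zero duration.

```python
def frame_args(duration):
    # same helper as above, repeated so this snippet is self-contained
    return {
        "frame": {"duration": duration},
        "mode": "immediate",
        "fromcurrent": True,
        "transition": {"duration": duration, "easing": "linear"},
    }


play_args = frame_args(100)  # each frame and each transition lasts 100 ms
stop_args = frame_args(0)    # zero duration for both frame and transition

print(play_args["frame"], play_args["transition"])
print(stop_args["frame"], stop_args["transition"])
```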
In [10]:
fig = make_subplots(rows=2, cols=1, specs=[[{"type": "scattergeo"}], [{"type": "xy"}]], row_heights=[0.8, 0.2])

# set up the geo data, the slider, the play and pause buttons, and the title
fig.layout.geo = {"showcountries": True}
fig.layout.sliders = [{"active": 0, "steps": []}]
fig.layout.updatemenus = [
    {
        "type": "buttons",
        "buttons": [
            {
                "label": "▶",  # play symbol
                "method": "animate",
                "args": [None, frame_args(100)],
            },
            {
                "label": "◼",  # stop symbol
                "method": "animate",
                "args": [[None], frame_args(0)],
            },
        ],
        "showactive": False,
        "direction": "left",
    }
]
fig.layout.title = {"text": "Covid-19 Global Case Tracker", "x": 0.5}
In [11]:
frames = []
steps = []
# set up colourbar tick values, ranging from 1 to the highest num. of confirmed cases for any country thus far
max_country_confirmed_cases = cases_by_date[dates_list[-1]]["confirmed_cases"].max()

# to account for the significant variance in number of cases, we want the scale to be logarithmic...
high_tick = np.log1p(max_country_confirmed_cases)
low_tick = np.log1p(1)
log_tick_values = np.geomspace(low_tick, high_tick, num=6)

# ...however, we want the /labels/ on the scale to be the actual number of cases (i.e. not log(n_cases))
visual_tick_values = np.expm1(log_tick_values).astype(int)
# explicitly set max cbar value, otherwise it might be max - 1 due to a rounding error
visual_tick_values[-1] = max_country_confirmed_cases  
visual_tick_values = [f"{val:,}" for val in visual_tick_values]

# generate line chart data
# list of tuples: [(confirmed_cases, deaths), ...]
# select each column by name, which avoids a positional [0] lookup on the aggregated result
cases_deaths_totals = [(int(df["confirmed_cases"].sum()), int(df["deaths"].sum()))
                       for df in cases_by_date.values()]

confirmed_cases_totals = [daily_total[0] for daily_total in cases_deaths_totals]
deaths_totals = [daily_total[1] for daily_total in cases_deaths_totals]


# this loop generates the data for each frame
for i, (date, data) in enumerate(cases_by_date.items(), start=1):
    df = data

    # the z-scale (for calculating the colour for each country) needs to be logarithmic
    df["confirmed_cases_log"] = np.log1p(df["confirmed_cases"])

    df["text"] = (
        date
        + "<br>"
        + df["country"]
        + "<br>Confirmed cases: "
        + df["confirmed_cases"].apply(lambda x: "{:,}".format(x))
        + "<br>Deaths: "
        + df["deaths"].apply(lambda x: "{:,}".format(x))
    )

    # create the choropleth chart
    choro_trace = go.Choropleth(
        **{
            "locations": df["alpha-3_code"],
            "z": df["confirmed_cases_log"],
            "zmax": high_tick,
            "zmin": low_tick,
            "colorscale": "reds",
            "colorbar": {
                "ticks": "outside",
                "ticktext": visual_tick_values,
                "tickmode": "array",
                "tickvals": log_tick_values,
                "title": {"text": "<b>Confirmed Cases</b>"},
                "len": 0.8,
                "y": 1,
                "yanchor": "top"
            },
            "hovertemplate": df["text"],
            "name": "",
            "showlegend": False
        }
    )
    
    # create the confirmed cases trace
    confirmed_cases_trace = go.Scatter(
        x=dates_list,
        y=confirmed_cases_totals[:i],
        mode="markers" if i == 1 else "lines",
        name="Total Confirmed Cases",
        line={"color": "Red"},
        hovertemplate="%{x}<br>Total confirmed cases: %{y:,}<extra></extra>"
    )
        
    # create the deaths trace
    deaths_trace = go.Scatter(
        x=dates_list,
        y=deaths_totals[:i],
        mode="markers" if i == 1 else "lines",
        name="Total Deaths",
        line={"color": "Black"},
        hovertemplate="%{x}<br>Total deaths: %{y:,}<extra></extra>"
    )

    if i == 1:
        # the first frame is what the figure initially shows...
        fig.add_trace(choro_trace, row=1, col=1)
        fig.add_traces([confirmed_cases_trace, deaths_trace], rows=[2, 2], cols=[1, 1])
    # ...and all the other frames are appended to the `frames` list and slider
    frames.append({"data": [choro_trace, confirmed_cases_trace, deaths_trace], "name": date})

    steps.append(
        {"args": [[date], frame_args(50)], "label": date, "method": "animate",}
    )

# tidy up the axes and finalise the chart ready for display
fig.update_xaxes(range=[0, len(dates_list)-1], visible=False)
fig.update_yaxes(range=[0, max(confirmed_cases_totals)])
fig.frames = frames
fig.layout.sliders[0].steps = steps
fig.layout.geo.domain = {"x": [0,1], "y": [0.2, 1]}
fig.update_layout(
    height=650, 
    legend={"x": 0.05, "y": 0.175, "yanchor": "top", "bgcolor": "rgba(0, 0, 0, 0)"})
fig
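
The colour bar's tick calculation from the cell above can also be tried out on its own. This sketch uses a made-up maximum of 1,000,000 cases rather than the real data. It shows that the tick positions live in log1p space, matching the z-values given to the choropleth, while the labels are mapped back to case counts with expm1.

```python
import numpy as np

max_cases = 1_000_000  # made-up maximum, for illustration only

# tick positions are in log1p space, like the "confirmed_cases_log" column
low_tick = np.log1p(1)
high_tick = np.log1p(max_cases)
log_tick_values = np.geomspace(low_tick, high_tick, num=6)

# labels are mapped back to case counts; astype(int) truncates towards zero,
# which is why the last label is pinned to the true maximum afterwards
tick_labels = np.expm1(log_tick_values).astype(int)
tick_labels[-1] = max_cases

print([f"{val:,}" for val in tick_labels])
```

Pinning the last label matters because expm1(log1p(n)) is a floating point round trip, and truncating a result that lands just below n would otherwise show n - 1 on the colour bar.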