Introduction to the Bass Diffusion Model

What is the Bass Model?

The Bass diffusion model, developed by Frank Bass in 1969, is a mathematical model that describes how new products get adopted in a population over time. It’s widely used in marketing to forecast sales of new products, especially when historical data is limited or non-existent.

The model captures the entire lifecycle of product adoption, from introduction to saturation, making it a powerful tool for product planning and marketing strategy development.

The Motivation Behind the Bass Model

Before the Bass model, companies struggled to predict the adoption patterns of new products. Traditional forecasting methods often failed because they couldn’t account for the social dynamics that drive product adoption. Frank Bass recognized that product adoption follows a distinct pattern:

  • Initial slow growth: When a product first launches, adoption starts slowly

  • Rapid acceleration: As more people adopt, word-of-mouth spreads and adoption accelerates

  • Eventual saturation: Eventually, the market becomes saturated and adoption slows down

The Bass model provides a mathematical framework to capture these patterns, enabling businesses to make more informed decisions about production planning, inventory management, and marketing resource allocation.

Mathematical Formulation

The Bass model is based on a differential equation that describes the rate of adoption:

\[\frac{f(t)}{1-F(t)} = p + q F(t)\]

Where:

  • \(F(t)\) is the installed base fraction (cumulative proportion of adopters)

  • \(f(t)\) is the rate of change of the installed base fraction (\(f(t) = F'(t)\))

  • \(p\) is the coefficient of innovation or external influence

  • \(q\) is the coefficient of imitation or internal influence

The solution to this equation gives the adoption curve:

\[F(t) = \frac{1 - e^{-(p+q)t}}{1 + (\frac{q}{p})e^{-(p+q)t}}\]

The adoption rate at time \(t\) is given by:

\[f(t) = (p + q F(t))(1 - F(t))\]

Alternatively, this can be written as:

\[f(t) = \frac{(p+q)^2 \cdot e^{-(p+q)t}}{p \cdot (1+\frac{q}{p}e^{-(p+q)t})^2}\]
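As a quick numerical check of the formulas above, the following sketch implements the closed-form \(F(t)\) and \(f(t)\) with NumPy. It confirms that they satisfy the hazard-rate equation, the product form \(f(t) = (p + qF(t))(1 - F(t))\), and \(f(t) = F'(t)\). The helper names bass_cdf and bass_pdf and the values p = 0.03, q = 0.38 are illustrative choices, not part of PyMC-Marketing.

```python
import numpy as np


def bass_cdf(t, p, q):
    """Cumulative adoption fraction F(t) of the Bass model (illustrative helper)."""
    decay = np.exp(-(p + q) * t)
    return (1 - decay) / (1 + (q / p) * decay)


def bass_pdf(t, p, q):
    """Adoption rate f(t) = F'(t), using the closed form (illustrative helper)."""
    decay = np.exp(-(p + q) * t)
    return (p + q) ** 2 * decay / (p * (1 + (q / p) * decay) ** 2)


p, q = 0.03, 0.38  # illustrative values inside the typical ranges
t = np.linspace(0, 30, 301)
F = bass_cdf(t, p, q)
f = bass_pdf(t, p, q)

# Product form f(t) = (p + q F(t)) (1 - F(t))
product_form = (p + q * F) * (1 - F)
# Hazard-rate form f / (1 - F), which should equal p + q F
hazard = f / (1 - F)
# Finite-difference derivative of F(t), which should approximate f(t)
numeric_derivative = np.gradient(F, t)
```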

Key Components of the Bass Model Implementation

The Bass model implementation in PyMC-Marketing consists of several key components:

  1. Adopters - The number of new adoptions at time \(t\):

\[\text{adopters}(t) = m \cdot f(p, q, t)\]

  2. Innovators - Adoptions driven by external influence (advertising, etc.):

\[\text{innovators}(t) = m \cdot p \cdot (1 - F(p, q, t))\]

  3. Imitators - Adoptions driven by internal influence (word-of-mouth):

\[\text{imitators}(t) = m \cdot q \cdot F(p, q, t) \cdot (1 - F(p, q, t))\]

  4. Peak Adoption Time - When the adoption rate reaches its maximum:

\[\text{peak} = \frac{\ln(q) - \ln(p)}{p + q}\]

At each time point, the number of new adopters is the sum of innovators and imitators, since \(m \cdot p \cdot (1 - F) + m \cdot q \cdot F \cdot (1 - F) = m \cdot (p + qF)(1 - F) = \text{adopters}(t)\). All of these components are implemented directly in the PyMC model, allowing us to analyze each aspect of the diffusion process separately.
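The component definitions can be sketched directly in NumPy to check that innovators and imitators add up to adopters at every time step, and to evaluate the peak formula. This is a standalone illustration with hypothetical values (m = 50,000, p = 0.03, q = 0.38); bass_components is a made-up helper, not a PyMC-Marketing function.

```python
import numpy as np


def bass_components(t, m, p, q):
    """Return (adopters, innovators, imitators) at times t (hypothetical helper)."""
    decay = np.exp(-(p + q) * t)
    F = (1 - decay) / (1 + (q / p) * decay)  # cumulative fraction F(t)
    innovators = m * p * (1 - F)             # external influence
    imitators = m * q * F * (1 - F)          # word-of-mouth
    adopters = m * (p + q * F) * (1 - F)     # m * f(t)
    return adopters, innovators, imitators


m, p, q = 50_000, 0.03, 0.38  # hypothetical market size and coefficients
t = np.arange(52)             # weekly time steps
adopters, innovators, imitators = bass_components(t, m, p, q)

# Peak time from the closed-form expression (in weeks)
t_peak = (np.log(q) - np.log(p)) / (p + q)
```

Because the weekly grid is discrete, the largest value of adopters falls on the integer week closest (in terms of \(F\)) to the continuous peak.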

Understanding the Relationship Between Components

A key insight of the Bass model is how it decomposes adoption into two sources:

\[\text{adopters}(t) = \text{innovators}(t) + \text{imitators}(t)\]

At each time point:

  • Innovators (\(m \cdot p \cdot (1 - F(t))\)) represent new adoptions by people influenced by external factors such as advertising

  • Imitators (\(m \cdot q \cdot F(t) \cdot (1 - F(t))\)) represent new adoptions by people influenced by previous adopters

As time progresses:

  • Initially, innovators dominate the adoption process when few people have adopted (\(F(t)\) is small)

  • Later, imitators become the primary source of new adoptions as the word-of-mouth effect grows

  • Eventually, both decrease as the market approaches saturation (\(F(t)\) approaches 1)
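The hand-over from innovators to imitators can be located exactly. Imitators outnumber innovators once \(qF(t) > p\), that is, once \(F(t) > p/q\); solving \(F(t) = p/q\) with the closed form gives \(t_{\text{cross}} = \ln\big(2q/(q-p)\big)/(p+q)\) for \(q > p\). The sketch below, with illustrative p and q, compares this closed form with a grid search.

```python
import numpy as np

p, q = 0.03, 0.38  # illustrative coefficients with q > p

t = np.linspace(0, 20, 20_001)
decay = np.exp(-(p + q) * t)
F = (1 - decay) / (1 + (q / p) * decay)
innovators = p * (1 - F)          # per unit of market potential m
imitators = q * F * (1 - F)

# First grid point where word-of-mouth overtakes external influence
t_cross_numeric = t[np.argmax(imitators > innovators)]
# Closed form obtained by solving F(t) = p / q
t_cross = np.log(2 * q / (q - p)) / (p + q)
# Peak time, for comparison
t_peak = np.log(q / p) / (p + q)
```

With these values the crossover happens when only about 8% of the market (\(p/q\)) has adopted, well before the peak. In general the crossover precedes the peak exactly when \(q > 3p\).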

The cumulative adoption at any time point is:

\[\text{Cumulative Adoption}(t) = m \cdot F(t)\]

This means that as \(t \to \infty\), the cumulative adoption approaches the total market potential \(m\):

\[\lim_{t \to \infty} \text{Cumulative Adoption}(t) = m\]

Therefore, the Bass model provides a complete accounting of the market:

  • At each time point, new adopters are either innovators or imitators

  • Over the entire product lifecycle, all potential adopters (m) eventually adopt the product

  • The model tracks both the adoption rate (new adopters per time period) and the cumulative adoption (total adopters to date)

This structure enables marketers to understand not just how many people will adopt over time, but also the driving forces behind adoption at different stages of the product lifecycle.
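Inverting the closed form shows how long it takes to reach any fraction \(a\) of the market: setting \(F(t) = a\) and solving for \(t\) gives \(t_a = \ln\big((1 + a\,q/p)/(1 - a)\big)/(p + q)\). The sketch below, with hypothetical m, p and q, checks this inversion and shows cumulative adoption \(m F(t)\) approaching \(m\). The helper time_to_fraction is illustrative only.

```python
import numpy as np

m, p, q = 50_000, 0.03, 0.38  # hypothetical market size and coefficients


def bass_cdf(t, p, q):
    """Cumulative adoption fraction F(t)."""
    decay = np.exp(-(p + q) * t)
    return (1 - decay) / (1 + (q / p) * decay)


def time_to_fraction(a, p, q):
    """Time at which a fraction a (0 <= a < 1) of the market has adopted."""
    return np.log((1 + a * q / p) / (1 - a)) / (p + q)


# Cumulative adoption m * F(t) at a few horizons (weeks); it approaches m
horizons = np.array([10, 25, 52, 200])
cumulative = m * bass_cdf(horizons, p, q)

# Weeks needed to reach 50% and 90% of the market
t_half = time_to_fraction(0.5, p, q)
t_90 = time_to_fraction(0.9, p, q)
```

With these values, half of the market has adopted after about 6.55 weeks.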

Understanding the Key Parameters

The model has three main parameters:

  • Market potential (m): Total number of eventual adopters (the ultimate market size)

  • Innovation coefficient (p): Measures external influence like advertising and media - typically \(0.01-0.03\)

  • Imitation coefficient (q): Measures internal influence like word-of-mouth - typically \(0.3-0.5\)

Parameter Interpretation

  • A higher p value indicates stronger external influence (advertising, marketing)

  • A higher q value indicates stronger internal influence (word-of-mouth, social interactions)

  • The ratio q/p indicates the relative strength of internal vs. external influences

  • If \(q > p\), the peak of adoption occurs at time

\[t^* = \frac{\ln(q / p)}{p + q}\]
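The peak formula follows from treating \(f\) as a function of \(F\): \(\frac{d}{dF}\big[(p + qF)(1 - F)\big] = q - p - 2qF\), which vanishes at \(F(t^*) = (q - p)/(2q)\). Substituting back gives the peak adoption rate \(f(t^*) = (p + q)^2/(4q)\), and solving \(F(t) = (q - p)/(2q)\) for \(t\) recovers \(t^* = \ln(q/p)/(p + q)\). When \(q \le p\) the derivative is never positive, so adoption is highest at launch. The sketch below checks these closed forms against a fine grid search, using illustrative values of p and q.

```python
import numpy as np

p, q = 0.03, 0.38  # illustrative coefficients with q > p

t = np.linspace(0, 30, 300_001)
decay = np.exp(-(p + q) * t)
F = (1 - decay) / (1 + (q / p) * decay)
f = (p + q * F) * (1 - F)

# Closed-form peak quantities (per unit of market potential m)
t_star = np.log(q / p) / (p + q)
F_star = (q - p) / (2 * q)          # share of the market adopted at the peak
f_star = (p + q) ** 2 / (4 * q)     # peak adoption rate

# Grid search for comparison
i = np.argmax(f)
```

Because \(F(t^*) = (q - p)/(2q) < 1/2\), the peak always arrives before half of the market has adopted; with these values about 46% has adopted at the peak.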

Innovators vs. Imitators

The Bass model distinguishes between two types of adopters:

  1. Innovators: People who adopt independently of others’ decisions, influenced mainly by mass media and external communications

    • Mathematically represented as: \(\text{innovators}(t) = m \cdot p \cdot (1 - F(p, q, t))\)

  2. Imitators: People who adopt because of social influence and word-of-mouth from previous adopters

    • Mathematically represented as: \(\text{imitators}(t) = m \cdot q \cdot F(p, q, t) \cdot (1 - F(p, q, t))\)
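Over the whole lifecycle, what share of the \(m\) adopters are innovators? Changing the integration variable from \(t\) to \(F\), with \(dF/dt = (p + qF)(1 - F)\), gives \(\int_0^\infty p(1 - F)\,dt = \int_0^1 \frac{p}{p + qF}\,dF = \frac{p}{q}\ln\left(1 + \frac{q}{p}\right)\). The sketch below checks this by integrating the innovators curve numerically, with illustrative p and q.

```python
import numpy as np

p, q = 0.03, 0.38  # illustrative coefficients

# Closed form for the lifetime share of innovators
innovator_share = (p / q) * np.log(1 + q / p)

# Numerical check: trapezoidal integration of innovators per unit of m
t = np.linspace(0, 200, 200_001)
decay = np.exp(-(p + q) * t)
F = (1 - decay) / (1 + (q / p) * decay)
innovators = p * (1 - F)
numeric_share = np.sum((innovators[1:] + innovators[:-1]) / 2 * np.diff(t))
```

With these values roughly 21% of all adopters are innovators and the remaining 79% are imitators.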

Real-World Applications

The Bass model has been successfully applied to forecast the adoption of various products and technologies:

  • Consumer durables: TVs, refrigerators, washing machines

  • Technology products: Smartphones, computers, software

  • Pharmaceutical products: New drugs and treatments

  • Entertainment products: Movies, games, streaming services

  • Services and subscriptions: Banking services, subscription plans

Business Value: Why the Bass Model Matters to Executives and Marketers

From a business perspective, the Bass diffusion model supports decisions in four areas:

1. Resource Optimization and Cash Flow Management

  • Production Planning: Avoid costly overproduction or stockouts by accurately forecasting demand curves

  • Marketing Budget Allocation: Optimize spending across the product lifecycle, investing more during key inflection points

  • Supply Chain Efficiency: Coordinate with suppliers and distributors based on predicted adoption rates

  • Cash Flow Optimization: Better predict revenue streams, improving financial planning and investor relations

2. Strategic Decision Making

  • Launch Timing: Determine the optimal time to enter a market based on diffusion patterns

  • Pricing Strategy: Implement dynamic pricing strategies aligned with the adoption curve

  • Competitive Analysis: Compare your product’s adoption parameters with competitors to identify strengths and weaknesses

  • Product Portfolio Management: Make informed decisions about when to phase out older products and introduce new ones

3. Risk Mitigation

  • Scenario Planning: Test different market assumptions and external factors through parameter variations

  • Early Warning System: Identify deviations from expected adoption curves early, enabling faster intervention

  • Investment Justification: Provide data-driven forecasts to justify R&D and marketing investments to stakeholders

4. Performance Measurement

  • Marketing Effectiveness: Measure the impact of marketing campaigns on the innovation coefficient (p)

  • Word-of-Mouth Strength: Quantify the power of your brand’s social influence through the imitation coefficient (q)

  • Total Market Potential: Validate or adjust your total addressable market estimates (m)

In short, models like the Bass diffusion model give companies more precise forecasts, a sounder basis for allocating resources, and better-informed market timing.

Bayesian Extensions

In this notebook, we show how to generate simulated data from the Bass model and fit a Bayesian model to it. The Bayesian formulation offers several advantages:

  • Uncertainty quantification: prior distributions on the parameters are updated into full posterior distributions

  • Hierarchical modeling for multiple products or markets

  • Incorporation of expert knowledge through informative priors

  • Full probability distributions for future adoption forecasts

What we’ll do in this notebook

In this notebook, we’ll:

  1. Set up parameters for a Bass model simulation

  2. Generate simulated adoption data for multiple products

  3. Fit the Bass model to our simulated data with the BassModel class

  4. Visualize the adoption curves with the built-in plotting methods

  5. Forecast adoption beyond the observed window

  6. Save and load the fitted model

  7. Track the workflow with MLflow

Prepare Notebook

from typing import Any

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import xarray as xr
from pymc_extras.prior import Prior, Scaled

from pymc_marketing.bass import BassModel, create_bass_model
from pymc_marketing.plot import plot_curve

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
seed: int = sum(map(ord, "bass"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

Setting Up Simulation Parameters

First, we’ll set up the parameters for our simulation. This includes:

  • The time period for our simulation (in weeks)

  • The number of products to simulate

  • Start dates for the simulation period

def setup_simulation_parameters(
    n_weeks: int = 52,
    n_products: int = 9,
    start_date: str = "2023-01-01",
    cutoff_start_date: str = "2023-12-01",
) -> tuple[
    npt.NDArray[np.int_],
    pd.DatetimeIndex,
    pd.DatetimeIndex,
    list[str],
    pd.Series,
    dict[str, Any],
]:
    """Set up initial parameters for the Bass diffusion model simulation.

    Parameters
    ----------
    n_weeks : int
        Number of weeks to simulate
    n_products : int
        Number of products to include in the simulation
    start_date : str
        Starting date for the simulation period
    cutoff_start_date : str
        Nominal latest start date for products; any simulation date earlier
        than one week after it can be a start date

    Returns
    -------
    T : numpy.ndarray
        Time array (weeks)
    possible_dates : pandas.DatetimeIndex
        All dates in the simulation period
    possible_start_dates : pandas.DatetimeIndex
        Possible start dates for products
    products : list
        List of product names
    product_start : pandas.Series
        Start date for each product
    coords : dict
        Coordinates for PyMC model
    """
    # Set a seed for reproducibility
    seed = sum(map(ord, "bass"))
    rng = np.random.default_rng(seed)

    # Create time array and date range
    T = np.arange(n_weeks)
    possible_dates = pd.date_range(start_date, freq="W-MON", periods=n_weeks)
    cutoff_start_date = pd.to_datetime(cutoff_start_date)
    cutoff_start_date = cutoff_start_date + pd.DateOffset(weeks=1)
    possible_start_dates = possible_dates[possible_dates < cutoff_start_date]

    # Generate product names and random start dates
    products = [f"P{i}" for i in range(n_products)]
    product_start = pd.Series(
        rng.choice(possible_start_dates, size=len(products)),
        index=pd.Index(products, name="product"),
    )

    coords = {"T": T, "product": products}
    return T, possible_dates, possible_start_dates, products, product_start, coords
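The date handling in setup_simulation_parameters has two details worth making explicit: freq="W-MON" snaps the range to Mondays, so the first date is 2023-01-02 rather than 2023-01-01, and shifting the cutoff by one week before the strict "<" comparison lets products start slightly after the nominal cutoff. The snippet below reproduces only that date logic with the default arguments.

```python
import pandas as pd

# Weekly Monday dates, as in setup_simulation_parameters
possible_dates = pd.date_range("2023-01-01", freq="W-MON", periods=52)

# The cutoff is pushed one week forward and compared with a strict "<",
# so Mondays shortly after 2023-12-01 remain valid start dates.
cutoff = pd.to_datetime("2023-12-01") + pd.DateOffset(weeks=1)
possible_start_dates = possible_dates[possible_dates < cutoff]
```

So the latest possible start date is 2023-12-04, and a product launched then contributes only four observed weeks.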

Creating Prior Distributions

For our Bayesian Bass model, we need to specify prior distributions for the key parameters:

  • m (market potential): How many units can potentially be sold in total

  • p (innovation coefficient): Rate of adoption from external influences

  • q (imitation coefficient): Rate of adoption from internal/social influences

  • likelihood: The probability distribution that models the observed adoption data

For the market potential m we use a scaling trick: we specify a scale-free prior centered at 1 and then multiply it by a global factor:

def create_bass_priors(factor: float) -> dict[str, Prior | Scaled]:
    """Define prior distributions for the Bass model parameters.

    Parameters
    ----------
    factor : float
        Global scale that multiplies the unit-mean prior on m

    Returns
    -------
    dict
        Dictionary of prior distributions for m, p, q, and likelihood

    Notes
    -----
    - m: Market potential (scaled Gamma distribution)
    - p: Innovation coefficient (Beta distribution)
    - q: Imitation coefficient (Beta distribution)
    - likelihood: Observation model (Negative Binomial)
    """
    return {
        # We use a scaled Gamma distribution for the market potential.
        "m": Scaled(Prior("Gamma", mu=1, sigma=0.1, dims="product"), factor=factor),
        "p": Prior("Beta", mu=0.03, dims="product").constrain(lower=0.01, upper=0.03),
        "q": Prior("Beta", dims="product").constrain(lower=0.3, upper=0.5),
        "likelihood": Prior("NegativeBinomial", n=1.5, dims="product"),
    }
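The scaling trick for m can be mimicked with NumPy. The sketch below assumes that Scaled simply multiplies draws of the wrapped Gamma(mu=1, sigma=0.1) prior by factor (consistent with the m_unscaled variable in the sampling log), and converts mu and sigma to NumPy's shape/scale parameterization: shape = (mu/sigma)^2 = 100 and scale = sigma^2/mu = 0.01. The resulting prior on m has mean 50,000 and standard deviation 5,000.

```python
import numpy as np

rng = np.random.default_rng(42)

FACTOR = 50_000
mu, sigma = 1.0, 0.1

# Gamma(mu, sigma) in shape/scale form
shape = (mu / sigma) ** 2
scale = sigma**2 / mu

m_unscaled = rng.gamma(shape, scale, size=100_000)  # unit-free, centered at 1
m = FACTOR * m_unscaled                             # market potential in adopters
```

Keeping the Gamma on a unit scale means the same prior shape can be reused for any market size by changing only FACTOR.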

Let’s generate and visualize the priors.

FACTOR = 50_000
priors = create_bass_priors(factor=FACTOR)
/home/pablo/micromamba/envs/pymc-marketing-dev/lib/python3.12/site-packages/pymc_extras/prior.py:1205: UserWarning: 
The requested mass is 0.95,
but the computed one is 0.528
  new_parameters = maxent(self.preliz, lower, upper, mass, **kwargs).params_dict
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(15, 12))

priors["p"].preliz.plot_pdf(ax=ax[0])
ax[0].set(title="Innovation Coefficient (p)")
priors["q"].preliz.plot_pdf(ax=ax[1])
ax[1].set(title="Imitation Coefficient (q)")
fig.suptitle(
    "Prior Distributions for Bass Model Parameters",
    fontsize=18,
    fontweight="bold",
    y=0.95,
);
[Figure: Prior Distributions for Bass Model Parameters]

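Observe that we have chosen the priors to match the ranges usually reported in empirical studies: \(0.01-0.03\) for the innovation coefficient p and \(0.3-0.5\) for the imitation coefficient q. The warning printed when creating the priors is worth noting, though: because the mean of the p prior is fixed at 0.03, the upper end of its target interval, the requested 95% mass cannot be achieved, and only about 53% of the prior mass for p lies in \([0.01, 0.03]\).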

Generate Synthetic Data

With the generative Bass model, we can build a synthetic dataset by sampling from the prior predictive distribution and keeping one draw as the observed data. For this purpose we define two auxiliary functions.

Here we use the lower-level create_bass_model() function, which returns a raw pm.Model. It is the right tool when you need direct access to the model object, for example to sample from the prior. For fitting we will use the higher-level BassModel class below.

def sample_prior_bass_data(model: pm.Model) -> xr.DataArray:
    """Generate a sample from the prior predictive distribution of the Bass model.

    Parameters
    ----------
    model : pymc.Model
        The PyMC model to sample from

    Returns
    -------
    xarray.DataArray
        Simulated adoption data
    """
    with model:
        idata = pm.sample_prior_predictive(random_seed=rng)
    return idata["prior"]["y"].sel(chain=0, draw=0)


def transform_to_actual_dates(bass_data, product_start, possible_dates) -> pd.DataFrame:
    """Transform simulation data from time index to calendar dates.

    Parameters
    ----------
    bass_data : xarray.DataArray
        Simulated bass model data
    product_start : pandas.Series
        Start date for each product
    possible_dates : pandas.DatetimeIndex
        All dates in the simulation period

    Returns
    -------
    pandas.DataFrame
        Adoption data with actual calendar dates
    """
    bass_data = bass_data.to_dataset()
    bass_data["product_start"] = product_start.to_xarray()

    df_bass_data = (
        bass_data.to_dataframe().drop(columns=["chain", "draw"]).reset_index()
    )
    df_bass_data["actual_date"] = df_bass_data["product_start"] + pd.to_timedelta(
        7 * df_bass_data["T"], unit="days"
    )

    return (
        df_bass_data.set_index(["actual_date", "product"])
        .y.unstack(fill_value=0)
        .reindex(possible_dates, fill_value=0)
    )
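For intuition, a single prior predictive draw like the one sample_prior_bass_data returns can be mimicked with plain NumPy. This is a simplified sketch, not the PyMC-Marketing implementation: p and q are fixed illustrative values instead of Beta draws, and we assume the NegativeBinomial likelihood is parameterized by its mean (here adopters) and n = 1.5.

```python
import numpy as np

rng = np.random.default_rng(0)

n_weeks = 52
t = np.arange(n_weeks)

# Hypothetical parameter draw for one product (stand-ins for prior samples)
m = 50_000 * rng.gamma(shape=100.0, scale=0.01)  # market potential
p, q = 0.02, 0.4                                  # innovation / imitation

# Expected weekly adoptions: adopters(t) = m * f(t)
decay = np.exp(-(p + q) * t)
F = (1 - decay) / (1 + (q / p) * decay)
adopters = m * (p + q * F) * (1 - F)

# Negative binomial noise with mean `adopters` and dispersion n = 1.5.
# NumPy's negative_binomial(n, prob) has mean n * (1 - prob) / prob,
# so prob = n / (n + mean) reproduces the desired mean.
n = 1.5
y = rng.negative_binomial(n, n / (n + adopters))
```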

Now we can generate the observed data:

# Setup simulation parameters
T, possible_dates, _, products, product_start, coords = setup_simulation_parameters()

# Create and configure the Bass model
generative_model = create_bass_model(t=T, coords=coords, observed=None, priors=priors)

# Sample and select one "observed" dataset.
bass_data = sample_prior_bass_data(generative_model)
actual_data = transform_to_actual_dates(bass_data, product_start, possible_dates)
Sampling: [m_unscaled, p, q, y]

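The actual_data data frame has the format of a typical real dataset: one row per calendar week (Mondays) and one column per product, with zeros before each product's launch.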

actual_data
product P0 P1 P2 P3 P4 P5 P6 P7 P8
2023-01-02 0 0 1735 0 0 0 0 0 0
2023-01-09 0 0 1034 0 0 0 0 0 0
2023-01-16 0 0 1332 0 0 0 0 0 0
2023-01-23 0 0 4222 0 0 0 0 0 0
2023-01-30 0 1790 3034 0 0 766 0 0 0
2023-02-06 0 611 424 0 0 379 0 0 0
2023-02-13 0 1728 8364 0 0 2582 0 0 0
2023-02-20 0 1604 8897 0 0 822 0 0 0
2023-02-27 0 1282 3142 4054 0 3920 0 0 0
2023-03-06 0 7988 1275 211 0 4626 0 0 0
2023-03-13 538 2445 1657 1714 0 8210 0 0 0
2023-03-20 2369 194 2038 924 0 4338 0 0 0
2023-03-27 421 1399 1052 724 0 7798 0 0 0
2023-04-03 3491 4473 242 3921 0 647 0 0 0
2023-04-10 5214 2798 169 14889 0 3357 0 0 0
2023-04-17 11602 894 252 2810 0 1830 0 0 0
2023-04-24 15316 21 156 5408 0 3235 0 0 0
2023-05-01 2299 115 12 1609 0 1414 0 0 0
2023-05-08 1449 95 26 4436 0 976 0 0 0
2023-05-15 1618 96 104 2320 0 335 0 0 0
2023-05-22 4392 221 6 534 0 528 0 0 0
2023-05-29 102 119 15 802 0 103 0 0 0
2023-06-05 328 35 13 239 0 37 0 0 0
2023-06-12 558 20 3 698 0 17 0 0 0
2023-06-19 873 44 7 280 0 43 0 0 0
2023-06-26 247 1 15 379 0 49 0 0 0
2023-07-03 170 6 5 38 0 0 0 0 0
2023-07-10 114 2 4 270 0 19 0 0 0
2023-07-17 134 3 0 52 0 3 0 0 0
2023-07-24 45 1 1 81 0 6 0 0 652
2023-07-31 38 3 2 15 0 5 0 0 2270
2023-08-07 5 0 0 2 0 2 0 0 1432
2023-08-14 27 0 0 6 0 1 0 0 2882
2023-08-21 2 0 0 4 0 0 0 0 6372
2023-08-28 3 0 0 5 0 0 0 0 430
2023-09-04 2 1 0 4 0 0 0 880 6659
2023-09-11 3 0 0 3 0 0 875 5882 8570
2023-09-18 12 0 0 0 0 0 9348 2024 3969
2023-09-25 1 0 0 1 0 1 4281 14057 9505
2023-10-02 1 0 0 0 0 0 5266 7695 2724
2023-10-09 0 0 0 0 0 0 3594 5308 6818
2023-10-16 0 0 0 1 0 0 3454 4711 1658
2023-10-23 0 0 0 0 0 0 4153 2200 4602
2023-10-30 0 0 0 0 248 0 1879 4318 310
2023-11-06 0 0 0 1 156 0 7918 607 1224
2023-11-13 0 0 0 0 3344 0 536 3806 1914
2023-11-20 0 0 0 0 2997 0 144 3937 223
2023-11-27 0 0 0 0 1626 0 1311 2942 160
2023-12-04 0 0 0 0 2689 0 571 40 218
2023-12-11 0 0 0 0 8190 0 997 64 5
2023-12-18 0 0 0 0 5116 0 1329 863 270
2023-12-25 0 0 0 0 3313 0 117 442 97

By contrast, bass_data holds the same values as an array indexed by relative time T (weeks since each product's launch) and product.

Let’s visualize both.

fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(15, 12), sharex=False, sharey=True, layout="constrained"
)

# Plot raw simulated data (by time step)
bass_data.to_series().unstack().plot(ax=ax[0])
ax[0].legend(
    title="Product", title_fontsize=14, loc="center left", bbox_to_anchor=(1, 0.5)
)
ax[0].set(
    title="Simulated Weekly Adoption by Product (Time Steps)",
    xlabel="Time Step (Weeks)",
    ylabel="Number of Adoptions",
)

# Plot data with actual calendar dates
actual_data.plot(ax=ax[1])
ax[1].legend(
    title="Product", title_fontsize=14, loc="center left", bbox_to_anchor=(1, 0.5)
)
ax[1].set(
    title="Simulated Weekly Adoption by Product (Calendar Dates)",
    xlabel="Date",
    ylabel="Number of Adoptions",
)

fig.suptitle(
    "Bass Diffusion Model - Simulated Product Adoption", fontsize=18, fontweight="bold"
);

Fit the Model

We are now ready to fit the model. We use the BassModel class, a ModelBuilder subclass that wraps create_bass_model behind the standard .fit(), .save() and .load() workflow.

The priors go into model_config and the sampler settings into sampler_config. The fit method accepts the data as xr.Dataset (with an observed variable), a wide pd.DataFrame (one column per product), a pd.Series, or a np.ndarray; see to_bass_dataset() for the conversion rules. Here we pass the simulated data as an xr.Dataset.

observed_ds = bass_data.drop_vars(["chain", "draw"]).to_dataset(name="observed")

model = BassModel(
    model_config=priors,
    sampler_config={
        "tune": 1_500,
        "draws": 2_000,
        "chains": 4,
        "nuts_sampler": "nutpie",
        "compile_kwargs": {"mode": "NUMBA"},
    },
)

idata = model.fit(data=observed_ds, random_seed=rng)

model.sample_posterior_predictive(X=observed_ds, extend_idata=True, random_seed=rng);

Sampler progress: 4 of 4 chains finished.

Draws Divergences Step Size Gradients/Draw
3500 0 0.27 31
3500 0 0.27 15
3500 0 0.25 15
3500 0 0.29 15

Sampling: [y]

We do not have any divergences. Let’s look at the summary of the parameters.

az.summary(data=idata, var_names=["p", "q", "m"])
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
p[P0] 0.032 0.006 0.021 0.043 0.000 0.000 8377.0 6209.0 1.0
p[P1] 0.033 0.006 0.022 0.045 0.000 0.000 8450.0 5775.0 1.0
p[P2] 0.033 0.006 0.022 0.044 0.000 0.000 8127.0 6307.0 1.0
p[P3] 0.030 0.005 0.021 0.041 0.000 0.000 7935.0 6544.0 1.0
p[P4] 0.025 0.005 0.016 0.034 0.000 0.000 8260.0 6129.0 1.0
p[P5] 0.026 0.005 0.016 0.035 0.000 0.000 7156.0 5251.0 1.0
p[P6] 0.035 0.006 0.025 0.046 0.000 0.000 7932.0 5844.0 1.0
p[P7] 0.033 0.005 0.023 0.043 0.000 0.000 8255.0 6270.0 1.0
p[P8] 0.026 0.005 0.018 0.035 0.000 0.000 7408.0 5910.0 1.0
q[P0] 0.414 0.020 0.377 0.450 0.000 0.000 8116.0 5775.0 1.0
q[P1] 0.453 0.022 0.411 0.495 0.000 0.000 8550.0 5235.0 1.0
q[P2] 0.409 0.020 0.374 0.447 0.000 0.000 7650.0 6387.0 1.0
q[P3] 0.392 0.019 0.356 0.429 0.000 0.000 8109.0 6002.0 1.0
q[P4] 0.442 0.020 0.404 0.481 0.000 0.000 8095.0 6166.0 1.0
q[P5] 0.426 0.020 0.387 0.464 0.000 0.000 7126.0 5467.0 1.0
q[P6] 0.433 0.021 0.392 0.471 0.000 0.000 7830.0 5370.0 1.0
q[P7] 0.417 0.020 0.381 0.455 0.000 0.000 8191.0 5963.0 1.0
q[P8] 0.328 0.016 0.298 0.357 0.000 0.000 7482.0 5521.0 1.0
m[P0] 48788.337 4381.141 40455.521 56795.585 37.466 54.011 13696.0 5815.0 1.0
m[P1] 45561.569 4345.081 37422.823 53613.919 37.898 56.582 13254.0 5723.0 1.0
m[P2] 46857.727 4347.919 38679.781 54966.294 37.238 53.173 13482.0 5688.0 1.0
m[P3] 49533.535 4502.664 41033.697 57907.223 40.052 54.934 12923.0 6452.0 1.0
m[P4] 48841.898 4531.126 40127.342 57132.759 41.452 54.893 11854.0 5540.0 1.0
m[P5] 48889.856 4514.144 40302.619 57149.439 39.818 57.523 13092.0 6218.0 1.0
m[P6] 50495.324 4525.490 41939.392 58898.973 40.074 55.601 12744.0 6178.0 1.0
m[P7] 52494.504 4535.057 43701.878 60797.826 39.628 53.102 13069.0 5958.0 1.0
m[P8] 52249.953 4483.539 43657.380 60565.393 41.467 58.969 11882.0 5366.0 1.0
_ = az.plot_trace(
    data=idata,
    var_names=["p", "q", "m"],
    compact=True,
    backend_kwargs={"figsize": (12, 7), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);

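Overall, the diagnostics look good: r_hat is 1.0 for every parameter, the bulk and tail effective sample sizes all exceed 5,000, and the trace plots show no obvious problems.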

Next, we look into the posterior distributions of the parameters.

ax, *_ = az.plot_forest(idata["posterior"]["p"], combined=True)
ax.axvline(x=priors["p"].parameters["mu"], color="gray", linestyle="--")
ax.get_figure().suptitle("Innovation Coefficient (p)", fontsize=18, fontweight="bold");
[Figure: Innovation Coefficient (p)]
ax, *_ = az.plot_forest(idata["posterior"]["q"], combined=True)
ax.axvline(x=priors["q"].preliz.mean(), color="gray", linestyle="--")
ax.get_figure().suptitle("Imitation Coefficient (q)", fontsize=18, fontweight="bold");
[Figure: Imitation Coefficient (q)]

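We see some heterogeneity across products, which is expected since each product's true p and q are a separate draw from the prior of the generative model; overall, the estimates are centered around those true values. The dashed gray lines mark the prior means.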

Examining Posterior Predictions for Specific Products

Let’s look at the posterior predictive distributions to see how well our model captures the simulated data.

fig, axes = plt.subplots(
    nrows=3, ncols=3, figsize=(15, 12), sharex=True, sharey=True, layout="constrained"
)

idata["posterior_predictive"]["y"].pipe(plot_curve, {"T"}, axes=axes)

for i, ax in enumerate(axes.flatten()):
    ax.plot(T, bass_data[:, i], color="black")

fig.suptitle("Posterior Predictive vs Observed Data", fontsize=18, fontweight="bold");
fig, axes = plt.subplots(
    nrows=3, ncols=3, figsize=(15, 12), sharex=True, sharey=True, layout="constrained"
)

idata["posterior_predictive"]["y"].cumsum(dim="T").pipe(plot_curve, {"T"}, axes=axes)

for i, ax in enumerate(axes.flatten()):
    ax.plot(T, bass_data[:, i].cumsum(), color="black")

fig.suptitle(
    "Cumulative Posterior Predictive vs Cumulative Observed Data",
    fontsize=18,
    fontweight="bold",
);
observed_cumulative = bass_data.cumsum(dim="T").isel(T=-1).to_series()

ref_val = {
    "m": [
        {"product": name, "ref_val": value}
        for name, value in observed_cumulative.items()
    ]
}

az.plot_posterior(
    idata.posterior,
    var_names=["m"],
    backend_kwargs=dict(sharex=True, layout="constrained", figsize=(15, 12)),
    ref_val=ref_val,
)

max_T = bass_data.coords["T"].max().item()
fig = plt.gcf()
fig.suptitle(
    f"Estimated Market Cap (m) vs Observed Cumulative at T={max_T}",
    fontsize=18,
    fontweight="bold",
);

Overall, the model does a good job of capturing the data.

Next, we look into the adopters, which represent the expected value of the likelihood.

fig, axes = model.plot_adoption_curve(subplot_kwargs={"ncols": 3, "figsize": (15, 12)})

fig.suptitle("Adopters vs Observed Data", fontsize=18, fontweight="bold");

This shows that the fit is quite reasonable.
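Since adopters is only the mean of the likelihood, it helps to quantify how much the observed weekly counts should scatter around it. Assuming the NegativeBinomial likelihood has mean adopters and n = 1.5 (the parameterization in which the variance is \(\mu + \mu^2/n\)), the coefficient of variation is roughly \(\sqrt{1/n} \approx 0.82\) for large means, so large week-to-week swings like those in actual_data are expected. The sketch below checks these moments by simulation for one hypothetical value of adopters.

```python
import numpy as np

rng = np.random.default_rng(1)

mu = 3_000.0   # hypothetical expected weekly adoptions (an `adopters` value)
n = 1.5        # dispersion used in the likelihood prior

# NumPy parameterization: prob = n / (n + mu) gives mean mu
draws = rng.negative_binomial(n, n / (n + mu), size=200_000)

mean_theory = mu
var_theory = mu + mu**2 / n
cv_theory = np.sqrt(var_theory) / mu   # coefficient of variation
```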

We can also assess the goodness of fit by looking at the cumulative data:

Note

Remember that adopters is the mean of the likelihood, not a draw from it, so some observed cumulative curves lie above the fitted curve and some below.

To compare against the full predictive distribution of the observed data, use idata["posterior_predictive"]["y"].

fig, axes = model.plot_cumulative(subplot_kwargs={"ncols": 3, "figsize": (15, 12)})

fig.suptitle("Adopters Cumulative vs Observed Data", fontsize=18, fontweight="bold");

We can enhance this view by looking into the components of the model: innovators and imitators (in orange and green, respectively). The per-period components go on the left y-axis and the cumulative adoption on a twin right y-axis, since they live on very different scales.

fig, axes = model.plot_decomposition(subplot_kwargs={"ncols": 3, "figsize": (15, 12)})

fig.suptitle("Innovators vs Imitators", fontsize=18, fontweight="bold");

Finally, we can inspect the peak of the adoption curve.

fig, axes = model.plot_peak(figsize=(15, 12))

fig.suptitle("Peak", fontsize=18, fontweight="bold");

The estimated peaks line up well with the observed data. Let’s zoom in on product P4 as an example.

fig, ax = plt.subplots()

product_id = 4

bass_data[:, product_id].plot(ax=ax, color="black")

idata["posterior"]["adopters"].sel(product=f"P{product_id}").pipe(
    plot_curve, {"T"}, axes=ax
)

peak_hdi = az.hdi(idata["posterior"]["peak"].sel(product=f"P{product_id}"))["peak"]
ax.axvspan(
    peak_hdi.sel(hdi="lower").item(),
    peak_hdi.sel(hdi="higher").item(),
    color="C1",
    alpha=0.4,
)

ax.set_title(f"Peak Product {products[product_id]}", fontsize=18, fontweight="bold");
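As a rough cross-check of the plot above, we can plug the posterior means for P4 from the az.summary table (p ≈ 0.025, q ≈ 0.442) into \(t^* = \ln(q/p)/(p+q)\), which gives about 6.15 weeks after launch. A plug-in value is not the same as the posterior mean of peak, since \(t^*\) is a nonlinear function of p and q, so it should be read only as a sanity check against the shaded HDI.

```python
import numpy as np

# Posterior means for product P4 copied from the az.summary table
p_P4, q_P4 = 0.025, 0.442

# Plug-in peak time in weeks since launch
t_star = np.log(q_P4 / p_P4) / (p_P4 + q_P4)
```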

Forecasting Beyond the Observed Window

The sample_posterior_predictive method accepts new data with a different time range than the one used for fitting. Passing an extended T coordinate produces an out-of-sample forecast: the posterior of \(m\), \(p\) and \(q\) stays fixed and the adoption curve is evaluated over the new time points.

T_extended = np.arange(len(T) + 26)  # 52 observed weeks followed by 26 forecast weeks

forecast = model.sample_posterior_predictive(
    X=xr.Dataset(coords={"T": T_extended}),
    extend_idata=False,
    random_seed=rng,
)

fig, ax = plt.subplots()

forecast["y"].sel(product="P4").pipe(plot_curve, {"T"}, axes=ax, legend=False)
(observed_line,) = ax.plot(
    T, bass_data.sel(product="P4"), color="black", label="observed"
)
cutoff_line = ax.axvline(
    int(T.max()), color="gray", linestyle="--", label="end of observed data"
)
ax.legend(handles=[observed_line, cutoff_line], loc="upper right")
ax.set_title("26-week forecast for product P4", fontsize=18, fontweight="bold");
Sampling: [y]

[Figure: 26-week forecast for product P4]
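The forecasting step can be sketched without PyMC: for every posterior draw of \((m, p, q)\), the adoption curve is evaluated on the extended time grid, and pointwise quantiles summarize the uncertainty. The draws below are hypothetical stand-ins for idata["posterior"] (loosely modeled on the P4 summary), and the sketch covers only the expected curve adopters, whereas sample_posterior_predictive also adds NegativeBinomial noise to produce y. Note that np.arange(52 + 26) yields exactly 26 time steps past the 52 observed weeks.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical posterior draws for one product (stand-ins for idata["posterior"])
n_draws = 4_000
m = rng.normal(50_000, 4_500, size=n_draws)
p = rng.normal(0.025, 0.005, size=n_draws).clip(1e-4)
q = rng.normal(0.44, 0.02, size=n_draws)

n_observed, horizon = 52, 26
T_extended = np.arange(n_observed + horizon)  # 0..77: 52 observed + 26 new weeks

# Broadcast draws (draw, 1) against time (1, T): one curve per posterior draw
decay = np.exp(-(p + q)[:, None] * T_extended[None, :])
F = (1 - decay) / (1 + (q / p)[:, None] * decay)
adopters = m[:, None] * (p[:, None] + q[:, None] * F) * (1 - F)

# Pointwise 94% interval and median of the expected adoptions
lower, median, upper = np.quantile(adopters, [0.03, 0.5, 0.97], axis=0)
forecast_window = T_extended[n_observed:]
```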

Save and Load the Model

The fitted model can be stored as a single NetCDF file and restored later. The file contains the posterior together with the model configuration (the priors) and the data used for fitting, so the loaded model is ready for posterior analysis, plotting, and posterior predictive sampling.

model.save("bass_model.nc")

loaded_model = BassModel.load("bass_model.nc")

az.summary(loaded_model.idata, var_names=["p", "q"]).head()
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
p[P0] 0.032 0.006 0.021 0.043 0.0 0.0 8377.0 6209.0 1.0
p[P1] 0.033 0.006 0.022 0.045 0.0 0.0 8450.0 5775.0 1.0
p[P2] 0.033 0.006 0.022 0.044 0.0 0.0 8127.0 6307.0 1.0
p[P3] 0.030 0.005 0.021 0.041 0.0 0.0 7935.0 6544.0 1.0
p[P4] 0.025 0.005 0.016 0.034 0.0 0.0 8260.0 6129.0 1.0

MLflow Integration

The MLflow autologging from pymc_marketing.mlflow supports the Bass model through the log_bass flag. Enabling it patches BassModel.fit so every fit inside an MLflow run logs the model configuration (the priors for m, p, q and the likelihood), the sampler diagnostics, the model graph, and the resulting InferenceData as artifacts. Figures from the plotting methods can be logged with mlflow.log_figure.

We refit the model with a lighter sampler configuration to keep the demo fast.

import mlflow

import pymc_marketing.mlflow

pymc_marketing.mlflow.autolog(log_bass=True)

mlflow.set_experiment("bass-model")

mlflow_model = BassModel(
    model_config=priors,
    sampler_config={
        "tune": 1_000,
        "draws": 1_000,
        "chains": 2,
        "nuts_sampler": "nutpie",
        "compile_kwargs": {"mode": "NUMBA"},
    },
)

with mlflow.start_run():
    mlflow_model.fit(data=observed_ds, random_seed=rng)

    fig, _ = mlflow_model.plot_adoption_curve()
    mlflow.log_figure(fig, "adoption_curve.png")

    fig, _ = mlflow_model.plot_decomposition()
    mlflow.log_figure(fig, "decomposition.png")

Sampler progress: 2 of 2 chains finished.

Draws Divergences Step Size Gradients/Draw
2000 0 0.26 31
2000 0 0.25 31

[Figure: adoption curves from mlflow_model.plot_adoption_curve()]
[Figure: innovators vs imitators from mlflow_model.plot_decomposition()]
%load_ext watermark
%watermark -n -u -v -iv -w -p nutpie,pymc_marketing,pytensor
Last updated: Fri, 12 Jun 2026

Python implementation: CPython
Python version       : 3.12.13
IPython version      : 9.13.0

nutpie        : 0.16.8
pymc_marketing: 0.19.4
pytensor      : 2.38.3

arviz         : 0.23.4
matplotlib    : 3.10.9
mlflow        : 3.12.0
numpy         : 2.4.3
pandas        : 2.3.3
pymc          : 5.28.5
pymc_extras   : 0.10.0
pymc_marketing: 0.19.4
xarray        : 2026.4.0

Watermark: 2.6.0