thermox: The First Thermodynamic Computing Simulator

Fast and exact Ornstein-Uhlenbeck processes with JAX
Author

Sam Duffield and Kaelan Donatella

Published

May 30, 2024

Conventional Computing is Hindering AI Progress

Energy consumption has become the primary bottleneck for AI, with growing concerns around the overall cost, waste heat, and emissions associated with AI infrastructure. This highlights a deep problem: the mismatch between the underlying digital hardware and the computations required for AI. Moreover, there is a clear need to democratize AI research, which has become inaccessible to many due to massive energy consumption, cost, and compute requirements. These issues are especially pronounced when one desires AI with high-level reasoning, where probabilistic AI plays a key role but is highly computationally expensive. Overall, this provides motivation for a novel computational paradigm to make AI faster, more capable, and energy-efficient.

Democratizing Thermodynamic Computing

At Normal Computing, our team has been pioneering physics-based hardware utilizing thermodynamics (i.e., a thermodynamic computer) to address the above issues. To accelerate the development of efficient AI with reasoning capabilities, we aim to grow the community around thermodynamic computing. To this end, we have made many of our hardware designs and results publicly available in our publications and GitHub repositories. Advancing thermodynamic computing out of the lab and into large-scale devices will be accelerated by coordination between public, private, and academic entities. This is why we have made publicly available the results necessary to understand how thermodynamic computing devices work, and the knowledge necessary to build them. We are now taking another step towards democratization by open-sourcing a cutting-edge tool developed at Normal: thermox, a fast and exact thermodynamic computing simulator.

At a fundamental level, the thermodynamic computer we are building – coined the Stochastic Processing Unit (SPU) – is a physical device that evolves according to a stochastic differential equation (SDE).

Possibly the simplest SDE and the focus of our first prototype is the Ornstein-Uhlenbeck (OU) process, which takes the form:

\[ dx = - A(x - b) dt + \mathcal{N}(0, D dt), \]

where \(x\) and \(b\) are \(d\)-dimensional vectors, \(A\) and \(D\) are \(d\times d\) matrices¹, and \(\mathcal{N}(0, D)\) denotes a normal distribution with covariance \(D\).

Accelerating Linear Algebra with OU Processes

As we describe in our foundational paper, one may perform computations by building a device that simulates that OU process. One may send in the problem inputs to the SPU, wait for some time, and measure the state of the SPU to get a solution to the problem at hand. One reason why this is exciting is that, for a scaled-up system, we can get a speedup over digital computers on a range of linear algebra primitives. Let us take solving a linear system as an example. We wish to solve:

\[ Ax^* = b ,\]

that is, find the \(x^*\) that satisfies this equation. By sending in \(A\) and \(b\) to the device so that its drift is \(-(Ax - b)\) (the OU process above with the displacement vector set to \(A^{-1}b\)), after some relaxation time the average value of \(x\) will be \(x^*\), up to some error. It turns out that with this method, we obtain an asymptotic speedup over digital methods, with a runtime that is linear in \(d\)!
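To make the relaxation argument concrete, here is a minimal JAX sketch of the noise-free part of the dynamics. It assumes the device implements the drift \(-(Ax - b)\), so that the mean obeys \(d\langle x\rangle/dt = -(A\langle x\rangle - b)\); the function name `relax_mean` and the stepsize and step-count defaults are illustrative choices, not part of thermox or of any hardware interface.

```python
import jax
import jax.numpy as jnp


def relax_mean(A, b, delta=0.01, n_steps=2000):
    """Integrate d<x>/dt = -(A<x> - b) from <x> = 0 with forward Euler.

    Hypothetical helper for illustration only: it tracks the mean of the
    process, not the noisy state that a physical device would expose.
    """
    def step(mu, _):
        mu_next = mu - (A @ mu - b) * delta
        return mu_next, None

    mu_final, _ = jax.lax.scan(step, jnp.zeros_like(b), None, length=n_steps)
    return mu_final


# Small symmetric positive-definite example
A = jnp.array([[2.0, 0.5], [0.5, 2.0]])
b = jnp.array([1.0, -1.0])

x_relaxed = relax_mean(A, b)       # mean after relaxing for time 20
x_direct = jnp.linalg.solve(A, b)  # digital reference solution
```

For a positive-definite \(A\), the relaxed mean should agree with the direct solve up to floating-point error, which is the sense in which waiting and then measuring the average state "solves" the linear system.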

Thermodynamic linear system solving is part of a broader range of thermodynamic linear algebra primitives, including the full inverse \(A^{-1}\) and matrix exponentials.

thermox: Fast and Exact OU Processes

Studying our SPU at large scale (think thousands of dimensions) requires solving SDEs of the form described above. This is computationally expensive, so we needed a fast and efficient tool to simulate OU processes, one that works on any machine. So, as JAX aficionados, we built thermox.

On a digital computer, SDEs are typically simulated by discretizing time, and this is how OU processes are usually handled in the community. SDEs are especially sensitive to the choice of stepsize, and although some second-order solvers are more robust, simulating them is generally a pain.

Let us consider the case that we have an initial point \(x_0\) and want to sample a point from the process \(x_T\) at time \(T\). The simplest way to do this is to use the Euler-Maruyama method, which is a first-order method that discretizes the SDE as:

\[ x(t + \delta) = x(t) - A(x(t) - b) \delta + \sqrt{\delta} w, \]

with \(w\) a random draw from \(\mathcal{N}(0, D)\). Here \(\delta\) is a stepsize that is required to be small enough to ensure the discretization error is small. The process is repeated until we reach time \(T\).
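As a rough illustration of this discretized update, here is a minimal JAX sketch of Euler-Maruyama for the OU process with drift \(-A(x - b)\) and noise covariance \(D\). The function name `euler_maruyama_ou` and its argument order are hypothetical and chosen for this example; this is not how thermox or diffrax implement their solvers.

```python
import jax
import jax.numpy as jnp


def euler_maruyama_ou(key, x0, A, b, D, delta, n_steps):
    """Simulate dx = -A(x - b) dt + N(0, D dt) with a fixed stepsize delta."""
    # Cholesky factor L with L @ L.T = D, so L @ z ~ N(0, D) for z ~ N(0, I)
    chol_D = jnp.linalg.cholesky(D)

    def step(x, step_key):
        w = chol_D @ jax.random.normal(step_key, x.shape)
        x_next = x - (A @ (x - b)) * delta + jnp.sqrt(delta) * w
        return x_next, x_next

    keys = jax.random.split(key, n_steps)  # one independent key per step
    _, trajectory = jax.lax.scan(step, x0, keys)
    return trajectory  # shape (n_steps, d)


# Example: 1000 steps of size 0.01, i.e. a trajectory up to time T = 10
A = jnp.array([[2.0, 0.5], [0.5, 2.0]])
b = jnp.array([1.0, -1.0])
D = jnp.eye(2)
trajectory = euler_maruyama_ou(
    jax.random.PRNGKey(0), jnp.zeros(2), A, b, D, delta=0.01, n_steps=1000
)
```

Each step costs one matrix-vector product with \(A\) and one with the Cholesky factor of \(D\), and the accuracy of the result depends on how small `delta` is relative to the fastest timescale of \(A\).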

The time-complexity of running \(n_{\mathrm{steps}}\) steps is \(O(d^2n_{\mathrm{steps}})\) because of matrix-vector multiplications, which doesn’t look too bad at first. But remember there will be some discretization error, which can only be kept down by using a large number of steps². For long OU trajectories where we want samples at many times, this can quickly become intractable. As an example, in molecular dynamics simulations, up to billions of time steps are needed 🫠. Most existing libraries, such as diffrax, would typically run OU processes in a similar way, with better solvers that nonetheless still suffer from discretization error. Thankfully, it turns out that we don’t need to discretize time if we want to simulate multivariate OU processes.

With thermox, a trajectory is run very differently. For a multivariate OU process, there are analytical expressions for the mean and covariance matrix at all times. The mean reads:

\[ \mu(t) = \langle x(t) \rangle = \exp{(-At)} (x_0 - b) + b, \]

and the covariance:

\[\begin{aligned} \Sigma(t) &= \langle [x(t) - \langle x(t)\rangle] [x(t) - \langle x(t)\rangle]^\top \rangle \\ &= \int_0^{t} \exp{[-A(t-t')]} D \exp{[-A^\top(t-t')]} dt', \end{aligned}\]

assuming \(\Sigma(0) = 0\). By diagonalizing \(A\) and \(D\), it turns out we can construct these matrices (with a \(O(d^3)\) preprocessing step), and simply sample from \(\mathcal{N}(\mu(t), \Sigma(t))\) for any time \(t\) to obtain a sample. This is convenient as it avoids discretization errors altogether and eliminates the dependence on the number (and sparsity) of time steps from the time complexity. The total thermox complexity is \(O(d^3 + Nd^2)\) to collect \(N\) samples, which can provide a huge improvement over the \(O(Nd^2n_{\mathrm{steps}})\) complexity of discretized methods.³
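The idea of sampling without time stepping can be sketched in a few lines of JAX. The example below is a simplified illustration, not thermox's actual implementation: it assumes \(A\) is symmetric positive-definite (so a single `eigh` call diagonalizes it), uses the SDE exactly as written earlier (drift \(-A(x - b)\), noise covariance \(D\)), and the helper names `ou_exact_moments` and `ou_exact_sample` are made up for this example. In the eigenbasis of \(A\), with eigenvalues \(\lambda_i\) and \(\tilde{D} = V^\top D V\), the covariance integral reduces entrywise to \(\tilde{D}_{ij}\,(1 - e^{-(\lambda_i + \lambda_j)t})/(\lambda_i + \lambda_j)\).

```python
import jax
import jax.numpy as jnp


def ou_exact_moments(t, x0, A, b, D):
    """Mean and covariance of x(t) for dx = -A(x - b) dt + N(0, D dt), x(0) = x0.

    Assumes A is symmetric positive-definite (an illustrative restriction).
    """
    lam, V = jnp.linalg.eigh(A)  # A = V @ diag(lam) @ V.T, the O(d^3) step
    # Mean relaxes from x0 towards b: b + exp(-A t) (x0 - b)
    mean = b + V @ (jnp.exp(-lam * t) * (V.T @ (x0 - b)))
    # Covariance: closed form of the time integral, entry by entry in the eigenbasis
    D_tilde = V.T @ D @ V
    lam_sum = lam[:, None] + lam[None, :]
    cov_tilde = D_tilde * (1.0 - jnp.exp(-lam_sum * t)) / lam_sum
    cov = V @ cov_tilde @ V.T
    return mean, cov


def ou_exact_sample(key, t, x0, A, b, D):
    """Draw x(t) ~ N(mu(t), Sigma(t)) directly, with no discretization."""
    mean, cov = ou_exact_moments(t, x0, A, b, D)
    return jax.random.multivariate_normal(key, mean, cov)


A = jnp.array([[2.0, 0.5], [0.5, 2.0]])
b = jnp.array([1.0, -1.0])
D = jnp.array([[2.0, 1.0], [1.0, 2.0]])
x0 = jnp.zeros(2)

x_T = ou_exact_sample(jax.random.PRNGKey(0), 1.0, x0, A, b, D)
```

To build a trajectory at several times, the same transition can be applied between consecutive time points, using the previous sample as the new starting point. The cost per sample is then independent of how far apart the time points are, which is where the gain over stepwise solvers comes from.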

Comparison to diffrax

diffrax is an awesome library for discretizing generic SDEs, but as shown above we can simulate OU processes exactly. Let’s show that numerically! Here we compare thermox and diffrax to simulate a multivariate OU process. Details of the code can be found here. Here are the results we get:

And there you see it. For \(d=100\) and a large number of time steps, you get up to \(800\times\) speedup and the simulation is exact (no discretization error).

Quick Example

Let’s simulate a 5-dimensional OU process with thermox:

import thermox
import jax
import jax.numpy as jnp

# Set random seed
key = jax.random.PRNGKey(0)

# Timeframe
ts = jnp.arange(0, 1, 0.01)

# System parameters for a 5-dimensional OU process
A = jnp.array([[2.0, 0.5, 0.0, 0.0, 0.0],
               [0.5, 2.0, 0.5, 0.0, 0.0],
               [0.0, 0.5, 2.0, 0.5, 0.0],
               [0.0, 0.0, 0.5, 2.0, 0.5],
               [0.0, 0.0, 0.0, 0.5, 2.0]])

b, x0 = jnp.zeros(5), jnp.zeros(5) # Zero drift displacement vector and initial state

# Diffusion matrix with correlations between x_1 and x_2
D = jnp.array([[2.0, 1.0, 0.0, 0.0, 0.0],
               [1.0, 2.0, 0.0, 0.0, 0.0],
               [0.0, 0.0, 2.0, 0.0, 0.0],
               [0.0, 0.0, 0.0, 2.0, 0.0],
               [0.0, 0.0, 0.0, 0.0, 2.0]])

# Collect samples
samples = thermox.sample(key, ts, x0, A, b, D)

Ok let’s plot those OU samples:

import matplotlib.pyplot as plt
plt.figure(figsize=(12, 5))
plt.plot(ts, samples, label=[f'Dimension {i+1}' for i in range(5)])
plt.xlabel('Time', fontsize=16)
plt.ylabel('Value', fontsize=16)
plt.title('Trajectories of 5-Dimensional OU Process', fontsize=16)
plt.legend(fontsize=16)
plt.show()

How pretty!

A First Thermodynamic Computing Simulator

As mentioned, thermox is a great tool to run benchmarks for thermodynamic computers. With thermox you can run such thermodynamic simulations in a single line of code (similarly to the widely-used scipy.linalg.solve):

x_s = thermox.linalg.solve(A, b, num_samples=1000)

Here, 1000 samples are collected from an OU process and then averaged to obtain an approximate solution to the linear system \(Ax = b\), with a preset sampling interval. In fact, we believe that any thermodynamic advantage experiment would have to beat thermox, as it provides a baseline for digital computers simulating OU processes. We can therefore view thermox as the first open-source thermodynamic computing simulator. This is analogous to how quantum computing simulators (e.g., IBM Qiskit or tensor-network methods) provide a classical baseline for quantum computers to beat. For further details on this, please take a look at this notebook. We hope to see the community also run its own thermodynamic experiments, propose new ones, and run custom applications with thermox!

posteriors and thermox

We recently released posteriors, a Python library for uncertainty quantification which features a few methods that can leverage thermox. For example, consider that you wish to apply a Laplace approximation to a desired application, and sample from it. You can do this with thermox to get an idea of how using a thermodynamic device to collect samples would influence your end result. In the long-term, our vision is to use a real thermodynamic device in conjunction with posteriors for uncertainty quantification in machine learning.

What’s next?

thermox is an open-source project and we welcome contributions, particularly further research into thermodynamic computation and its applications in machine learning and other fields, such as finance and evolutionary biology! We hope that thermox can accelerate research and our journey to practical thermodynamic computation. Come join the thermox fun on GitHub!

Footnotes

  1. With the matrix \(D\) being symmetric positive-definite.↩︎

  2. And therefore small stepsize \(\delta\).↩︎

  3. Of course, thermodynamic devices will ultimately be faster than thermox. More specifically, for solving a linear system, we expect a runtime scaling as \(O(d\kappa^2)\), with \(\kappa\) the condition number of \(A\), as shown here.↩︎
