leveraging rust's types for linear algebra

17/01/25
using rust's type system to ensure safety of linear-algebraic operations at compile time

type-flexible matrices

In my university linear algebra course we were tasked with a group project, part of which was to implement and benchmark various matrix operations. I had initially created a naive implementation in Rust, built around the idea of a matrix that works for arbitrary element types. The base struct looked like this:

pub struct Matrix<T> {
    // owning pointer to a heap buffer of m * n elements of T
    ptr: std::ptr::NonNull<T>,

    // runtime dimensions: m rows, n columns
    m: usize,
    n: usize,
}

It ends up representing an m by n matrix of element type T; however, I wasn't too happy with the usage of std::ptr::NonNull<T> and the manual memory management it implies. But as I was refining the definition further, a new idea came to mind.

const generics

Back in 2017, Rust RFC 2000 proposed the addition of "const generics", allowing compile-time constants to be passed as generic parameters. As an example, consider

type Array<T, const N: usize> = [T; N];

which acts as an extremely thin alias for the primitive array type. But we can also do things like compile-time arithmetic on generic parameters, such as

type NullTerminated<const LEN: usize> = [u8; { LEN + 1 }];

representing a C-like null-terminated string by abstracting the extra byte away from the caller.
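
Even at this basic level, the compiler enforces lengths for us:

let a: Array<i32, 3> = [1, 2, 3];
// let b: Array<i32, 3> = [1, 2]; // fails to compile: expected an array with a size of 3

One caveat worth flagging: plain const generic parameters like those in Array are stable, but arithmetic in const positions like LEN + 1 still requires, to my knowledge, the unstable generic_const_exprs feature on nightly Rust.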

compile-time matrices

Applying the new concept of const generics brings the power of compile-time checks to our matrix definition. We can instead use

pub struct Matrix<T, const M: usize, const N: usize>(Box<[[T; N]; M]>);

Firstly, we stripped the struct of the m and n fields, since this information is now embedded within the type signature. Notice that we also defined the inner data as a plain two-dimensional array rather than through a raw pointer; since the compiler now knows its exact size, we don't have to do any weird manual memory management. We do still wrap it in std::boxed::Box (a pointer type that uniquely owns its heap allocation), as large M * N values would otherwise result in impractically long compilation times due to some interesting LLVM internals. Even if that weren't the case, we still wouldn't want to allocate large (M, N) matrices on the stack, where they could easily cause an overflow.

matrix impls

Now that we have all the size information embedded in the type signature, we can make some very powerful impl blocks, which we'll go over in the following sections.

indexing

We'll first go over some components that will make the code in the following impls a bit cleaner. Namely, we'll implement std::ops::Index and subsequently std::ops::IndexMut. We want to set it up in a way that allows us to index a matrix A at row i and column j, as A[(i, j)]. Therefore, we are implementing Index<(usize, usize)> as

use std::ops::{Index, IndexMut};

impl<T, const M: usize, const N: usize> Index<(usize, usize)>
    for Matrix<T, M, N>
{
    type Output = T;

    fn index(&self, index: (usize, usize)) -> &Self::Output {
        &self.0[index.0][index.1]
    }
}

IndexMut<(usize, usize)> then comes naturally:

impl<T, const M: usize, const N: usize> IndexMut<(usize, usize)>
    for Matrix<T, M, N>
{
    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
        &mut self.0[index.0][index.1]
    }
}
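
With both traits in place, element access reads naturally. A small usage sketch (assuming here that Matrix::new() produces a zero-initialized matrix):

let mut a: Matrix<i32, 2, 3> = Matrix::new();
a[(0, 1)] = 5;            // writes through IndexMut
assert_eq!(a[(0, 1)], 5); // reads back through Index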

transpose

It's a simple start, but a solid introduction to the power of using types in this manner.

impl<T, const M: usize, const N: usize> Matrix<T, M, N> {
    pub fn new() -> Self {
        /* ... */
    }

    pub fn transpose(&self) -> Matrix<T, N, M>
    where
        T: Clone,
    {
        let mut m = Matrix::<T, N, M>::new();
        for i in 0..M {
            for j in 0..N {
                m[(j, i)] = self[(i, j)].clone();
            }
        }
        m
    }
}

The impl block above lets us define methods for an arbitrary M by N matrix of element type T. Thus, to define a transpose, we take in our M by N matrix and return the same contents transposed into an N by M matrix.

This property is also clearly visible from the type signature

fn transpose(self: &Matrix<T, M, N>) -> Matrix<T, N, M>;

although we must also keep in mind that, since the output holds copies of the input's elements, the element type T must implement the Clone trait; without it we couldn't duplicate values out of the borrowed self into m. We'll see this bound again in future definitions, since it is a common need, though I do wonder whether the Copy trait would be more idiomatic.

From here, we can then simply clone every self[(i, j)] into m[(j, i)], as per the definition of the transpose. And our transpose function is complete!
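
As a quick sanity check of the compile-time guarantee, note that the compiler rejects a wrongly-annotated output type (a small sketch, again assuming Matrix::new() exists as above):

let a: Matrix<i32, 2, 3> = Matrix::new();
let b: Matrix<i32, 3, 2> = a.transpose(); // fine: the dimensions flip
// let c: Matrix<i32, 2, 3> = a.transpose(); // fails to compile: expected Matrix<i32, 2, 3>, found Matrix<i32, 3, 2>

Onto some more exciting examples...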

matrix addition

To define matrix addition, i.e. A+B for two matrices A and B, we'll want to implement the std::ops::Add trait for our type Matrix<T, M, N>.

use std::ops::Add;

impl<T, const M: usize, const N: usize> std::ops::Add<Self> for Matrix<T, M, N>
where
    T: Add<T, Output = T> + Clone,
{
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        /* ... */
    }
}

This sort of trait would work for our definition, but I'd like to change the perspective a bit and really embrace the type system of Rust.

We often talk about matrices as "transformations": a rotation matrix, for example, helps us rotate axes in a plane. I'd like to extend this notion to the types our matrices inhabit; in other words, not only the values partake in the transformation, but also the datatype of the matrix's internal elements. Concretely, we'll transform a type T ("Type", essentially the base type by standard notation) by a type P ("Parameter", the right-hand side in our definitions) into a type R ("Result", the output of our definitions). So, instead of the code block above, we'll consider the following:

impl<T, P, R, const M: usize, const N: usize> std::ops::Add<Matrix<P, M, N>>
    for Matrix<T, M, N>
where
    T: Add<P, Output = R> + Clone,
    P: Clone,
{
    type Output = Matrix<R, M, N>;

    fn add(self, rhs: Matrix<P, M, N>) -> Self::Output {
        let mut m = Self::Output::new();

        for i in 0..M {
            for j in 0..N {
                m[(i, j)] = self[(i, j)].clone() + rhs[(i, j)].clone();
            }
        }

        m
    }
}

Now, given two matrices A and B, both M by N, where A has elements of type T and B has elements of type P, we know by the definition of matrix addition that the resulting matrix C has the same dimensions M by N, and we'll name its element type R. To ensure that addition +: T -> P -> R (borrowing some functional notation) is possible, we apply the constraint that type T can be added to some type P, producing a result of type R.
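
To make the compile-time dimension check concrete, here is a small sketch; the commented-out line is rejected because the dimensions disagree:

let a: Matrix<i32, 2, 2> = Matrix::new();
let b: Matrix<i32, 2, 2> = Matrix::new();
let c = a + b; // fine: both are 2 by 2, and i32: Add<i32, Output = i32>

let d: Matrix<i32, 2, 3> = Matrix::new();
// let e = c + d; // fails to compile: Matrix<i32, 2, 2> has no Add<Matrix<i32, 2, 3>> impl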

It is worth separating the concepts of generic type parameters and the concrete types they resolve to in, e.g., trait implementations. If we use T = P = R = i32 in the above example, we require that i32 implements the trait Add<i32, Output = i32>. This may simply be inferred from the fact that the code block

let a: i32 = 22;
let b: i32 = 20;

assert_eq!(a + b, 42i32);

compiles, but we can also verify it by looking at the trait implementations of the primitive type i32, where we see that it does in fact provide `impl Add for i32` with `type Output = i32`. So in the common case, the abstraction T: Add<P, Output = R> collapses to T: Add<T, Output = T>, which is, conveniently, implemented for exactly the numerical types we're most likely to fill matrices with.

matrix multiplication

On to the most exciting one, in my opinion: matrix multiplication. It is infamously the more interesting matrix operation due to its lack of commutativity, seen most simply from its dimension requirements: AB requires the column count of A to equal the row count of B, so BA may not even be defined.

Let A be an m by n matrix and B an n by p matrix. Then the product of these two matrices is C=AB, an m by p matrix. This instantly gives us hints as to what our trait implementation of std::ops::Mul may look like, so we can start building off of what is provided there.

use std::ops::{Add, Mul};

impl<T, const M: usize, const N: usize, const P: usize> std::ops::Mul<Matrix<T, N, P>>
    for Matrix<T, M, N>
where
    T: Add<Output = T> + Mul<Output = T> + Clone,
{
    type Output = Matrix<T, M, P>;

    fn mul(self, rhs: Matrix<T, N, P>) -> Self::Output {
        /* ... */
    }
}

In this case we can clearly see that we are multiplying an M by N matrix by an N by P matrix and outputting an M by P matrix, just as we noted above. In this way, the following example can be checked by the compiler and fail at compilation time with a mismatch error.

let a: Matrix<i32, 1, 2> = Matrix::new();
let b: Matrix<i32, 3, 4> = Matrix::new();

a * b // fails to compile: no Mul impl matches, since the inner dimensions disagree (2 != 3)

But, just as with addition, we can loosen the constraints to allow an arbitrary combination of three types.

impl<T, S, R, const M: usize, const N: usize, const P: usize> std::ops::Mul<Matrix<S, N, P>>
    for Matrix<T, M, N>

However, consider that matrix multiplication multiplies row-column pairs of A and B and then sums up the results. So, we'll need further constraints on the element types T, S and R, similar to how we constrained T above: T: Add<Output = T> + Mul<Output = T> + Clone.

Since we're computing products of values of type T with values of type S, we restrict T by multiplication with a right-hand side S, and we'll call the output R, since that is the element type of the output matrix. Values of type R are then what we sum up at the end, so our implementation is extended by these where-clauses:

impl<T, S, R, const M: usize, const N: usize, const P: usize> std::ops::Mul<Matrix<S, N, P>>
    for Matrix<T, M, N>
where
    T: Mul<S, Output = R> + Clone,
    S: Clone,
    R: Add<R, Output = R> + Clone,
{
    type Output = Matrix<R, M, P>;

    fn mul(self, rhs: Matrix<S, N, P>) -> Self::Output {
        let mut m = Self::Output::new();

        for i in 0..M {
            for j in 0..P {
                m[(i, j)] = self
                    .iter_row(i)
                    .zip(rhs.iter_col(j))
                    .map(|(a, b)| a.clone() * b.clone())
                    .reduce(|acc, it| acc + it)
                    .unwrap();
            }
        }

        m
    }
}

The internals of matrix multiplication then follow by iterating over each row of the left-hand side matrix and each column of the right-hand side matrix, multiplying the paired values together and reducing them into a sum. The unwrap there is safe for any N > 0, since the zipped iterator then yields at least one element.
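
Putting it all together, the happy path now reads like the textbook definition, with the inner dimension N eliminated from the output type:

let a: Matrix<i32, 2, 3> = Matrix::new();
let b: Matrix<i32, 3, 4> = Matrix::new();
let c: Matrix<i32, 2, 4> = a * b; // (2 by 3) * (3 by 4) = (2 by 4), checked at compile time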

For completeness, the iter_row and iter_col functions are defined as follows:

pub fn iter_row(&self, row: usize) -> impl Iterator<Item = &T> {
    self.0[row].iter()
}

pub fn iter_col(&self, col: usize) -> impl Iterator<Item = &T> {
    // iter_vals (not shown) yields all elements in row-major order, so
    // skipping to `col` and stepping by N visits exactly column `col`
    self.iter_vals().skip(col).step_by(N)
}

Perhaps the implementation could also define multiplication for borrowed types T and S, such that cloning the internal values wouldn't be necessary for matrix multiplication, but for now we will go with this.
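
For the curious, here is an untested sketch of what that could look like, reusing iter_row and iter_col and leaning on a higher-ranked trait bound so that the borrowed elements can be multiplied directly, without cloning:

impl<'l, T, S, R, const M: usize, const N: usize, const P: usize>
    std::ops::Mul<&'l Matrix<S, N, P>> for &'l Matrix<T, M, N>
where
    for<'a> &'a T: Mul<&'a S, Output = R>,
    R: Add<R, Output = R>,
{
    type Output = Matrix<R, M, P>;

    fn mul(self, rhs: &'l Matrix<S, N, P>) -> Self::Output {
        let mut m = Self::Output::new();

        for i in 0..M {
            for j in 0..P {
                m[(i, j)] = self
                    .iter_row(i)
                    .zip(rhs.iter_col(j))
                    .map(|(a, b)| a * b) // &T * &S, no clones
                    .reduce(|acc, it| acc + it)
                    .unwrap();
            }
        }

        m
    }
}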

We have established some pretty fun definitions thus far, so to wrap things up, we'll implement a useful real-world linear-algebraic function: LUP decomposition.

lup decomposition

LU decomposition (without the P) factors a matrix A into a lower-triangular matrix L and an upper-triangular matrix U, such that A=LU. This is often useful, since triangular matrices are much easier to work with, though not every matrix admits such a factorization without first exchanging rows, which is where pivoting comes in.

Taking this idea further, we have LUP decomposition (LU factorization with partial pivoting), which factors a square matrix A into L (lower-triangular) and U (upper-triangular) along with a third output P, a permutation matrix, such that the equation PA=LU is satisfied. Partial pivoting buys us numerical stability, meaning the errors produced by the calculation don't grow much beyond those already present in the input.

real-world applications

In real-world applications, LUP decomposition allows us to break down difficult linear-algebraic calculations into more computationally efficient subproblems. For example, computing the determinant of a matrix is costly with the recursive cofactor method commonly explained in class. Factorizing the matrix first simplifies this greatly: the determinant reduces to the product of the diagonal entries of U, with the sign flipped once for every pivot (row exchange) performed during decomposition, i.e. det A = (-1)^s * u_11 * ... * u_nn, where s is the number of exchanges.

Similarly, both solving systems of linear equations and computing inverses of matrices become computationally more efficient using LUP decomposition.

implementation

Finally, we return to the implementation of LUP decomposition. Unlike the previous impl blocks, we want to restrict this one to square matrices, so instead of specifying two const generics we'll use just one, <T, const N: usize>, and specify it in both dimensions of the matrix type: Matrix<_, N, N>. This way, if the caller attempts to LUP-decompose a rectangular matrix, the compiler complains and tells them the function is not implemented for it.

impl<T: Clone, const N: usize> Matrix<T, N, N>
where
    T: std::cmp::PartialOrd<T>,
    T: std::ops::DivAssign + std::ops::SubAssign,
    T: std::ops::Mul<T, Output = T> + std::ops::Neg<Output = T>,
    Self: Clone + Identity,
{

For this block, we'll also want our matrix to have constructible identity matrices (i.e. a matrix with 1s across the diagonal and 0s everywhere else). The implementation of this trait is out of the scope of this article, but for clarity, here is its definition:

pub trait Identity {
    fn identity() -> Self;

    fn is_identity(&self) -> bool;
}
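
That said, for a flavour of what an implementation for our square matrices might look like, here is a minimal sketch; the Default and From<u8> bounds (for producing zeros and a one) are my assumptions, not necessarily what mitra uses:

impl<T, const N: usize> Identity for Matrix<T, N, N>
where
    T: Default + From<u8> + PartialEq,
{
    fn identity() -> Self {
        // assumes Matrix::new() yields an all-zero matrix (e.g. via T: Default)
        let mut m = Self::new();
        for i in 0..N {
            m[(i, i)] = T::from(1u8); // ones down the diagonal
        }
        m
    }

    fn is_identity(&self) -> bool {
        let (zero, one) = (T::default(), T::from(1u8));
        (0..N).all(|i| {
            (0..N).all(|j| {
                let expected = if i == j { &one } else { &zero };
                self[(i, j)] == *expected
            })
        })
    }
}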

Additionally, we'll need T to satisfy the std::cmp::PartialOrd trait so that we can compare absolute values (more on this later). We also need std::ops::{DivAssign, SubAssign, Mul, Neg} for the purposes of the algorithm; their usages will become clear as we write it out.

The input to our function is the calling matrix itself, so the parameter list is simply (&self). The output of our decomposition will be a tuple (A', P, n), in which A' = (L - E) + U, where L and U denote the lower- and upper-triangular matrices established earlier and E is the identity matrix (so L - E is L with its diagonal zeroed out, which lets both factors share one matrix, since L's diagonal is known to be all ones). The recovered L and U still satisfy PA = LU. The n in the output denotes the number of pivots (row exchanges) performed during the decomposition.

Thus, our function signature is

    pub fn lup_decompose(&self) -> (Self, Self, usize) {

We'll continue by setting up our return values for mutation throughout the function: cloning self (since we're only borrowing it), initializing our permutation matrix to the identity (which we'll pivot throughout), and initializing the pivot counter to N, the row/column dimension of self. Starting the counter at N rather than 0 is a convention that pays off later: the parity of pivots - N then directly gives the determinant's sign.

        let mut r = self.clone();
        let mut perm = Self::identity();
        let mut pivots = N;

In the algorithm we now iterate over the matrix with a standard for loop (the index serves first as a column index, then for swapping rows, and finally as a column index again), and for each column we find the index imax of the row holding the element with the largest absolute value. For this we'd ideally want an Abs-like trait, but one sadly doesn't exist in the standard library, and we still want our implementation to work natively for primitive numerical types. So we use a sort of "generic" emulation: we check whether a value is greater than the current maximum, and if not, whether its negation is.

        for i in 0..N {
            let mut imax = i;
            let mut max_val = r[(i, i)].clone();

            for k in i..N {
                let val = r[(k, i)].clone();
                let neg_val = -(r[(k, i)].clone());

                if val > max_val {
                    max_val = val;
                    imax = k;
                } else if neg_val > max_val {
                    max_val = neg_val;
                    imax = k;
                }
            }

Given the index of the maximal row imax, we now want to swap the rows at index i and index imax, which we'll do using a swap_rows helper (out of scope for this post, though a one-line sketch appears after the full function):

            if i != imax {
                perm.swap_rows(i, imax);
                r.swap_rows(i, imax);

                pivots += 1;
            }

Since we pivoted, we also increment the pivot counter. For the last part of the algorithm, we just have to perform the row operations seen below (this is where the DivAssign, SubAssign and Mul constraints come into play). Note that a singular matrix would leave a zero pivot in r[(i, i)] here; a more defensive implementation would detect that and bail out rather than divide by zero.

            for j in i + 1..N {
                let r_ii = r[(i, i)].clone();
                r[(j, i)] /= r_ii;

                for k in i + 1..N {
                    let r_ji = r[(j, i)].clone();
                    let r_ik = r[(i, k)].clone();
                    r[(j, k)] -= r_ji * r_ik;
                }
            }
        }

And finally, we return the tuple and close off our open braces.

        (r, perm, pivots)
    }
}
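
For reference, the swap_rows helper used above can be as simple as delegating to the standard slice swap on the outer rows (a one-line sketch, not necessarily how mitra defines it):

impl<T, const M: usize, const N: usize> Matrix<T, M, N> {
    pub fn swap_rows(&mut self, a: usize, b: usize) {
        self.0.swap(a, b); // swaps the two [T; N] rows in place
    }
}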

And with that, we have a working LUP decomposition for our Matrix! From here, it can serve as the backbone for determinant calculations and for solving systems of linear equations.
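
As a taste of the determinant use case, a det method could live in the same impl block as lup_decompose, since every bound it needs (Clone, Mul, Neg) is already present there; a minimal sketch, assuming N > 0:

    pub fn det(&self) -> T {
        let (r, _perm, pivots) = self.lup_decompose();

        // the diagonal of r is the diagonal of U, since L's diagonal is
        // implicitly all ones
        let mut d = r[(0, 0)].clone();
        for i in 1..N {
            d = d * r[(i, i)].clone();
        }

        // every row exchange flips the sign; pivots started at N, so
        // pivots - N counts the exchanges performed
        if (pivots - N) % 2 == 1 { -d } else { d }
    }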

conclusion

Moving on from here, we can build on the basic building blocks we have established. The implementation described here exists and is still being worked on under the name mitra, available on Codeberg.

Not only is it quite fun to work with Rust in this way, but it's also extremely interesting (and kind of brain fart-inducing) to think that this form of matrices also allows things such as matrices of strings or matrices of matrices. The one downside is that it becomes fairly difficult, and possibly impossible in idiomatic safe Rust, to read in dynamically sized matrices, since the dimensions must be known at compile time. There may be a way of doing it, but with my current knowledge I haven't figured out how; it'll be a fun problem to try to solve as time goes on.