Algorithm Implementation/Linear Algebra/Tridiagonal matrix algorithm


All the provided implementations of the tridiagonal matrix algorithm assume that the three diagonals, a (below), b (main), and c (above), are passed as arguments.
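As an illustration of this storage convention, the following Python sketch expands the three diagonals into a dense matrix and multiplies it by a candidate solution, which is one convenient way to check the output of any of the solvers on this page. The helper names `tridiag_to_dense` and `matvec` are hypothetical, and the sketch assumes the padded layout used by the C++ example (all three lists have length n, with a[0] and c[n-1] ignored).

```python
def tridiag_to_dense(a, b, c):
    """Expand diagonals a (below), b (main), c (above) into a dense n x n matrix.

    Assumed layout: all three lists have length n; a[0] and c[n-1] are ignored.
    """
    n = len(b)
    dense = [[0.0] * n for _ in range(n)]
    for i in range(n):
        dense[i][i] = float(b[i])
        if i > 0:
            dense[i][i - 1] = float(a[i])   # entry left of the main diagonal
        if i < n - 1:
            dense[i][i + 1] = float(c[i])   # entry right of the main diagonal
    return dense


def matvec(matrix, x):
    """Multiply a dense matrix (list of rows) by a vector."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in matrix]


# System taken from the C++ example on this page; its solution is (2, 3, 5, 7).
a = [0, -1, -1, -1]
b = [4, 4, 4, 4]
c = [-1, -1, -1, 0]
d = [5, 5, 10, 23]
x = [2, 3, 5, 7]
residual = [lhs - rhs for lhs, rhs in zip(matvec(tridiag_to_dense(a, b, c), x), d)]
print(residual)  # [0.0, 0.0, 0.0, 0.0]
```

A residual of zero (or close to zero in floating point) indicates that x solves the system described by the three diagonals.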

Haskell

thomas :: Fractional g => [g] -> [g] -> [g] -> [g] -> [g]
thomas as bs cs ds = xs
  where
    n = length bs
    bs' = b(0) : [b(i) - a(i)/b'(i-1) * c(i-1) | i <- [1..n-1]]
    ds' = d(0) : [d(i) - a(i)/b'(i-1) * d'(i-1) | i <- [1..n-1]]
    xs = reverse $ d'(n-1) / b'(n-1) : [(d'(i) - c(i) * x(i+1)) / b'(i) | i <- [n-2, n-3..0]]
    
    -- convenience accessors (because otherwise it's hard to read)
    a i = as !! (i-1) --  because the list's first item is equivalent to a_1
    b i = bs !! i
    c i = cs !! i
    d i = ds !! i
    x i = xs !! i
    b' i = bs' !! i
    d' i = ds' !! i

Note that the function takes a subdiagonal list (as) whose first element (index 0) lies on row 1 of the matrix. The convenience accessor a shifts the index by one so that this element can still be written as a(1).

C

The following is an example of the implementation of this algorithm in the C programming language.

void solve_tridiagonal_in_place_destructive(float * restrict const x, const size_t X, const float * restrict const a, const float * restrict const b, float * restrict const c) {
    /*
     solves Ax = v where A is a tridiagonal matrix consisting of vectors a, b, c
     x - initially contains the input vector v, and returns the solution x. indexed from 0 to X - 1 inclusive
     X - number of equations (length of vector x)
     a - subdiagonal (means it is the diagonal below the main diagonal), indexed from 1 to X - 1 inclusive
     b - the main diagonal, indexed from 0 to X - 1 inclusive
     c - superdiagonal (means it is the diagonal above the main diagonal), indexed from 0 to X - 2 inclusive (the forward sweep also reads and writes c[X - 1], so the array must hold X elements)
     
     Note: contents of input vector c will be modified, making this a one-time-use function (scratch space can be allocated instead for this purpose to make it reusable)
     Note 2: We don't check for diagonal dominance, etc.; this is not guaranteed stable
     */

    c[0] = c[0] / b[0];
    x[0] = x[0] / b[0];
    
    /* loop from 1 to X - 1 inclusive, performing the forward sweep */
    for (size_t ix = 1; ix < X; ix++) {
        const float m = 1.0f / (b[ix] - a[ix] * c[ix - 1]);
        c[ix] = c[ix] * m;
        x[ix] = (x[ix] - a[ix] * x[ix - 1]) * m;
    }
    
    /* loop from X - 2 to 0 inclusive (safely testing loop condition for an unsigned integer), to perform the back substitution */
    for (size_t ix = X - 1; ix-- > 0; )
        x[ix] -= c[ix] * x[ix + 1];
}

The following variant preserves the system of equations for reuse on other inputs. It needs library calls to allocate and free scratch space; a more efficient implementation for solving the same tridiagonal system on many inputs would have the calling function provide a pointer to the scratch space.

#include <stdlib.h> /* malloc, free */

void solve_tridiagonal_in_place_reusable(float * restrict const x, const size_t X, const float * restrict const a, const float * restrict const b, const float * restrict const c) {
    /* Allocate scratch space. */
    float * restrict const cprime = malloc(sizeof(float) * X);
    
    if (!cprime) {
        /* allocation failed: x is left unmodified; report the error as the caller requires */
        return;
    }
    
    cprime[0] = c[0] / b[0];
    x[0] = x[0] / b[0];
    
    /* loop from 1 to X - 1 inclusive */
    for (size_t ix = 1; ix < X; ix++) {
        const float m = 1.0f / (b[ix] - a[ix] * cprime[ix - 1]);
        cprime[ix] = c[ix] * m;
        x[ix] = (x[ix] - a[ix] * x[ix - 1]) * m;
    }
    
    /* loop from X - 2 to 0 inclusive, safely testing loop end condition */
    for (size_t ix = X - 1; ix-- > 0; )
        x[ix] -= cprime[ix] * x[ix + 1];
    
    /* free scratch space */
    free(cprime);
}
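
The suggestion made above, letting the caller own the scratch space, can be sketched in a few lines. The example below is written in Python rather than C for brevity; the function name `solve_tridiagonal` and the `scratch` parameter are hypothetical, and the arithmetic mirrors the reusable C variant under the assumption that every array has the same length.

```python
def solve_tridiagonal(x, a, b, c, scratch):
    """Solve in place: x holds the right-hand side on entry, the solution on exit.

    a, b, c are left untouched; scratch must have at least len(x) entries.
    Assumed layout: a[0] and c[n-1] are not part of the matrix.
    No pivoting or zero-division checks are performed.
    """
    n = len(x)
    scratch[0] = c[0] / b[0]
    x[0] = x[0] / b[0]
    # forward sweep
    for i in range(1, n):
        m = 1.0 / (b[i] - a[i] * scratch[i - 1])
        scratch[i] = c[i] * m
        x[i] = (x[i] - a[i] * x[i - 1]) * m
    # back substitution
    for i in range(n - 2, -1, -1):
        x[i] -= scratch[i] * x[i + 1]
    return x


a = [0.0, -1.0, -1.0, -1.0]
b = [4.0, 4.0, 4.0, 4.0]
c = [-1.0, -1.0, -1.0, 0.0]
scratch = [0.0] * 4                 # allocated once by the caller
rhs1 = [5.0, 5.0, 10.0, 23.0]       # exact solution (2, 3, 5, 7)
rhs2 = [3.0, 2.0, 2.0, 3.0]         # exact solution (1, 1, 1, 1)
solve_tridiagonal(rhs1, a, b, c, scratch)
solve_tridiagonal(rhs2, a, b, c, scratch)
print(rhs1, rhs2)
```

Because the buffer is passed in, repeated solves perform no allocation, and the diagonals a, b, c remain available for the next right-hand side.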

C++

The following C++ function solves a general tridiagonal system; it destroys the input vectors c and d in the process and returns the solution in d. Indices are zero-based, i.e. i = 0, 1, ..., n - 1, where n is the number of unknowns. Save for the printing of text in the main() function, this code is valid C as well as C++.

#include <iostream>

using namespace std;

void solve(double* a, double* b, double* c, double* d, int n) {
    /*
    // n is the number of unknowns

    |b0 c0 0 ||x0| |d0|
    |a1 b1 c1||x1|=|d1|
    |0  a2 b2||x2| |d2|

    1st iteration: b0x0 + c0x1 = d0 -> x0 + (c0/b0)x1 = d0/b0 ->

        x0 + g0x1 = r0               where g0 = c0/b0        , r0 = d0/b0

    2nd iteration:     | a1x0 + b1x1   + c1x2 = d1
        from 1st it.: -| a1x0 + a1g0x1        = a1r0
                    -----------------------------
                          (b1 - a1g0)x1 + c1x2 = d1 - a1r0

        x1 + g1x2 = r1               where g1=c1/(b1 - a1g0) , r1 = (d1 - a1r0)/(b1 - a1g0)

    3rd iteration:      | a2x1 + b2x2   = d2
        from 2nd it. : -| a2x1 + a2g1x2 = a2r1
                       -----------------------
                       (b2 - a2g1)x2 = d2 - a2r1
        x2 = r2                      where                     r2 = (d2 - a2r1)/(b2 - a2g1)
    Finally we have a triangular matrix:
    |1  g0 0 ||x0| |r0|
    |0  1  g1||x1|=|r1|
    |0  0  1 ||x2| |r2|

    Condition (strict diagonal dominance): |bi| > |ai| + |ci|

    in this version the c array is reused in place of g,
    and the d array is reused in place of r and x to report the results
    Written by Keivan Moradi, 2014
    */
    n--; // since we start from x0 (not x1)
    c[0] /= b[0];
    d[0] /= b[0];

    for (int i = 1; i < n; i++) {
        c[i] /= b[i] - a[i]*c[i-1];
        d[i] = (d[i] - a[i]*d[i-1]) / (b[i] - a[i]*c[i-1]);
    }

    d[n] = (d[n] - a[n]*d[n-1]) / (b[n] - a[n]*c[n-1]);

    for (int i = n; i-- > 0;) {
        d[i] -= c[i]*d[i+1];
    }
}

int main() {
	int  n = 4;
	double a[4] = { 0, -1, -1, -1 };
	double b[4] = { 4,  4,  4,  4 };
	double c[4] = {-1, -1, -1,  0 };
	double d[4] = { 5,  5, 10, 23 };
	// results    { 2,  3,  5, 7  }
	solve(a,b,c,d,n);
	for (int i = 0; i < n; i++) {
		cout << d[i] << endl;
	}
	cout << endl << "n= " << n << endl << "n is not changed hooray !!";
	return 0;
}
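
The condition quoted in the comment above, |bi| > |ai| + |ci| for every row, is strict diagonal dominance. It is a sufficient (not a necessary) condition for the elimination to run without meeting a zero pivot. A minimal Python sketch of such a check follows; the function name `is_strictly_diagonally_dominant` is hypothetical, and the padded layout of the example above (a[0] and c[n-1] ignored) is assumed.

```python
def is_strictly_diagonally_dominant(a, b, c):
    """Return True if |b[i]| > |a[i]| + |c[i]| holds for every row i.

    Assumed layout: a, b, c all have length n; a[0] and c[n-1] are ignored
    because the first row has no subdiagonal entry and the last row has no
    superdiagonal entry.
    """
    n = len(b)
    for i in range(n):
        below = abs(a[i]) if i > 0 else 0.0
        above = abs(c[i]) if i < n - 1 else 0.0
        if abs(b[i]) <= below + above:
            return False
    return True


# The system used in main() above satisfies the condition.
print(is_strictly_diagonally_dominant([0, -1, -1, -1], [4, 4, 4, 4], [-1, -1, -1, 0]))  # True
# A system whose off-diagonal entries outweigh the main diagonal does not.
print(is_strictly_diagonally_dominant([0, 2, 2], [1, 1, 1], [2, 2, 0]))  # False
```

A False result does not mean the solver will fail, only that this particular guarantee is not available.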

Python

Note that the index here is zero-based, i.e. i = 0, 1, ..., n - 1, where n is the number of unknowns. All four arguments have length n; a[0] and c[n-1] are not used.

try:
   import numpypy as np    # for compatibility with numpy in pypy
except ImportError:
   import numpy as np      # if using numpy in cpython

def TDMASolve(a, b, c, d):
    # Returns the solution as a numpy array, or None if a zero pivot is met.
    n = len(d)
    # work on floating-point copies so that the inputs are left untouched
    ac, bc, cc, dc = (np.array(v, dtype=float) for v in (a, b, c, d))
    # forward elimination, starting at the second row (index 1)
    for j in range(1, n):
        if bc[j - 1] == 0:
            return None
        ac[j] = ac[j]/bc[j-1]
        bc[j] = bc[j] - ac[j]*cc[j-1]
        dc[j] = dc[j] - ac[j]*dc[j-1]
    if bc[n-1] == 0:
        return None
    # back substitution
    dc[n-1] = dc[n-1]/bc[n-1]
    for j in range(n-2, -1, -1):
        dc[j] = (dc[j] - cc[j]*dc[j+1])/bc[j]
    return dc

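Another way, in which a holds the subdiagonal (n - 1 entries), b the superdiagonal (n - 1 entries), c the main diagonal (n entries), and f the right-hand side (n entries):

def TDMA(a,b,c,f):
    # convert to lists of floats (map objects are not subscriptable in Python 3)
    a, b, c, f = map(lambda k_list: list(map(float, k_list)), (a, b, c, f))
    a = [0.0] + a   # pad so that a[i] is the subdiagonal entry of row i

    alpha = [0]
    beta = [0]
    n = len(f)
    x = [0] * n

    # forward sweep: x[i] = alpha[i+1]*x[i+1] + beta[i+1]
    for i in range(n-1):
        alpha.append(-b[i]/(a[i]*alpha[i] + c[i]))
        beta.append((f[i] - a[i]*beta[i])/(a[i]*alpha[i] + c[i]))
            
    x[n-1] = (f[n-1] - a[n-1]*beta[n-1])/(c[n-1] + a[n-1]*alpha[n-1])

    # back substitution
    for i in reversed(range(n-1)):
        x[i] = alpha[i+1]*x[i+1] + beta[i+1]
    
    return x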

# small test. x = (1,2,3)
if __name__ == '__main__':
    a = [3,3]
    b = [2,1]
    c = [6,5,8]
    f = [10,16,30]
    print(TDMA(a,b,c,f))

MATLAB

function x = TDMAsolver(a,b,c,d)
% a, b, c are the column vectors for the compressed tridiagonal matrix (each of length n; a(1) and c(n) are not used), d is the right-hand-side vector
n = length(d); % n is the number of rows
 
% Modify the first-row coefficients
c(1) = c(1) / b(1);    % Division by zero risk.
d(1) = d(1) / b(1);   
 
for i = 2:n-1
    temp = b(i) - a(i) * c(i-1);
    c(i) = c(i) / temp;
    d(i) = (d(i) - a(i) * d(i-1))/temp;
end
 
d(n) = (d(n) - a(n) * d(n-1))/( b(n) - a(n) * c(n-1));
 
% Now back substitute.
x(n) = d(n);
for i = n-1:-1:1
    x(i) = d(i) - c(i) * x(i + 1);
end

Fortran 90

Note that the index here is one-based, i.e. i = 1, 2, ..., n, where n is the number of unknowns.

Sometimes it is undesirable to have the solver routine overwrite the tridiagonal coefficients (e.g. for solving multiple systems of equations where only the right side of the system changes), so this implementation gives an example of a relatively inexpensive method of preserving the coefficients.

      subroutine solve_tridiag(a,b,c,d,x,n)
      implicit none
!	 a - sub-diagonal (means it is the diagonal below the main diagonal)
!	 b - the main diagonal
!	 c - super-diagonal (means it is the diagonal above the main diagonal)
!	 d - right-hand side
!	 x - the solution
!	 n - number of equations

        integer,parameter :: r8 = kind(1.d0)

        integer,intent(in) :: n
        real(r8),dimension(n),intent(in) :: a,b,c,d
        real(r8),dimension(n),intent(out) :: x
        real(r8),dimension(n) :: cp,dp
        real(r8) :: m
        integer i

! initialize c-prime and d-prime
        cp(1) = c(1)/b(1)
        dp(1) = d(1)/b(1)
! solve for vectors c-prime and d-prime
         do i = 2,n
           m = b(i)-cp(i-1)*a(i)
           cp(i) = c(i)/m
           dp(i) = (d(i)-dp(i-1)*a(i))/m
         enddo
! initialize x
         x(n) = dp(n)
! solve for x from the vectors c-prime and d-prime
        do i = n-1, 1, -1
          x(i) = dp(i)-cp(i)*x(i+1)
        end do

    end subroutine solve_tridiag
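
When only the right-hand side changes between solves, the part of the forward sweep that depends on a, b and c alone can be computed once and reused. The following Python sketch shows one way to split the work; the function names `factor_tridiagonal` and `solve_factored` are hypothetical, and the layout assumed is the one used by the subroutine above (all arrays of length n, with the first element of a and the last element of c unused), shifted to zero-based indexing.

```python
def factor_tridiagonal(a, b, c):
    """Precompute the quantities of the forward sweep that do not depend on d.

    Returns (cp, inv_pivot), where cp is the modified superdiagonal (c-prime)
    and inv_pivot[i] = 1 / (b[i] - a[i] * cp[i-1]). No zero-pivot checks.
    """
    n = len(b)
    cp = [0.0] * n
    inv_pivot = [0.0] * n
    inv_pivot[0] = 1.0 / b[0]
    cp[0] = c[0] * inv_pivot[0]
    for i in range(1, n):
        inv_pivot[i] = 1.0 / (b[i] - a[i] * cp[i - 1])
        cp[i] = c[i] * inv_pivot[i]
    return cp, inv_pivot


def solve_factored(a, cp, inv_pivot, d):
    """Solve for one right-hand side d using the precomputed cp and inv_pivot."""
    n = len(d)
    dp = [0.0] * n                      # d-prime
    dp[0] = d[0] * inv_pivot[0]
    for i in range(1, n):
        dp[i] = (d[i] - a[i] * dp[i - 1]) * inv_pivot[i]
    x = [0.0] * n
    x[n - 1] = dp[n - 1]
    for i in range(n - 2, -1, -1):      # back substitution
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x


a = [0.0, -1.0, -1.0, -1.0]
b = [4.0, 4.0, 4.0, 4.0]
c = [-1.0, -1.0, -1.0, 0.0]
cp, inv_pivot = factor_tridiagonal(a, b, c)                   # done once
x1 = solve_factored(a, cp, inv_pivot, [5.0, 5.0, 10.0, 23.0]) # approx. (2, 3, 5, 7)
x2 = solve_factored(a, cp, inv_pivot, [3.0, 2.0, 2.0, 3.0])   # approx. (1, 1, 1, 1)
print(x1, x2)
```

Each additional right-hand side then costs only the d-dependent half of the forward sweep plus the back substitution.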

The following subroutine performs the elimination in place: b and d are declared intent(inout) and are overwritten, while the solution is returned in x.[1]

      subroutine tdma(n,a,b,c,d,x)
	  implicit none
      integer, intent(in) :: n
      real, intent(in) :: a(n), c(n)
      real, intent(inout), dimension(n) :: b, d
	  real, intent(out) :: x(n)
	  !  --- Local variables ---
	  integer :: i
	  real :: q
      !  --- Elimination ---
      do i = 2,n
         q = a(i)/b(i - 1)
         b(i) = b(i) - c(i - 1)*q
         d(i) = d(i) - d(i - 1)*q
      end do
      ! --- Backsubstitution ---
      q = d(n)/b(n)
      x(n) = q
      do i = n - 1,1,-1
         q = (d(i) - c(i)*q)/b(i)
         x(i) = q
      end do
      return
      end

References