LU Decomposition
Introduction
LU decomposition is a numerical method for solving systems of linear equations. It decomposes a square matrix into two triangular matrices, one lower triangular matrix (L) and one upper triangular matrix (U). This decomposition can be used to solve linear equations in the form Ax = b, where A is the matrix, x is the vector of unknowns, and b is the vector of constants.
Implementation
The following implementation of LU decomposition is written in OCaml. It takes a matrix a as input and returns the L and U matrices, as well as a row permutation vector. The _lu_base function performs the decomposition and returns the L and U matrices and the permutation vector. The lu function calls _lu_base and then formats the output into the L and U matrices. The _lu_solve_vec function uses the output of _lu_base to solve a system of linear equations.
(* LU decomposition.
* Input matrix: a[n][n]; return L/U in one matrix, and the row permutation vector.
* Test: https://github.com/scipy/scipy/blob/master/scipy/linalg/tests/test_decomp.py
*)
let _lu_base a =
let k = M.kind a in
let _abs = Owl_base_dense_common._abs_elt k in
let _mul = Owl_base_dense_common._mul_elt k in
let _div = Owl_base_dense_common._div_elt k in
let _sub = Owl_base_dense_common._sub_elt k in
let _flt = Owl_base_dense_common._float_typ_elt k in
let _zero = Owl_const.zero k in
let _one = Owl_const.one k in
let lu = M.copy a in
let n = (M.shape a).(0) in
let m = (M.shape a).(1) in
assert (n = m);
let indx = Array.make n 0 in
(* implicit scaling of each row *)
let vv = Array.make n _zero in
let tiny = _flt 1.0e-40 in
let big = ref _zero in
let temp = ref _zero in
(* flag of row exchange *)
let d = ref 1.0 in
let imax = ref 0 in
(* loop over rows to get the implicit scaling information *)
for i = 0 to n - 1 do
big := _zero;
for j = 0 to n - 1 do
temp := M.get lu [| i; j |] |> _abs;
if !temp > !big then big := !temp
done;
if !big = _zero then raise Owl_exception.SINGULAR;
vv.(i) <- _div _one !big
done;
for k = 0 to n - 1 do
big := _zero;
(* choose suitable pivot *)
for i = k to n - 1 do
temp := _mul (M.get lu [| i; k |] |> _abs) vv.(i);
if !temp > !big
then (
big := !temp;
imax := i)
done;
(* interchange rows *)
if k <> !imax
then (
for j = 0 to n - 1 do
temp := M.get lu [| !imax; j |];
let tmp = M.get lu [| k; j |] in
M.set lu [| !imax; j |] tmp;
M.set lu [| k; j |] !temp
done;
d := !d *. -1.;
vv.(!imax) <- vv.(k));
indx.(k) <- !imax;
if M.get lu [| k; k |] = _zero then M.set lu [| k; k |] tiny;
for i = k + 1 to n - 1 do
let tmp0 = M.get lu [| i; k |] in
let tmp1 = M.get lu [| k; k |] in
temp := _div tmp0 tmp1;
M.set lu [| i; k |] !temp;
for j = k + 1 to n - 1 do
let prev = M.get lu [| i; j |] in
M.set lu [| i; j |] (_sub prev (_mul !temp (M.get lu [| k; j |])))
done
done
done;
lu, indx, !d
(* LU decomposition, return L, U, and permutation vector *)
let lu a =
let k = M.kind a in
let _zero = Owl_const.zero k in
let lu, indx, _ = _lu_base a in
let n = (M.shape lu).(0) in
let m = (M.shape lu).(1) in
assert (n = m && n >= 2);
let l = M.eye k n in
for r = 1 to n - 1 do
for c = 0 to r - 1 do
let v = M.get lu [| r; c |] in
M.set l [| r; c |] v;
M.set lu [| r; c |] _zero
done
done;
l, lu, indx
let _lu_solve_vec a b =
let _k = M.kind a in
let _mul = Owl_base_dense_common._mul_elt _k in
let _div = Owl_base_dense_common._div_elt _k in
let _sub = Owl_base_dense_common._sub_elt _k in
let _zero = Owl_const.zero _k in
assert (Array.length (M.shape b) = 1);
let n = (M.shape a).(0) in
if (M.shape b).(0) <> n then failwith "LUdcmp::solve bad sizes";
let ii = ref 0 in
let sum = ref _zero in
let x = M.copy b in
let lu, indx, _ = _lu_base a in
for i = 0 to n - 1 do
let ip = indx.(i) in
sum := M.get x [| ip |];
M.set x [| ip |] (M.get x [| i |]);
if !ii <> 0
then
for j = !ii - 1 to i - 1 do
sum := _sub !sum (_mul (M.get lu [| i; j |]) (M.get x [| j |]))
done
else if !sum <> _zero
then ii := !ii + 1;
M.set x [| i |] !sum
done;
for i = n - 1 downto 0 do
sum := M.get x [| i |];
for j = i + 1 to n - 1 do
sum := _sub !sum (_mul (M.get lu [| i; j |]) (M.get x [| j |]))
done;
M.set x [| i |] (_div !sum (M.get lu [| i; i |]))
done;
x
Step-by-step Explanation
-
The
_lu_basefunction takes a square matrixaas input and returns the L and U matrices, as well as a row permutation vector. -
The function first makes a copy of the input matrix
ainto a new matrixlu. -
It initializes some variables, including the row permutation vector
indx, the scaling vectorvv, a small valuetiny, and a flagd. -
The function then calculates the implicit scaling of each row by finding the maximum absolute value in each row and storing the reciprocal of that value in the
vvvector. -
It then performs the LU decomposition using Gaussian elimination with partial pivoting. The function loops over each column
kand selects the pivot element as the one with the largest scaled value in the column. If necessary, it exchanges rows to bring the pivot element to the diagonal. It stores the row index of each pivot element in theindxvector. -
During the elimination process, the function stores the multipliers used to eliminate the elements below the pivot element in the
lumatrix. -
The function also keeps track of the sign of the row exchanges in the
dflag. -
After the elimination process is complete, the function returns the
lumatrix, theindxvector, and thedflag. -
The
lufunction calls_lu_baseand then formats the output into the L and U matrices. It first calls_lu_baseto get thelumatrix,indxvector, anddflag. It then creates an identity matrixlwith the same size aslu. -
The function sets the elements of
lto the corresponding elements oflubelow the diagonal. -
The function returns the
landlumatrices, as well as theindxvector. -
The
_lu_solve_vecfunction uses the output of_lu_baseto solve a system of linear equations. It takes a matrixaand a vectorbas input and returns the solution vectorx. -
The function first checks that the dimensions of
aandbare compatible. -
It then calls
_lu_baseto get thelumatrix,indxvector, anddflag. -
The function initializes some variables and loops over each row of the
lumatrix to solve for the intermediate solution vectory. -
The function then loops over each row of the
lumatrix again to solve for the final solution vectorx. -
The function returns the solution vector
x.
Complexity Analysis
The time complexity of LU decomposition is O(n^3), where n is the size of the input matrix. This is because the algorithm involves performing Gaussian elimination on the matrix, which requires O(n^3) operations. The space complexity of the algorithm is also O(n^2), since it requires storing the L and U matrices, as well as the row permutation vector. However, the algorithm is numerically stable and is widely used in practice for solving systems of linear equations.