
LU decomposition is a numerical method for solving systems of linear equations. It decomposes a square matrix into two triangular matrices, one lower triangular matrix (L) and one upper triangular matrix (U). This decomposition can be used to solve linear equations in the form Ax = b, where A is the matrix, x is the vector of unknowns, and b is the vector of constants.


The following implementation of LU decomposition is written in OCaml. The _lu_base function takes a square matrix a as input, performs the decomposition, and returns the L and U factors packed into a single matrix, together with a row permutation vector and a sign flag that records the parity of the row exchanges. The lu function calls _lu_base and then splits the packed result into separate L and U matrices. The _lu_solve_vec function uses the output of _lu_base to solve a system of linear equations.

(* LU decomposition with implicit row scaling and partial pivoting.
 * Input matrix: a[n][n]; return L/U in one matrix, and the row permutation vector.
 *
 * Returns (lu, indx, d):
 *   - lu is a copy of a that is overwritten in place: below the diagonal it
 *     holds the elimination multipliers (L without its unit diagonal), on and
 *     above the diagonal it holds U;
 *   - indx.(k) is the row that was exchanged with row k at step k;
 *   - d starts at 1.0 and is negated on every row exchange.
 *
 * The input must be square (checked with assert).
 * Raises Owl_exception.SINGULAR if some row of a is entirely zero.
 *)
let _lu_base a =
  let k = M.kind a in
  let _abs = Owl_base_dense_common._abs_elt k in
  let _mul = Owl_base_dense_common._mul_elt k in
  let _div = Owl_base_dense_common._div_elt k in
  let _sub = Owl_base_dense_common._sub_elt k in
  let _flt = Owl_base_dense_common._float_typ_elt k in
  (* zero and one of the element type, not the kind itself *)
  let _zero = Owl_const.zero k in
  let _one = Owl_const.one k in
  let lu = M.copy a in
  let n = (M.shape a).(0) in
  let m = (M.shape a).(1) in
  assert (n = m);
  let indx = Array.make n 0 in
  (* implicit scaling of each row *)
  let vv = Array.make n _zero in
  (* stand-in for an exactly-zero pivot so the divisions below stay defined *)
  let tiny = _flt 1.0e-40 in
  let big = ref _zero in
  let temp = ref _zero in
  (* flag of row exchange *)
  let d = ref 1.0 in
  let imax = ref 0 in
  (* loop over rows to get the implicit scaling information *)
  for i = 0 to n - 1 do
    big := _zero;
    for j = 0 to n - 1 do
      temp := M.get lu [| i; j |] |> _abs;
      if !temp > !big then big := !temp
    done;
    if !big = _zero then raise Owl_exception.SINGULAR;
    vv.(i) <- _div _one !big
  done;
  for k = 0 to n - 1 do
    big := _zero;
    (* Default to the diagonal row: if every candidate in this column is zero,
     * no row may be exchanged, and a stale index from the previous column
     * must not be reused. *)
    imax := k;
    (* choose suitable pivot *)
    for i = k to n - 1 do
      temp := _mul (M.get lu [| i; k |] |> _abs) vv.(i);
      if !temp > !big
      then (
        big := !temp;
        imax := i)
    done;
    (* interchange rows *)
    if k <> !imax
    then (
      for j = 0 to n - 1 do
        temp := M.get lu [| !imax; j |];
        let tmp = M.get lu [| k; j |] in
        M.set lu [| !imax; j |] tmp;
        M.set lu [| k; j |] !temp
      done;
      d := !d *. -1.;
      (* the scale factor travels with the row that moved down *)
      vv.(!imax) <- vv.(k));
    indx.(k) <- !imax;
    if M.get lu [| k; k |] = _zero then M.set lu [| k; k |] tiny;
    (* eliminate column k below the pivot, storing each multiplier in place *)
    for i = k + 1 to n - 1 do
      let tmp0 = M.get lu [| i; k |] in
      let tmp1 = M.get lu [| k; k |] in
      temp := _div tmp0 tmp1;
      M.set lu [| i; k |] !temp;
      for j = k + 1 to n - 1 do
        let prev = M.get lu [| i; j |] in
        M.set lu [| i; j |] (_sub prev (_mul !temp (M.get lu [| k; j |])))
      done
    done
  done;
  lu, indx, !d

(* LU decomposition, return L, U, and permutation vector.
 *
 * Calls _lu_base on a and splits its packed result: the entries strictly
 * below the diagonal are moved into an identity matrix to form L (unit
 * diagonal), and are zeroed in the packed matrix so that it becomes U.
 * The index vector is returned exactly as _lu_base produced it.
 *
 * The matrix must be square with at least two rows (checked with assert).
 *)
let lu a =
  let k = M.kind a in
  (* zero of the element type, not the kind itself *)
  let _zero = Owl_const.zero k in
  let lu, indx, _ = _lu_base a in
  let n = (M.shape lu).(0) in
  let m = (M.shape lu).(1) in
  assert (n = m && n >= 2);
  let l = M.eye k n in
  for r = 1 to n - 1 do
    for c = 0 to r - 1 do
      let v = M.get lu [| r; c |] in
      M.set l [| r; c |] v;
      M.set lu [| r; c |] _zero
    done
  done;
  l, lu, indx

(* Solve a * x = b for x, where b is a one-dimensional array whose length
 * equals the number of rows of a.
 *
 * The matrix is factorised with _lu_base; b is copied, so neither input is
 * modified. The copy is then overwritten in two passes: forward substitution
 * (which also applies the recorded row exchanges) followed by back
 * substitution, and is returned as the solution.
 *
 * b must be one-dimensional (checked with assert).
 * Raises Failure if the length of b differs from the row count of a.
 *)
let _lu_solve_vec a b =
  let _k = M.kind a in
  let _mul = Owl_base_dense_common._mul_elt _k in
  let _div = Owl_base_dense_common._div_elt _k in
  let _sub = Owl_base_dense_common._sub_elt _k in
  (* zero of the element type, not the kind itself *)
  let _zero = Owl_const.zero _k in
  assert (Array.length (M.shape b) = 1);
  let n = (M.shape a).(0) in
  if (M.shape b).(0) <> n then failwith "LUdcmp::solve bad sizes";
  (* ii = 0 until the first non-zero element of the permuted right-hand side
   * is met; afterwards it is that element index plus one *)
  let ii = ref 0 in
  let sum = ref _zero in
  let x = M.copy b in
  let lu, indx, _ = _lu_base a in
  (* forward substitution, undoing the row exchanges on the fly *)
  for i = 0 to n - 1 do
    let ip = indx.(i) in
    sum := M.get x [| ip |];
    M.set x [| ip |] (M.get x [| i |]);
    if !ii <> 0
    then
      for j = !ii - 1 to i - 1 do
        sum := _sub !sum (_mul (M.get lu [| i; j |]) (M.get x [| j |]))
      done
    else if !sum <> _zero
    then
      (* every earlier element is zero, so later sums can start at index i *)
      ii := i + 1;
    M.set x [| i |] !sum
  done;
  (* back substitution against the upper triangle *)
  for i = n - 1 downto 0 do
    sum := M.get x [| i |];
    for j = i + 1 to n - 1 do
      sum := _sub !sum (_mul (M.get lu [| i; j |]) (M.get x [| j |]))
    done;
    M.set x [| i |] (_div !sum (M.get lu [| i; i |]))
  done;
  x

Step-by-step Explanation

  1. The _lu_base function takes a square matrix a as input and returns the L and U factors packed into a single matrix, as well as a row permutation vector and a flag recording the sign of the row exchanges.

  2. The function first makes a copy of the input matrix a into a new matrix lu.

  3. It initializes some variables, including the row permutation vector indx, the scaling vector vv, a small value tiny, and a flag d.

  4. The function then calculates the implicit scaling of each row by finding the maximum absolute value in each row and storing the reciprocal of that value in the vv vector.

  5. It then performs the LU decomposition using Gaussian elimination with partial pivoting. The function loops over each column k and selects the pivot element as the one with the largest scaled value in the column. If necessary, it exchanges rows to bring the pivot element to the diagonal. It stores the row index of each pivot element in the indx vector.

  6. During the elimination process, the function stores the multipliers used to eliminate the elements below the pivot element in the lu matrix.

  7. The function also keeps track of the sign of the row exchanges in the d flag.

  8. After the elimination process is complete, the function returns the lu matrix, the indx vector, and the d flag.

  9. The lu function calls _lu_base and then formats the output into the L and U matrices. It first calls _lu_base to get the lu matrix, indx vector, and d flag. It then creates an identity matrix l with the same size as lu.

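  10. The function copies the elements of lu below the diagonal into the corresponding positions of l and then sets those elements of lu to zero, so that l becomes the unit lower triangular matrix L and lu becomes the upper triangular matrix U.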

  11. The function returns the l and lu matrices, as well as the indx vector.

  12. The _lu_solve_vec function uses the output of _lu_base to solve a system of linear equations. It takes a matrix a and a vector b as input and returns the solution vector x.

  13. The function first checks that the dimensions of a and b are compatible.

  14. It then calls _lu_base to get the lu matrix, indx vector, and d flag.

  15. The function initializes some variables and loops over each row of the lu matrix to solve for the intermediate solution vector y.

  16. The function then loops over each row of the lu matrix again to solve for the final solution vector x.

  17. The function returns the solution vector x.

Complexity Analysis

The time complexity of LU decomposition is O(n^3), where n is the size of the input matrix. This is because the algorithm involves performing Gaussian elimination on the matrix, which requires O(n^3) operations. The space complexity of the algorithm is O(n^2), since it requires storing the L and U matrices, as well as the row permutation vector. With partial pivoting, the algorithm is numerically stable in practice and is widely used for solving systems of linear equations.