simple-squiggle

A restricted subset of Squiggle
Log | Files | Refs | README

trace.js (3242B)


      1 import { clone } from '../../utils/object.js';
      2 import { format } from '../../utils/string.js';
      3 import { factory } from '../../utils/factory.js';
      4 var name = 'trace';
      5 var dependencies = ['typed', 'matrix', 'add'];
      6 export var createTrace = /* #__PURE__ */factory(name, dependencies, _ref => {
      7   var {
      8     typed,
      9     matrix,
     10     add
     11   } = _ref;
     12 
     13   /**
     14    * Calculate the trace of a matrix: the sum of the elements on the main
     15    * diagonal of a square matrix.
     16    *
     17    * Syntax:
     18    *
     19    *    math.trace(x)
     20    *
     21    * Examples:
     22    *
     23    *    math.trace([[1, 2], [3, 4]]) // returns 5
     24    *
     25    *    const A = [
     26    *      [1, 2, 3],
     27    *      [-1, 2, 3],
     28    *      [2, 0, 3]
     29    *    ]
     30    *    math.trace(A) // returns 6
     31    *
     32    * See also:
     33    *
     34    *    diag
     35    *
     36    * @param {Array | Matrix} x  A matrix
     37    *
     38    * @return {number} The trace of `x`
     39    */
     40   return typed('trace', {
     41     Array: function _arrayTrace(x) {
     42       // use dense matrix implementation
     43       return _denseTrace(matrix(x));
     44     },
     45     SparseMatrix: _sparseTrace,
     46     DenseMatrix: _denseTrace,
     47     any: clone
     48   });
     49 
     50   function _denseTrace(m) {
     51     // matrix size & data
     52     var size = m._size;
     53     var data = m._data; // process dimensions
     54 
     55     switch (size.length) {
     56       case 1:
     57         // vector
     58         if (size[0] === 1) {
     59           // return data[0]
     60           return clone(data[0]);
     61         }
     62 
     63         throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
     64 
     65       case 2:
     66         {
     67           // two dimensional
     68           var rows = size[0];
     69           var cols = size[1];
     70 
     71           if (rows === cols) {
     72             // calulate sum
     73             var sum = 0; // loop diagonal
     74 
     75             for (var i = 0; i < rows; i++) {
     76               sum = add(sum, data[i][i]);
     77             } // return trace
     78 
     79 
     80             return sum;
     81           } else {
     82             throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
     83           }
     84         }
     85 
     86       default:
     87         // multi dimensional
     88         throw new RangeError('Matrix must be two dimensional (size: ' + format(size) + ')');
     89     }
     90   }
     91 
     92   function _sparseTrace(m) {
     93     // matrix arrays
     94     var values = m._values;
     95     var index = m._index;
     96     var ptr = m._ptr;
     97     var size = m._size; // check dimensions
     98 
     99     var rows = size[0];
    100     var columns = size[1]; // matrix must be square
    101 
    102     if (rows === columns) {
    103       // calulate sum
    104       var sum = 0; // check we have data (avoid looping columns)
    105 
    106       if (values.length > 0) {
    107         // loop columns
    108         for (var j = 0; j < columns; j++) {
    109           // k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1]
    110           var k0 = ptr[j];
    111           var k1 = ptr[j + 1]; // loop k within [k0, k1[
    112 
    113           for (var k = k0; k < k1; k++) {
    114             // row index
    115             var i = index[k]; // check row
    116 
    117             if (i === j) {
    118               // accumulate value
    119               sum = add(sum, values[k]); // exit loop
    120 
    121               break;
    122             }
    123 
    124             if (i > j) {
    125               // exit loop, no value on the diagonal for column j
    126               break;
    127             }
    128           }
    129         }
    130       } // return trace
    131 
    132 
    133       return sum;
    134     }
    135 
    136     throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
    137   }
    138 });