simple-squiggle

A restricted subset of Squiggle
Log | Files | Refs | README

trace.js (3505B)


      1 "use strict";
      2 
      3 Object.defineProperty(exports, "__esModule", {
      4   value: true
      5 });
      6 exports.createTrace = void 0;
      7 
      8 var _object = require("../../utils/object.js");
      9 
     10 var _string = require("../../utils/string.js");
     11 
     12 var _factory = require("../../utils/factory.js");
     13 
     14 var name = 'trace';
     15 var dependencies = ['typed', 'matrix', 'add'];
     16 var createTrace = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (_ref) {
     17   var typed = _ref.typed,
     18       matrix = _ref.matrix,
     19       add = _ref.add;
     20 
     21   /**
     22    * Calculate the trace of a matrix: the sum of the elements on the main
     23    * diagonal of a square matrix.
     24    *
     25    * Syntax:
     26    *
     27    *    math.trace(x)
     28    *
     29    * Examples:
     30    *
     31    *    math.trace([[1, 2], [3, 4]]) // returns 5
     32    *
     33    *    const A = [
     34    *      [1, 2, 3],
     35    *      [-1, 2, 3],
     36    *      [2, 0, 3]
     37    *    ]
     38    *    math.trace(A) // returns 6
     39    *
     40    * See also:
     41    *
     42    *    diag
     43    *
     44    * @param {Array | Matrix} x  A matrix
     45    *
     46    * @return {number} The trace of `x`
     47    */
     48   return typed('trace', {
     49     Array: function _arrayTrace(x) {
     50       // use dense matrix implementation
     51       return _denseTrace(matrix(x));
     52     },
     53     SparseMatrix: _sparseTrace,
     54     DenseMatrix: _denseTrace,
     55     any: _object.clone
     56   });
     57 
     58   function _denseTrace(m) {
     59     // matrix size & data
     60     var size = m._size;
     61     var data = m._data; // process dimensions
     62 
     63     switch (size.length) {
     64       case 1:
     65         // vector
     66         if (size[0] === 1) {
     67           // return data[0]
     68           return (0, _object.clone)(data[0]);
     69         }
     70 
     71         throw new RangeError('Matrix must be square (size: ' + (0, _string.format)(size) + ')');
     72 
     73       case 2:
     74         {
     75           // two dimensional
     76           var rows = size[0];
     77           var cols = size[1];
     78 
     79           if (rows === cols) {
     80             // calulate sum
     81             var sum = 0; // loop diagonal
     82 
     83             for (var i = 0; i < rows; i++) {
     84               sum = add(sum, data[i][i]);
     85             } // return trace
     86 
     87 
     88             return sum;
     89           } else {
     90             throw new RangeError('Matrix must be square (size: ' + (0, _string.format)(size) + ')');
     91           }
     92         }
     93 
     94       default:
     95         // multi dimensional
     96         throw new RangeError('Matrix must be two dimensional (size: ' + (0, _string.format)(size) + ')');
     97     }
     98   }
     99 
    100   function _sparseTrace(m) {
    101     // matrix arrays
    102     var values = m._values;
    103     var index = m._index;
    104     var ptr = m._ptr;
    105     var size = m._size; // check dimensions
    106 
    107     var rows = size[0];
    108     var columns = size[1]; // matrix must be square
    109 
    110     if (rows === columns) {
    111       // calulate sum
    112       var sum = 0; // check we have data (avoid looping columns)
    113 
    114       if (values.length > 0) {
    115         // loop columns
    116         for (var j = 0; j < columns; j++) {
    117           // k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1]
    118           var k0 = ptr[j];
    119           var k1 = ptr[j + 1]; // loop k within [k0, k1[
    120 
    121           for (var k = k0; k < k1; k++) {
    122             // row index
    123             var i = index[k]; // check row
    124 
    125             if (i === j) {
    126               // accumulate value
    127               sum = add(sum, values[k]); // exit loop
    128 
    129               break;
    130             }
    131 
    132             if (i > j) {
    133               // exit loop, no value on the diagonal for column j
    134               break;
    135             }
    136           }
    137         }
    138       } // return trace
    139 
    140 
    141       return sum;
    142     }
    143 
    144     throw new RangeError('Matrix must be square (size: ' + (0, _string.format)(size) + ')');
    145   }
    146 });
    147 exports.createTrace = createTrace;