trace.js (3242B)
import { clone } from '../../utils/object.js';
import { format } from '../../utils/string.js';
import { factory } from '../../utils/factory.js';

const name = 'trace';
const dependencies = ['typed', 'matrix', 'add'];

export const createTrace = /* #__PURE__ */factory(name, dependencies, ({ typed, matrix, add }) => {
  /**
   * Sum the main-diagonal entries of a dense matrix.
   *
   * Accepted shapes, as read from `m._size`:
   *   - one dimension of length 1: the single entry is passed to `clone`
   *   - two equal dimensions: the diagonal entries are folded with `add`
   * Every other shape is rejected.
   *
   * @param {DenseMatrix} m   Object exposing `_size` and `_data`
   * @return {*}              The accumulated diagonal sum
   * @throws {RangeError}     When the shape is not square or not 1-D/2-D
   */
  function _denseTrace(m) {
    const dims = m._size;
    const entries = m._data;

    // one dimensional: only a single-element vector has a trace
    if (dims.length === 1) {
      if (dims[0] !== 1) {
        throw new RangeError('Matrix must be square (size: ' + format(dims) + ')');
      }

      return clone(entries[0]);
    }

    // anything that is not two dimensional at this point is unsupported
    if (dims.length !== 2) {
      throw new RangeError('Matrix must be two dimensional (size: ' + format(dims) + ')');
    }

    const rowCount = dims[0];
    const colCount = dims[1];

    if (rowCount !== colCount) {
      throw new RangeError('Matrix must be square (size: ' + format(dims) + ')');
    }

    // calculate the sum by walking the diagonal; `add` comes from the
    // injected dependencies, so the accumulator starts as the plain number 0
    let total = 0;

    for (let d = 0; d < rowCount; d++) {
      total = add(total, entries[d][d]);
    }

    return total;
  }

  /**
   * Sum the main-diagonal entries of a sparse matrix.
   *
   * The matrix is read column by column: for column `c`, the stored entries
   * occupy positions `_ptr[c]` up to (not including) `_ptr[c + 1]` of the
   * `_index` (row numbers) and `_values` arrays.
   *
   * @param {SparseMatrix} m  Object exposing `_values`, `_index`, `_ptr`, `_size`
   * @return {*}              The accumulated diagonal sum (0 when nothing is stored)
   * @throws {RangeError}     When the two dimensions differ
   */
  function _sparseTrace(m) {
    const { _values: stored, _index: rowOf, _ptr: colStart, _size: dims } = m;

    const rowCount = dims[0];
    const colCount = dims[1];

    // matrix must be square
    if (rowCount !== colCount) {
      throw new RangeError('Matrix must be square (size: ' + format(dims) + ')');
    }

    let total = 0;

    // nothing stored: skip the column scan entirely
    if (stored.length === 0) {
      return total;
    }

    for (let col = 0; col < colCount; col++) {
      const from = colStart[col];
      const until = colStart[col + 1];

      for (let k = from; k < until; k++) {
        const row = rowOf[k];

        // Stop scanning this column as soon as we reach or pass the diagonal.
        // NOTE(review): leaving at the first row > col is only correct if row
        // indices are stored in ascending order within each column - confirm
        // that every producer of SparseMatrix data upholds this.
        if (row >= col) {
          if (row === col) {
            total = add(total, stored[k]);
          }

          break;
        }
      }
    }

    return total;
  }

  /**
   * Calculate the trace of a matrix: the sum of the elements on the main
   * diagonal of a square matrix.
   *
   * Syntax:
   *
   *    math.trace(x)
   *
   * Examples:
   *
   *    math.trace([[1, 2], [3, 4]]) // returns 5
   *
   *    const A = [
   *      [1, 2, 3],
   *      [-1, 2, 3],
   *      [2, 0, 3]
   *    ]
   *    math.trace(A) // returns 6
   *
   * See also:
   *
   *    diag
   *
   * @param {Array | Matrix} x  A matrix
   *
   * @return {number} The trace of `x`
   */
  return typed('trace', {
    // plain arrays are wrapped with `matrix` and handled by the dense routine
    Array: function _arrayTrace(x) {
      return _denseTrace(matrix(x));
    },
    SparseMatrix: _sparseTrace,
    DenseMatrix: _denseTrace,
    // any other value is handed to `clone` and the result is returned
    any: clone
  });
});