dot.js (4501B)
1 import { factory } from '../../utils/factory.js'; 2 import { isMatrix } from '../../utils/is.js'; 3 var name = 'dot'; 4 var dependencies = ['typed', 'addScalar', 'multiplyScalar', 'conj', 'size']; 5 export var createDot = /* #__PURE__ */factory(name, dependencies, _ref => { 6 var { 7 typed, 8 addScalar, 9 multiplyScalar, 10 conj, 11 size 12 } = _ref; 13 14 /** 15 * Calculate the dot product of two vectors. The dot product of 16 * `A = [a1, a2, ..., an]` and `B = [b1, b2, ..., bn]` is defined as: 17 * 18 * dot(A, B) = conj(a1) * b1 + conj(a2) * b2 + ... + conj(an) * bn 19 * 20 * Syntax: 21 * 22 * math.dot(x, y) 23 * 24 * Examples: 25 * 26 * math.dot([2, 4, 1], [2, 2, 3]) // returns number 15 27 * math.multiply([2, 4, 1], [2, 2, 3]) // returns number 15 28 * 29 * See also: 30 * 31 * multiply, cross 32 * 33 * @param {Array | Matrix} x First vector 34 * @param {Array | Matrix} y Second vector 35 * @return {number} Returns the dot product of `x` and `y` 36 */ 37 return typed(name, { 38 'Array | DenseMatrix, Array | DenseMatrix': _denseDot, 39 'SparseMatrix, SparseMatrix': _sparseDot 40 }); 41 42 function _validateDim(x, y) { 43 var xSize = _size(x); 44 45 var ySize = _size(y); 46 47 var xLen, yLen; 48 49 if (xSize.length === 1) { 50 xLen = xSize[0]; 51 } else if (xSize.length === 2 && xSize[1] === 1) { 52 xLen = xSize[0]; 53 } else { 54 throw new RangeError('Expected a column vector, instead got a matrix of size (' + xSize.join(', ') + ')'); 55 } 56 57 if (ySize.length === 1) { 58 yLen = ySize[0]; 59 } else if (ySize.length === 2 && ySize[1] === 1) { 60 yLen = ySize[0]; 61 } else { 62 throw new RangeError('Expected a column vector, instead got a matrix of size (' + ySize.join(', ') + ')'); 63 } 64 65 if (xLen !== yLen) throw new RangeError('Vectors must have equal length (' + xLen + ' != ' + yLen + ')'); 66 if (xLen === 0) throw new RangeError('Cannot calculate the dot product of empty vectors'); 67 return xLen; 68 } 69 70 function _denseDot(a, b) { 71 var N = _validateDim(a, b); 72 73 var adata = isMatrix(a) ? a._data : a; 74 var adt = isMatrix(a) ? a._datatype : undefined; 75 var bdata = isMatrix(b) ? b._data : b; 76 var bdt = isMatrix(b) ? b._datatype : undefined; // are these 2-dimensional column vectors? (as opposed to 1-dimensional vectors) 77 78 var aIsColumn = _size(a).length === 2; 79 var bIsColumn = _size(b).length === 2; 80 var add = addScalar; 81 var mul = multiplyScalar; // process data types 82 83 if (adt && bdt && adt === bdt && typeof adt === 'string') { 84 var dt = adt; // find signatures that matches (dt, dt) 85 86 add = typed.find(addScalar, [dt, dt]); 87 mul = typed.find(multiplyScalar, [dt, dt]); 88 } // both vectors 1-dimensional 89 90 91 if (!aIsColumn && !bIsColumn) { 92 var c = mul(conj(adata[0]), bdata[0]); 93 94 for (var i = 1; i < N; i++) { 95 c = add(c, mul(conj(adata[i]), bdata[i])); 96 } 97 98 return c; 99 } // a is 1-dim, b is column 100 101 102 if (!aIsColumn && bIsColumn) { 103 var _c = mul(conj(adata[0]), bdata[0][0]); 104 105 for (var _i = 1; _i < N; _i++) { 106 _c = add(_c, mul(conj(adata[_i]), bdata[_i][0])); 107 } 108 109 return _c; 110 } // a is column, b is 1-dim 111 112 113 if (aIsColumn && !bIsColumn) { 114 var _c2 = mul(conj(adata[0][0]), bdata[0]); 115 116 for (var _i2 = 1; _i2 < N; _i2++) { 117 _c2 = add(_c2, mul(conj(adata[_i2][0]), bdata[_i2])); 118 } 119 120 return _c2; 121 } // both vectors are column 122 123 124 if (aIsColumn && bIsColumn) { 125 var _c3 = mul(conj(adata[0][0]), bdata[0][0]); 126 127 for (var _i3 = 1; _i3 < N; _i3++) { 128 _c3 = add(_c3, mul(conj(adata[_i3][0]), bdata[_i3][0])); 129 } 130 131 return _c3; 132 } 133 } 134 135 function _sparseDot(x, y) { 136 _validateDim(x, y); 137 138 var xindex = x._index; 139 var xvalues = x._values; 140 var yindex = y._index; 141 var yvalues = y._values; // TODO optimize add & mul using datatype 142 143 var c = 0; 144 var add = addScalar; 145 var mul = multiplyScalar; 146 var i = 0; 147 var j = 0; 148 149 while (i < xindex.length && j < yindex.length) { 150 var I = xindex[i]; 151 var J = yindex[j]; 152 153 if (I < J) { 154 i++; 155 continue; 156 } 157 158 if (I > J) { 159 j++; 160 continue; 161 } 162 163 if (I === J) { 164 c = add(c, mul(xvalues[i], yvalues[j])); 165 i++; 166 j++; 167 } 168 } 169 170 return c; 171 } // TODO remove this once #1771 is fixed 172 173 174 function _size(x) { 175 return isMatrix(x) ? x.size() : size(x); 176 } 177 });