simple-squiggle

A restricted subset of Squiggle
Log | Files | Refs | README

dot.js (4501B)


      1 import { factory } from '../../utils/factory.js';
      2 import { isMatrix } from '../../utils/is.js';
      3 var name = 'dot';
      4 var dependencies = ['typed', 'addScalar', 'multiplyScalar', 'conj', 'size'];
      5 export var createDot = /* #__PURE__ */factory(name, dependencies, _ref => {
      6   var {
      7     typed,
      8     addScalar,
      9     multiplyScalar,
     10     conj,
     11     size
     12   } = _ref;
     13 
     14   /**
     15    * Calculate the dot product of two vectors. The dot product of
     16    * `A = [a1, a2, ..., an]` and `B = [b1, b2, ..., bn]` is defined as:
     17    *
     18    *    dot(A, B) = conj(a1) * b1 + conj(a2) * b2 + ... + conj(an) * bn
     19    *
     20    * Syntax:
     21    *
     22    *    math.dot(x, y)
     23    *
     24    * Examples:
     25    *
     26    *    math.dot([2, 4, 1], [2, 2, 3])       // returns number 15
     27    *    math.multiply([2, 4, 1], [2, 2, 3])  // returns number 15
     28    *
     29    * See also:
     30    *
     31    *    multiply, cross
     32    *
     33    * @param  {Array | Matrix} x     First vector
     34    * @param  {Array | Matrix} y     Second vector
     35    * @return {number}               Returns the dot product of `x` and `y`
     36    */
     37   return typed(name, {
     38     'Array | DenseMatrix, Array | DenseMatrix': _denseDot,
     39     'SparseMatrix, SparseMatrix': _sparseDot
     40   });
     41 
     42   function _validateDim(x, y) {
     43     var xSize = _size(x);
     44 
     45     var ySize = _size(y);
     46 
     47     var xLen, yLen;
     48 
     49     if (xSize.length === 1) {
     50       xLen = xSize[0];
     51     } else if (xSize.length === 2 && xSize[1] === 1) {
     52       xLen = xSize[0];
     53     } else {
     54       throw new RangeError('Expected a column vector, instead got a matrix of size (' + xSize.join(', ') + ')');
     55     }
     56 
     57     if (ySize.length === 1) {
     58       yLen = ySize[0];
     59     } else if (ySize.length === 2 && ySize[1] === 1) {
     60       yLen = ySize[0];
     61     } else {
     62       throw new RangeError('Expected a column vector, instead got a matrix of size (' + ySize.join(', ') + ')');
     63     }
     64 
     65     if (xLen !== yLen) throw new RangeError('Vectors must have equal length (' + xLen + ' != ' + yLen + ')');
     66     if (xLen === 0) throw new RangeError('Cannot calculate the dot product of empty vectors');
     67     return xLen;
     68   }
     69 
     70   function _denseDot(a, b) {
     71     var N = _validateDim(a, b);
     72 
     73     var adata = isMatrix(a) ? a._data : a;
     74     var adt = isMatrix(a) ? a._datatype : undefined;
     75     var bdata = isMatrix(b) ? b._data : b;
     76     var bdt = isMatrix(b) ? b._datatype : undefined; // are these 2-dimensional column vectors? (as opposed to 1-dimensional vectors)
     77 
     78     var aIsColumn = _size(a).length === 2;
     79     var bIsColumn = _size(b).length === 2;
     80     var add = addScalar;
     81     var mul = multiplyScalar; // process data types
     82 
     83     if (adt && bdt && adt === bdt && typeof adt === 'string') {
     84       var dt = adt; // find signatures that matches (dt, dt)
     85 
     86       add = typed.find(addScalar, [dt, dt]);
     87       mul = typed.find(multiplyScalar, [dt, dt]);
     88     } // both vectors 1-dimensional
     89 
     90 
     91     if (!aIsColumn && !bIsColumn) {
     92       var c = mul(conj(adata[0]), bdata[0]);
     93 
     94       for (var i = 1; i < N; i++) {
     95         c = add(c, mul(conj(adata[i]), bdata[i]));
     96       }
     97 
     98       return c;
     99     } // a is 1-dim, b is column
    100 
    101 
    102     if (!aIsColumn && bIsColumn) {
    103       var _c = mul(conj(adata[0]), bdata[0][0]);
    104 
    105       for (var _i = 1; _i < N; _i++) {
    106         _c = add(_c, mul(conj(adata[_i]), bdata[_i][0]));
    107       }
    108 
    109       return _c;
    110     } // a is column, b is 1-dim
    111 
    112 
    113     if (aIsColumn && !bIsColumn) {
    114       var _c2 = mul(conj(adata[0][0]), bdata[0]);
    115 
    116       for (var _i2 = 1; _i2 < N; _i2++) {
    117         _c2 = add(_c2, mul(conj(adata[_i2][0]), bdata[_i2]));
    118       }
    119 
    120       return _c2;
    121     } // both vectors are column
    122 
    123 
    124     if (aIsColumn && bIsColumn) {
    125       var _c3 = mul(conj(adata[0][0]), bdata[0][0]);
    126 
    127       for (var _i3 = 1; _i3 < N; _i3++) {
    128         _c3 = add(_c3, mul(conj(adata[_i3][0]), bdata[_i3][0]));
    129       }
    130 
    131       return _c3;
    132     }
    133   }
    134 
    135   function _sparseDot(x, y) {
    136     _validateDim(x, y);
    137 
    138     var xindex = x._index;
    139     var xvalues = x._values;
    140     var yindex = y._index;
    141     var yvalues = y._values; // TODO optimize add & mul using datatype
    142 
    143     var c = 0;
    144     var add = addScalar;
    145     var mul = multiplyScalar;
    146     var i = 0;
    147     var j = 0;
    148 
    149     while (i < xindex.length && j < yindex.length) {
    150       var I = xindex[i];
    151       var J = yindex[j];
    152 
    153       if (I < J) {
    154         i++;
    155         continue;
    156       }
    157 
    158       if (I > J) {
    159         j++;
    160         continue;
    161       }
    162 
    163       if (I === J) {
    164         c = add(c, mul(xvalues[i], yvalues[j]));
    165         i++;
    166         j++;
    167       }
    168     }
    169 
    170     return c;
    171   } // TODO remove this once #1771 is fixed
    172 
    173 
    174   function _size(x) {
    175     return isMatrix(x) ? x.size() : size(x);
    176   }
    177 });