simple-squiggle

A restricted subset of Squiggle
Log | Files | Refs | README

dot.js (4764B)


      1 "use strict";
      2 
      3 Object.defineProperty(exports, "__esModule", {
      4   value: true
      5 });
      6 exports.createDot = void 0;
      7 
      8 var _factory = require("../../utils/factory.js");
      9 
     10 var _is = require("../../utils/is.js");
     11 
     12 var name = 'dot';
     13 var dependencies = ['typed', 'addScalar', 'multiplyScalar', 'conj', 'size'];
     14 var createDot = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (_ref) {
     15   var typed = _ref.typed,
     16       addScalar = _ref.addScalar,
     17       multiplyScalar = _ref.multiplyScalar,
     18       conj = _ref.conj,
     19       size = _ref.size;
     20 
     21   /**
     22    * Calculate the dot product of two vectors. The dot product of
     23    * `A = [a1, a2, ..., an]` and `B = [b1, b2, ..., bn]` is defined as:
     24    *
     25    *    dot(A, B) = conj(a1) * b1 + conj(a2) * b2 + ... + conj(an) * bn
     26    *
     27    * Syntax:
     28    *
     29    *    math.dot(x, y)
     30    *
     31    * Examples:
     32    *
     33    *    math.dot([2, 4, 1], [2, 2, 3])       // returns number 15
     34    *    math.multiply([2, 4, 1], [2, 2, 3])  // returns number 15
     35    *
     36    * See also:
     37    *
     38    *    multiply, cross
     39    *
     40    * @param  {Array | Matrix} x     First vector
     41    * @param  {Array | Matrix} y     Second vector
     42    * @return {number}               Returns the dot product of `x` and `y`
     43    */
     44   return typed(name, {
     45     'Array | DenseMatrix, Array | DenseMatrix': _denseDot,
     46     'SparseMatrix, SparseMatrix': _sparseDot
     47   });
     48 
     49   function _validateDim(x, y) {
     50     var xSize = _size(x);
     51 
     52     var ySize = _size(y);
     53 
     54     var xLen, yLen;
     55 
     56     if (xSize.length === 1) {
     57       xLen = xSize[0];
     58     } else if (xSize.length === 2 && xSize[1] === 1) {
     59       xLen = xSize[0];
     60     } else {
     61       throw new RangeError('Expected a column vector, instead got a matrix of size (' + xSize.join(', ') + ')');
     62     }
     63 
     64     if (ySize.length === 1) {
     65       yLen = ySize[0];
     66     } else if (ySize.length === 2 && ySize[1] === 1) {
     67       yLen = ySize[0];
     68     } else {
     69       throw new RangeError('Expected a column vector, instead got a matrix of size (' + ySize.join(', ') + ')');
     70     }
     71 
     72     if (xLen !== yLen) throw new RangeError('Vectors must have equal length (' + xLen + ' != ' + yLen + ')');
     73     if (xLen === 0) throw new RangeError('Cannot calculate the dot product of empty vectors');
     74     return xLen;
     75   }
     76 
     77   function _denseDot(a, b) {
     78     var N = _validateDim(a, b);
     79 
     80     var adata = (0, _is.isMatrix)(a) ? a._data : a;
     81     var adt = (0, _is.isMatrix)(a) ? a._datatype : undefined;
     82     var bdata = (0, _is.isMatrix)(b) ? b._data : b;
     83     var bdt = (0, _is.isMatrix)(b) ? b._datatype : undefined; // are these 2-dimensional column vectors? (as opposed to 1-dimensional vectors)
     84 
     85     var aIsColumn = _size(a).length === 2;
     86     var bIsColumn = _size(b).length === 2;
     87     var add = addScalar;
     88     var mul = multiplyScalar; // process data types
     89 
     90     if (adt && bdt && adt === bdt && typeof adt === 'string') {
     91       var dt = adt; // find signatures that matches (dt, dt)
     92 
     93       add = typed.find(addScalar, [dt, dt]);
     94       mul = typed.find(multiplyScalar, [dt, dt]);
     95     } // both vectors 1-dimensional
     96 
     97 
     98     if (!aIsColumn && !bIsColumn) {
     99       var c = mul(conj(adata[0]), bdata[0]);
    100 
    101       for (var i = 1; i < N; i++) {
    102         c = add(c, mul(conj(adata[i]), bdata[i]));
    103       }
    104 
    105       return c;
    106     } // a is 1-dim, b is column
    107 
    108 
    109     if (!aIsColumn && bIsColumn) {
    110       var _c = mul(conj(adata[0]), bdata[0][0]);
    111 
    112       for (var _i = 1; _i < N; _i++) {
    113         _c = add(_c, mul(conj(adata[_i]), bdata[_i][0]));
    114       }
    115 
    116       return _c;
    117     } // a is column, b is 1-dim
    118 
    119 
    120     if (aIsColumn && !bIsColumn) {
    121       var _c2 = mul(conj(adata[0][0]), bdata[0]);
    122 
    123       for (var _i2 = 1; _i2 < N; _i2++) {
    124         _c2 = add(_c2, mul(conj(adata[_i2][0]), bdata[_i2]));
    125       }
    126 
    127       return _c2;
    128     } // both vectors are column
    129 
    130 
    131     if (aIsColumn && bIsColumn) {
    132       var _c3 = mul(conj(adata[0][0]), bdata[0][0]);
    133 
    134       for (var _i3 = 1; _i3 < N; _i3++) {
    135         _c3 = add(_c3, mul(conj(adata[_i3][0]), bdata[_i3][0]));
    136       }
    137 
    138       return _c3;
    139     }
    140   }
    141 
    142   function _sparseDot(x, y) {
    143     _validateDim(x, y);
    144 
    145     var xindex = x._index;
    146     var xvalues = x._values;
    147     var yindex = y._index;
    148     var yvalues = y._values; // TODO optimize add & mul using datatype
    149 
    150     var c = 0;
    151     var add = addScalar;
    152     var mul = multiplyScalar;
    153     var i = 0;
    154     var j = 0;
    155 
    156     while (i < xindex.length && j < yindex.length) {
    157       var I = xindex[i];
    158       var J = yindex[j];
    159 
    160       if (I < J) {
    161         i++;
    162         continue;
    163       }
    164 
    165       if (I > J) {
    166         j++;
    167         continue;
    168       }
    169 
    170       if (I === J) {
    171         c = add(c, mul(xvalues[i], yvalues[j]));
    172         i++;
    173         j++;
    174       }
    175     }
    176 
    177     return c;
    178   } // TODO remove this once #1771 is fixed
    179 
    180 
    181   function _size(x) {
    182     return (0, _is.isMatrix)(x) ? x.size() : size(x);
    183   }
    184 });
    185 exports.createDot = createDot;