simple-squiggle

A restricted subset of Squiggle
Log | Files | Refs | README

kldivergence.js (2463B)


      1 "use strict";
      2 
      3 Object.defineProperty(exports, "__esModule", {
      4   value: true
      5 });
      6 exports.createKldivergence = void 0;
      7 
      8 var _factory = require("../../utils/factory.js");
      9 
     10 var name = 'kldivergence';
     11 var dependencies = ['typed', 'matrix', 'divide', 'sum', 'multiply', 'dotDivide', 'log', 'isNumeric'];
     12 var createKldivergence = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (_ref) {
     13   var typed = _ref.typed,
     14       matrix = _ref.matrix,
     15       divide = _ref.divide,
     16       sum = _ref.sum,
     17       multiply = _ref.multiply,
     18       dotDivide = _ref.dotDivide,
     19       log = _ref.log,
     20       isNumeric = _ref.isNumeric;
     21 
     22   /**
     23      * Calculate the Kullback-Leibler (KL) divergence  between two distributions
     24      *
     25      * Syntax:
     26      *
     27      *     math.kldivergence(x, y)
     28      *
     29      * Examples:
     30      *
     31      *     math.kldivergence([0.7,0.5,0.4], [0.2,0.9,0.5])   //returns 0.24376698773121153
     32      *
     33      *
     34      * @param  {Array | Matrix} q    First vector
     35      * @param  {Array | Matrix} p    Second vector
     36      * @return {number}              Returns distance between q and p
     37      */
     38   return typed(name, {
     39     'Array, Array': function ArrayArray(q, p) {
     40       return _kldiv(matrix(q), matrix(p));
     41     },
     42     'Matrix, Array': function MatrixArray(q, p) {
     43       return _kldiv(q, matrix(p));
     44     },
     45     'Array, Matrix': function ArrayMatrix(q, p) {
     46       return _kldiv(matrix(q), p);
     47     },
     48     'Matrix, Matrix': function MatrixMatrix(q, p) {
     49       return _kldiv(q, p);
     50     }
     51   });
     52 
     53   function _kldiv(q, p) {
     54     var plength = p.size().length;
     55     var qlength = q.size().length;
     56 
     57     if (plength > 1) {
     58       throw new Error('first object must be one dimensional');
     59     }
     60 
     61     if (qlength > 1) {
     62       throw new Error('second object must be one dimensional');
     63     }
     64 
     65     if (plength !== qlength) {
     66       throw new Error('Length of two vectors must be equal');
     67     } // Before calculation, apply normalization
     68 
     69 
     70     var sumq = sum(q);
     71 
     72     if (sumq === 0) {
     73       throw new Error('Sum of elements in first object must be non zero');
     74     }
     75 
     76     var sump = sum(p);
     77 
     78     if (sump === 0) {
     79       throw new Error('Sum of elements in second object must be non zero');
     80     }
     81 
     82     var qnorm = divide(q, sum(q));
     83     var pnorm = divide(p, sum(p));
     84     var result = sum(multiply(qnorm, log(dotDivide(qnorm, pnorm))));
     85 
     86     if (isNumeric(result)) {
     87       return result;
     88     } else {
     89       return Number.NaN;
     90     }
     91   }
     92 });
     93 exports.createKldivergence = createKldivergence;