simple-squiggle

A restricted subset of Squiggle
Log | Files | Refs | README

kldivergence.js (2170B)


      1 import { factory } from '../../utils/factory.js';
      2 var name = 'kldivergence';
      3 var dependencies = ['typed', 'matrix', 'divide', 'sum', 'multiply', 'dotDivide', 'log', 'isNumeric'];
      4 export var createKldivergence = /* #__PURE__ */factory(name, dependencies, _ref => {
      5   var {
      6     typed,
      7     matrix,
      8     divide,
      9     sum,
     10     multiply,
     11     dotDivide,
     12     log,
     13     isNumeric
     14   } = _ref;
     15 
     16   /**
     17      * Calculate the Kullback-Leibler (KL) divergence  between two distributions
     18      *
     19      * Syntax:
     20      *
     21      *     math.kldivergence(x, y)
     22      *
     23      * Examples:
     24      *
     25      *     math.kldivergence([0.7,0.5,0.4], [0.2,0.9,0.5])   //returns 0.24376698773121153
     26      *
     27      *
     28      * @param  {Array | Matrix} q    First vector
     29      * @param  {Array | Matrix} p    Second vector
     30      * @return {number}              Returns distance between q and p
     31      */
     32   return typed(name, {
     33     'Array, Array': function ArrayArray(q, p) {
     34       return _kldiv(matrix(q), matrix(p));
     35     },
     36     'Matrix, Array': function MatrixArray(q, p) {
     37       return _kldiv(q, matrix(p));
     38     },
     39     'Array, Matrix': function ArrayMatrix(q, p) {
     40       return _kldiv(matrix(q), p);
     41     },
     42     'Matrix, Matrix': function MatrixMatrix(q, p) {
     43       return _kldiv(q, p);
     44     }
     45   });
     46 
     47   function _kldiv(q, p) {
     48     var plength = p.size().length;
     49     var qlength = q.size().length;
     50 
     51     if (plength > 1) {
     52       throw new Error('first object must be one dimensional');
     53     }
     54 
     55     if (qlength > 1) {
     56       throw new Error('second object must be one dimensional');
     57     }
     58 
     59     if (plength !== qlength) {
     60       throw new Error('Length of two vectors must be equal');
     61     } // Before calculation, apply normalization
     62 
     63 
     64     var sumq = sum(q);
     65 
     66     if (sumq === 0) {
     67       throw new Error('Sum of elements in first object must be non zero');
     68     }
     69 
     70     var sump = sum(p);
     71 
     72     if (sump === 0) {
     73       throw new Error('Sum of elements in second object must be non zero');
     74     }
     75 
     76     var qnorm = divide(q, sum(q));
     77     var pnorm = divide(p, sum(p));
     78     var result = sum(multiply(qnorm, log(dotDivide(qnorm, pnorm))));
     79 
     80     if (isNumeric(result)) {
     81       return result;
     82     } else {
     83       return Number.NaN;
     84     }
     85   }
     86 });