kldivergence.js (2463B)
"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.createKldivergence = void 0;

var _factory = require("../../utils/factory.js");

var name = 'kldivergence';
var dependencies = ['typed', 'matrix', 'divide', 'sum', 'multiply', 'dotDivide', 'log', 'isNumeric'];
var createKldivergence = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (_ref) {
  var typed = _ref.typed,
      matrix = _ref.matrix,
      divide = _ref.divide,
      sum = _ref.sum,
      multiply = _ref.multiply,
      dotDivide = _ref.dotDivide,
      log = _ref.log,
      isNumeric = _ref.isNumeric;

  /**
   * Calculate the Kullback-Leibler (KL) divergence D_KL(q || p) between two
   * discrete distributions given as vectors.
   *
   * Both inputs are first normalized so that their elements sum to 1, then
   * sum_i q_i * log(q_i / p_i) is returned. The divergence is not symmetric:
   * swapping the arguments generally gives a different value, so it is not
   * a distance.
   *
   * Syntax:
   *
   *     math.kldivergence(q, p)
   *
   * Examples:
   *
   *     math.kldivergence([0.7,0.5,0.4], [0.2,0.9,0.5]) // returns ~0.3086
   *     math.kldivergence([0.2,0.9,0.5], [0.7,0.5,0.4]) // returns ~0.2438
   *
   * @param {Array | Matrix} q  First vector
   * @param {Array | Matrix} p  Second vector, same length as q
   * @return {number} The divergence D_KL(q || p), or NaN when the computed
   *                  result is not numeric (isNumeric returns false)
   * @throws {Error} When either input has more than one dimension, when the
   *                 vectors differ in length, or when either input sums to 0
   */
  return typed(name, {
    'Array, Array': function ArrayArray(q, p) {
      return _kldiv(matrix(q), matrix(p));
    },
    'Matrix, Array': function MatrixArray(q, p) {
      return _kldiv(q, matrix(p));
    },
    'Array, Matrix': function ArrayMatrix(q, p) {
      return _kldiv(matrix(q), p);
    },
    'Matrix, Matrix': function MatrixMatrix(q, p) {
      return _kldiv(q, p);
    }
  });

  /**
   * Validate two matrix vectors and compute D_KL(q || p) on their
   * normalized forms.
   *
   * @param {Matrix} q  First (left) distribution
   * @param {Matrix} p  Second (right) distribution
   * @return {number} The divergence, or NaN when the result is not numeric
   */
  function _kldiv(q, p) {
    var qsize = q.size();
    var psize = p.size();

    // q is the first argument and p the second; report them accordingly.
    if (qsize.length > 1) {
      throw new Error('first object must be one dimensional');
    }

    if (psize.length > 1) {
      throw new Error('second object must be one dimensional');
    }

    // Compare element counts, not ranks. Both ranks are already known to be
    // at most 1 here, so comparing size().length would never detect a mismatch.
    if (qsize[0] !== psize[0]) {
      throw new Error('Length of two vectors must be equal');
    }

    // Before calculation, apply normalization so each vector sums to 1.
    var sumq = sum(q);

    if (sumq === 0) {
      throw new Error('Sum of elements in first object must be non zero');
    }

    var sump = sum(p);

    if (sump === 0) {
      throw new Error('Sum of elements in second object must be non zero');
    }

    var qnorm = divide(q, sumq);
    var pnorm = divide(p, sump);
    // multiply of two vectors yields their dot product: sum_i q_i * log(q_i / p_i).
    var result = sum(multiply(qnorm, log(dotDivide(qnorm, pnorm))));

    // Non-numeric results (e.g. if log yields a non-real value) are reported as NaN.
    return isNumeric(result) ? result : Number.NaN;
  }
});
exports.createKldivergence = createKldivergence;