lda.js (5661B)
/**
* @license Apache-2.0
*
* Copyright (c) 2018 The Stdlib Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

// MODULES //

var isNonNegativeInteger = require( '@stdlib/assert/is-nonnegative-integer' );
var isPositiveInteger = require( '@stdlib/assert/is-positive-integer' );
var isStringArray = require( '@stdlib/assert/is-string-array' );
var setReadOnly = require( '@stdlib/utils/define-read-only-property' );
var contains = require( '@stdlib/assert/contains' );
var tokenize = require( './../../tokenize' );
var Int32Array = require( '@stdlib/array/int32' );
var matrix = require( './matrix.js' );
var getThetas = require( './get_thetas.js' );
var validate = require( './validate.js' );
var getPhis = require( './get_phis.js' );
var init = require( './init.js' );
var fit = require( './fit.js' );


// FUNCTIONS //

/**
* Find index of the value in vocabulary equal to the supplied search value.
*
* ## Notes
*
* -   Performs a linear scan using strict equality and returns the first match.
*
* @private
* @param {Array} vocab - vocabulary
* @param {string} searchVal - search value
* @returns {integer} index in vocab if search value is found, -1 otherwise
*/
function findIndex( vocab, searchVal ) {
	var i;
	for ( i = 0; i < vocab.length; i++ ) {
		if ( vocab[ i ] === searchVal ) {
			return i;
		}
	}
	return -1;
}


// MAIN //

/**
* Latent Dirichlet Allocation via collapsed Gibbs sampling.
*
* ## Notes
*
* -   Each document is tokenized and every token is replaced by its index in a vocabulary built in order of first appearance.
* -   The returned model exposes read-only properties `K`, `D`, `W`, `alpha`, and `beta`, read-only methods `init`, `fit`, `getPhis`, `getThetas`, and `getTerms`, and writable working state (`nwSum`, `ndSum`, `nw`, `nd`, `phiList`, `thetaList`, `w`).
* -   `model.init()` is invoked before the model is returned.
*
* @param {StringArray} documents - document corpus
* @param {PositiveInteger} K - number of topics
* @param {Options} [options] - options object
* @param {PositiveNumber} [options.alpha=50/K] - Dirichlet hyper-parameter of topic vector theta
* @param {PositiveNumber} [options.beta=0.1] - Dirichlet hyper-parameter for word vector phi
* @throws {TypeError} first argument must be an array of strings
* @throws {TypeError} second argument must be a positive integer
* @throws {TypeError} must provide valid options
* @returns {Object} model object
*/
function lda( documents, K, options ) {
	var target;
	var vocab;
	var model;
	var alpha;
	var beta;
	var opts;
	var err;
	var pos;
	var nd;
	var it;
	var wd;
	var D;
	var d;
	var i;
	var W;
	var w;

	if ( !isStringArray( documents ) ) {
		throw new TypeError( 'invalid argument. First argument must be a string array. Value: `' + documents + '`.' );
	}
	if ( !isPositiveInteger( K ) ) {
		throw new TypeError( 'invalid argument. Number of topics `K` must be a positive integer. Value: `' + K + '`.' );
	}
	opts = {};
	if ( arguments.length > 2 ) {
		// `validate` populates `opts` and returns an error object (rather than throwing) upon encountering an invalid option:
		err = validate( opts, options );
		if ( err ) {
			throw err;
		}
	}

	// Number of documents:
	D = documents.length;

	// Hyper-parameter for Dirichlet distribution of topic vector theta:
	alpha = opts.alpha || 50 / K;

	// Hyper-parameter of Dirichlet distribution of phi:
	beta = opts.beta || 0.1;

	// Extract words & construct vocabulary:
	vocab = [];
	w = [];
	pos = 0;
	for ( d = 0; d < D; d++ ) {
		w.push( [] );
		wd = tokenize( documents[ d ] );
		nd = wd.length;
		for ( i = 0; i < nd; i++ ) {
			target = wd[ i ];
			it = findIndex( vocab, target );
			if ( it === -1 ) {
				// First occurrence of the token: append it to the vocabulary and record its new index...
				vocab.push( target );
				w[ d ].push( pos );
				pos += 1;
			} else {
				w[ d ].push( it );
			}
		}
	}
	// Size of vocabulary:
	W = vocab.length;

	model = {};

	// Attach read-only properties:
	setReadOnly( model, 'K', K );
	setReadOnly( model, 'D', D );
	setReadOnly( model, 'W', W );
	setReadOnly( model, 'alpha', alpha );
	setReadOnly( model, 'beta', beta );

	// Attach methods:
	setReadOnly( model, 'init', init );
	setReadOnly( model, 'fit', fit );
	setReadOnly( model, 'getPhis', getPhis );
	setReadOnly( model, 'getThetas', getThetas );
	setReadOnly( model, 'getTerms', getTerms );

	// Allocate working state (presumably the count tables maintained by `init` and `fit`; confirm against those modules):
	model.nwSum = new Int32Array( K );
	model.ndSum = new Int32Array( D );
	model.nw = matrix( [ W, K ], 'int32' );
	model.nd = matrix( [ D, K ], 'int32' );

	model.phiList = [];
	model.thetaList = [];

	// Documents encoded as arrays of vocabulary indices:
	model.w = w;
	model.init();

	return model;

	/**
	* Get top terms for the specified topic.
	*
	* ## Notes
	*
	* -   Must be invoked as a method of the model, as the function reads `this.W` and `this.avgPhi`.
	* -   `this.avgPhi` is not assigned in this file (presumably populated by `fit`; confirm before calling on an unfitted model).
	* -   Any falsy `no` (including `0`) falls back to the default of `10`.
	* -   Only words having a probability greater than zero are returned. If fewer than `no` such words exist (e.g., `no` exceeds the vocabulary size), the returned array is shorter than `no` rather than being padded.
	*
	* @private
	* @param {NonNegativeInteger} k - topic
	* @param {PositiveInteger} [no=10] - number of terms
	* @throws {TypeError} first argument must be a nonnegative integer smaller than the total number of topics
	* @throws {TypeError} second argument must be a positive integer
	* @returns {Array} array of `{ 'word': string, 'prob': number }` objects ordered by decreasing probability
	*/
	function getTerms( k, no ) {
		/* eslint-disable no-invalid-this */
		var skip;
		var phi;
		var ret;
		var max;
		var mid;
		var i;
		var j;

		if ( !isNonNegativeInteger( k ) || k >= K ) {
			throw new TypeError( 'invalid argument. First argument must be a nonnegative integer smaller than the total number of topics. Value: `' + k + '`.' );
		}
		if ( no ) {
			if ( !isPositiveInteger( no ) ) {
				throw new TypeError( 'invalid argument. Second argument must be a positive integer. Value: `' + no + '`.' );
			}
		} else {
			no = 10;
		}

		ret = [];
		skip = [];
		for ( i = 0; i < no; i++ ) {
			// Reset the selection on every pass so that a previous winner is never reported twice:
			max = 0;
			mid = -1;
			for ( j = 0; j < this.W; j++ ) {
				phi = this.avgPhi.get( k, j );
				if ( phi > max && !contains( skip, j ) ) {
					max = phi;
					mid = j;
				}
			}
			if ( mid === -1 ) {
				// No remaining word has a positive probability, so there is nothing more to report...
				break;
			}
			skip.push( mid );
			ret.push({
				'word': vocab[ mid ],
				'prob': max
			});
		}
		return ret;
	}
}


// EXPORTS //

module.exports = lda;