time-to-botec

Benchmark sampling in different programming languages
Log | Files | Refs | README

lda.js (5661B)


      1 /**
      2 * @license Apache-2.0
      3 *
      4 * Copyright (c) 2018 The Stdlib Authors.
      5 *
      6 * Licensed under the Apache License, Version 2.0 (the "License");
      7 * you may not use this file except in compliance with the License.
      8 * You may obtain a copy of the License at
      9 *
     10 *    http://www.apache.org/licenses/LICENSE-2.0
     11 *
     12 * Unless required by applicable law or agreed to in writing, software
     13 * distributed under the License is distributed on an "AS IS" BASIS,
     14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     15 * See the License for the specific language governing permissions and
     16 * limitations under the License.
     17 */
     18 
     19 'use strict';
     20 
     21 // MODULES //
     22 
     23 var isNonNegativeInteger = require( '@stdlib/assert/is-nonnegative-integer' );
     24 var isPositiveInteger = require( '@stdlib/assert/is-positive-integer' );
     25 var isStringArray = require( '@stdlib/assert/is-string-array' );
     26 var setReadOnly = require( '@stdlib/utils/define-read-only-property' );
     27 var contains = require( '@stdlib/assert/contains' );
     28 var tokenize = require( './../../tokenize' );
     29 var Int32Array = require( '@stdlib/array/int32' );
     30 var matrix = require( './matrix.js' );
     31 var getThetas = require( './get_thetas.js' );
     32 var validate = require( './validate.js' );
     33 var getPhis = require( './get_phis.js' );
     34 var init = require( './init.js' );
     35 var fit = require( './fit.js' );
     36 
     37 
     38 // FUNCTIONS //
     39 
     40 /**
     41 * Find index of the value in vocabulary equal to the supplied search value.
     42 *
     43 * @private
     44 * @param {Array} vocab - vocabulary
     45 * @param {string} searchVal - search value
     46 * @returns {integer} index in vocab if search value is found, -1 otherwise
     47 */
     48 function findIndex( vocab, searchVal ) {
     49 	var i;
     50 	for ( i = 0; i < vocab.length; i++ ) {
     51 		if ( vocab[ i ] === searchVal ) {
     52 			return i;
     53 		}
     54 	}
     55 	return -1;
     56 }
     57 
     58 
     59 // MAIN //
     60 
     61 /**
     62 * Latent Dirichlet Allocation via collapsed Gibbs sampling.
     63 *
     64 * @param {StringArray} documents - document corpus
     65 * @param {PositiveInteger} K - number of topics
     66 * @param {Options} [options] - options object
     67 * @param {PositiveNumber} [options.alpha=50/K] - Dirichlet hyper-parameter of topic vector theta:
     68 * @param {PositiveNumber} [options.beta=0.1] - Dirichlet hyper-parameter for word vector phi
     69 * @throws {TypeError} first argument must be an array of strings
     70 * @throws {TypeError} second argument must be a positive integer
     71 * @throws {TypeError} must provide valid options
     72 * @returns {Object} model object
     73 */
     74 function lda( documents, K, options ) {
     75 	var target;
     76 	var vocab;
     77 	var model;
     78 	var alpha;
     79 	var beta;
     80 	var opts;
     81 	var err;
     82 	var pos;
     83 	var nd;
     84 	var it;
     85 	var wd;
     86 	var D;
     87 	var d;
     88 	var i;
     89 	var j;
     90 	var W;
     91 	var w;
     92 
     93 	if ( !isStringArray( documents ) ) {
     94 		throw new TypeError( 'invalid argument. First argument must be a string array. Value: `' + documents + '`.' );
     95 	}
     96 	if ( !isPositiveInteger( K ) ) {
     97 		throw new TypeError( 'invalid argument. Number of topics `K` must be a positive integer. Value: `' + K + '`.' );
     98 	}
     99 	opts = {};
    100 	if ( arguments.length > 2 ) {
    101 		err = validate( opts, options );
    102 		if ( err ) {
    103 			throw err;
    104 		}
    105 	}
    106 
    107 	// Number of documents:
    108 	D = documents.length;
    109 
    110 	// Hyper-parameter for Dirichlet distribution of topic vector theta:
    111 	alpha = opts.alpha || 50 / K;
    112 
    113 	// Hyper-parameter of Dirichlet distribution of phi:
    114 	beta = opts.beta || 0.1;
    115 
    116 	// Extract words & construct vocabulary:s
    117 	vocab = [];
    118 	w = [];
    119 	pos = 0;
    120 	for ( d = 0; d < D; d++ ) {
    121 		w.push( [] );
    122 		wd = tokenize( documents[ d ] );
    123 		nd = wd.length;
    124 		for ( i = 0; i < nd; i++ ) {
    125 			target = wd[ i ];
    126 			it = findIndex( vocab, target );
    127 			if ( it === -1 ) {
    128 				vocab.push( target );
    129 				w[ d ].push( pos );
    130 				pos += 1;
    131 			} else {
    132 				w[ d ].push( it );
    133 			}
    134 		}
    135 	}
    136 	// Size of vocabulary:
    137 	W = vocab.length;
    138 
    139 	model = {};
    140 
    141 	// Attach read-only properties:
    142 	setReadOnly( model, 'K', K );
    143 	setReadOnly( model, 'D', D );
    144 	setReadOnly( model, 'W', W );
    145 	setReadOnly( model, 'alpha', alpha );
    146 	setReadOnly( model, 'beta', beta );
    147 
    148 	// Attach methods:
    149 	setReadOnly( model, 'init', init );
    150 	setReadOnly( model, 'fit', fit );
    151 	setReadOnly( model, 'getPhis', getPhis );
    152 	setReadOnly( model, 'getThetas', getThetas );
    153 	setReadOnly( model, 'getTerms', getTerms );
    154 
    155 	model.nwSum = new Int32Array( K );
    156 	model.ndSum = new Int32Array( D );
    157 	model.nw = matrix( [ W, K ], 'int32' );
    158 	model.nd = matrix( [ D, K ], 'int32' );
    159 
    160 	model.phiList = [];
    161 	model.thetaList = [];
    162 
    163 	model.w = w;
    164 	model.init();
    165 
    166 	return model;
    167 
    168 	/**
    169 	* Get top terms for the specified topic.
    170 	*
    171 	* @private
    172 	* @param {NonNegativeInteger} k - topic
    173 	* @param {PositiveInteger} [no=10] - number of terms
    174 	* @throws {TypeError} first argument must be a nonnegative integer smaller than the total number of topics
    175 	* @throws {TypeError} second argument must be a positive integer
    176 	* @returns {Array} word probability array
    177 	*/
    178 	function getTerms( k, no ) {
    179 		/* eslint-disable no-invalid-this */
    180 		var skip;
    181 		var phi;
    182 		var ret;
    183 		var max;
    184 		var mid;
    185 		var i;
    186 
    187 		if ( !isNonNegativeInteger( k ) || k >= K ) {
    188 			throw new TypeError( 'invalid argument. First argument must be a nonnegative integer smaller than the total number of topics. Value: `' + k + '`.' );
    189 		}
    190 		if ( no ) {
    191 			if ( !isPositiveInteger( no ) ) {
    192 				throw new TypeError( 'invalid argument. Second argument must be a positive integer. Value: `' + no + '`.' );
    193 			}
    194 		} else {
    195 			no = 10;
    196 		}
    197 
    198 		ret = [];
    199 		skip = [];
    200 		for ( i = 0; i < no; i++ ) {
    201 			max = 0;
    202 			for ( j = 0; j < this.W; j++ ) {
    203 				phi = this.avgPhi.get( k, j );
    204 				if ( phi > max && !contains( skip, j ) ) {
    205 					max = phi;
    206 					mid = j;
    207 				}
    208 			}
    209 			skip.push( mid );
    210 			ret.push({
    211 				'word': vocab[ mid ],
    212 				'prob': max
    213 			});
    214 		}
    215 		return ret;
    216 	}
    217 }
    218 
    219 
    220 // EXPORTS //
    221 
    222 module.exports = lda;