time-to-botec

Benchmark sampling in different programming languages
Log | Files | Refs | README

fit.js (3647B)


      1 /**
      2 * @license Apache-2.0
      3 *
      4 * Copyright (c) 2018 The Stdlib Authors.
      5 *
      6 * Licensed under the Apache License, Version 2.0 (the "License");
      7 * you may not use this file except in compliance with the License.
      8 * You may obtain a copy of the License at
      9 *
     10 *    http://www.apache.org/licenses/LICENSE-2.0
     11 *
     12 * Unless required by applicable law or agreed to in writing, software
     13 * distributed under the License is distributed on an "AS IS" BASIS,
     14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     15 * See the License for the specific language governing permissions and
     16 * limitations under the License.
     17 */
     18 
     19 'use strict';
     20 
     21 // MODULES //
     22 
     23 var isPositiveInteger = require( '@stdlib/assert/is-positive-integer' );
     24 var randu = require( '@stdlib/random/base/randu' );
     25 var avgMatrix = require( './avg_matrix.js' );
     26 
     27 
     28 // MAIN //
     29 
     30 /**
     31 * Fit model using collapsed Gibbs sampling.
     32 *
     33 * @private
     34 * @param {PositiveInteger} iter - number of sampling iterations
     35 * @param {PositiveInteger} burnin - number of estimates to be thrown away at beginning
     36 * @param {PositiveInteger} thin - number of discarded in-between iterations
     37 * @throws {TypeError} first argument must be a positive integer
     38 * @throws {TypeError} second argument must be a positive integer
     39 * @throws {TypeError} third argument must be a positive integer
     40 */
     41 function fit( iter, burnin, thin ) {
     42 	/* eslint-disable no-invalid-this */
     43 	var kalpha;
     44 	var wbeta;
     45 	var topic;
     46 	var theta;
     47 	var prob;
     48 	var word;
     49 	var phi;
     50 	var len;
     51 	var nt;
     52 	var d;
     53 	var i;
     54 	var j;
     55 	var u;
     56 	var w;
     57 
     58 	if ( !isPositiveInteger( iter ) ) {
     59 		throw new TypeError( 'invalid argument. First argument must be a positive integer. Value: `' + iter + '`.' );
     60 	}
     61 	if ( !isPositiveInteger( burnin ) ) {
     62 		throw new TypeError( 'invalid argument. Second argument must be a positive integer. Value: `' + burnin + '`.' );
     63 	}
     64 	if ( !isPositiveInteger( thin ) ) {
     65 		throw new TypeError( 'invalid argument. Third argument must be a positive integer. Value: `' + thin + '`.' );
     66 	}
     67 
     68 	wbeta = this.W * this.beta;
     69 	kalpha = this.K * this.alpha;
     70 
     71 	for ( i = 0; i < iter; i++ ) {
     72 		for ( d = 0; d < this.D; d++ ) {
     73 			for ( w = 0; w < this.ndSum[ d ]; w++ ) {
     74 				word = this.w[ d ][ w ];
     75 				topic = this.z[ d ][ w ];
     76 
     77 				this.nw.set( word, topic, this.nw.get( word, topic ) - 1 );
     78 				this.nd.set( d, topic, this.nd.get( d, topic ) - 1 );
     79 				this.ndSum[ d ] -= 1;
     80 				this.nwSum[ topic ] -= 1;
     81 
     82 				prob = [];
     83 				for ( j = 0; j < this.K; j++ ) {
     84 					prob.push( ( this.nw.get( word, j ) + this.beta ) /
     85 						( this.nwSum[ j ] + wbeta ) *
     86 						( this.nd.get( d, j ) + this.alpha ) /
     87 						( this.ndSum[ d ] + kalpha ) );
     88 				}
     89 				for ( j = 1; j < this.K; j++ ) {
     90 					prob[ j ] += prob[ j - 1 ];
     91 				}
     92 				u = prob[ this.K - 1 ] * randu();
     93 				topic = 0;
     94 				for ( nt = 0; nt < this.K; nt++ ) {
     95 					if ( prob[ nt ] > u ) {
     96 						topic = nt;
     97 						break;
     98 					}
     99 				}
    100 				// Assign new z_i to counts...
    101 				this.nw.set( word, topic, this.nw.get( word, topic ) + 1 );
    102 				this.nd.set( d, topic, this.nd.get( d, topic ) + 1 );
    103 				this.nwSum[ topic ] += 1;
    104 				this.ndSum[ d ] += 1;
    105 
    106 				this.z[ d ][ w ] = topic;
    107 			}
    108 		}
    109 
    110 		if ( i % thin === 0 && i > burnin ) {
    111 			phi = this.getPhis();
    112 			theta = this.getThetas();
    113 
    114 			this.phiList.push( phi );
    115 			this.thetaList.push( theta );
    116 
    117 			len = this.phiList.length;
    118 			if ( len === 1 ) {
    119 				this.avgPhi = phi;
    120 			} else {
    121 				this.avgPhi = avgMatrix( this.avgPhi, phi, len );
    122 			}
    123 			len = this.thetaList.length;
    124 			if ( len === 1 ) {
    125 				this.avgTheta = theta;
    126 			} else {
    127 				this.avgTheta = avgMatrix( this.avgTheta, theta, len );
    128 			}
    129 		}
    130 	}
    131 }
    132 
    133 
    134 // EXPORTS //
    135 
    136 module.exports = fit;