fit.js (3647B)
/**
* @license Apache-2.0
*
* Copyright (c) 2018 The Stdlib Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

'use strict';

// MODULES //

var isPositiveInteger = require( '@stdlib/assert/is-positive-integer' );
var randu = require( '@stdlib/random/base/randu' );
var avgMatrix = require( './avg_matrix.js' );


// MAIN //

/**
* Fits the model using collapsed Gibbs sampling.
*
* ## Notes
*
* -   Must be called with `this` bound to a model instance.
*     -   It reads `this.D`, `this.K`, `this.W`, `this.alpha`, `this.beta`, `this.w` (per-document word indices) and `this.z` (per-document topic assignments).
*     -   It mutates `this.nw` (counts indexed by word and topic), `this.nd` (counts indexed by document and topic), `this.nwSum`, `this.ndSum` and `this.z`.
* -   Each of the `iter` iterations is one sweep that resamples the topic of every token once, using `randu` as the source of randomness.
* -   After the sweep with zero-based index `i`, a sample is collected if `i % thin === 0` and `i > burnin`.
*     -   Collecting a sample appends `this.getPhis()` and `this.getThetas()` to `this.phiList` and `this.thetaList`.
*     -   It also updates `this.avgPhi` and `this.avgTheta`.
* -   NOTE(review): the condition is `i > burnin` and `i` starts at `0`, so the first `burnin+1` sweeps are never sampled. That is one more than "number of estimates to be thrown away" suggests.
*     -   If `iter <= burnin+1`, no samples are collected and `this.avgPhi`/`this.avgTheta` are left untouched.
*     -   The behavior is kept as-is to avoid changing results. Confirm the intended semantics.
*
* @private
* @param {PositiveInteger} iter - number of sampling iterations (sweeps over all tokens)
* @param {PositiveInteger} burnin - burn-in length; samples are only collected from iterations whose index is strictly greater than `burnin`
* @param {PositiveInteger} thin - thinning interval; samples are only collected from iterations whose index is a multiple of `thin`
* @throws {TypeError} first argument must be a positive integer
* @throws {TypeError} second argument must be a positive integer
* @throws {TypeError} third argument must be a positive integer
* @returns {void}
*/
function fit( iter, burnin, thin ) {
	/* eslint-disable no-invalid-this */
	var nwords;
	var kalpha;
	var weight;
	var wbeta;
	var topic;
	var theta;
	var prob;
	var word;
	var phi;
	var len;
	var nt;
	var d;
	var i;
	var j;
	var u;
	var w;

	if ( !isPositiveInteger( iter ) ) {
		throw new TypeError( 'invalid argument. First argument must be a positive integer. Value: `' + iter + '`.' );
	}
	if ( !isPositiveInteger( burnin ) ) {
		throw new TypeError( 'invalid argument. Second argument must be a positive integer. Value: `' + burnin + '`.' );
	}
	if ( !isPositiveInteger( thin ) ) {
		throw new TypeError( 'invalid argument. Third argument must be a positive integer. Value: `' + thin + '`.' );
	}

	// Smoothing terms shared by every denominator below:
	wbeta = this.W * this.beta;
	kalpha = this.K * this.alpha;

	// A single buffer for cumulative topic weights, reused for every token; entries `0..K-1` are fully overwritten per token:
	prob = [];

	for ( i = 0; i < iter; i++ ) {
		for ( d = 0; d < this.D; d++ ) {
			// `ndSum[ d ]` is decremented and re-incremented exactly once per token below, so its value at each loop check is this one:
			nwords = this.ndSum[ d ];
			for ( w = 0; w < nwords; w++ ) {
				word = this.w[ d ][ w ];
				topic = this.z[ d ][ w ];

				// Remove the token's current assignment from the counts, so the conditional below excludes the token itself:
				this.nw.set( word, topic, this.nw.get( word, topic ) - 1 );
				this.nd.set( d, topic, this.nd.get( d, topic ) - 1 );
				this.ndSum[ d ] -= 1;
				this.nwSum[ topic ] -= 1;

				// Unnormalized conditional weight of each topic, accumulated into a running (cumulative) sum:
				for ( j = 0; j < this.K; j++ ) {
					weight = ( this.nw.get( word, j ) + this.beta ) /
						( this.nwSum[ j ] + wbeta ) *
						( this.nd.get( d, j ) + this.alpha ) /
						( this.ndSum[ d ] + kalpha );
					prob[ j ] = ( j === 0 ) ? weight : weight + prob[ j - 1 ];
				}

				// Draw a topic: the first index whose cumulative weight exceeds a uniform variate scaled by the total weight (falls back to topic 0):
				u = prob[ this.K - 1 ] * randu();
				topic = 0;
				for ( nt = 0; nt < this.K; nt++ ) {
					if ( prob[ nt ] > u ) {
						topic = nt;
						break;
					}
				}

				// Add the token back to the counts under its newly drawn topic...
				this.nw.set( word, topic, this.nw.get( word, topic ) + 1 );
				this.nd.set( d, topic, this.nd.get( d, topic ) + 1 );
				this.nwSum[ topic ] += 1;
				this.ndSum[ d ] += 1;

				this.z[ d ][ w ] = topic;
			}
		}

		// Collect a thinned, post-burn-in sample and fold it into the running averages:
		if ( i % thin === 0 && i > burnin ) {
			phi = this.getPhis();
			theta = this.getThetas();

			this.phiList.push( phi );
			this.thetaList.push( theta );

			// NOTE(review): `avgMatrix( avg, sample, len )` presumably computes an incremental mean over `len` samples; confirm in avg_matrix.js.
			len = this.phiList.length;
			if ( len === 1 ) {
				this.avgPhi = phi;
			} else {
				this.avgPhi = avgMatrix( this.avgPhi, phi, len );
			}
			len = this.thetaList.length;
			if ( len === 1 ) {
				this.avgTheta = theta;
			} else {
				this.avgTheta = avgMatrix( this.avgTheta, theta, len );
			}
		}
	}
}


// EXPORTS //

module.exports = fit;