-- samples.lua (1957B)
-- Consts and prep
PI = math.pi
-- 0.95 quantile of the standard normal distribution (z with P(Z <= z) = 0.95).
-- sample_to() uses it so that [low, high] becomes a central 90% interval
-- (5% of the mass in each tail).
NORMAL95CONFIDENCE = 1.6448536269514722
-- Fixed seed so repeated runs produce the same printed estimate.
math.randomseed(1234)

-- Random distribution functions

--- Draw one standard normal variate (mean 0, standard deviation 1) using the
-- sine branch of the Box-Muller transform.
-- @return number
function sample_normal_0_1()
  -- math.random() returns a float in [0, 1). Redraw on an exact 0 so that
  -- math.log(u1) stays finite; log(0) = -inf would make the result inf/nan.
  local u1
  repeat
    u1 = math.random()
  until u1 > 0
  local u2 = math.random()
  return math.sqrt(-2 * math.log(u1)) * math.sin(2 * PI * u2)
end

--- Draw from a normal distribution.
-- @param mean  mean of the distribution
-- @param sigma standard deviation
-- @return number
function sample_normal(mean, sigma)
  return mean + (sigma * sample_normal_0_1())
end

--- Draw uniformly from [min, max).
-- @param min lower bound (inclusive)
-- @param max upper bound (exclusive)
-- @return number
function sample_uniform(min, max)
  return math.random() * (max - min) + min
end

--- Draw from a lognormal distribution: exp of a normal draw.
-- @param logmean  mean of the underlying normal (in log space)
-- @param logsigma standard deviation of the underlying normal (in log space)
-- @return number
function sample_lognormal(logmean, logsigma)
  return math.exp(sample_normal(logmean, logsigma))
end

--- Draw from a lognormal whose 5th and 95th percentiles are `low` and `high`.
-- The log-space mean is the midpoint of log(low) and log(high). The log-space
-- sigma puts each bound NORMAL95CONFIDENCE standard deviations from that mean.
-- @param low  positive lower bound (5th percentile)
-- @param high positive upper bound (95th percentile)
-- @return number
function sample_to(low, high)
  -- math.log of a non-positive number is -inf/nan, which would silently
  -- poison every sample; fail loudly instead.
  assert(low > 0 and high > 0, "sample_to: low and high must be positive")
  local loglow = math.log(low)
  local loghigh = math.log(high)
  local logmean = (loglow + loghigh) / 2
  local logsigma = (loghigh - loglow) / (2.0 * NORMAL95CONFIDENCE)
  return sample_lognormal(logmean, logsigma)
end

-- Mixture

--- Draw `n` samples from a weighted mixture of sampler functions.
-- Each sample picks sampler i with probability weights[i] / sum(weights) and
-- stores the value returned by calling it with no arguments.
-- @param samplers array of zero-argument functions
-- @param weights  array of non-negative numbers, same length as samplers,
--                 not all zero
-- @param n        number of samples to draw
-- @return array of the sampled values, in draw order
function mixture(samplers, weights, n)
  assert(#samplers == #weights, "mixture: samplers and weights must have the same length")
  local l = #weights
  assert(l > 0, "mixture: need at least one sampler")

  local sum_weights = 0
  for i = 1, l do
    -- Lua arrays start at 1
    local w = weights[i]
    assert(w >= 0, "mixture: weights must be non-negative")
    sum_weights = sum_weights + w
  end
  assert(sum_weights > 0, "mixture: weights must not all be zero")

  -- cumulative[i] = (weights[1] + ... + weights[i]) / sum_weights
  local cumulative = {}
  local running = 0
  for i = 1, l do
    running = running + weights[i] / sum_weights
    cumulative[i] = running
  end

  local result = {}
  for _ = 1, n do
    local r = math.random()
    -- Choose the first bucket whose upper edge is above r, so zero-width
    -- buckets are never chosen. Stop at the last sampler even if floating-point
    -- rounding left cumulative[l] slightly below 1.
    local idx = 1
    while idx < l and r >= cumulative[idx] do
      idx = idx + 1
    end
    table.insert(result, samplers[idx]())
  end
  return result
end


-- Main
-- With probability 1 - p_c the value is 0. Otherwise it is 1 half of the time,
-- a draw from sample_to(1, 3) a quarter of the time, or a draw from
-- sample_to(2, 10) the remaining quarter.
p_a = 0.8
p_b = 0.5
p_c = p_a * p_b

function sample_0() return 0 end
function sample_1() return 1 end
function sample_few() return sample_to(1, 3) end
function sample_many() return sample_to(2, 10) end

samplers = {sample_0, sample_1, sample_few, sample_many}
weights = { (1 - p_c), p_c/2, p_c/4, p_c/4 }

n = 1000000
result = mixture(samplers, weights, n)
-- Monte Carlo estimate of the mixture's mean.
sum = 0
for i = 1, n, 1 do
  sum = sum + result[i]
  -- print(result[i])
end
print(sum/n)