commit 5e39c386f7e73efca69df35fb36ddbc91618ed22
parent 4dad518d3f39bb3787d6486d0999b7a79feccd34
Author: NunoSempere <nuno.sempere@protonmail.com>
Date: Sun, 23 Jul 2023 09:29:00 +0200
fix dumb beta sampling bug
Diffstat:
| M | squiggle.c | | | 92 | +++++++++++++++++++++++++++++++++++++++++-------------------------------------- |
1 file changed, 48 insertions(+), 44 deletions(-)
diff --git a/squiggle.c b/squiggle.c
@@ -74,46 +74,48 @@ float sample_to(float low, float high, uint32_t* seed)
return sample_lognormal(logmean, logsigma, seed);
}
-float sample_gamma(float alpha, uint32_t* seed){
-
- // A Simple Method for Generating Gamma Variables, Marsaglia and Wan Tsang, 2001
- // https://dl.acm.org/doi/pdf/10.1145/358407.358414
- // see also the references/ folder
- if(alpha >=1){
- float d, c, x, v, u;
- d = alpha - 1.0/3.0;
- c = 1.0/sqrt(9.0 * d);
- while(1){
-
- do {
- x = sample_unit_normal(seed);
- v = 1.0 + c * x;
- } while(v <= 0.0);
-
- v = pow(v, 3);
- u = sample_unit_uniform(seed);
- if( u < 1.0 - 0.0331 * pow(x, 4)){ // Condition 1
- // the 0.0331 doesn't inspire much confidence
- // however, this isn't the whole story
- // by knowing that Condition 1 implies condition 2
- // we realize that this is just a way of making the algorithm faster
- // i.e., of not using the logarithms
- return d*v;
- }
- if(log(u) < 0.5*pow(x,2) + d*(1.0 - v + log(v))){ // Condition 2
- return d*v;
- }
- }
- }else{
- return sample_gamma(1 + alpha, seed) * pow(sample_unit_uniform(seed), 1/alpha);
- // see note in p. 371 of https://dl.acm.org/doi/pdf/10.1145/358407.358414
- }
+float sample_gamma(float alpha, uint32_t* seed)
+{
+
+ // A Simple Method for Generating Gamma Variables, Marsaglia and Wan Tsang, 2001
+ // https://dl.acm.org/doi/pdf/10.1145/358407.358414
+ // see also the references/ folder
+ if (alpha >= 1) {
+ float d, c, x, v, u;
+ d = alpha - 1.0 / 3.0;
+ c = 1.0 / sqrt(9.0 * d);
+ while (1) {
+
+ do {
+ x = sample_unit_normal(seed);
+ v = 1.0 + c * x;
+ } while (v <= 0.0);
+
+ v = pow(v, 3);
+ u = sample_unit_uniform(seed);
+ if (u < 1.0 - 0.0331 * pow(x, 4)) { // Condition 1
+ // the 0.0331 doesn't inspire much confidence
+ // however, this isn't the whole story
+ // by knowing that Condition 1 implies condition 2
+ // we realize that this is just a way of making the algorithm faster
+ // i.e., of not using the logarithms
+ return d * v;
+ }
+ if (log(u) < 0.5 * pow(x, 2) + d * (1.0 - v + log(v))) { // Condition 2
+ return d * v;
+ }
+ }
+ } else {
+ return sample_gamma(1 + alpha, seed) * pow(sample_unit_uniform(seed), 1 / alpha);
+ // see note in p. 371 of https://dl.acm.org/doi/pdf/10.1145/358407.358414
+ }
}
-float sample_beta(float a, float b, uint32_t* seed){
- float gamma_a = sample_gamma(a, seed);
- float gamma_b = sample_gamma(b, seed);
- return a / (a + b);
+float sample_beta(float a, float b, uint32_t* seed)
+{
+ float gamma_a = sample_gamma(a, seed);
+ float gamma_b = sample_gamma(b, seed);
+ return gamma_a / (gamma_a + gamma_b);
}
// Array helpers
@@ -134,18 +136,20 @@ void array_cumsum(float* array_to_sum, float* array_cumsummed, int length)
}
}
-float array_mean(float* array, int length){
- float sum = array_sum(array, length);
- return sum / length;
+float array_mean(float* array, int length)
+{
+ float sum = array_sum(array, length);
+ return sum / length;
}
-float array_std(float* array, int length){
- float mean = array_mean(array, length);
+float array_std(float* array, int length)
+{
+ float mean = array_mean(array, length);
float std = 0.0;
for (int i = 0; i < length; i++) {
std += pow(array[i] - mean, 2.0);
}
- std=sqrt(std/length);
+ std = sqrt(std / length);
return std;
}