commit b0f48286d5e230a6b6ed9041f89ce9aa7c33a5ae
parent 6cf0b3f9d5f1a9869347de3912ac52c2516d565d
Author: NunoSempere <nuno.semperelh@protonmail.com>
Date: Mon, 10 Jun 2024 01:12:02 +0200
add code to multiply beta distributions
Diffstat:
2 files changed, 37 insertions(+), 8 deletions(-)
diff --git a/f.go b/f.go
@@ -42,7 +42,9 @@ type Lognormal struct {
func (ln Lognormal) Samples() []float64 {
sampler := func(r sample.Src) float64 { return sample.Sample_to(ln.low, ln.high, r) }
- return sample.Sample_parallel(sampler, N_SAMPLES)
+ // return sample.Sample_parallel(sampler, N_SAMPLES)
+ // Can't do parallel because then I'd have to await throughout the code
+ return sample.Sample_serially(sampler, N_SAMPLES)
}
type Beta struct {
@@ -52,7 +54,8 @@ type Beta struct {
func (beta Beta) Samples() []float64 {
sampler := func(r sample.Src) float64 { return sample.Sample_beta(beta.a, beta.b, r) }
- return sample.Sample_parallel(sampler, N_SAMPLES)
+ // return sample.Sample_parallel(sampler, N_SAMPLES)
+ return sample.Sample_serially(sampler, N_SAMPLES)
}
type FilledSamples struct {
@@ -156,12 +159,12 @@ func multiplyBetaDists(beta1 Beta, beta2 Beta) Beta {
func multiplyAsSamples(dist1 Dist, dist2 Dist) Dist {
// dist2 = Beta{a: 1, b: 2}
- fmt.Printf("dist1: %v\n", dist1)
- fmt.Printf("dist2: %v\n", dist2)
+ // fmt.Printf("dist1: %v\n", dist1)
+ // fmt.Printf("dist2: %v\n", dist2)
xs := dist1.Samples()
ys := dist2.Samples()
- fmt.Printf("xs: %v\n", xs)
- fmt.Printf("ys: %v\n", ys)
+ // fmt.Printf("xs: %v\n", xs)
+ // fmt.Printf("ys: %v\n", ys)
zs := make([]float64, N_SAMPLES)
for i := 0; i < N_SAMPLES; i++ {
@@ -188,6 +191,9 @@ func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) {
}
case Scalar:
{
+ if o.p == 1 {
+ return new_dist, nil
+ }
switch n := new_dist.(type) {
case Lognormal:
return multiplyLogDists(Lognormal{low: o.p, high: o.p}, n), nil
@@ -197,8 +203,16 @@ func multiplyDists(old_dist Dist, new_dist Dist) (Dist, error) {
return multiplyAsSamples(o, n), nil
}
}
+ case Beta:
+ switch n := new_dist.(type) {
+ case Beta:
+ return multiplyBetaDists(o, n), nil
+ default:
+ return multiplyAsSamples(o, n), nil
+ }
default:
- return nil, errors.New("Can't multiply dists")
+ return multiplyAsSamples(old_dist, new_dist), nil
+ // return nil, errors.New("Can't multiply dists")
}
}
@@ -227,7 +241,6 @@ func joinDists(old_dist Dist, new_dist Dist, op string) (Dist, error) {
/* Pretty print distributions */
func prettyPrint90CI(low float64, high float64) {
// fmt.Printf("=> %.1f %.1f\n", low, high)
- fmt.Printf("=> ")
switch {
case math.Abs(low) >= 1_000_000_000_000:
fmt.Printf("%.1fT", low/1_000_000_000_000)
@@ -264,6 +277,7 @@ func prettyPrint90CI(low float64, high float64) {
func prettyPrintDist(dist Dist) {
switch v := dist.(type) {
case Lognormal:
+ fmt.Printf("=> ")
prettyPrint90CI(v.low, v.high)
case FilledSamples:
tmp_xs := make([]float64, N_SAMPLES)
@@ -276,6 +290,9 @@ func prettyPrintDist(dist Dist) {
high_int := N_SAMPLES * 19 / 20
high := tmp_xs[high_int]
prettyPrint90CI(low, high)
+ case Beta:
+ fmt.Printf("=> beta ")
+ prettyPrint90CI(v.a, v.b)
default:
fmt.Printf("%v", v)
}
@@ -342,6 +359,9 @@ EventForLoop:
joint_dist, err := joinDists(old_dist, new_dist, op)
if err != nil {
+ fmt.Printf("%v\n", err)
+ fmt.Printf("Dist on stack: ")
+ prettyPrintDist(old_dist)
continue EventForLoop
}
old_dist = joint_dist
diff --git a/sample/sample.go b/sample/sample.go
@@ -137,6 +137,15 @@ func Sample_mixture(fs []func64, weights []float64, r Src) float64 {
}
+func Sample_serially(f func64, n_samples int) []float64 {
+ var r = rand.New(rand.NewPCG(uint64(1), uint64(2)))
+ xs := make([]float64, n_samples)
+ for i := 0; i < n_samples; i++ {
+ xs[i] = f(r)
+ }
+ return xs
+}
+
func Sample_parallel(f func64, n_samples int) []float64 {
var num_threads = 16
var xs = make([]float64, n_samples)