commit 3950946d6847223dea45c2f43a4bcd04cecae1da
parent 351f4c584de7722136fb7b160c3561818684af0d
Author: NunoSempere <nuno.sempere@protonmail.com>
Date: Sun, 15 Oct 2023 00:33:59 +0100
tweak: move from array to list
Diffstat:
5 files changed, 25 insertions(+), 11 deletions(-)
diff --git a/ocaml/out/samples b/ocaml/out/samples
Binary files differ.
diff --git a/ocaml/out/samples.cmi b/ocaml/out/samples.cmi
Binary files differ.
diff --git a/ocaml/out/samples.cmx b/ocaml/out/samples.cmx
Binary files differ.
diff --git a/ocaml/out/samples.o b/ocaml/out/samples.o
Binary files differ.
diff --git a/ocaml/samples.ml b/ocaml/samples.ml
@@ -2,17 +2,27 @@
let pi = acos (-1.)
let normal_95_ci_length = 1.6448536269514722
-(* Array manipulation helpers *)
-let sumFloats xs = Array.fold_left(fun acc x -> acc +. x) 0.0 xs
+(* List manipulation helpers *)
+let sumFloats xs = List.fold_left(fun acc x -> acc +. x) 0.0 xs
let normalizeXs xs =
let sum_xs = sumFloats xs in
- Array.map(fun x -> x /. sum_xs) xs
+ List.map(fun x -> x /. sum_xs) xs
let cumsumXs xs =
- let _, cum_sum = Array.fold_left(fun (sum, ys) x ->
+ let _, cum_sum = List.fold_left(fun (sum, ys) x ->
let new_sum = sum +. x in
new_sum, ys @ [new_sum]
) (0.0, []) xs in
cum_sum
+let rec nth xs (n: int) =
+ match xs with
+ | [] -> None
+ | y :: ys -> if n = 0 then Some(y) else nth ys (n-1)
+ (*
+ Note that this is O(n) access.
+ That is the cost of using the nice match syntax,
+ which is not possible with OCaml arrays
+ *)
+
let findIndex xs test =
let rec recursiveHelper ys i =
match ys with
@@ -41,18 +51,22 @@ let sampleTo low high =
let logstd = (loghigh -. loglow) /. (2.0 -. normal_95_ci_length ) in
sampleLognormal logmean logstd
-let mixture (samplers: (unit -> float) array) (weights: float array): float option =
- if (Array.length samplers == Array.length weights)
+let mixture (samplers: (unit -> float) list) (weights: float list): float option =
+ if (List.length samplers == List.length weights)
then None
else
let normalized_weights = normalizeXs weights in
let cumsummed_normalized_weights = cumsumXs normalized_weights in
let p = sampleZeroToOne () in
let chosenSamplerIndex = findIndex cumsummed_normalized_weights (fun x -> x < p) in
- let sample = match chosenSamplerIndex with
+ let sampler = match chosenSamplerIndex with
| None -> None
- | Some(i) -> Some (samplers.(i) ())
+ | Some(i) -> nth samplers i
in
+ let sample = match sampler with
+ | None -> None
+ | Some(f) -> Some(f ())
+ in
sample
let () =
@@ -63,9 +77,9 @@ let () =
let p1 = 0.8 in
let p2 = 0.5 in
let p3 = p1 *. p2 in
- let weights = [| 1. -. p3; p3 /. 2.; p3 /. 4.; p3/. 4. |] in
- let sampler () = mixture [| sample0; sample1; sampleFew; sampleMany |] weights in
+ let weights = [ 1. -. p3; p3 /. 2.; p3 /. 4.; p3/. 4. ] in
+ let sampler () = mixture [ sample0; sample1; sampleFew; sampleMany ] weights in
let n = 1_000_000 in
- let samples = Array.init n (fun _ -> sampler ()) in
+ let samples = List.init n (fun _ -> sampler ()) in
(* let mean = sumFloats samples /. n in *)
Printf.printf "Hello world\n"