samples.ml (3626B)
(* Constants *)

(** The constant pi, obtained as [acos (-1.)]. *)
let pi = acos (-1.)

(** z-score used by [sampleTo] to turn an interval into a log-space standard
    deviation. NOTE(review): 1.6448536269514722 is the 95th percentile of the
    standard normal, i.e. the half-width (in standard deviations) of a central
    90% interval, not a 95% one. The name is kept so existing references keep
    working. *)
let normal_95_ci_length = 1.6448536269514722

(* List manipulation helpers *)

(** [sumFloats xs] is the sum of the floats in [xs]; [0.0] for the empty list. *)
let sumFloats xs = List.fold_left (fun acc x -> acc +. x) 0.0 xs

(** [normalizeXs xs] divides every element of [xs] by the sum of [xs].
    If the sum is zero the results are NaN or infinite; callers that need a
    guarantee must check the sum themselves. *)
let normalizeXs xs =
  let sum_xs = sumFloats xs in
  List.map (fun x -> x /. sum_xs) xs

(** [cumsumXs xs] is the list of running totals of [xs], in the same order:
    [cumsumXs [1.; 2.; 3.]] is [[1.; 3.; 6.]]. *)
let cumsumXs xs =
  (* Accumulate in reverse with cons and reverse once at the end; appending
     with [@] on every step would make this quadratic. *)
  let _, rev_cum_sum =
    List.fold_left
      (fun (sum, ys) x ->
        let new_sum = sum +. x in
        (new_sum, new_sum :: ys))
      (0.0, []) xs
  in
  List.rev rev_cum_sum

(** [nth xs n] is [Ok] of the element at zero-based position [n] of [xs], or
    [Error] with a message when the list is exhausted first (which includes
    every negative [n]). *)
let rec nth xs (n: int) =
  match xs with
  | [] -> Error "nth function finds no match"
  | y :: ys -> if n = 0 then Ok y else nth ys (n - 1)
(*
   Note that this is O(n) access.
   That is the cost of using the nice match syntax,
   which is not possible with OCaml arrays
*)

(** [findIndex xs test] is [Ok i] where [i] is the zero-based position of the
    first element of [xs] satisfying [test], or [Error] with a message when no
    element does. *)
let findIndex xs test =
  let rec recursiveHelper ys i =
    match ys with
    | [] -> Error "findIndex doesn't find an index"
    | z :: zs -> if test z then Ok i else recursiveHelper zs (i + 1)
  in
  recursiveHelper xs 0

(** [unwind xs] turns a list of results into a result of a list: [Ok] of all
    the payloads, in their original order, when every element is [Ok], or the
    first [Error] encountered otherwise. *)
let unwind xs =
  let rec tailRecursiveHelper ys acc =
    match ys with
    (* The accumulator is built by consing, so it is in reverse order here;
       reverse it so callers get the elements in input order. *)
    | [] -> Ok (List.rev acc)
    | Error e :: _ -> Error e
    | Ok y :: ys -> tailRecursiveHelper ys (y :: acc)
  in
  tailRecursiveHelper xs []

(** [unwindSum xs] is [Ok] of the sum of the payloads when every element of
    [xs] is [Ok], or the first [Error] encountered otherwise. *)
let unwindSum xs =
  let rec tailRecursiveHelper ys sum =
    match ys with
    | [] -> Ok sum
    | Error e :: _ -> Error e
    | Ok y :: ys -> tailRecursiveHelper ys (y +. sum)
  in
  tailRecursiveHelper xs 0.0

(* Array helpers *)

(** [unwindSumArray xs] is the array counterpart of [unwindSum]: [Ok] of the
    sum of the payloads, or the first [Error] found scanning left to right. *)
let unwindSumArray xs =
  Array.fold_left
    (fun acc x ->
      match acc, x with
      | Error e, _ -> Error e
      | _, Error e -> Error e
      | Ok sum, Ok y -> Ok (sum +. y))
    (Ok 0.0) xs

(* Basic samplers *)

(** Draws a float from [Random.float 1.0]. The generator is not seeded
    anywhere in this file. *)
let sampleZeroToOne () : float = Random.float 1.0

(** Draws a sample from the standard normal distribution using the sine
    branch of the Box-Muller transform. *)
let sampleStandardNormal (): float =
  (* [log 0.] is [neg_infinity], which would produce an infinite or NaN
     sample, so redraw in the (very unlikely) event of an exact zero. *)
  let rec samplePositive () =
    let u = sampleZeroToOne () in
    if u > 0.0 then u else samplePositive ()
  in
  let u1 = samplePositive () in
  let u2 = sampleZeroToOne () in
  sqrt (-2.0 *. log u1) *. sin (2.0 *. pi *. u2)

(** [sampleNormal mean std] draws from a normal distribution with the given
    mean and standard deviation. *)
let sampleNormal mean std = mean +. std *. (sampleStandardNormal ())

(** [sampleLognormal logmean logstd] draws from a lognormal distribution whose
    logarithm has mean [logmean] and standard deviation [logstd]. *)
let sampleLognormal logmean logstd = exp (sampleNormal logmean logstd)

(** [sampleTo low high] draws from the lognormal distribution whose log-space
    mean is the midpoint of [log low] and [log high], and whose log-space
    standard deviation places [low] and [high] [normal_95_ci_length] standard
    deviations either side of that mean. Both bounds must be strictly
    positive, otherwise the logarithms (and therefore the sample) are NaN or
    infinite. *)
let sampleTo low high =
  let loglow = log low in
  let loghigh = log high in
  let logmean = (loglow +. loghigh) /. 2.0 in
  let logstd = (loghigh -. loglow) /. (2.0 *. normal_95_ci_length) in
  sampleLognormal logmean logstd

(** [mixture samplers weights] picks one of [samplers] at random, with
    probability proportional to the matching entry of [weights], and returns
    [Ok] of one sample drawn from it.

    Returns [Error] with a message when the two lists differ in length, when
    they are empty, when any weight is negative or NaN, or when the weights do
    not sum to a positive finite number. *)
let mixture (samplers: (unit -> float) list) (weights: float list): (float, string) result =
  if List.length samplers <> List.length weights
  then Error "in mixture function, List.length samplers != List.length weights"
  else if weights = []
  then Error "in mixture function, samplers and weights must not be empty"
  else if List.exists (fun w -> not (w >= 0.0)) weights
  then Error "in mixture function, weights must be non-negative numbers"
  else
    let total = sumFloats weights in
    if not (total > 0.0 && total < infinity)
    then Error "in mixture function, weights must sum to a positive finite number"
    else
      let normalized_weights = normalizeXs weights in
      let cumsummed_normalized_weights = cumsumXs normalized_weights in
      let p = sampleZeroToOne () in
      let chosenSamplerIndex =
        match findIndex cumsummed_normalized_weights (fun x -> p < x) with
        | Ok i -> i
        | Error _ ->
          (* Rounding can leave the final cumulative weight a hair below 1.0,
             and [p] itself may equal 1.0; in either case no entry exceeds
             [p]. Fall back to the last sampler with a positive weight rather
             than reporting a spurious error. *)
          let _, last_positive =
            List.fold_left
              (fun (i, found) w -> (i + 1, if w > 0.0 then i else found))
              (0, 0) weights
          in
          last_positive
      in
      match nth samplers chosenSamplerIndex with
      | Error e -> Error e
      | Ok f -> Ok (f ())

(* Entry point: estimates the mean of a four-component mixture from one
   million samples and prints it. *)
let () =
  let sample0 () = 0. in
  let sample1 () = 1. in
  let sampleFew () = sampleTo 1. 3. in
  let sampleMany () = sampleTo 2. 10. in
  let p1 = 0.8 in
  let p2 = 0.5 in
  let p3 = p1 *. p2 in
  let weights = [ 1. -. p3; p3 /. 2.; p3 /. 4.; p3 /. 4. ] in
  let sampler () = mixture [ sample0; sample1; sampleFew; sampleMany ] weights in
  let n = 1_000_000 in
  let samples = Array.init n (fun _ -> sampler ()) in
  match unwindSumArray samples with
  | Error err -> Printf.printf "Error %s\n" err
  | Ok sum ->
    let mean = sum /. float_of_int n in
    Printf.printf "Mean: %f\n" mean