time-to-botec

Benchmark sampling in different programming languages
Log | Files | Refs | README

samples.ml (3626B)


      1 (* Constants *)
      (** Pi, obtained as the arc cosine of -1. *)
      let pi = acos (-1.0)

      (** Multiplier used by [sampleTo] to turn the log-distance between its two
          bounds into a log-standard-deviation. 1.6448536269514722 is the 95th
          percentile of the standard normal, so the interval [-z, z] covers the
          central 90% of its mass. *)
      let normal_95_ci_length = 1.6448536269514722
      4 
      5 (* List manipulation helpers *)
      (** [sumFloats xs] adds up every float in [xs], left to right, starting
          from [0.0]; the empty list sums to [0.0]. *)
      let sumFloats xs = List.fold_left ( +. ) 0.0 xs
      7 
      (** [normalizeXs xs] divides each element of [xs] by the sum of all
          elements, as computed by [sumFloats]. No check is made for a zero
          sum. *)
      let normalizeXs xs =
        let total = sumFloats xs in
        xs |> List.map (fun x -> x /. total)
     11 
     (** [cumsumXs xs] returns the running totals of [xs]: element [i] of the
         result is the sum of [xs] up to and including position [i]. The empty
         list maps to the empty list.

         Partial sums are consed onto a reversed accumulator and the list is
         reversed once at the end, so the whole pass is linear; appending with
         [@] at every step would copy the accumulator each time. *)
     let cumsumXs xs =
       let _, reversed_sums =
         List.fold_left
           (fun (running, acc) x ->
             let running = running +. x in
             (running, running :: acc))
           (0.0, []) xs
       in
       List.rev reversed_sums
     18 
     (** [nth xs n] returns [Ok] of the element at zero-based position [n] of
         [xs], or [Error] with a message when the list runs out before that
         position is reached.

         Access is O(n): the list is walked from its head, one cell per
         recursive call. *)
     let rec nth xs (n: int) =
       match xs, n with
       | [], _ -> Error "nth function finds no match"
       | y :: _, 0 -> Ok y
       | _ :: ys, _ -> nth ys (n - 1)
     28 
     (** [findIndex xs test] returns [Ok i] where [i] is the zero-based position
         of the first element of [xs] for which [test] holds, or [Error] with a
         message when no element satisfies it. *)
     let findIndex xs test =
       let rec scan position = function
         | [] -> Error "findIndex doesn't find an index"
         | z :: rest ->
           if test z then Ok position else scan (position + 1) rest
       in
       scan 0 xs
     36 
     (** [unwind xs] turns a list of results into a result of a list.

         Returns [Ok ys], where [ys] holds the payload of every [Ok] in [xs] in
         the same order as the input, or the first [Error] met while walking
         [xs] from the head.

         Payloads are consed onto an accumulator (keeping the helper
         tail-recursive), which leaves them in reverse order; the accumulator is
         therefore reversed once before being returned. *)
     let unwind xs =
       let rec collect acc = function
         | [] -> Ok (List.rev acc)
         | Error e :: _ -> Error e
         | Ok y :: rest -> collect (y :: acc) rest
       in
       collect [] xs
     45   
     (** [unwindSum xs] adds up the float payloads of a list of results.

         Returns [Ok total] when every element is [Ok], starting from [0.0], or
         the first [Error] found when walking [xs] from the head. *)
     let unwindSum xs =
       let rec go total = function
         | [] -> Ok total
         | Ok y :: rest -> go (y +. total) rest
         | Error e :: _ -> Error e
       in
       go 0.0 xs
     54 
     55 (* Array helpers *)
     (** [unwindSumArray xs] adds up the float payloads of an array of results.

         Returns [Ok total] when every element is [Ok], starting from [0.0].
         Once an [Error] has been met it is carried through the rest of the
         fold, so the result is the first [Error] in the array. *)
     let unwindSumArray xs =
       let step acc x =
         match acc with
         | Error _ -> acc
         | Ok total -> (
             match x with
             | Ok y -> Ok (total +. y)
             | Error e -> Error e)
       in
       Array.fold_left step (Ok 0.0) xs
     65 
     (** [sumFloats xs] adds up the floats of the list [xs] from left to right,
         starting at [0.0]. Same contract as the [sumFloats] defined earlier in
         this file, which this definition shadows from here on. *)
     let sumFloats xs =
       let rec go total = function
         | [] -> total
         | x :: rest -> go (total +. x) rest
       in
       go 0.0 xs
     67 
     68 (* Basic samplers *)
     (** [sampleZeroToOne ()] draws one float from [Random.float] with bound
         [1.0]. *)
     let sampleZeroToOne () : float =
       let upper_bound = 1.0 in
       Random.float upper_bound
     70 
     (** [sampleStandardNormal ()] draws one standard-normal sample with the
         Box-Muller transform: [sqrt (-2 ln u1) * sin (2 pi u2)] for two
         uniform draws [u1] and [u2] taken from [sampleZeroToOne].

         [u1] is redrawn until it is strictly positive. A draw of exactly [0.0]
         would make [log u1] equal to [neg_infinity] and the returned sample
         infinite or NaN, which would poison any sum or mean computed from it. *)
     let sampleStandardNormal (): float =
       let rec draw_positive () =
         let u = sampleZeroToOne () in
         if u > 0.0 then u else draw_positive ()
       in
       (* u1 is drawn before u2, the same order as a plain pair of draws. *)
       let u1 = draw_positive () in
       let u2 = sampleZeroToOne () in
       sqrt (-2.0 *. log u1) *. sin (2.0 *. pi *. u2)
     76 
     (** [sampleNormal mean std] shifts and scales one draw from
         [sampleStandardNormal]: the result is [mean + std * z]. *)
     let sampleNormal mean std =
       let z = sampleStandardNormal () in
       mean +. std *. z
     78 
     (** [sampleLognormal logmean logstd] exponentiates one draw from
         [sampleNormal logmean logstd]. *)
     let sampleLognormal logmean logstd =
       sampleNormal logmean logstd |> exp
     80 
     (** [sampleTo low high] draws from a lognormal fitted to the two bounds.

         In log space the mean is the midpoint of [log low] and [log high], and
         the standard deviation is their distance divided by
         [2 * normal_95_ci_length], which places each bound
         [normal_95_ci_length] standard deviations away from the mean. The
         bounds are passed to [log] without any validation. *)
     let sampleTo low high =
       let log_lo = log low in
       let log_hi = log high in
       let mu = (log_lo +. log_hi) /. 2.0 in
       let sigma = (log_hi -. log_lo) /. (2.0 *. normal_95_ci_length) in
       sampleLognormal mu sigma
     87 
     (** [mixture samplers weights] picks one sampler at random, with
         probability proportional to its weight, and returns [Ok] of one draw
         from it.

         - [samplers]: thunks that each produce one sample.
         - [weights]: relative weights, matched positionally with [samplers];
           they are normalized here, so they need not sum to one.

         Returns [Error] when the two lists differ in length, or when no
         sampler can be selected (for instance when both lists are empty).

         Selection takes the first cumulative normalized weight strictly
         greater than a uniform draw [p]. [Random.float] may return its bound,
         and the last cumulative weight can land just under [1.0] after
         rounding, so that search can come up empty for [p] at the very top of
         the range. In that case the last sampler with a strictly positive
         weight is used instead of reporting a spurious error. *)
     let mixture (samplers: (unit -> float) list) (weights: float list): (float, string) result =
       if List.length samplers <> List.length weights then
         Error "in mixture function, List.length samplers != List.length weights"
       else
         let normalized_weights = normalizeXs weights in
         let cumulative_weights = cumsumXs normalized_weights in
         let p = sampleZeroToOne () in
         let chosen_index =
           match findIndex cumulative_weights (fun c -> p < c) with
           | Ok i -> Ok i
           | Error e -> (
               (* Only computed on the rare miss. NaN weights (e.g. from a
                  zero total) compare false here, so they never qualify and
                  the original error is kept. *)
               let _, last_positive =
                 List.fold_left
                   (fun (i, found) w ->
                     (i + 1, if w > 0.0 then Some i else found))
                   (0, None) normalized_weights
               in
               match last_positive with
               | Some i -> Ok i
               | None -> Error e)
         in
         match chosen_index with
         | Error e -> Error e
         | Ok i -> (
             match nth samplers i with
             | Error e -> Error e
             | Ok draw -> Ok (draw ()))
    105 
    (* Entry point: draws 1_000_000 samples from a four-component mixture
       (the constants 0 and 1, and two [sampleTo] lognormals) and prints their
       mean, or the first error reported while sampling. *)
    let () =
      let n = 1_000_000 in
      let p1 = 0.8 in
      let p2 = 0.5 in
      let p3 = p1 *. p2 in
      let weights = [ 1. -. p3; p3 /. 2.; p3 /. 4.; p3 /. 4. ] in
      let samplers = [
        (fun () -> 0.);
        (fun () -> 1.);
        (fun () -> sampleTo 1. 3.);
        (fun () -> sampleTo 2. 10.);
      ] in
      let draws = Array.init n (fun _ -> mixture samplers weights) in
      match unwindSumArray draws with
      | Ok sum -> Printf.printf "Mean: %f\n" (sum /. float_of_int n)
      | Error err -> Printf.printf "Error %s\n" err
    123       )