helpers.js (7868B)
import { BaseDist } from "../../dist/BaseDist.js";
import { PointMass } from "../../dist/SymbolicDist.js";
import { REDistributionError, REOperationError, REOther, } from "../../errors/messages.js";
import { SampleMapNeedsNtoNFunction } from "../../operationError.js";
import * as Result from "../../utility/result.js";
import { vArray, vBool, vDist, vNumber, vString, } from "../../value/index.js";
import { makeDefinition } from "./fnDefinition.js";
import { frBool, frDist, frDistOrNumber, frNumber, frString, } from "./frTypes.js";
import * as SampleSetDist from "../../dist/SampleSetDist/index.js";
import { OtherOperationError } from "../../operationError.js";

/**
 * Module-private helper shared by the FnFactory shorthand methods.
 *
 * Calls `factory.make` with the caller's remaining options, followed by a
 * fixed `output` label and a one-element `definitions` list. Because `output`
 * and `definitions` are placed after the spread, they override same-named
 * keys in `args`.
 *
 * @param {FnFactory} factory - Factory whose `make` method builds the record.
 * @param {object} args - Options left over after `fn` was pulled out.
 * @param {string} output - Output label ("Number", "Bool", "String", "Dist").
 * @param {*} definition - Value returned by `makeDefinition`.
 * @returns {object} Whatever `factory.make` returns.
 */
function makeSingleDefinition(factory, args, output, definition) {
    return factory.make({
        ...args,
        output,
        definitions: [definition],
    });
}

/**
 * Factory for function records that share a namespace.
 *
 * The shorthand methods are named `<inputs>2<output>`, where each letter maps
 * to the type helpers used in the method body: `n` = number, `b` = bool,
 * `s` = string, `d` = dist. For example, `nn2b` wraps a function taking two
 * numbers and returning a bool. Methods with a `d` input also pass the
 * `environment` taken from the definition's second callback argument as the
 * last argument to `fn`.
 */
export class FnFactory {
    /**
     * @param {object} opts
     * @param {*} opts.nameSpace - Copied onto every record built by `make`.
     * @param {*} opts.requiresNamespace - Copied onto every record built by `make`.
     */
    constructor(opts) {
        this.nameSpace = opts.nameSpace;
        this.requiresNamespace = opts.requiresNamespace;
    }

    /**
     * Merges the factory's namespace settings with `args`.
     * Keys in `args` win over the factory's defaults because they are spread last.
     *
     * @param {object} args - Additional fields for the record.
     * @returns {object} `{ nameSpace, requiresNamespace, ...args }`.
     */
    make(args) {
        return {
            nameSpace: this.nameSpace,
            requiresNamespace: this.requiresNamespace,
            ...args,
        };
    }

    /** number -> number */
    n2n({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Number", makeDefinition([frNumber], ([x]) => vNumber(fn(x))));
    }

    /** (number, number) -> number */
    nn2n({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Number", makeDefinition([frNumber, frNumber], ([x, y]) => vNumber(fn(x, y))));
    }

    /** (number, number) -> bool */
    nn2b({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Bool", makeDefinition([frNumber, frNumber], ([x, y]) => vBool(fn(x, y))));
    }

    /** (bool, bool) -> bool */
    bb2b({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Bool", makeDefinition([frBool, frBool], ([x, y]) => vBool(fn(x, y))));
    }

    /** (string, string) -> bool */
    ss2b({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Bool", makeDefinition([frString, frString], ([x, y]) => vBool(fn(x, y))));
    }

    /** (string, string) -> string */
    ss2s({ fn, ...args }) {
        return makeSingleDefinition(this, args, "String", makeDefinition([frString, frString], ([x, y]) => vString(fn(x, y))));
    }

    /** dist -> string; `fn` receives `(dist, environment)`. */
    d2s({ fn, ...args }) {
        return makeSingleDefinition(this, args, "String", makeDefinition([frDist], ([dist], { environment }) => vString(fn(dist, environment))));
    }

    /** (dist, number) -> string; `fn` receives `(dist, n, environment)`. */
    dn2s({ fn, ...args }) {
        return makeSingleDefinition(this, args, "String", makeDefinition([frDist, frNumber], ([dist, n], { environment }) => vString(fn(dist, n, environment))));
    }

    /** dist -> number; `fn` receives `(dist, environment)`. */
    d2n({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Number", makeDefinition([frDist], ([x], { environment }) => vNumber(fn(x, environment))));
    }

    /** dist -> bool; `fn` receives `(dist, environment)`. */
    d2b({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Bool", makeDefinition([frDist], ([x], { environment }) => vBool(fn(x, environment))));
    }

    /** dist -> dist; `fn` receives `(dist, environment)`. */
    d2d({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Dist", makeDefinition([frDist], ([dist], { environment }) => vDist(fn(dist, environment))));
    }

    /** (dist, number) -> dist; `fn` receives `(dist, n, environment)`. */
    dn2d({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Dist", makeDefinition([frDist, frNumber], ([dist, n], { environment }) => vDist(fn(dist, n, environment))));
    }

    /** (dist, number) -> number; `fn` receives `(dist, n, environment)`. */
    dn2n({ fn, ...args }) {
        return makeSingleDefinition(this, args, "Number", makeDefinition([frDist, frNumber], ([dist, n], { environment }) => vNumber(fn(dist, n, environment))));
    }

    /**
     * Wraps one ready-made definition under `name`. No `output` field is set.
     *
     * @param {string} name
     * @param {*} def - A definition, e.g. one produced by `makeDefinition`.
     * @returns {object} Record built by `make`.
     */
    fromDefinition(name, def) {
        return this.make({
            name,
            definitions: [def],
        });
    }
}

/**
 * Returns the payload of a successful result, or throws.
 *
 * @param {{ ok: boolean, value: * }} result
 * @returns {*} `result.value` when `result.ok` is truthy.
 * @throws {REDistributionError} Wrapping `result.value` when `result.ok` is falsy.
 */
export function unpackDistResult(result) {
    if (!result.ok) {
        throw new REDistributionError(result.value);
    }
    return result.value;
}

/**
 * Unpacks a result and wraps its payload with `vDist`.
 *
 * @param {{ ok: boolean, value: * }} result
 * @returns {*} `vDist(result.value)`.
 * @throws {REDistributionError} When `result.ok` is falsy.
 */
export function repackDistResult(result) {
    return vDist(unpackDistResult(result));
}

/**
 * Invokes `lambda.call(args, context)` and requires a "Number"-typed value back.
 *
 * @param {*} lambda - Object exposing `call(args, context)`.
 * @param {*} args
 * @param {*} context
 * @returns {*} The `value` field of the returned value.
 * @throws {REOperationError} Wrapping `SampleMapNeedsNtoNFunction` when the
 *   returned value's `type` is not "Number".
 */
export function doNumberLambdaCall(lambda, args, context) {
    const value = lambda.call(args, context);
    if (value.type === "Number") {
        return value.value;
    }
    throw new REOperationError(new SampleMapNeedsNtoNFunction());
}

/**
 * Invokes `lambda.call(args, context)` and requires a "Bool"-typed value back.
 *
 * Note the parameter order: `args` comes first here, whereas
 * `doNumberLambdaCall` takes `lambda` first. The order is kept for existing
 * callers.
 *
 * @param {*} args
 * @param {*} lambda - Object exposing `call(args, context)`.
 * @param {*} context
 * @returns {*} The `value` field of the returned value.
 * @throws {REOther} When the returned value's `type` is not "Bool".
 */
export function doBinaryLambdaCall(args, lambda, context) {
    const value = lambda.call(args, context);
    if (value.type === "Bool") {
        return value.value;
    }
    throw new REOther("Expected function to return a boolean value");
}

/**
 * Turns a number into a distribution via `PointMass.make` (unwrapped with
 * `Result.getExt`); any non-number input is returned as-is.
 */
export const parseDistFromDistOrNumber = (d) => typeof d === "number" ? Result.getExt(PointMass.make(d)) : d;

/**
 * Converts a distribution result into a `vDist` value.
 *
 * @param {{ ok: boolean, value: * }} result
 * @returns {*} `vDist(result.value)`.
 * @throws {REDistributionError} When `result.ok` is falsy.
 */
export function distResultToValue(result) {
    return vDist(unpackDistResult(result));
}

/**
 * Converts a result holding an array into a `vArray` of `vDist` values.
 *
 * @param {{ ok: boolean, value: * }} result - On success, `value` must support `.map`.
 * @returns {*} `vArray` of each element wrapped with `vDist`.
 * @throws {REDistributionError} When `result.ok` is falsy.
 */
export function distsResultToValue(result) {
    // Explicit arrow so that `map`'s index/array arguments are not forwarded to vDist.
    return vArray(unpackDistResult(result).map((r) => vDist(r)));
}

/**
 * Converts `d` with `SampleSetDist.fromDist` and unwraps the result.
 *
 * @param {*} d - Distribution to convert.
 * @param {*} env - Environment forwarded to `fromDist`.
 * @returns {*} The value carried by the successful result.
 * @throws {REDistributionError} When the conversion result is not ok.
 */
export function makeSampleSet(d, env) {
    return unpackDistResult(SampleSetDist.SampleSetDist.fromDist(d, env));
}

/**
 * Applies a two-argument, result-returning distribution constructor `fn` to
 * inputs that may each be a `BaseDist` instance or a number.
 *
 * - dist, dist: both are converted with `makeSampleSet` and combined through
 *   `SampleSetDist.map2`.
 * - dist, number / number, dist: the dist is converted with `makeSampleSet`
 *   and mapped with `samplesMap`, holding the number fixed.
 * - number, number: `fn` is called once and its distribution is converted with
 *   `makeSampleSet`.
 *
 * In the mapped cases, each call of `fn` is followed by `.sample()` on the
 * successful value; a failure is wrapped in `OtherOperationError`.
 *
 * @param {*} v1
 * @param {*} v2
 * @param {*} env - Environment forwarded to `makeSampleSet`.
 * @param {Function} fn - `(a, b) => result` with `ok` / `value` fields.
 * @returns {*} A `vDist` value.
 * @throws {REDistributionError} When a sample-set conversion or mapping result is not ok.
 * @throws {REOther} When `fn` fails in the number/number case, or when the
 *   inputs match none of the cases above.
 */
export function twoVarSample(v1, v2, env, fn) {
    const sampleFn = (a, b) => Result.fmap2(fn(a, b), (d) => d.sample(), (e) => new OtherOperationError(e));
    const firstIsDist = v1 instanceof BaseDist;
    const secondIsDist = v2 instanceof BaseDist;
    const firstIsNumber = typeof v1 === "number";
    const secondIsNumber = typeof v2 === "number";

    if (firstIsDist && secondIsDist) {
        const s1 = makeSampleSet(v1, env);
        const s2 = makeSampleSet(v2, env);
        return distResultToValue(SampleSetDist.map2({
            fn: sampleFn,
            t1: s1,
            t2: s2,
        }));
    }
    if (firstIsDist && secondIsNumber) {
        const s1 = makeSampleSet(v1, env);
        return distResultToValue(s1.samplesMap((a) => sampleFn(a, v2)));
    }
    if (firstIsNumber && secondIsDist) {
        const s2 = makeSampleSet(v2, env);
        return distResultToValue(s2.samplesMap((a) => sampleFn(v1, a)));
    }
    if (firstIsNumber && secondIsNumber) {
        const result = fn(v1, v2);
        if (!result.ok) {
            throw new REOther(result.value);
        }
        return vDist(makeSampleSet(result.value, env));
    }
    throw new REOther("Impossible branch");
}

/**
 * Builds a definition taking two dist-or-number inputs and delegating to
 * `twoVarSample` with the call's `environment`.
 *
 * @param {Function} fn - `(a, b) => result` with `ok` / `value` fields.
 * @returns {*} The value returned by `makeDefinition`.
 */
export function makeTwoArgsDist(fn) {
    return makeDefinition([frDistOrNumber, frDistOrNumber], ([v1, v2], { environment }) => twoVarSample(v1, v2, environment, fn));
}

/**
 * One-argument counterpart of `makeTwoArgsDist`.
 *
 * For a `BaseDist` input, the input is converted with `makeSampleSet` and
 * mapped with `samplesMap`, calling `fn` and then `.sample()` per element.
 * For a number input, `fn` is called once and its distribution is converted
 * with `makeSampleSet`.
 *
 * @param {Function} fn - `(a) => result` with `ok` / `value` fields.
 * @returns {*} The value returned by `makeDefinition`. Its callback throws
 *   `REOther` when `fn` fails on a number input or when the input is neither
 *   a `BaseDist` nor a number.
 */
export function makeOneArgDist(fn) {
    return makeDefinition([frDistOrNumber], ([v], { environment }) => {
        const sampleFn = (a) => Result.fmap2(fn(a), (d) => d.sample(), (e) => new OtherOperationError(e));
        if (v instanceof BaseDist) {
            const s = makeSampleSet(v, environment);
            return distResultToValue(s.samplesMap(sampleFn));
        }
        if (typeof v === "number") {
            const result = fn(v);
            if (!result.ok) {
                throw new REOther(result.value);
            }
            return vDist(makeSampleSet(result.value, environment));
        }
        throw new REOther("Impossible branch");
    });
}
//# sourceMappingURL=helpers.js.map