commit 3e4360f9304bf848ef1173ba95c2cf7101d1c5df
parent dc3f7eed4d9fd279cbeeef591f4ab29f4b1c025e
Author: NunoSempere <nuno.sempere@protonmail.com>
Date: Wed, 29 Nov 2023 21:49:22 +0000
move quickselect to squiggle_more.c
Diffstat:
5 files changed, 55 insertions(+), 80 deletions(-)
diff --git a/scratchpad/quickselect/makefile b/scratchpad/quickselect/makefile
@@ -1,11 +0,0 @@
-build:
- gcc quickselect.c -lm -o quickselect
-
-run:
- ./quickselect
-
-## Formatter
-STYLE_BLUEPRINT=webkit
-FORMATTER=clang-format -i -style=$(STYLE_BLUEPRINT)
-format:
- $(FORMATTER) quickselect.c
diff --git a/scratchpad/quickselect/quickselect b/scratchpad/quickselect/quickselect
Binary files differ.
diff --git a/scratchpad/quickselect/quickselect.c b/scratchpad/quickselect/quickselect.c
@@ -1,69 +0,0 @@
-#include <math.h>
-#include <stdio.h>
-#include <stdlib.h>
-
-void swp(int i, int j, double xs[])
-{
- double tmp = xs[i];
- xs[i] = xs[j];
- xs[j] = tmp;
-}
-
-void array_print(double xs[], int n)
-{
- printf("[");
- for (int i = 0; i < n; i++) {
- printf("%f, ", xs[i]);
- }
- printf("]\n");
-}
-
-int partition(int low, int high, double xs[], int length)
-{
- // To understand this function:
- // - see the note after gt variable definition
- // - go to commit 578bfa27 and the scratchpad/ folder in it, which has printfs sprinkled throughout
- int pivot = low + floor((high - low) / 2);
- double pivot_value = xs[pivot];
- swp(pivot, high, xs);
- int gt = low; /* This pointer will iterate until finding an element which is greater than the pivot. Then it will move elements that are smaller before it--more specifically, it will move elements to its position and then increment. As a result all elements between gt and i will be greater than the pivot. */
- for (int i = low; i < high; i++) {
- if (xs[i] < pivot_value) {
- swp(gt, i, xs);
- gt++;
- }
- }
- swp(high, gt, xs);
- return gt;
-}
-
-double quickselect(int k, double xs[], int length)
-{
- int low = 0;
- int high = length - 1;
- for (;;) {
- if (low == high) {
- return xs[low];
- }
- int pivot = partition(low, high, xs, length);
- if (pivot == k) {
- return xs[pivot];
- } else if (k < pivot) {
- high = pivot - 1;
- } else {
- low = pivot + 1;
- }
- }
-}
-
-int main()
-{
- double xs[] = { 2.1, 1.0, 6.0, 4.0, 7.0, -1.0, 2.0, 10.0 };
- int length = 8;
- int k = 2;
- array_print(xs, 8);
- double result = quickselect(k, xs, length);
- printf("The item in pos #%d is: %f\n", k, result);
- array_print(xs, 8);
- return 0;
-}
diff --git a/squiggle.c b/squiggle.c
@@ -196,6 +196,16 @@ double array_std(double* array, int length)
return std;
}
+void array_print(double xs[], int n)
+{
+ printf("[");
+ for (int i = 0; i < n - 1; i++) {
+ printf("%f, ", xs[i]);
+ }
+ printf("%f", xs[n - 1]);
+ printf("]\n");
+}
+
// Mixture function
double sample_mixture(double (*samplers[])(uint64_t*), double* weights, int n_dists, uint64_t* seed)
{
diff --git a/squiggle_more.c b/squiggle_more.c
@@ -30,6 +30,51 @@ typedef struct ci_searcher_t {
int remaining;
} ci_searcher;
+void swp(int i, int j, double xs[])
+{
+ double tmp = xs[i];
+ xs[i] = xs[j];
+ xs[j] = tmp;
+}
+
+int partition(int low, int high, double xs[], int length)
+{
+ // To understand this function:
+ // - see the note after gt variable definition
+ // - go to commit 578bfa27 and the scratchpad/ folder in it, which has printfs sprinkled throughout
+ int pivot = low + floor((high - low) / 2);
+ double pivot_value = xs[pivot];
+ swp(pivot, high, xs);
+ int gt = low; /* This pointer will iterate until finding an element which is greater than the pivot. Then it will move elements that are smaller before it--more specifically, it will move elements to its position and then increment. As a result all elements between gt and i will be greater than the pivot. */
+ for (int i = low; i < high; i++) {
+ if (xs[i] < pivot_value) {
+ swp(gt, i, xs);
+ gt++;
+ }
+ }
+ swp(high, gt, xs);
+ return gt;
+}
+
+double quickselect(int k, double xs[], int length)
+{
+ int low = 0;
+ int high = length - 1;
+ for (;;) {
+ if (low == high) {
+ return xs[low];
+ }
+ int pivot = partition(low, high, xs, length);
+ if (pivot == k) {
+ return xs[pivot];
+ } else if (k < pivot) {
+ high = pivot - 1;
+ } else {
+ low = pivot + 1;
+ }
+ }
+}
+
ci get_90_confidence_interval(double (*sampler)(uint64_t*), uint64_t* seed)
{
int n = 100 * 1000;