1 /++
2 Simple mir-random wrappers for convenience
3  +/
4 module numir.random;
5 
6 import std.algorithm : fold;
7 import mir.random : unpredictableSeed, Random;
8 import mir.random.variable : UniformVariable, NormalVariable;
9 import mir.ndslice: slicedField, slice;
10 
11 
/++
Holder of the process-wide random number generator used by the helpers in
this module.

The engine lives behind a `__gshared` pointer, so it is shared by all threads
rather than being thread-local. It is created on the first call to `get`
with an unpredictable seed, unless `setSeed` has installed an engine before.

Note: only creation and replacement of the engine are guarded by a lock;
drawing numbers through the reference returned by `get` is not.
 +/
class RNG
{
    private static this() {}
    private __gshared Random* _rng = null;

    /++
    Returns: a reference to the global engine, creating it on first use.
     +/
    @property static ref get()
    {
        if (!_rng)
        {
            synchronized(RNG.classinfo)
            {
                // Re-check under the lock: another thread may have created
                // (or seeded) the engine between the test above and this
                // point, and that engine must not be thrown away.
                if (!_rng)
                {
                    _rng = new Random(unpredictableSeed!size_t);
                }
            }
        }
        return *_rng;
    }

    /++
    Replaces the global engine with a new one constructed from `seed`.

    Params:
        seed = value passed to the `Random` constructor
     +/
    static void setSeed(uint seed)
    {
        // Same lock as the lazy initialisation in `get`, so a concurrent
        // first-use initialisation cannot overwrite the seeded engine.
        synchronized(RNG.classinfo)
        {
            _rng = new Random(seed);
        }
    }
}
37 
38 
39 /* 
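// FIXME: this test won't finish.
// NOTE(review): likely because the TaskPool created below is never stopped or
// finished, so its non-daemon worker threads keep the process alive; the loop
// itself does not use that pool (`parallel` runs on the default `taskPool`).
// Confirm before re-enabling.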
41 unittest
42 {
43     import std.parallelism;
44     import std.range;
45     import std.stdio;
46 
47     auto pool = new TaskPool();
48 
49     RNG.setSeed(1);
50     foreach (i, p; iota(4).parallel)
51     {
52         uniform(3).writeln;
53     }
54 }
55 */
56 
/++
Draws a random slice from the module-wide generator.

Params:
    var = random variable (distribution) to sample from
    length = extent of each dimension of the result

Returns: the result of `mir.random.algorithm.randomSlice` called with the
    engine from `RNG.get`, `var` and `length`
 +/
auto generate(V, size_t N)(V var, size_t[N] length...)
{
    import mir.random.algorithm : randomSlice;

    return randomSlice(RNG.get, var, length);
}
63 
/++
Samples a slice of the requested shape from `NormalVariable!E(0, 1)` using
the global `RNG`.

Params:
    length = extent of each dimension of the result

Returns: the result of `generate` for that variable and `length`
 +/
auto normal(E=double, size_t N)(size_t[N] length...)
{
    auto standardNormal = NormalVariable!E(0, 1);
    return generate(standardNormal, length);
}
69 
/++
Samples a slice of the requested shape from `UniformVariable!E(0, 1)` using
the global `RNG`.

Params:
    length = extent of each dimension of the result

Returns: the result of `generate` for that variable and `length`
 +/
auto uniform(E=double, size_t N)(size_t[N] length...)
{
    auto unitUniform = UniformVariable!E(0, 1);
    return generate(unitUniform, length);
}
75 
///
unittest
{
    import mir.random.variable : BernoulliVariable;
    import mir.ndslice : all;
    import std.algorithm : sum;

    // any random variable can be passed to `generate`:
    // here the sum of 100 draws from BernoulliVariable!double(0.25)
    auto successes = generate(BernoulliVariable!double(0.25), 100).sum;
    assert(successes > 0 && successes < 50, "maybe fail");

    // pre-defined random variables (normal and uniform)
    RNG.setSeed(1);
    auto seededWithOne = normal(3, 4).slice;
    assert(seededWithOne.shape == [3, 4]);

    // a different seed gives different values
    RNG.setSeed(0);
    auto seededWithZero = normal(3, 4).slice;
    assert(seededWithOne != seededWithZero);

    // the same seed reproduces the same values
    RNG.setSeed(0);
    auto reseededWithZero = normal(3, 4).slice;
    assert(seededWithZero == reseededWithZero);

    // uniform samples lie in [0, 1)
    auto unit = uniform(3, 4).slice;
    assert(unit.shape == [3, 4]);
    assert(unit.all!(x => 0 <= x && x < 1));
}
101 
102 
103 
/++
Builds the same sequence as `numir.core.arange` and returns it in random
order.

Params:
    t = arguments forwarded unchanged to `arange`

Returns: `arange(t)` materialised with `slice` and shuffled in place using
    the global `RNG`
 +/
auto permutation(T...)(T t)
{
    import mir.ndslice : slice;
    import mir.random.algorithm : shuffle;
    import numir.core : arange;

    auto shuffled = slice(arange(t));
    RNG.get.shuffle(shuffled);
    return shuffled;
}
113 
114 
///
unittest
{
    import mir.ndslice.sorting : sort;
    import numir : arange;
    import numir.testing : approxEqual;
    import std.stdio;

    // two permutations of the same integer range should differ,
    // and sorting one of them restores the range
    auto first = permutation(100);
    auto second = permutation(100);
    assert(first != second, "maybe fail at 1%");
    assert(first.sort() == arange(100));

    // same checks for a floating-point range with a step
    auto steppedA = permutation(1, 10, 0.1);
    auto steppedB = permutation(1, 10, 0.1);
    assert(steppedA != steppedB);
    assert(steppedB.sort.approxEqual(arange(1, 10, 0.1)));
}