From c9db08cb0a48904848baf793a58b9388abddca8e Mon Sep 17 00:00:00 2001 From: Green Sky Date: Mon, 21 Jun 2021 23:07:48 +0200 Subject: [PATCH] srng: impl floating point numbers and make std dist compat --- framework/random/src/mm/random/srng.cpp | 14 ++++++++++++++ framework/random/src/mm/random/srng.hpp | 16 +++++++++++++++- framework/std_utils/src/mm/scalar_range2.hpp | 5 +++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/framework/random/src/mm/random/srng.cpp b/framework/random/src/mm/random/srng.cpp index 1f8035e..dcf7f40 100644 --- a/framework/random/src/mm/random/srng.cpp +++ b/framework/random/src/mm/random/srng.cpp @@ -1,2 +1,16 @@ #include "./srng.hpp" +namespace MM::Random { + +template<> +double SRNG::range(const ScalarRange2& range) { + return zeroToOne() * (range.max() - range.min()) + range.min(); +} + +template<> +float SRNG::range(const ScalarRange2& range) { + return zeroToOne() * (range.max() - range.min()) + range.min(); +} + +} // MM::Random + diff --git a/framework/random/src/mm/random/srng.hpp b/framework/random/src/mm/random/srng.hpp index 9c0d3f6..7899461 100644 --- a/framework/random/src/mm/random/srng.hpp +++ b/framework/random/src/mm/random/srng.hpp @@ -37,7 +37,6 @@ struct SRNG { // more advanced // inclusive - // TODO: test for floats template T range(const ScalarRange2& range) { return (getNext() % ((range.max() - range.min()) + 1)) + range.min(); @@ -51,7 +50,22 @@ struct SRNG { bool operator()(float prob) { return roll(prob); } + + // std:: distributions need those + constexpr static uint32_t min(void) { + return 0; + } + + constexpr static uint32_t max(void) { + return 0xffffffff; + } }; +template<> +double SRNG::range(const ScalarRange2& range); + +template<> +float SRNG::range(const ScalarRange2& range); + } // MM::Random diff --git a/framework/std_utils/src/mm/scalar_range2.hpp b/framework/std_utils/src/mm/scalar_range2.hpp index 45ab7a8..bc230f1 100644 --- a/framework/std_utils/src/mm/scalar_range2.hpp +++ b/framework/std_utils/src/mm/scalar_range2.hpp @@ -9,6 +9,11 @@ struct ScalarRange2 { ScalarRange2(void) = default; + ScalarRange2(const T& both) noexcept { + v_min = both; + v_max = both; + } + ScalarRange2(const T& min, const T& max) noexcept { if (min <= max) { v_min = min;