From 34c7ac027796bf866c785f2f106fe14c404aba8b Mon Sep 17 00:00:00 2001 From: "Gustavo J. A. M. Carneiro" Date: Tue, 3 Mar 2009 10:57:42 +0000 Subject: [PATCH] Bug 485: implement deserialization of normal/gaussian random variables --- src/core/random-variable.cc | 93 +++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 4 deletions(-) diff --git a/src/core/random-variable.cc b/src/core/random-variable.cc index a622fed5c..402068679 100644 --- a/src/core/random-variable.cc +++ b/src/core/random-variable.cc @@ -763,6 +763,10 @@ public: virtual double GetValue(); virtual RandomVariableBase* Copy(void) const; + double GetMean (void) const; + double GetVariance (void) const; + double GetBound (void) const; + private: double m_mean; // Mean value of RV double m_variance; // Mean value of RV @@ -835,6 +839,24 @@ RandomVariableBase* NormalVariableImpl::Copy() const return new NormalVariableImpl(*this); } +double +NormalVariableImpl::GetMean (void) const +{ + return m_mean; +} + +double +NormalVariableImpl::GetVariance (void) const +{ + return m_variance; +} + +double +NormalVariableImpl::GetBound (void) const +{ + return m_bound; +} + NormalVariable::NormalVariable() : RandomVariable (NormalVariableImpl ()) {} @@ -1280,6 +1302,17 @@ std::ostream &operator << (std::ostream &os, const RandomVariable &var) os << "Uniform:" << uniform->GetMin () << ":" << uniform->GetMax (); return os; } + NormalVariableImpl *normal = dynamic_cast (base); + if (normal != 0) + { + os << "Normal:" << normal->GetMean () << ":" << normal->GetVariance (); + double bound = normal->GetBound (); + if (bound != NormalVariableImpl::INFINITE_VALUE) + { + os << ":" << bound; + } + return os; + } // XXX: support other distributions os.setstate (std::ios_base::badbit); return os; @@ -1325,6 +1358,44 @@ std::istream &operator >> (std::istream &is, RandomVariable &var) var = UniformVariable (a, b); } } + else if (type == "Normal") + { + if (value.size () == 0) + { + var = NormalVariable (); + } + else + { + tmp = value.find (":"); + if (tmp == value.npos) + { + NS_FATAL_ERROR ("bad Normal value: " << value); + } + std::string::size_type tmp2; + std::string sub = value.substr (tmp + 1, value.npos); + tmp2 = sub.find (":"); + if (tmp2 == value.npos) + { + istringstream issA (value.substr (0, tmp)); + istringstream issB (sub); + double a, b; + issA >> a; + issB >> b; + var = NormalVariable (a, b); + } + else + { + istringstream issA (value.substr (0, tmp)); + istringstream issB (sub.substr (0, tmp2)); + istringstream issC (sub.substr (tmp2 + 1, value.npos)); + double a, b, c; + issA >> a; + issB >> b; + issC >> c; + var = NormalVariable (a, b, c); + } + } + } else { NS_FATAL_ERROR ("RandomVariable deserialization not implemented for " << type); @@ -1396,10 +1467,24 @@ public: // Test attribute serialization { - RandomVariableValue val; - val.DeserializeFromString ("Uniform:0.1:0.2", MakeRandomVariableChecker ()); - RandomVariable rng = val.Get (); - NS_TEST_ASSERT_EQUAL (val.SerializeToString (MakeRandomVariableChecker ()), "Uniform:0.1:0.2"); + { + RandomVariableValue val; + val.DeserializeFromString ("Uniform:0.1:0.2", MakeRandomVariableChecker ()); + RandomVariable rng = val.Get (); + NS_TEST_ASSERT_EQUAL (val.SerializeToString (MakeRandomVariableChecker ()), "Uniform:0.1:0.2"); + } + { + RandomVariableValue val; + val.DeserializeFromString ("Normal:0.1:0.2", MakeRandomVariableChecker ()); + RandomVariable rng = val.Get (); + NS_TEST_ASSERT_EQUAL (val.SerializeToString (MakeRandomVariableChecker ()), "Normal:0.1:0.2"); + } + { + RandomVariableValue val; + val.DeserializeFromString ("Normal:0.1:0.2:0.15", MakeRandomVariableChecker ()); + RandomVariable rng = val.Get (); + NS_TEST_ASSERT_EQUAL (val.SerializeToString (MakeRandomVariableChecker ()), "Normal:0.1:0.2:0.15"); + } } return result;