mtp: Add multithreaded parallel simulation support

This commit is contained in:
F5
2022-10-25 17:01:57 +08:00
parent 6764518fff
commit f818faabcd
11 changed files with 1847 additions and 0 deletions

15
src/mtp/CMakeLists.txt Normal file
View File

@@ -0,0 +1,15 @@
build_lib(
LIBNAME mtp
SOURCE_FILES
model/logical-process.cc
model/mtp-interface.cc
model/multithreaded-simulator-impl.cc
HEADER_FILES
model/logical-process.h
model/mtp-interface.h
model/multithreaded-simulator-impl.h
LIBRARIES_TO_LINK ${libcore}
${libnetwork}
TEST_SOURCES
test/mtp-test-suite.cc
)

98
src/mtp/doc/mtp.rst Normal file
View File

@@ -0,0 +1,98 @@
Example Module Documentation
----------------------------
.. include:: replace.txt
.. highlight:: cpp
.. heading hierarchy:
------------- Chapter
************* Section (#.#)
============= Subsection (#.#.#)
############# Paragraph (no number)
This is a suggested outline for adding new module documentation to |ns3|.
See ``src/click/doc/click.rst`` for an example.
The introductory paragraph is for describing what this code is trying to
model.
For consistency (italicized formatting), please use |ns3| to refer to
ns-3 in the documentation (and likewise, |ns2| for ns-2). These macros
are defined in the file ``replace.txt``.
Model Description
*****************
The source code for the new module lives in the directory ``contrib/mtp``.
Add here a basic description of what is being modeled.
Design
======
Briefly describe the software design of the model and how it fits into
the existing ns-3 architecture.
Scope and Limitations
=====================
What can the model do? What can it not do? Please use this section to
describe the scope and limitations of the model.
References
==========
Add academic citations here, such as if you published a paper on this
model, or if readers should read a particular specification or other work.
Usage
*****
This section is principally concerned with the usage of your model, using
the public API. Focus first on most common usage patterns, then go
into more advanced topics.
Building New Module
===================
Include this subsection only if there are special build instructions or
platform limitations.
Helpers
=======
What helper API will users typically use? Describe it here.
Attributes
==========
What classes hold attributes, and what are the key ones worth mentioning?
Output
======
What kind of data does the model generate? What are the key trace
sources? What kind of logging output can be enabled?
Advanced Usage
==============
Go into further details (such as using the API outside of the helpers)
in additional sections, as needed.
Examples
========
What examples using this new code are available? Describe them here.
Troubleshooting
===============
Add any tips for avoiding pitfalls, etc.
Validation
**********
Describe how the model has been tested/validated. What tests run in the
test suite? How much API and code is covered by the tests? Again,
references to outside published work may help here.

View File

@@ -0,0 +1,10 @@
build_lib_example(
NAME simple-mtp
SOURCE_FILES simple-mtp.cc
LIBRARIES_TO_LINK
${libmtp}
${libpoint-to-point}
${libinternet}
${libnix-vector-routing}
${libapplications}
)

View File

@@ -0,0 +1,244 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation;
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
/**
* \file
* \ingroup mtp
*
* TestDistributed creates a dumbbell topology and logically splits it in
* half. The left half is placed on logical processor 1 and the right half
* is placed on logical processor 2.
*
* ------- -------
* RANK 1 RANK 2
* ------- | -------
* |
* n0 ---------| | |---------- n6
* | | |
* n1 -------\ | | | /------- n7
* n4 ----------|---------- n5
* n2 -------/ | | | \------- n8
* | | |
* n3 ---------| | |---------- n9
*
*
* OnOff clients are placed on each left leaf node. Each right leaf node
* is a packet sink for a left leaf node. As a packet travels from one
* logical processor to another (the link between n4 and n5), MPI messages
* are passed containing the serialized packet. The message is then
* deserialized into a new packet and sent on as normal.
*
* One packet is sent from each left leaf node. The packet sinks on the
* right leaf nodes output logging information when they receive the packet.
*/
#include "ns3/core-module.h"
#include "ns3/network-module.h"
#include "ns3/multithreaded-simulator-impl.h"
#include "ns3/ipv4-global-routing-helper.h"
#include "ns3/point-to-point-helper.h"
#include "ns3/internet-stack-helper.h"
#include "ns3/nix-vector-helper.h"
#include "ns3/ipv4-address-helper.h"
#include "ns3/on-off-helper.h"
#include "ns3/packet-sink-helper.h"
#include "ns3/mtp-interface.h"
#include <iomanip>
using namespace ns3;
NS_LOG_COMPONENT_DEFINE ("SimpleMtp");
int
main (int argc, char *argv[])
{
bool nix = true;
bool tracing = false;
bool verbose = false;
// Parse command line
CommandLine cmd (__FILE__);
cmd.AddValue ("nix", "Enable the use of nix-vector or global routing", nix);
cmd.AddValue ("tracing", "Enable pcap tracing", tracing);
cmd.AddValue ("verbose", "verbose output", verbose);
cmd.Parse (argc, argv);
MtpInterface::Enable (2, 2);
GlobalValue::Bind ("SimulatorImplementationType",
StringValue ("ns3::MultithreadedSimulatorImpl"));
if (verbose)
{
LogComponentEnable ("PacketSink",
(LogLevel) (LOG_LEVEL_INFO | LOG_PREFIX_NODE | LOG_PREFIX_TIME));
}
// Some default values
Config::SetDefault ("ns3::OnOffApplication::PacketSize", UintegerValue (512));
Config::SetDefault ("ns3::OnOffApplication::DataRate", StringValue ("1Mbps"));
Config::SetDefault ("ns3::OnOffApplication::MaxBytes", UintegerValue (512));
// Create leaf nodes on left with system id 1
NodeContainer leftLeafNodes;
leftLeafNodes.Create (4, 1);
// Create router nodes. Left router
// with system id 0, right router with
// system id 1
NodeContainer routerNodes;
Ptr<Node> routerNode1 = CreateObject<Node> (1);
Ptr<Node> routerNode2 = CreateObject<Node> (2);
routerNodes.Add (routerNode1);
routerNodes.Add (routerNode2);
// Create leaf nodes on right with system id 2
NodeContainer rightLeafNodes;
rightLeafNodes.Create (4, 2);
PointToPointHelper routerLink;
routerLink.SetDeviceAttribute ("DataRate", StringValue ("5Mbps"));
routerLink.SetChannelAttribute ("Delay", StringValue ("5ms"));
PointToPointHelper leafLink;
leafLink.SetDeviceAttribute ("DataRate", StringValue ("1Mbps"));
leafLink.SetChannelAttribute ("Delay", StringValue ("2ms"));
// Add link connecting routers
NetDeviceContainer routerDevices;
routerDevices = routerLink.Install (routerNodes);
// Add links for left side leaf nodes to left router
NetDeviceContainer leftRouterDevices;
NetDeviceContainer leftLeafDevices;
for (uint32_t i = 0; i < 4; ++i)
{
NetDeviceContainer temp = leafLink.Install (leftLeafNodes.Get (i), routerNodes.Get (0));
leftLeafDevices.Add (temp.Get (0));
leftRouterDevices.Add (temp.Get (1));
}
// Add links for right side leaf nodes to right router
NetDeviceContainer rightRouterDevices;
NetDeviceContainer rightLeafDevices;
for (uint32_t i = 0; i < 4; ++i)
{
NetDeviceContainer temp = leafLink.Install (rightLeafNodes.Get (i), routerNodes.Get (1));
rightLeafDevices.Add (temp.Get (0));
rightRouterDevices.Add (temp.Get (1));
}
InternetStackHelper stack;
if (nix)
{
Ipv4NixVectorHelper nixRouting;
stack.SetRoutingHelper (nixRouting); // has effect on the next Install ()
}
stack.InstallAll ();
Ipv4InterfaceContainer routerInterfaces;
Ipv4InterfaceContainer leftLeafInterfaces;
Ipv4InterfaceContainer leftRouterInterfaces;
Ipv4InterfaceContainer rightLeafInterfaces;
Ipv4InterfaceContainer rightRouterInterfaces;
Ipv4AddressHelper leftAddress;
leftAddress.SetBase ("10.1.1.0", "255.255.255.0");
Ipv4AddressHelper routerAddress;
routerAddress.SetBase ("10.2.1.0", "255.255.255.0");
Ipv4AddressHelper rightAddress;
rightAddress.SetBase ("10.3.1.0", "255.255.255.0");
// Router-to-Router interfaces
routerInterfaces = routerAddress.Assign (routerDevices);
// Left interfaces
for (uint32_t i = 0; i < 4; ++i)
{
NetDeviceContainer ndc;
ndc.Add (leftLeafDevices.Get (i));
ndc.Add (leftRouterDevices.Get (i));
Ipv4InterfaceContainer ifc = leftAddress.Assign (ndc);
leftLeafInterfaces.Add (ifc.Get (0));
leftRouterInterfaces.Add (ifc.Get (1));
leftAddress.NewNetwork ();
}
// Right interfaces
for (uint32_t i = 0; i < 4; ++i)
{
NetDeviceContainer ndc;
ndc.Add (rightLeafDevices.Get (i));
ndc.Add (rightRouterDevices.Get (i));
Ipv4InterfaceContainer ifc = rightAddress.Assign (ndc);
rightLeafInterfaces.Add (ifc.Get (0));
rightRouterInterfaces.Add (ifc.Get (1));
rightAddress.NewNetwork ();
}
if (!nix)
{
Ipv4GlobalRoutingHelper::PopulateRoutingTables ();
}
if (tracing == true)
{
routerLink.EnablePcap ("router-left", routerDevices, true);
leafLink.EnablePcap ("leaf-left", leftLeafDevices, true);
routerLink.EnablePcap ("router-right", routerDevices, true);
leafLink.EnablePcap ("leaf-right", rightLeafDevices, true);
}
// Create a packet sink on the right leafs to receive packets from left leafs
uint16_t port = 50000;
Address sinkLocalAddress (InetSocketAddress (Ipv4Address::GetAny (), port));
PacketSinkHelper sinkHelper ("ns3::UdpSocketFactory", sinkLocalAddress);
ApplicationContainer sinkApp;
for (uint32_t i = 0; i < 4; ++i)
{
sinkApp.Add (sinkHelper.Install (rightLeafNodes.Get (i)));
}
sinkApp.Start (Seconds (1.0));
sinkApp.Stop (Seconds (5));
// Create the OnOff applications to send
OnOffHelper clientHelper ("ns3::UdpSocketFactory", Address ());
clientHelper.SetAttribute ("OnTime", StringValue ("ns3::ConstantRandomVariable[Constant=1]"));
clientHelper.SetAttribute ("OffTime", StringValue ("ns3::ConstantRandomVariable[Constant=0]"));
ApplicationContainer clientApps;
for (uint32_t i = 0; i < 4; ++i)
{
AddressValue remoteAddress (InetSocketAddress (rightLeafInterfaces.GetAddress (i), port));
clientHelper.SetAttribute ("Remote", remoteAddress);
clientApps.Add (clientHelper.Install (leftLeafNodes.Get (i)));
}
clientApps.Start (Seconds (1.0));
clientApps.Stop (Seconds (5));
Simulator::Stop (Seconds (5));
Simulator::Run ();
Simulator::Destroy ();
// Exit the MPI execution environment
return 0;
}

View File

@@ -0,0 +1,292 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#include "logical-process.h"
#include "mtp-interface.h"
#include "ns3/channel.h"
#include "ns3/node-container.h"
#include "ns3/simulator.h"
#include <algorithm>
namespace ns3 {
NS_LOG_COMPONENT_DEFINE ("LogicalProcess");
LogicalProcess::LogicalProcess ()
{
m_systemId = 0;
m_systemCount = 0;
m_uid = EventId::UID::VALID;
m_stop = false;
m_currentContext = Simulator::NO_CONTEXT;
m_currentUid = 0;
m_currentTs = 0;
m_eventCount = 0;
m_pendingEventCount = 0;
m_events = 0;
m_lookAhead = TimeStep (0);
}
LogicalProcess::~LogicalProcess ()
{
NS_LOG_INFO ("system " << m_systemId << " finished with event count " << m_eventCount);
// if others hold references to event list, do not unref events
if (m_events->GetReferenceCount () == 1)
{
while (!m_events->IsEmpty ())
{
Scheduler::Event next = m_events->RemoveNext ();
next.impl->Unref ();
}
}
}
void
LogicalProcess::Enable (const uint32_t systemId, const uint32_t systemCount)
{
m_systemId = systemId;
m_systemCount = systemCount;
}
void
LogicalProcess::CalculateLookAhead ()
{
NS_LOG_FUNCTION (this);
if (m_systemId == 0)
{
m_lookAhead = TimeStep (0); // No lookahead for public LP
}
else
{
m_lookAhead = Time::Max () / 2 - TimeStep (1);
NodeContainer c = NodeContainer::GetGlobal ();
for (NodeContainer::Iterator iter = c.Begin (); iter != c.End (); ++iter)
{
if ((*iter)->GetSystemId () != m_systemId)
{
continue;
}
for (uint32_t i = 0; i < (*iter)->GetNDevices (); ++i)
{
Ptr<NetDevice> localNetDevice = (*iter)->GetDevice (i);
// only works for p2p links currently
if (!localNetDevice->IsPointToPoint ())
{
continue;
}
Ptr<Channel> channel = localNetDevice->GetChannel ();
if (channel == 0)
{
continue;
}
// grab the adjacent node
Ptr<Node> remoteNode;
if (channel->GetDevice (0) == localNetDevice)
{
remoteNode = (channel->GetDevice (1))->GetNode ();
}
else
{
remoteNode = (channel->GetDevice (0))->GetNode ();
}
// if it's not remote, don't consider it
if (remoteNode->GetSystemId () == m_systemId)
{
continue;
}
// compare delay on the channel with current value of m_lookAhead.
// if delay on channel is smaller, make it the new lookAhead.
TimeValue delay;
channel->GetAttribute ("Delay", delay);
if (delay.Get () < m_lookAhead)
{
m_lookAhead = delay.Get ();
}
// add the neighbour to the mailbox
m_mailbox[remoteNode->GetSystemId ()];
}
}
}
NS_LOG_INFO ("lookahead of system " << m_systemId << " is set to " << m_lookAhead.GetTimeStep ());
}
void
LogicalProcess::ReceiveMessages ()
{
NS_LOG_FUNCTION (this);
m_pendingEventCount = 0;
for (auto &item : m_mailbox)
{
auto &queue = item.second;
std::sort (queue.begin (), queue.end (), std::greater<> ());
while (!queue.empty ())
{
auto &evWithTs = queue.back ();
Scheduler::Event &ev = std::get<3> (evWithTs);
ev.key.m_uid = m_uid++;
m_events->Insert (ev);
queue.pop_back ();
m_pendingEventCount++;
}
}
}
void
LogicalProcess::ProcessOneRound ()
{
NS_LOG_FUNCTION (this);
// set thread context
MtpInterface::SetSystem (m_systemId);
// calculate time window
Time grantedTime =
Min (MtpInterface::GetSmallestTime () + m_lookAhead, MtpInterface::GetNextPublicTime ());
auto start = std::chrono::system_clock::now ();
// process events
while (Next () <= grantedTime)
{
Scheduler::Event next = m_events->RemoveNext ();
m_eventCount++;
NS_LOG_LOGIC ("handle " << next.key.m_ts);
m_currentTs = next.key.m_ts;
m_currentContext = next.key.m_context;
m_currentUid = next.key.m_uid;
next.impl->Invoke ();
next.impl->Unref ();
}
auto end = std::chrono::system_clock::now ();
m_executionTime = std::chrono::duration_cast<std::chrono::nanoseconds> (end - start).count ();
}
EventId
LogicalProcess::Schedule (Time const &delay, EventImpl *event)
{
Scheduler::Event ev;
ev.impl = event;
ev.key.m_ts = m_currentTs + delay.GetTimeStep ();
ev.key.m_context = GetContext ();
ev.key.m_uid = m_uid++;
m_events->Insert (ev);
return EventId (event, ev.key.m_ts, ev.key.m_context, ev.key.m_uid);
}
void
LogicalProcess::ScheduleWithContext (LogicalProcess *remote, uint32_t context, Time const &delay,
EventImpl *event)
{
Scheduler::Event ev;
ev.impl = event;
ev.key.m_ts = delay.GetTimeStep () + m_currentTs;
ev.key.m_context = context;
if (remote == this)
{
ev.key.m_uid = m_uid++;
m_events->Insert (ev);
}
else
{
ev.key.m_uid = EventId::UID::INVALID;
remote->m_mailbox[m_systemId].push_back (
std::make_tuple (m_currentTs, m_systemId, m_uid, ev));
}
}
void
LogicalProcess::InvokeNow (Scheduler::Event const &ev)
{
uint32_t oldSystemId = MtpInterface::GetSystem ()->GetSystemId ();
MtpInterface::SetSystem (m_systemId);
m_eventCount++;
NS_LOG_LOGIC ("handle " << ev.key.m_ts);
m_currentTs = ev.key.m_ts;
m_currentContext = ev.key.m_context;
m_currentUid = ev.key.m_uid;
ev.impl->Invoke ();
ev.impl->Unref ();
// restore previous thread context
MtpInterface::SetSystem (oldSystemId);
}
void
LogicalProcess::Remove (const EventId &id)
{
if (IsExpired (id))
{
return;
}
Scheduler::Event event;
event.impl = id.PeekEventImpl ();
event.key.m_ts = id.GetTs ();
event.key.m_context = id.GetContext ();
event.key.m_uid = id.GetUid ();
m_events->Remove (event);
event.impl->Cancel ();
// whenever we remove an event from the event list, we have to unref it.
event.impl->Unref ();
}
bool
LogicalProcess::IsExpired (const EventId &id) const
{
if (id.PeekEventImpl () == 0 || id.GetTs () < m_currentTs ||
(id.GetTs () == m_currentTs && id.GetUid () <= m_currentUid) ||
id.PeekEventImpl ()->IsCancelled ())
{
return true;
}
else
{
return false;
}
}
void
LogicalProcess::SetScheduler (ObjectFactory schedulerFactory)
{
Ptr<Scheduler> scheduler = schedulerFactory.Create<Scheduler> ();
if (m_events != 0)
{
while (!m_events->IsEmpty ())
{
Scheduler::Event next = m_events->RemoveNext ();
scheduler->Insert (next);
}
}
m_events = scheduler;
}
Time
LogicalProcess::Next () const
{
if (m_stop || m_events->IsEmpty ())
{
return Time::Max ();
}
else
{
Scheduler::Event ev = m_events->PeekNext ();
return TimeStep (ev.key.m_ts);
}
}
} // namespace ns3

View File

@@ -0,0 +1,122 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#ifndef LOGICAL_PROCESS_H
#define LOGICAL_PROCESS_H
#include "ns3/event-id.h"
#include "ns3/event-impl.h"
#include "ns3/nstime.h"
#include "ns3/object-factory.h"
#include "ns3/ptr.h"
#include "ns3/scheduler.h"
#include <atomic>
#include <chrono>
#include <map>
#include <tuple>
#include <vector>
namespace ns3 {
class LogicalProcess
{
public:
LogicalProcess ();
~LogicalProcess ();
void Enable (const uint32_t systemId, const uint32_t systemCount);
void CalculateLookAhead ();
void ReceiveMessages ();
void ProcessOneRound ();
inline uint64_t
GetExecutionTime () const
{
return m_executionTime;
}
inline uint64_t
GetPendingEventCount () const
{
return m_pendingEventCount;
}
inline Ptr<Scheduler>
GetPendingEvents () const
{
return m_events;
}
// mapped from MultithreadedSimulatorImpl
EventId Schedule (Time const &delay, EventImpl *event);
void ScheduleWithContext (LogicalProcess *remote, uint32_t context, Time const &delay,
EventImpl *event);
void InvokeNow (Scheduler::Event const &ev); // cross context immediate invocation
void Remove (const EventId &id);
void Cancel (const EventId &id);
bool IsExpired (const EventId &id) const;
void SetScheduler (ObjectFactory schedulerFactory);
Time Next () const;
inline bool
isLocalFinished () const
{
return m_stop || m_events->IsEmpty ();
}
inline void
Stop ()
{
m_stop = true;
}
inline Time
Now () const
{
return TimeStep (m_currentTs);
}
inline Time
GetDelayLeft (const EventId &id) const
{
return TimeStep (id.GetTs () - m_currentTs);
}
inline uint32_t
GetSystemId (void) const
{
return m_systemId;
}
inline uint32_t
GetContext () const
{
return m_currentContext;
}
inline uint64_t
GetEventCount () const
{
return m_eventCount;
}
private:
uint32_t m_systemId;
uint32_t m_systemCount;
bool m_stop;
uint32_t m_uid;
uint32_t m_currentContext;
uint32_t m_currentUid;
uint64_t m_currentTs;
uint64_t m_eventCount;
uint64_t m_pendingEventCount;
Ptr<Scheduler> m_events;
Time m_lookAhead;
std::map<uint32_t, std::vector<std::tuple<uint64_t, uint32_t, uint32_t, Scheduler::Event>>>
m_mailbox; // event message mail box
std::chrono::nanoseconds::rep m_executionTime;
};
} // namespace ns3
#endif /* LOGICAL_PROCESS_H */

View File

@@ -0,0 +1,368 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#include "mtp-interface.h"
#include "ns3/assert.h"
#include "ns3/log.h"
#include "ns3/string.h"
#include "ns3/uinteger.h"
#include "ns3/config.h"
#include <algorithm>
#include <cmath>
namespace ns3 {
NS_LOG_COMPONENT_DEFINE ("MtpInterface");
void
MtpInterface::Enable ()
{
#ifdef NS3_MPI
GlobalValue::Bind ("SimulatorImplementationType", StringValue ("ns3::HybridSimulatorImpl"));
#else
GlobalValue::Bind ("SimulatorImplementationType",
StringValue ("ns3::MultithreadedSimulatorImpl"));
#endif
g_enabled = true;
}
void
MtpInterface::Enable (const uint32_t threadCount)
{
#ifdef NS3_MPI
Config::SetDefault ("ns3::HybridSimulatorImpl::MaxThreads", UintegerValue (threadCount));
#else
Config::SetDefault ("ns3::MultithreadedSimulatorImpl::MaxThreads", UintegerValue (threadCount));
#endif
MtpInterface::Enable ();
}
void
MtpInterface::Enable (const uint32_t threadCount, const uint32_t systemCount)
{
NS_ASSERT_MSG (threadCount > 0, "There must be at least one thread");
// called by manual partition
if (!g_enabled)
{
GlobalValue::Bind ("SimulatorImplementationType",
StringValue ("ns3::MultithreadedSimulatorImpl"));
}
// set size
g_threadCount = threadCount;
g_systemCount = systemCount;
// allocate systems
g_systems = new LogicalProcess[g_systemCount + 1]; // include the public LP
for (uint32_t i = 0; i <= g_systemCount; i++)
{
g_systems[i].Enable (i, g_systemCount + 1);
}
StringValue s;
g_sortMethod.GetValue (s);
if (s.Get () == "ByExecutionTime")
{
g_sortFunc = SortByExecutionTime;
}
else if (s.Get () == "ByPendingEventCount")
{
g_sortFunc = SortByPendingEventCount;
}
else if (s.Get () == "ByEventCount")
{
g_sortFunc = SortByEventCount;
}
else if (s.Get () == "BySimulationTime")
{
g_sortFunc = SortBySimulationTime;
}
UintegerValue ui;
g_sortPeriod.GetValue (ui);
if (ui.Get () == 0)
{
g_period = std::ceil (std::log2 (g_systemCount) / 4 + 1);
NS_LOG_INFO ("Secheduling period is automatically set to " << g_period);
}
else
{
g_period = ui.Get ();
}
// create a thread local storage key
// so that we can access the currently assigned LP of each thread
pthread_key_create (&g_key, nullptr);
pthread_setspecific (g_key, &g_systems[0]);
}
void
MtpInterface::EnableNew (const uint32_t newSystemCount)
{
const LogicalProcess *oldSystems = g_systems;
g_systems = new LogicalProcess[g_systemCount + newSystemCount + 1];
for (uint32_t i = 0; i <= g_systemCount; i++)
{
g_systems[i] = oldSystems[i];
}
delete[] oldSystems;
g_systemCount += newSystemCount;
for (uint32_t i = 0; i <= g_systemCount; i++)
{
g_systems[i].Enable (i, g_systemCount + 1);
}
}
void
MtpInterface::Disable ()
{
g_threadCount = 0;
g_systemCount = 0;
g_sortFunc = nullptr;
g_globalFinished = false;
delete[] g_systems;
delete[] g_threads;
delete[] g_sortedSystemIndices;
}
void
MtpInterface::Run ()
{
RunBefore ();
while (!g_globalFinished)
{
ProcessOneRound ();
CalculateSmallestTime ();
}
RunAfter ();
}
void
MtpInterface::RunBefore ()
{
CalculateLookAhead ();
// LP index for sorting & holding worker threads
g_sortedSystemIndices = new uint32_t[g_systemCount];
for (uint32_t i = 0; i < g_systemCount; i++)
{
g_sortedSystemIndices[i] = i + 1;
}
g_systemIndex.store (g_systemCount, std::memory_order_release);
// start threads
g_threads = new pthread_t[g_threadCount - 1]; // exclude the main thread
for (uint32_t i = 0; i < g_threadCount - 1; i++)
{
pthread_create (&g_threads[i], nullptr, ThreadFunc, nullptr);
}
}
void
MtpInterface::ProcessOneRound ()
{
// assign logical process to threads
// determine the priority of logical processes
if (g_sortFunc != nullptr && g_round++ % g_period == 0)
{
std::sort (g_sortedSystemIndices, g_sortedSystemIndices + g_systemCount, g_sortFunc);
}
// stage 1: process events
g_recvMsgStage = false;
g_finishedSystemCount.store (0, std::memory_order_relaxed);
g_systemIndex.store (0, std::memory_order_release);
// main thread also needs to process an LP to reduce an extra thread overhead
while (true)
{
uint32_t index = g_systemIndex.fetch_add (1, std::memory_order_acquire);
if (index >= g_systemCount)
{
break;
}
LogicalProcess *system = &g_systems[g_sortedSystemIndices[index]];
system->ProcessOneRound ();
g_finishedSystemCount.fetch_add (1, std::memory_order_release);
}
// logical process barriar synchronization
while (g_finishedSystemCount.load (std::memory_order_acquire) != g_systemCount)
;
// stage 2: process the public LP
g_systems[0].ProcessOneRound ();
// stage 3: receive messages
g_recvMsgStage = true;
g_finishedSystemCount.store (0, std::memory_order_relaxed);
g_systemIndex.store (0, std::memory_order_release);
while (true)
{
uint32_t index = g_systemIndex.fetch_add (1, std::memory_order_acquire);
if (index >= g_systemCount)
{
break;
}
LogicalProcess *system = &g_systems[g_sortedSystemIndices[index]];
system->ReceiveMessages ();
g_finishedSystemCount.fetch_add (1, std::memory_order_release);
}
// logical process barriar synchronization
while (g_finishedSystemCount.load (std::memory_order_acquire) != g_systemCount)
;
}
void
MtpInterface::CalculateSmallestTime ()
{
// update smallest time
g_smallestTime = Time::Max () / 2;
for (uint32_t i = 0; i <= g_systemCount; i++)
{
Time nextTime = g_systems[i].Next ();
if (nextTime < g_smallestTime)
{
g_smallestTime = nextTime;
}
}
g_nextPublicTime = g_systems[0].Next ();
// test if global finished
bool globalFinished = true;
for (uint32_t i = 0; i <= g_systemCount; i++)
{
globalFinished &= g_systems[i].isLocalFinished ();
}
g_globalFinished = globalFinished;
}
void
MtpInterface::RunAfter ()
{
// global finished, terminate threads
g_systemIndex.store (0, std::memory_order_release);
for (uint32_t i = 0; i < g_threadCount - 1; i++)
{
pthread_join (g_threads[i], nullptr);
}
}
bool
MtpInterface::isEnabled ()
{
return g_enabled;
}
bool
MtpInterface::isPartitioned ()
{
return g_threadCount != 0;
}
void
MtpInterface::CalculateLookAhead ()
{
for (uint32_t i = 1; i <= g_systemCount; i++)
{
g_systems[i].CalculateLookAhead ();
}
}
void *
MtpInterface::ThreadFunc (void *arg)
{
while (!g_globalFinished)
{
uint32_t index = g_systemIndex.fetch_add (1, std::memory_order_acquire);
if (index >= g_systemCount)
{
while (g_systemIndex.load (std::memory_order_acquire) >= g_systemCount)
;
continue;
}
LogicalProcess *system = &g_systems[g_sortedSystemIndices[index]];
if (g_recvMsgStage)
{
system->ReceiveMessages ();
}
else
{
system->ProcessOneRound ();
}
g_finishedSystemCount.fetch_add (1, std::memory_order_release);
}
return nullptr;
}
bool
MtpInterface::SortByExecutionTime (const uint32_t &i, const uint32_t &j)
{
return g_systems[i].GetExecutionTime () > g_systems[j].GetExecutionTime ();
}
bool
MtpInterface::SortByEventCount (const uint32_t &i, const uint32_t &j)
{
return g_systems[i].GetEventCount () > g_systems[j].GetEventCount ();
}
bool
MtpInterface::SortByPendingEventCount (const uint32_t &i, const uint32_t &j)
{
return g_systems[i].GetPendingEventCount () > g_systems[j].GetPendingEventCount ();
}
bool
MtpInterface::SortBySimulationTime (const uint32_t &i, const uint32_t &j)
{
return g_systems[i].Now () > g_systems[j].Now ();
}
bool (*MtpInterface::g_sortFunc) (const uint32_t &, const uint32_t &) = nullptr;
GlobalValue MtpInterface::g_sortMethod = GlobalValue (
"PartitionSchedulingMethod", "The scheduling method to determine which partition runs first",
StringValue ("ByExecutionTime"), MakeStringChecker ());
GlobalValue MtpInterface::g_sortPeriod =
GlobalValue ("PartitionSchedulingPeriod", "The scheduling period of partitions",
UintegerValue (0), MakeUintegerChecker<uint32_t> (0));
uint32_t MtpInterface::g_period = 0;
pthread_t *MtpInterface::g_threads = nullptr;
LogicalProcess *MtpInterface::g_systems = nullptr;
uint32_t MtpInterface::g_threadCount = 0;
uint32_t MtpInterface::g_systemCount = 0;
uint32_t *MtpInterface::g_sortedSystemIndices = nullptr;
std::atomic<uint32_t> MtpInterface::g_systemIndex;
std::atomic<uint32_t> MtpInterface::g_finishedSystemCount;
uint32_t MtpInterface::g_round = 0;
Time MtpInterface::g_smallestTime = TimeStep (0);
Time MtpInterface::g_nextPublicTime = TimeStep (0);
bool MtpInterface::g_recvMsgStage = false;
bool MtpInterface::g_globalFinished = false;
bool MtpInterface::g_enabled = false;
pthread_key_t MtpInterface::g_key;
std::atomic<bool> MtpInterface::g_inCriticalSection (false);
} // namespace ns3

View File

@@ -0,0 +1,160 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#ifndef MTP_INTERFACE_H
#define MTP_INTERFACE_H
#include "logical-process.h"
#include "ns3/atomic-counter.h"
#include "ns3/global-value.h"
#include "ns3/nstime.h"
#include "ns3/simulator.h"
#include <pthread.h>
namespace ns3 {
class MtpInterface
{
public:
class CriticalSection
{
public:
inline CriticalSection ()
{
while (g_inCriticalSection.exchange (true, std::memory_order_acquire))
;
}
inline ~CriticalSection ()
{
g_inCriticalSection.store (false, std::memory_order_release);
}
};
static void Enable (); // auto topology partition
static void Enable (const uint32_t threadCount); // auto partition, specify thread count
static void Enable (const uint32_t threadCount, const uint32_t systemCount); // manual partition
static void EnableNew (const uint32_t newSystemCount); // add LPs for dynamic added node
static void Disable ();
static void Run ();
static void RunBefore ();
static void ProcessOneRound ();
static void CalculateSmallestTime ();
static void RunAfter ();
static bool isEnabled ();
static bool isPartitioned ();
static void CalculateLookAhead ();
// get current thread's executing logical process
inline static LogicalProcess *
GetSystem ()
{
return static_cast<LogicalProcess *> (pthread_getspecific (g_key));
}
inline static LogicalProcess *
GetSystem (const uint32_t systemId)
{
return &g_systems[systemId];
}
// set current thread's executing logical process
inline static void
SetSystem (const uint32_t systemId)
{
pthread_setspecific (g_key, &g_systems[systemId]);
}
inline static uint32_t
GetSize ()
{
return g_systemCount + 1;
}
inline static uint32_t
GetRound ()
{
return g_round;
}
inline static Time
GetSmallestTime ()
{
return g_smallestTime;
}
inline static void
SetSmallestTime (const Time smallestTime)
{
g_smallestTime = smallestTime;
}
inline static Time
GetNextPublicTime ()
{
return g_nextPublicTime;
}
inline static bool
isFinished ()
{
return g_globalFinished;
}
template <typename FUNC,
typename std::enable_if<!std::is_convertible<FUNC, Ptr<EventImpl>>::value, int>::type,
typename std::enable_if<
!std::is_function<typename std::remove_pointer<FUNC>::type>::value, int>::type,
typename... Ts>
inline static void
ScheduleGlobal (FUNC f, Ts &&...args)
{
CriticalSection cs;
g_systems[0].ScheduleAt (Simulator::NO_CONTEXT, Min (g_smallestTime, g_nextPublicTime),
MakeEvent (f, std::forward<Ts> (args)...));
}
template <typename... Us, typename... Ts>
inline static void
ScheduleGlobal (void (*f) (Us...), Ts &&...args)
{
CriticalSection cs;
g_systems[0].ScheduleAt (Simulator::NO_CONTEXT, Min (g_smallestTime, g_nextPublicTime),
MakeEvent (f, std::forward<Ts> (args)...));
}
private:
static void *ThreadFunc (void *arg);
// determine logical process priority
static bool SortByExecutionTime (const uint32_t &i, const uint32_t &j);
static bool SortByEventCount (const uint32_t &i, const uint32_t &j);
static bool SortByPendingEventCount (const uint32_t &i, const uint32_t &j);
static bool SortBySimulationTime (const uint32_t &i, const uint32_t &j);
static bool (*g_sortFunc) (const uint32_t &, const uint32_t &);
static GlobalValue g_sortMethod;
static GlobalValue g_sortPeriod;
static uint32_t g_period;
static pthread_t *g_threads;
static LogicalProcess *g_systems;
static uint32_t g_threadCount;
static uint32_t g_systemCount;
static uint32_t *g_sortedSystemIndices;
static std::atomic<uint32_t> g_systemIndex;
static std::atomic<uint32_t> g_finishedSystemCount;
static uint32_t g_round;
static Time g_smallestTime;
static Time g_nextPublicTime;
static bool g_recvMsgStage;
static bool g_globalFinished;
static bool g_enabled;
static pthread_key_t g_key;
static std::atomic<bool> g_inCriticalSection;
};
} // namespace ns3
#endif /* MTP_INTERFACE_H */

View File

@@ -0,0 +1,409 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#include "multithreaded-simulator-impl.h"
#include "mtp-interface.h"
#include "ns3/type-id.h"
#include "ns3/channel.h"
#include "ns3/simulator.h"
#include "ns3/node.h"
#include "ns3/node-container.h"
#include "ns3/node-list.h"
#include "ns3/uinteger.h"
#include <algorithm>
#include <queue>
#include <thread>
namespace ns3 {
NS_LOG_COMPONENT_DEFINE ("MultithreadedSimulatorImpl");
NS_OBJECT_ENSURE_REGISTERED (MultithreadedSimulatorImpl);
MultithreadedSimulatorImpl::MultithreadedSimulatorImpl ()
{
NS_LOG_FUNCTION (this);
if (!MtpInterface::isPartitioned ())
{
MtpInterface::Enable (1, 0);
m_partition = true;
}
else
{
m_partition = false;
}
}
MultithreadedSimulatorImpl::~MultithreadedSimulatorImpl ()
{
NS_LOG_FUNCTION (this);
}
TypeId
MultithreadedSimulatorImpl::GetTypeId ()
{
static TypeId tid =
TypeId ("ns3::MultithreadedSimulatorImpl")
.SetParent<SimulatorImpl> ()
.SetGroupName ("Mtp")
.AddConstructor<MultithreadedSimulatorImpl> ()
.AddAttribute ("MaxThreads", "The maximum threads used in simulation",
UintegerValue (std::thread::hardware_concurrency ()),
MakeUintegerAccessor (&MultithreadedSimulatorImpl::m_maxThreads),
MakeUintegerChecker<uint32_t> (1))
.AddAttribute ("MinLookahead", "The minimum lookahead in a partition",
TimeValue (TimeStep (0)),
MakeTimeAccessor (&MultithreadedSimulatorImpl::m_minLookahead),
MakeTimeChecker (TimeStep (0)));
return tid;
}
void
MultithreadedSimulatorImpl::Destroy ()
{
while (!m_destroyEvents.empty ())
{
Ptr<EventImpl> ev = m_destroyEvents.front ().PeekEventImpl ();
m_destroyEvents.pop_front ();
NS_LOG_LOGIC ("handle destroy " << ev);
if (!ev->IsCancelled ())
{
ev->Invoke ();
}
}
MtpInterface::Disable ();
}
bool
MultithreadedSimulatorImpl::IsFinished () const
{
return MtpInterface::isFinished ();
}
void
MultithreadedSimulatorImpl::Stop ()
{
NS_LOG_FUNCTION (this);
for (uint32_t i = 0; i < MtpInterface::GetSize (); i++)
{
MtpInterface::GetSystem (i)->Stop ();
}
}
void
MultithreadedSimulatorImpl::Stop (Time const &delay)
{
NS_LOG_FUNCTION (this << delay.GetTimeStep ());
Simulator::Schedule (delay, &Simulator::Stop);
}
EventId
MultithreadedSimulatorImpl::Schedule (Time const &delay, EventImpl *event)
{
NS_LOG_FUNCTION (this << delay.GetTimeStep () << event);
return MtpInterface::GetSystem ()->Schedule (delay, event);
}
void
MultithreadedSimulatorImpl::ScheduleWithContext (uint32_t context, Time const &delay,
EventImpl *event)
{
NS_LOG_FUNCTION (this << context << delay.GetTimeStep () << event);
LogicalProcess *remote = MtpInterface::GetSystem (NodeList::GetNode (context)->GetSystemId ());
MtpInterface::GetSystem ()->ScheduleWithContext (remote, context, delay, event);
}
EventId
MultithreadedSimulatorImpl::ScheduleNow (EventImpl *event)
{
return Schedule (TimeStep (0), event);
}
EventId
MultithreadedSimulatorImpl::ScheduleDestroy (EventImpl *event)
{
EventId id (Ptr<EventImpl> (event, false), GetMaximumSimulationTime ().GetTimeStep (), 0xffffffff,
EventId::DESTROY);
MtpInterface::CriticalSection cs;
m_destroyEvents.push_back (id);
return id;
}
void
MultithreadedSimulatorImpl::Remove (const EventId &id)
{
if (id.GetUid () == EventId::DESTROY)
{
// destroy events.
for (std::list<EventId>::iterator i = m_destroyEvents.begin (); i != m_destroyEvents.end ();
i++)
{
if (*i == id)
{
m_destroyEvents.erase (i);
break;
}
}
}
else
{
MtpInterface::GetSystem ()->Remove (id);
}
}
void
MultithreadedSimulatorImpl::Cancel (const EventId &id)
{
if (!IsExpired (id))
{
id.PeekEventImpl ()->Cancel ();
}
}
bool
MultithreadedSimulatorImpl::IsExpired (const EventId &id) const
{
if (id.GetUid () == EventId::DESTROY)
{
// destroy events.
if (id.PeekEventImpl () == nullptr || id.PeekEventImpl ()->IsCancelled ())
{
return true;
}
for (std::list<EventId>::const_iterator i = m_destroyEvents.begin ();
i != m_destroyEvents.end (); i++)
{
if (*i == id)
{
return false;
}
}
return true;
}
else
{
return MtpInterface::GetSystem ()->IsExpired (id);
}
}
void
MultithreadedSimulatorImpl::Run ()
{
NS_LOG_FUNCTION (this);
// auto partition
if (m_partition)
{
Partition ();
}
MtpInterface::Run ();
}
Time
MultithreadedSimulatorImpl::Now (void) const
{
// Do not add function logging here, to avoid stack overflow
return MtpInterface::GetSystem ()->Now ();
}
Time
MultithreadedSimulatorImpl::GetDelayLeft (const EventId &id) const
{
if (IsExpired (id))
{
return TimeStep (0);
}
else
{
return MtpInterface::GetSystem ()->GetDelayLeft (id);
}
}
Time
MultithreadedSimulatorImpl::GetMaximumSimulationTime (void) const
{
return Time::Max () / 2;
}
void
MultithreadedSimulatorImpl::SetScheduler (ObjectFactory schedulerFactory)
{
NS_LOG_FUNCTION (this << schedulerFactory);
for (uint32_t i = 0; i < MtpInterface::GetSize (); i++)
{
MtpInterface::GetSystem (i)->SetScheduler (schedulerFactory);
}
m_schedulerTypeId = schedulerFactory.GetTypeId ();
}
uint32_t
MultithreadedSimulatorImpl::GetSystemId (void) const
{
return MtpInterface::GetSystem ()->GetSystemId ();
}
uint32_t
MultithreadedSimulatorImpl::GetContext () const
{
return MtpInterface::GetSystem ()->GetContext ();
}
uint64_t
MultithreadedSimulatorImpl::GetEventCount (void) const
{
uint64_t eventCount = 0;
for (uint32_t i = 0; i < MtpInterface::GetSize (); i++)
{
eventCount += MtpInterface::GetSystem (i)->GetEventCount ();
}
return eventCount;
}
void
MultithreadedSimulatorImpl::DoDispose ()
{
SimulatorImpl::DoDispose ();
}
void
MultithreadedSimulatorImpl::Partition ()
{
NS_LOG_FUNCTION (this);
uint32_t systemId = 0;
const NodeContainer nodes = NodeContainer::GetGlobal ();
bool *visited = new bool[nodes.GetN ()]{false};
std::queue<Ptr<Node>> q;
// if m_minLookahead is not set, calculate the median of delay for every link
if (m_minLookahead == TimeStep (0))
{
std::vector<Time> delays;
for (NodeContainer::Iterator it = nodes.Begin (); it != nodes.End (); it++)
{
Ptr<Node> node = *it;
for (uint32_t i = 0; i < node->GetNDevices (); i++)
{
Ptr<NetDevice> localNetDevice = node->GetDevice (i);
Ptr<Channel> channel = localNetDevice->GetChannel ();
if (channel == 0)
{
continue;
}
// cut-off p2p links for partition
if (localNetDevice->IsPointToPoint ())
{
TimeValue delay;
channel->GetAttribute ("Delay", delay);
delays.push_back (delay.Get ());
}
}
}
std::sort (delays.begin (), delays.end ());
if (delays.size () % 2 == 1)
{
m_minLookahead = delays[delays.size () / 2];
}
else
{
m_minLookahead = (delays[delays.size () / 2 - 1] + delays[delays.size () / 2]) / 2;
}
NS_LOG_INFO ("Min lookahead is set to " << m_minLookahead);
}
// perform a BFS on the whole network topo to assign each node a systemId
for (NodeContainer::Iterator it = nodes.Begin (); it != nodes.End (); it++)
{
Ptr<Node> node = *it;
if (!visited[node->GetId ()])
{
q.push (node);
systemId++;
while (!q.empty ())
{
// pop from BFS queue
node = q.front ();
q.pop ();
visited[node->GetId ()] = true;
// assign this node the current systemId
node->SetSystemId (systemId);
NS_LOG_INFO ("node " << node->GetId () << " is set to system " << systemId);
for (uint32_t i = 0; i < node->GetNDevices (); i++)
{
Ptr<NetDevice> localNetDevice = node->GetDevice (i);
Ptr<Channel> channel = localNetDevice->GetChannel ();
if (channel == 0)
{
continue;
}
// cut-off p2p links for partition
if (localNetDevice->IsPointToPoint ())
{
TimeValue delay;
channel->GetAttribute ("Delay", delay);
// if delay is below threshold, do not cut-off
if (delay.Get () >= m_minLookahead)
{
continue;
}
}
// grab the adjacent nodes
for (uint32_t j = 0; j < channel->GetNDevices (); j++)
{
Ptr<Node> remote = channel->GetDevice (j)->GetNode ();
// if it's not visited, add it to the current partition
if (!visited[remote->GetId ()])
{
q.push (remote);
}
}
}
}
}
}
delete[] visited;
// after the partition, we finally know the system count (# of LPs)
const uint32_t systemCount = systemId;
const uint32_t threadCount = std::min (m_maxThreads, systemCount);
NS_LOG_INFO ("Partition done! " << systemCount << " systems share " << threadCount << " threads");
// create new LPs
const Ptr<Scheduler> events = MtpInterface::GetSystem ()->GetPendingEvents ();
MtpInterface::Disable ();
MtpInterface::Enable (threadCount, systemCount);
// set scheduler
ObjectFactory schedulerFactory;
schedulerFactory.SetTypeId (m_schedulerTypeId);
for (uint32_t i = 0; i <= systemCount; i++)
{
MtpInterface::GetSystem (i)->SetScheduler (schedulerFactory);
}
// transfer events to new LPs
while (!events->IsEmpty ())
{
Scheduler::Event ev = events->RemoveNext ();
// invoke initialization events (at time 0) by their insertion order
// since changing the execution order of these events may cause error,
// they have to be invoked now rather than parallelly executed
if (ev.key.m_ts == 0)
{
MtpInterface::GetSystem (ev.key.m_context == Simulator::NO_CONTEXT
? 0
: NodeList::GetNode (ev.key.m_context)->GetSystemId ())
->InvokeNow (ev);
}
else if (ev.key.m_context == Simulator::NO_CONTEXT)
{
Schedule (TimeStep (ev.key.m_ts), ev.impl);
}
else
{
ScheduleWithContext (ev.key.m_context, TimeStep (ev.key.m_ts), ev.impl);
}
}
}
} // namespace ns3

View File

@@ -0,0 +1,61 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#ifndef MULTITHREADED_SIMULATOR_IMPL_H
#define MULTITHREADED_SIMULATOR_IMPL_H
#include "ns3/event-id.h"
#include "ns3/event-impl.h"
#include "ns3/nstime.h"
#include "ns3/object-factory.h"
#include "ns3/simulator-impl.h"
#include <list>
namespace ns3 {
class MultithreadedSimulatorImpl : public SimulatorImpl
{
public:
static TypeId GetTypeId ();
/** Default constructor. */
MultithreadedSimulatorImpl ();
/** Destructor. */
~MultithreadedSimulatorImpl ();
// virtual from SimulatorImpl
virtual void Destroy ();
virtual bool IsFinished () const;
virtual void Stop ();
virtual void Stop (Time const &delay);
virtual EventId Schedule (Time const &delay, EventImpl *event);
virtual void ScheduleWithContext (uint32_t context, Time const &delay, EventImpl *event);
virtual EventId ScheduleNow (EventImpl *event);
virtual EventId ScheduleDestroy (EventImpl *event);
virtual void Remove (const EventId &id);
virtual void Cancel (const EventId &id);
virtual bool IsExpired (const EventId &id) const;
virtual void Run ();
virtual Time Now () const;
virtual Time GetDelayLeft (const EventId &id) const;
virtual Time GetMaximumSimulationTime () const;
virtual void SetScheduler (ObjectFactory schedulerFactory);
virtual uint32_t GetSystemId () const;
virtual uint32_t GetContext () const;
virtual uint64_t GetEventCount () const;
private:
// Inherited from Object
virtual void DoDispose ();
void Partition ();
bool m_partition;
uint32_t m_maxThreads;
Time m_minLookahead;
TypeId m_schedulerTypeId;
std::list<EventId> m_destroyEvents;
};
} // namespace ns3
#endif /* MULTITHREADED_SIMULATOR_IMPL_H */

View File

@@ -0,0 +1,68 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
// Include a header file from your module to test.
#include "ns3/mtp.h"
// An essential include is test.h
#include "ns3/test.h"
// Do not put your test classes in namespace ns3. You may find it useful
// to use the using directive to access the ns3 namespace directly
using namespace ns3;
// This is an example TestCase.
class MtpTestCase1 : public TestCase
{
public:
MtpTestCase1 ();
virtual ~MtpTestCase1 ();
private:
virtual void DoRun (void);
};
// Add some help text to this case to describe what it is intended to test
MtpTestCase1::MtpTestCase1 ()
: TestCase ("Mtp test case (does nothing)")
{
}
// This destructor does nothing but we include it as a reminder that
// the test case should clean up after itself
MtpTestCase1::~MtpTestCase1 ()
{
}
//
// This method is the pure virtual method from class TestCase that every
// TestCase must implement
//
void
MtpTestCase1::DoRun (void)
{
// A wide variety of test macros are available in src/core/test.h
NS_TEST_ASSERT_MSG_EQ (true, true, "true doesn't equal true for some reason");
// Use this one for floating point comparisons
NS_TEST_ASSERT_MSG_EQ_TOL (0.01, 0.01, 0.001, "Numbers are not equal within tolerance");
}
// The TestSuite class names the TestSuite, identifies what type of TestSuite,
// and enables the TestCases to be run. Typically, only the constructor for
// this class must be defined
//
class MtpTestSuite : public TestSuite
{
public:
MtpTestSuite ();
};
MtpTestSuite::MtpTestSuite ()
: TestSuite ("mtp", UNIT)
{
// TestDuration for TestCase can be QUICK, EXTENSIVE or TAKES_FOREVER
AddTestCase (new MtpTestCase1, TestCase::QUICK);
}
// Do not forget to allocate an instance of this TestSuite
static MtpTestSuite smtpTestSuite;