mtp: Add multithreaded parallel simulation support

This commit is contained in:
F5
2022-10-25 17:01:57 +08:00
parent 6764518fff
commit f818faabcd
11 changed files with 1847 additions and 0 deletions

View File

@@ -0,0 +1,292 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#include "logical-process.h"
#include "mtp-interface.h"
#include "ns3/channel.h"
#include "ns3/node-container.h"
#include "ns3/simulator.h"
#include <algorithm>
namespace ns3 {
NS_LOG_COMPONENT_DEFINE ("LogicalProcess");
/**
 * Build an idle logical process (LP).
 *
 * The LP starts with system id 0, no scheduler installed (m_events is null
 * until SetScheduler () is called), a zero lookahead and all counters cleared.
 */
LogicalProcess::LogicalProcess ()
{
  m_systemId = 0;
  m_systemCount = 0;
  m_uid = EventId::UID::VALID;
  m_stop = false;
  m_currentContext = Simulator::NO_CONTEXT;
  m_currentUid = 0;
  m_currentTs = 0;
  m_eventCount = 0;
  m_pendingEventCount = 0;
  m_events = 0;
  m_lookAhead = TimeStep (0);
  // GetExecutionTime () can be read (e.g. by a sort predicate) before the
  // first ProcessOneRound () has measured anything, so it must start defined.
  m_executionTime = 0;
}
/**
 * Release the events still queued in this LP.
 *
 * Events are only unreferenced when this LP is the last holder of the event
 * list; if the list is shared (reference count above one) it is left intact.
 */
LogicalProcess::~LogicalProcess ()
{
  NS_LOG_INFO ("system " << m_systemId << " finished with event count " << m_eventCount);
  // the constructor leaves m_events null until SetScheduler () is called,
  // so an LP that never got a scheduler has nothing to clean up
  if (m_events == 0)
    {
      return;
    }
  // if others hold references to event list, do not unref events
  if (m_events->GetReferenceCount () == 1)
    {
      while (!m_events->IsEmpty ())
        {
          Scheduler::Event next = m_events->RemoveNext ();
          next.impl->Unref ();
        }
    }
}
/**
 * Record the identity of this LP.
 *
 * \param systemId id assigned to this LP
 * \param systemCount value stored as the total LP count
 */
void
LogicalProcess::Enable (const uint32_t systemId, const uint32_t systemCount)
{
  m_systemCount = systemCount;
  m_systemId = systemId;
}
void
LogicalProcess::CalculateLookAhead ()
{
NS_LOG_FUNCTION (this);
if (m_systemId == 0)
{
m_lookAhead = TimeStep (0); // No lookahead for public LP
}
else
{
m_lookAhead = Time::Max () / 2 - TimeStep (1);
NodeContainer c = NodeContainer::GetGlobal ();
for (NodeContainer::Iterator iter = c.Begin (); iter != c.End (); ++iter)
{
if ((*iter)->GetSystemId () != m_systemId)
{
continue;
}
for (uint32_t i = 0; i < (*iter)->GetNDevices (); ++i)
{
Ptr<NetDevice> localNetDevice = (*iter)->GetDevice (i);
// only works for p2p links currently
if (!localNetDevice->IsPointToPoint ())
{
continue;
}
Ptr<Channel> channel = localNetDevice->GetChannel ();
if (channel == 0)
{
continue;
}
// grab the adjacent node
Ptr<Node> remoteNode;
if (channel->GetDevice (0) == localNetDevice)
{
remoteNode = (channel->GetDevice (1))->GetNode ();
}
else
{
remoteNode = (channel->GetDevice (0))->GetNode ();
}
// if it's not remote, don't consider it
if (remoteNode->GetSystemId () == m_systemId)
{
continue;
}
// compare delay on the channel with current value of m_lookAhead.
// if delay on channel is smaller, make it the new lookAhead.
TimeValue delay;
channel->GetAttribute ("Delay", delay);
if (delay.Get () < m_lookAhead)
{
m_lookAhead = delay.Get ();
}
// add the neighbour to the mailbox
m_mailbox[remoteNode->GetSystemId ()];
}
}
}
NS_LOG_INFO ("lookahead of system " << m_systemId << " is set to " << m_lookAhead.GetTimeStep ());
}
void
LogicalProcess::ReceiveMessages ()
{
NS_LOG_FUNCTION (this);
m_pendingEventCount = 0;
for (auto &item : m_mailbox)
{
auto &queue = item.second;
std::sort (queue.begin (), queue.end (), std::greater<> ());
while (!queue.empty ())
{
auto &evWithTs = queue.back ();
Scheduler::Event &ev = std::get<3> (evWithTs);
ev.key.m_uid = m_uid++;
m_events->Insert (ev);
queue.pop_back ();
m_pendingEventCount++;
}
}
}
/**
 * Execute all local events that fall inside the current time window.
 *
 * The window ends at the smaller of (global smallest time + this LP's
 * lookahead) and the next public-LP event time. The wall-clock duration of
 * the round is stored in m_executionTime (nanoseconds).
 */
void
LogicalProcess::ProcessOneRound ()
{
  NS_LOG_FUNCTION (this);
  // set thread context
  MtpInterface::SetSystem (m_systemId);
  // calculate time window
  Time grantedTime =
      Min (MtpInterface::GetSmallestTime () + m_lookAhead, MtpInterface::GetNextPublicTime ());
  // steady_clock is monotonic; system_clock may be adjusted while the round
  // runs, which would yield a bogus (even negative) elapsed time
  const auto start = std::chrono::steady_clock::now ();
  // process events
  while (Next () <= grantedTime)
    {
      Scheduler::Event next = m_events->RemoveNext ();
      m_eventCount++;
      NS_LOG_LOGIC ("handle " << next.key.m_ts);
      // publish the event's key as the current simulation state before invoking it
      m_currentTs = next.key.m_ts;
      m_currentContext = next.key.m_context;
      m_currentUid = next.key.m_uid;
      next.impl->Invoke ();
      next.impl->Unref ();
    }
  const auto end = std::chrono::steady_clock::now ();
  m_executionTime = std::chrono::duration_cast<std::chrono::nanoseconds> (end - start).count ();
}
/**
 * Schedule an event on this LP, in the current context.
 *
 * \param delay offset from the LP's current timestamp
 * \param event implementation to run
 * \return identifier of the newly inserted event
 */
EventId
LogicalProcess::Schedule (Time const &delay, EventImpl *event)
{
  Scheduler::Event scheduled;
  scheduled.impl = event;
  scheduled.key.m_ts = m_currentTs + delay.GetTimeStep ();
  scheduled.key.m_context = GetContext ();
  scheduled.key.m_uid = m_uid++;
  m_events->Insert (scheduled);
  const Scheduler::Event &inserted = scheduled;
  return EventId (event, inserted.key.m_ts, inserted.key.m_context, inserted.key.m_uid);
}
/**
 * Schedule an event for a given context, possibly on another LP.
 *
 * When \p remote is this LP the event goes straight into the local event
 * list. Otherwise it is appended to the remote LP's mailbox slot for this
 * system, tagged with this LP's current timestamp, system id and uid counter;
 * the event's own uid is set to INVALID until the receiver assigns one.
 *
 * \param remote LP that must execute the event
 * \param context context the event runs in
 * \param delay offset from this LP's current timestamp
 * \param event implementation to run
 */
void
LogicalProcess::ScheduleWithContext (LogicalProcess *remote, uint32_t context, Time const &delay,
                                     EventImpl *event)
{
  Scheduler::Event message;
  message.impl = event;
  message.key.m_context = context;
  message.key.m_ts = delay.GetTimeStep () + m_currentTs;
  if (remote != this)
    {
      message.key.m_uid = EventId::UID::INVALID;
      auto &outbox = remote->m_mailbox[m_systemId];
      outbox.push_back (std::make_tuple (m_currentTs, m_systemId, m_uid, message));
      return;
    }
  message.key.m_uid = m_uid++;
  m_events->Insert (message);
}
/**
 * Run an event immediately on behalf of this LP.
 *
 * The calling thread is temporarily bound to this LP, the LP's current
 * timestamp/context/uid are taken from the event, the event is invoked and
 * unreferenced, and finally the thread is bound back to the LP it had before.
 *
 * \param ev event to execute
 */
void
LogicalProcess::InvokeNow (Scheduler::Event const &ev)
{
  const uint32_t callerSystemId = MtpInterface::GetSystem ()->GetSystemId ();
  MtpInterface::SetSystem (m_systemId);
  ++m_eventCount;
  NS_LOG_LOGIC ("handle " << ev.key.m_ts);
  const auto &key = ev.key;
  m_currentTs = key.m_ts;
  m_currentContext = key.m_context;
  m_currentUid = key.m_uid;
  EventImpl *const handler = ev.impl;
  handler->Invoke ();
  handler->Unref ();
  // restore previous thread context
  MtpInterface::SetSystem (callerSystemId);
}
void
LogicalProcess::Remove (const EventId &id)
{
if (IsExpired (id))
{
return;
}
Scheduler::Event event;
event.impl = id.PeekEventImpl ();
event.key.m_ts = id.GetTs ();
event.key.m_context = id.GetContext ();
event.key.m_uid = id.GetUid ();
m_events->Remove (event);
event.impl->Cancel ();
// whenever we remove an event from the event list, we have to unref it.
event.impl->Unref ();
}
/**
 * Tell whether an event can no longer run on this LP.
 *
 * \param id identifier to check
 * \return true when the id has no implementation, lies before the LP's
 *         current (timestamp, uid) position, or has been cancelled
 */
bool
LogicalProcess::IsExpired (const EventId &id) const
{
  if (id.PeekEventImpl () == 0)
    {
      return true;
    }
  if (id.GetTs () < m_currentTs)
    {
      return true;
    }
  if (id.GetTs () == m_currentTs && id.GetUid () <= m_currentUid)
    {
      return true;
    }
  return id.PeekEventImpl ()->IsCancelled ();
}
/**
 * Install a new scheduler, carrying over any events already queued.
 *
 * \param schedulerFactory factory used to create the new scheduler
 */
void
LogicalProcess::SetScheduler (ObjectFactory schedulerFactory)
{
  Ptr<Scheduler> replacement = schedulerFactory.Create<Scheduler> ();
  // drain the previous scheduler (if any) into the new one
  while (m_events != 0 && !m_events->IsEmpty ())
    {
      Scheduler::Event moved = m_events->RemoveNext ();
      replacement->Insert (moved);
    }
  m_events = replacement;
}
/**
 * Timestamp of the next event to execute on this LP.
 *
 * \return Time::Max () when the LP is stopped or has no events, otherwise
 *         the timestamp of the head of the event list
 */
Time
LogicalProcess::Next () const
{
  if (!m_stop && !m_events->IsEmpty ())
    {
      const Scheduler::Event head = m_events->PeekNext ();
      return TimeStep (head.key.m_ts);
    }
  return Time::Max ();
}
} // namespace ns3

View File

@@ -0,0 +1,122 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#ifndef LOGICAL_PROCESS_H
#define LOGICAL_PROCESS_H
#include "ns3/event-id.h"
#include "ns3/event-impl.h"
#include "ns3/nstime.h"
#include "ns3/object-factory.h"
#include "ns3/ptr.h"
#include "ns3/scheduler.h"
#include <atomic>
#include <chrono>
#include <map>
#include <tuple>
#include <vector>
namespace ns3 {
class LogicalProcess
{
public:
LogicalProcess ();
~LogicalProcess ();
void Enable (const uint32_t systemId, const uint32_t systemCount);
void CalculateLookAhead ();
void ReceiveMessages ();
void ProcessOneRound ();
inline uint64_t
GetExecutionTime () const
{
return m_executionTime;
}
inline uint64_t
GetPendingEventCount () const
{
return m_pendingEventCount;
}
inline Ptr<Scheduler>
GetPendingEvents () const
{
return m_events;
}
// mapped from MultithreadedSimulatorImpl
EventId Schedule (Time const &delay, EventImpl *event);
void ScheduleWithContext (LogicalProcess *remote, uint32_t context, Time const &delay,
EventImpl *event);
void InvokeNow (Scheduler::Event const &ev); // cross context immediate invocation
void Remove (const EventId &id);
void Cancel (const EventId &id);
bool IsExpired (const EventId &id) const;
void SetScheduler (ObjectFactory schedulerFactory);
Time Next () const;
inline bool
isLocalFinished () const
{
return m_stop || m_events->IsEmpty ();
}
inline void
Stop ()
{
m_stop = true;
}
inline Time
Now () const
{
return TimeStep (m_currentTs);
}
inline Time
GetDelayLeft (const EventId &id) const
{
return TimeStep (id.GetTs () - m_currentTs);
}
inline uint32_t
GetSystemId (void) const
{
return m_systemId;
}
inline uint32_t
GetContext () const
{
return m_currentContext;
}
inline uint64_t
GetEventCount () const
{
return m_eventCount;
}
private:
uint32_t m_systemId;
uint32_t m_systemCount;
bool m_stop;
uint32_t m_uid;
uint32_t m_currentContext;
uint32_t m_currentUid;
uint64_t m_currentTs;
uint64_t m_eventCount;
uint64_t m_pendingEventCount;
Ptr<Scheduler> m_events;
Time m_lookAhead;
std::map<uint32_t, std::vector<std::tuple<uint64_t, uint32_t, uint32_t, Scheduler::Event>>>
m_mailbox; // event message mail box
std::chrono::nanoseconds::rep m_executionTime;
};
} // namespace ns3
#endif /* LOGICAL_PROCESS_H */

View File

@@ -0,0 +1,368 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#include "mtp-interface.h"
#include "ns3/assert.h"
#include "ns3/log.h"
#include "ns3/string.h"
#include "ns3/uinteger.h"
#include "ns3/config.h"
#include <algorithm>
#include <cmath>
namespace ns3 {
NS_LOG_COMPONENT_DEFINE ("MtpInterface");
void
MtpInterface::Enable ()
{
#ifdef NS3_MPI
GlobalValue::Bind ("SimulatorImplementationType", StringValue ("ns3::HybridSimulatorImpl"));
#else
GlobalValue::Bind ("SimulatorImplementationType",
StringValue ("ns3::MultithreadedSimulatorImpl"));
#endif
g_enabled = true;
}
/**
 * Enable multithreaded simulation with automatic partitioning and a given
 * maximum thread count.
 *
 * \param threadCount default value for the implementation's "MaxThreads" attribute
 */
void
MtpInterface::Enable (const uint32_t threadCount)
{
#ifdef NS3_MPI
  const char *const maxThreadsPath = "ns3::HybridSimulatorImpl::MaxThreads";
#else
  const char *const maxThreadsPath = "ns3::MultithreadedSimulatorImpl::MaxThreads";
#endif
  Config::SetDefault (maxThreadsPath, UintegerValue (threadCount));
  MtpInterface::Enable ();
}
/**
 * Set up the logical processes for a known partition.
 *
 * Allocates systemCount + 1 LPs (index 0 is the public LP), selects the LP
 * sorting predicate and period from the global values, and binds the calling
 * thread to the public LP.
 *
 * \param threadCount number of threads, must be positive
 * \param systemCount number of LPs, not counting the public LP
 */
void
MtpInterface::Enable (const uint32_t threadCount, const uint32_t systemCount)
{
  NS_ASSERT_MSG (threadCount > 0, "There must be at least one thread");
  // called by manual partition
  if (!g_enabled)
    {
      GlobalValue::Bind ("SimulatorImplementationType",
                         StringValue ("ns3::MultithreadedSimulatorImpl"));
    }
  // set size
  g_threadCount = threadCount;
  g_systemCount = systemCount;
  // allocate systems
  g_systems = new LogicalProcess[g_systemCount + 1]; // include the public LP
  for (uint32_t i = 0; i <= g_systemCount; i++)
    {
      g_systems[i].Enable (i, g_systemCount + 1);
    }
  StringValue s;
  g_sortMethod.GetValue (s);
  if (s.Get () == "ByExecutionTime")
    {
      g_sortFunc = SortByExecutionTime;
    }
  else if (s.Get () == "ByPendingEventCount")
    {
      g_sortFunc = SortByPendingEventCount;
    }
  else if (s.Get () == "ByEventCount")
    {
      g_sortFunc = SortByEventCount;
    }
  else if (s.Get () == "BySimulationTime")
    {
      g_sortFunc = SortBySimulationTime;
    }
  UintegerValue ui;
  g_sortPeriod.GetValue (ui);
  if (ui.Get () == 0)
    {
      // std::log2 (0) is -infinity and converting that to an unsigned integer
      // is undefined; it could leave g_period at 0, which is later used as a
      // modulus. With no LP to sort, any positive period will do.
      if (g_systemCount == 0)
        {
          g_period = 1;
        }
      else
        {
          g_period = static_cast<uint32_t> (std::ceil (std::log2 (g_systemCount) / 4 + 1));
        }
      NS_LOG_INFO ("Secheduling period is automatically set to " << g_period);
    }
  else
    {
      g_period = ui.Get ();
    }
  // create a thread local storage key
  // so that we can access the currently assigned LP of each thread
  pthread_key_create (&g_key, nullptr);
  pthread_setspecific (g_key, &g_systems[0]);
}
void
MtpInterface::EnableNew (const uint32_t newSystemCount)
{
const LogicalProcess *oldSystems = g_systems;
g_systems = new LogicalProcess[g_systemCount + newSystemCount + 1];
for (uint32_t i = 0; i <= g_systemCount; i++)
{
g_systems[i] = oldSystems[i];
}
delete[] oldSystems;
g_systemCount += newSystemCount;
for (uint32_t i = 0; i <= g_systemCount; i++)
{
g_systems[i].Enable (i, g_systemCount + 1);
}
}
/**
 * Tear down the multithreading state: reset the counters and the sort
 * predicate and free the LP, thread and index arrays.
 */
void
MtpInterface::Disable ()
{
  g_threadCount = 0;
  g_systemCount = 0;
  g_sortFunc = nullptr;
  g_globalFinished = false;
  // Null every pointer after freeing it: Disable () can run more than once
  // (e.g. from both Partition () and Destroy ()), and deleting a stale
  // pointer a second time would be a double free.
  delete[] g_systems;
  g_systems = nullptr;
  delete[] g_threads;
  g_threads = nullptr;
  delete[] g_sortedSystemIndices;
  g_sortedSystemIndices = nullptr;
}
/**
 * Main simulation loop: prepare, run rounds until every LP is finished,
 * then shut the worker threads down.
 */
void
MtpInterface::Run ()
{
  RunBefore ();
  for (;;)
    {
      if (g_globalFinished)
        {
          break;
        }
      ProcessOneRound ();
      CalculateSmallestTime ();
    }
  RunAfter ();
}
/**
 * Prepare a run: compute lookaheads, build the LP index array used for
 * sorting and work distribution, and start the worker threads.
 */
void
MtpInterface::RunBefore ()
{
  CalculateLookAhead ();
  // LP index for sorting & holding worker threads
  g_sortedSystemIndices = new uint32_t[g_systemCount];
  for (uint32_t i = 0; i < g_systemCount; i++)
    {
      g_sortedSystemIndices[i] = i + 1; // index 0 is the public LP, handled separately
    }
  // park the workers: an index >= g_systemCount means "no work available"
  g_systemIndex.store (g_systemCount, std::memory_order_release);
  // start threads
  // The main thread is one of the g_threadCount threads. Guard the
  // subtraction: with g_threadCount == 0 the unsigned "g_threadCount - 1"
  // would wrap around to a huge value.
  const uint32_t workerCount = (g_threadCount > 0) ? g_threadCount - 1 : 0;
  g_threads = new pthread_t[workerCount]; // exclude the main thread
  for (uint32_t i = 0; i < workerCount; i++)
    {
      pthread_create (&g_threads[i], nullptr, ThreadFunc, nullptr);
    }
}
// Run one synchronization round on the main thread:
//   stage 1 - every non-public LP processes its events,
//   stage 2 - the public LP (index 0) processes its events,
//   stage 3 - every non-public LP receives its mailbox messages.
// Stages 1 and 3 hand out LP indices through the atomic g_systemIndex; the
// worker threads (ThreadFunc) pull indices from the same counter. Each stage
// ends with a busy-wait until g_finishedSystemCount reaches g_systemCount.
void
MtpInterface::ProcessOneRound ()
{
// assign logical process to threads
// determine the priority of logical processes
// (g_round is only incremented when a sort predicate is set, because of
// the short-circuit evaluation)
if (g_sortFunc != nullptr && g_round++ % g_period == 0)
{
std::sort (g_sortedSystemIndices, g_sortedSystemIndices + g_systemCount, g_sortFunc);
}
// stage 1: process events
// the stage flag and the counter are written before the release store on
// g_systemIndex that opens the stage
g_recvMsgStage = false;
g_finishedSystemCount.store (0, std::memory_order_relaxed);
g_systemIndex.store (0, std::memory_order_release);
// main thread also needs to process an LP to reduce an extra thread overhead
while (true)
{
uint32_t index = g_systemIndex.fetch_add (1, std::memory_order_acquire);
if (index >= g_systemCount)
{
break;
}
LogicalProcess *system = &g_systems[g_sortedSystemIndices[index]];
system->ProcessOneRound ();
g_finishedSystemCount.fetch_add (1, std::memory_order_release);
}
// logical process barrier synchronization (spin until all LPs are done)
while (g_finishedSystemCount.load (std::memory_order_acquire) != g_systemCount)
;
// stage 2: process the public LP
g_systems[0].ProcessOneRound ();
// stage 3: receive messages
g_recvMsgStage = true;
g_finishedSystemCount.store (0, std::memory_order_relaxed);
g_systemIndex.store (0, std::memory_order_release);
while (true)
{
uint32_t index = g_systemIndex.fetch_add (1, std::memory_order_acquire);
if (index >= g_systemCount)
{
break;
}
LogicalProcess *system = &g_systems[g_sortedSystemIndices[index]];
system->ReceiveMessages ();
g_finishedSystemCount.fetch_add (1, std::memory_order_release);
}
// logical process barrier synchronization (spin until all LPs are done)
while (g_finishedSystemCount.load (std::memory_order_acquire) != g_systemCount)
;
}
void
MtpInterface::CalculateSmallestTime ()
{
// update smallest time
g_smallestTime = Time::Max () / 2;
for (uint32_t i = 0; i <= g_systemCount; i++)
{
Time nextTime = g_systems[i].Next ();
if (nextTime < g_smallestTime)
{
g_smallestTime = nextTime;
}
}
g_nextPublicTime = g_systems[0].Next ();
// test if global finished
bool globalFinished = true;
for (uint32_t i = 0; i <= g_systemCount; i++)
{
globalFinished &= g_systems[i].isLocalFinished ();
}
g_globalFinished = globalFinished;
}
/**
 * Finish a run: wake the parked worker threads so they can observe
 * g_globalFinished, then join them.
 */
void
MtpInterface::RunAfter ()
{
  // global finished, terminate threads
  g_systemIndex.store (0, std::memory_order_release);
  // The main thread is not in g_threads. Guard the subtraction: with
  // g_threadCount == 0 the unsigned "g_threadCount - 1" would wrap around.
  const uint32_t workerCount = (g_threadCount > 0) ? g_threadCount - 1 : 0;
  for (uint32_t i = 0; i < workerCount; i++)
    {
      pthread_join (g_threads[i], nullptr);
    }
}
// Returns the g_enabled flag, which is set by the parameterless Enable ().
bool
MtpInterface::isEnabled ()
{
return g_enabled;
}
/**
 * \return true once a thread count has been set, i.e. after
 *         Enable (threadCount, systemCount) and before Disable ()
 */
bool
MtpInterface::isPartitioned ()
{
  const bool hasThreads = (g_threadCount != 0);
  return hasThreads;
}
/**
 * Ask every non-public LP (indices 1..g_systemCount) to compute its lookahead.
 */
void
MtpInterface::CalculateLookAhead ()
{
  uint32_t id = 1;
  while (id <= g_systemCount)
    {
      g_systems[id].CalculateLookAhead ();
      ++id;
    }
}
// Entry point of a worker thread.
// Until g_globalFinished is set, the thread repeatedly claims an LP index
// from the atomic g_systemIndex and runs the stage selected by
// g_recvMsgStage on that LP. When no index is left it spins until the main
// thread resets g_systemIndex below g_systemCount.
// arg is unused; the function always returns nullptr.
void *
MtpInterface::ThreadFunc (void *arg)
{
while (!g_globalFinished)
{
uint32_t index = g_systemIndex.fetch_add (1, std::memory_order_acquire);
if (index >= g_systemCount)
{
// no work left in this stage: wait for the next stage (or for shutdown)
while (g_systemIndex.load (std::memory_order_acquire) >= g_systemCount)
;
continue;
}
LogicalProcess *system = &g_systems[g_sortedSystemIndices[index]];
if (g_recvMsgStage)
{
system->ReceiveMessages ();
}
else
{
system->ProcessOneRound ();
}
// report completion so the main thread's barrier can advance
g_finishedSystemCount.fetch_add (1, std::memory_order_release);
}
return nullptr;
}
/**
 * Ordering predicate: LP \p i goes before LP \p j when its last recorded
 * execution time is strictly larger.
 */
bool
MtpInterface::SortByExecutionTime (const uint32_t &i, const uint32_t &j)
{
  const uint64_t lhs = g_systems[i].GetExecutionTime ();
  const uint64_t rhs = g_systems[j].GetExecutionTime ();
  return rhs < lhs;
}
/**
 * Ordering predicate: LP \p i goes before LP \p j when it has executed
 * strictly more events.
 */
bool
MtpInterface::SortByEventCount (const uint32_t &i, const uint32_t &j)
{
  const uint64_t lhs = g_systems[i].GetEventCount ();
  const uint64_t rhs = g_systems[j].GetEventCount ();
  return rhs < lhs;
}
/**
 * Ordering predicate: LP \p i goes before LP \p j when it received strictly
 * more events in its last ReceiveMessages ().
 */
bool
MtpInterface::SortByPendingEventCount (const uint32_t &i, const uint32_t &j)
{
  const uint64_t lhs = g_systems[i].GetPendingEventCount ();
  const uint64_t rhs = g_systems[j].GetPendingEventCount ();
  return rhs < lhs;
}
bool
MtpInterface::SortBySimulationTime (const uint32_t &i, const uint32_t &j)
{
return g_systems[i].Now () > g_systems[j].Now ();
}
// Definitions of the MtpInterface static data members.
// Predicate used to order LPs; nullptr disables sorting.
bool (*MtpInterface::g_sortFunc) (const uint32_t &, const uint32_t &) = nullptr;
// Global value naming the sort predicate; Enable (threadCount, systemCount)
// recognises "ByExecutionTime", "ByPendingEventCount", "ByEventCount" and
// "BySimulationTime".
GlobalValue MtpInterface::g_sortMethod = GlobalValue (
"PartitionSchedulingMethod", "The scheduling method to determine which partition runs first",
StringValue ("ByExecutionTime"), MakeStringChecker ());
// Global value for the sort period in rounds; 0 lets Enable pick a value.
GlobalValue MtpInterface::g_sortPeriod =
GlobalValue ("PartitionSchedulingPeriod", "The scheduling period of partitions",
UintegerValue (0), MakeUintegerChecker<uint32_t> (0));
// Effective sort period in rounds.
uint32_t MtpInterface::g_period = 0;
// Worker thread handles (the main thread is not included).
pthread_t *MtpInterface::g_threads = nullptr;
// LP array; index 0 is the public LP.
LogicalProcess *MtpInterface::g_systems = nullptr;
uint32_t MtpInterface::g_threadCount = 0;
// Number of LPs, not counting the public LP.
uint32_t MtpInterface::g_systemCount = 0;
// Indices (1..g_systemCount) of the LPs in processing order.
uint32_t *MtpInterface::g_sortedSystemIndices = nullptr;
// Next position in g_sortedSystemIndices to hand out to a thread.
std::atomic<uint32_t> MtpInterface::g_systemIndex;
// Number of LPs that completed the current stage.
std::atomic<uint32_t> MtpInterface::g_finishedSystemCount;
uint32_t MtpInterface::g_round = 0;
Time MtpInterface::g_smallestTime = TimeStep (0);
Time MtpInterface::g_nextPublicTime = TimeStep (0);
// true while the receive-messages stage is running.
bool MtpInterface::g_recvMsgStage = false;
bool MtpInterface::g_globalFinished = false;
bool MtpInterface::g_enabled = false;
// Thread-specific key holding the LP currently bound to a thread.
pthread_key_t MtpInterface::g_key;
// Flag behind MtpInterface::CriticalSection.
std::atomic<bool> MtpInterface::g_inCriticalSection (false);
} // namespace ns3

View File

@@ -0,0 +1,160 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#ifndef MTP_INTERFACE_H
#define MTP_INTERFACE_H
#include "logical-process.h"
#include "ns3/atomic-counter.h"
#include "ns3/global-value.h"
#include "ns3/nstime.h"
#include "ns3/simulator.h"
#include <pthread.h>
namespace ns3 {
/**
 * Static facade that owns the logical processes (LPs) and the worker threads
 * of the multithreaded simulator and drives the round-based execution.
 */
class MtpInterface
{
public:
  /**
   * Scoped global lock: the constructor busy-waits until it flips
   * g_inCriticalSection from false to true, the destructor resets it.
   */
  class CriticalSection
  {
  public:
    inline CriticalSection ()
    {
      while (g_inCriticalSection.exchange (true, std::memory_order_acquire))
        ;
    }
    inline ~CriticalSection ()
    {
      g_inCriticalSection.store (false, std::memory_order_release);
    }
  };

  static void Enable (); // auto topology partition
  static void Enable (const uint32_t threadCount); // auto partition, specify thread count
  static void Enable (const uint32_t threadCount, const uint32_t systemCount); // manual partition
  static void EnableNew (const uint32_t newSystemCount); // add LPs for dynamic added node
  static void Disable ();
  static void Run ();
  static void RunBefore ();
  static void ProcessOneRound ();
  static void CalculateSmallestTime ();
  static void RunAfter ();
  static bool isEnabled ();
  static bool isPartitioned ();
  static void CalculateLookAhead ();

  // get current thread's executing logical process
  inline static LogicalProcess *
  GetSystem ()
  {
    return static_cast<LogicalProcess *> (pthread_getspecific (g_key));
  }
  /** \return the LP stored at index \p systemId (no bounds check) */
  inline static LogicalProcess *
  GetSystem (const uint32_t systemId)
  {
    return &g_systems[systemId];
  }
  // set current thread's executing logical process
  inline static void
  SetSystem (const uint32_t systemId)
  {
    pthread_setspecific (g_key, &g_systems[systemId]);
  }
  /** \return number of LPs including the public LP */
  inline static uint32_t
  GetSize ()
  {
    return g_systemCount + 1;
  }
  inline static uint32_t
  GetRound ()
  {
    return g_round;
  }
  inline static Time
  GetSmallestTime ()
  {
    return g_smallestTime;
  }
  inline static void
  SetSmallestTime (const Time smallestTime)
  {
    g_smallestTime = smallestTime;
  }
  inline static Time
  GetNextPublicTime ()
  {
    return g_nextPublicTime;
  }
  inline static bool
  isFinished ()
  {
    return g_globalFinished;
  }

  /**
   * Schedule a callable on the public LP at the earlier of the global
   * smallest time and the next public event time, under the global lock.
   *
   * This overload accepts anything that is neither convertible to
   * Ptr<EventImpl> nor a plain function (pointer). The two enable_if
   * parameters need the "= 0" default: without it they can never be deduced
   * and the overload would never be selected.
   *
   * \param f callable to run
   * \param args arguments forwarded to MakeEvent
   */
  template <typename FUNC,
            typename std::enable_if<!std::is_convertible<FUNC, Ptr<EventImpl>>::value, int>::type = 0,
            typename std::enable_if<
                !std::is_function<typename std::remove_pointer<FUNC>::type>::value, int>::type = 0,
            typename... Ts>
  inline static void
  ScheduleGlobal (FUNC f, Ts &&...args)
  {
    CriticalSection cs;
    g_systems[0].ScheduleAt (Simulator::NO_CONTEXT, Min (g_smallestTime, g_nextPublicTime),
                             MakeEvent (f, std::forward<Ts> (args)...));
  }
  /**
   * Same as above, for plain function pointers.
   *
   * \param f function to run
   * \param args arguments forwarded to MakeEvent
   */
  template <typename... Us, typename... Ts>
  inline static void
  ScheduleGlobal (void (*f) (Us...), Ts &&...args)
  {
    CriticalSection cs;
    g_systems[0].ScheduleAt (Simulator::NO_CONTEXT, Min (g_smallestTime, g_nextPublicTime),
                             MakeEvent (f, std::forward<Ts> (args)...));
  }

private:
  static void *ThreadFunc (void *arg);
  // determine logical process priority
  static bool SortByExecutionTime (const uint32_t &i, const uint32_t &j);
  static bool SortByEventCount (const uint32_t &i, const uint32_t &j);
  static bool SortByPendingEventCount (const uint32_t &i, const uint32_t &j);
  static bool SortBySimulationTime (const uint32_t &i, const uint32_t &j);

  static bool (*g_sortFunc) (const uint32_t &, const uint32_t &); // nullptr: no sorting
  static GlobalValue g_sortMethod;
  static GlobalValue g_sortPeriod;
  static uint32_t g_period;
  static pthread_t *g_threads;              // worker threads, main thread excluded
  static LogicalProcess *g_systems;         // index 0 is the public LP
  static uint32_t g_threadCount;
  static uint32_t g_systemCount;            // LPs, public LP excluded
  static uint32_t *g_sortedSystemIndices;   // LP indices in processing order
  static std::atomic<uint32_t> g_systemIndex;         // next LP slot to hand out
  static std::atomic<uint32_t> g_finishedSystemCount; // LPs done with the current stage
  static uint32_t g_round;
  static Time g_smallestTime;
  static Time g_nextPublicTime;
  static bool g_recvMsgStage;               // true during the receive-messages stage
  static bool g_globalFinished;
  static bool g_enabled;
  static pthread_key_t g_key;               // thread-specific "current LP" key
  static std::atomic<bool> g_inCriticalSection;
};
} // namespace ns3
#endif /* MTP_INTERFACE_H */

View File

@@ -0,0 +1,409 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#include "multithreaded-simulator-impl.h"
#include "mtp-interface.h"
#include "ns3/type-id.h"
#include "ns3/channel.h"
#include "ns3/simulator.h"
#include "ns3/node.h"
#include "ns3/node-container.h"
#include "ns3/node-list.h"
#include "ns3/uinteger.h"
#include <algorithm>
#include <queue>
#include <thread>
namespace ns3 {
NS_LOG_COMPONENT_DEFINE ("MultithreadedSimulatorImpl");
NS_OBJECT_ENSURE_REGISTERED (MultithreadedSimulatorImpl);
/**
 * Create the simulator implementation.
 *
 * When no manual partition exists yet, a provisional setup with one thread
 * and only the public LP is created and automatic partitioning is requested
 * for Run ().
 */
MultithreadedSimulatorImpl::MultithreadedSimulatorImpl ()
{
  NS_LOG_FUNCTION (this);
  const bool needsAutoPartition = !MtpInterface::isPartitioned ();
  if (needsAutoPartition)
    {
      MtpInterface::Enable (1, 0);
    }
  m_partition = needsAutoPartition;
}
// Destructor: only logs the call; it releases nothing itself.
MultithreadedSimulatorImpl::~MultithreadedSimulatorImpl ()
{
NS_LOG_FUNCTION (this);
}
/**
 * Register the type and its attributes.
 *
 * Attributes: "MaxThreads" (at least 1) and "MinLookahead".
 *
 * \return the TypeId of ns3::MultithreadedSimulatorImpl
 */
TypeId
MultithreadedSimulatorImpl::GetTypeId ()
{
  // std::thread::hardware_concurrency () may return 0 when the value cannot
  // be determined; clamp it so the default satisfies the checker's minimum
  static TypeId tid =
      TypeId ("ns3::MultithreadedSimulatorImpl")
          .SetParent<SimulatorImpl> ()
          .SetGroupName ("Mtp")
          .AddConstructor<MultithreadedSimulatorImpl> ()
          .AddAttribute ("MaxThreads", "The maximum threads used in simulation",
                         UintegerValue (std::max (1u, std::thread::hardware_concurrency ())),
                         MakeUintegerAccessor (&MultithreadedSimulatorImpl::m_maxThreads),
                         MakeUintegerChecker<uint32_t> (1))
          .AddAttribute ("MinLookahead", "The minimum lookahead in a partition",
                         TimeValue (TimeStep (0)),
                         MakeTimeAccessor (&MultithreadedSimulatorImpl::m_minLookahead),
                         MakeTimeChecker (TimeStep (0)));
  return tid;
}
void
MultithreadedSimulatorImpl::Destroy ()
{
while (!m_destroyEvents.empty ())
{
Ptr<EventImpl> ev = m_destroyEvents.front ().PeekEventImpl ();
m_destroyEvents.pop_front ();
NS_LOG_LOGIC ("handle destroy " << ev);
if (!ev->IsCancelled ())
{
ev->Invoke ();
}
}
MtpInterface::Disable ();
}
// Returns the global "finished" flag maintained by MtpInterface.
bool
MultithreadedSimulatorImpl::IsFinished () const
{
return MtpInterface::isFinished ();
}
/**
 * Stop the simulation by marking every LP (public LP included) as stopped.
 */
void
MultithreadedSimulatorImpl::Stop ()
{
  NS_LOG_FUNCTION (this);
  uint32_t id = 0;
  while (id < MtpInterface::GetSize ())
    {
      MtpInterface::GetSystem (id)->Stop ();
      ++id;
    }
}
// Schedule a call to Simulator::Stop after the given delay.
void
MultithreadedSimulatorImpl::Stop (Time const &delay)
{
NS_LOG_FUNCTION (this << delay.GetTimeStep ());
Simulator::Schedule (delay, &Simulator::Stop);
}
/**
 * Schedule an event on the LP bound to the calling thread.
 *
 * \param delay offset from that LP's current time
 * \param event implementation to run
 * \return identifier of the scheduled event
 */
EventId
MultithreadedSimulatorImpl::Schedule (Time const &delay, EventImpl *event)
{
  NS_LOG_FUNCTION (this << delay.GetTimeStep () << event);
  LogicalProcess *const current = MtpInterface::GetSystem ();
  return current->Schedule (delay, event);
}
void
MultithreadedSimulatorImpl::ScheduleWithContext (uint32_t context, Time const &delay,
EventImpl *event)
{
NS_LOG_FUNCTION (this << context << delay.GetTimeStep () << event);
LogicalProcess *remote = MtpInterface::GetSystem (NodeList::GetNode (context)->GetSystemId ());
MtpInterface::GetSystem ()->ScheduleWithContext (remote, context, delay, event);
}
// Schedule an event with zero delay (delegates to Schedule).
EventId
MultithreadedSimulatorImpl::ScheduleNow (EventImpl *event)
{
return Schedule (TimeStep (0), event);
}
/**
 * Register an event to run at Destroy ().
 *
 * The list of destroy events is modified under the global critical section.
 *
 * \param event implementation to run
 * \return identifier carrying the DESTROY uid
 */
EventId
MultithreadedSimulatorImpl::ScheduleDestroy (EventImpl *event)
{
  const uint64_t destroyTs = GetMaximumSimulationTime ().GetTimeStep ();
  EventId destroyId (Ptr<EventImpl> (event, false), destroyTs, 0xffffffff, EventId::DESTROY);
  MtpInterface::CriticalSection guard;
  m_destroyEvents.push_back (destroyId);
  return destroyId;
}
void
MultithreadedSimulatorImpl::Remove (const EventId &id)
{
if (id.GetUid () == EventId::DESTROY)
{
// destroy events.
for (std::list<EventId>::iterator i = m_destroyEvents.begin (); i != m_destroyEvents.end ();
i++)
{
if (*i == id)
{
m_destroyEvents.erase (i);
break;
}
}
}
else
{
MtpInterface::GetSystem ()->Remove (id);
}
}
/**
 * Cancel an event unless it has already expired.
 *
 * \param id identifier of the event to cancel
 */
void
MultithreadedSimulatorImpl::Cancel (const EventId &id)
{
  if (IsExpired (id))
    {
      return;
    }
  id.PeekEventImpl ()->Cancel ();
}
bool
MultithreadedSimulatorImpl::IsExpired (const EventId &id) const
{
if (id.GetUid () == EventId::DESTROY)
{
// destroy events.
if (id.PeekEventImpl () == nullptr || id.PeekEventImpl ()->IsCancelled ())
{
return true;
}
for (std::list<EventId>::const_iterator i = m_destroyEvents.begin ();
i != m_destroyEvents.end (); i++)
{
if (*i == id)
{
return false;
}
}
return true;
}
else
{
return MtpInterface::GetSystem ()->IsExpired (id);
}
}
void
MultithreadedSimulatorImpl::Run ()
{
NS_LOG_FUNCTION (this);
// auto partition
if (m_partition)
{
Partition ();
}
MtpInterface::Run ();
}
/**
 * \return current simulation time of the LP bound to the calling thread
 */
Time
MultithreadedSimulatorImpl::Now (void) const
{
  // Do not add function logging here, to avoid stack overflow
  const LogicalProcess *const current = MtpInterface::GetSystem ();
  return current->Now ();
}
/**
 * \param id identifier of the event
 * \return zero for an expired event, otherwise the time left as computed by
 *         the LP bound to the calling thread
 */
Time
MultithreadedSimulatorImpl::GetDelayLeft (const EventId &id) const
{
  if (!IsExpired (id))
    {
      return MtpInterface::GetSystem ()->GetDelayLeft (id);
    }
  return TimeStep (0);
}
// Returns half of Time::Max () as the maximum simulation time.
Time
MultithreadedSimulatorImpl::GetMaximumSimulationTime (void) const
{
return Time::Max () / 2;
}
/**
 * Install a scheduler of the given type on every LP and remember the type
 * so it can be re-applied after partitioning.
 *
 * \param schedulerFactory factory creating the scheduler
 */
void
MultithreadedSimulatorImpl::SetScheduler (ObjectFactory schedulerFactory)
{
  NS_LOG_FUNCTION (this << schedulerFactory);
  uint32_t id = 0;
  while (id < MtpInterface::GetSize ())
    {
      MtpInterface::GetSystem (id)->SetScheduler (schedulerFactory);
      ++id;
    }
  m_schedulerTypeId = schedulerFactory.GetTypeId ();
}
uint32_t
MultithreadedSimulatorImpl::GetSystemId (void) const
{
return MtpInterface::GetSystem ()->GetSystemId ();
}
uint32_t
MultithreadedSimulatorImpl::GetContext () const
{
return MtpInterface::GetSystem ()->GetContext ();
}
/**
 * \return total number of events executed, summed over all LPs
 */
uint64_t
MultithreadedSimulatorImpl::GetEventCount (void) const
{
  uint64_t total = 0;
  uint32_t id = 0;
  while (id < MtpInterface::GetSize ())
    {
      total += MtpInterface::GetSystem (id)->GetEventCount ();
      ++id;
    }
  return total;
}
// Forwards disposal to the SimulatorImpl base class; nothing is added here.
void
MultithreadedSimulatorImpl::DoDispose ()
{
SimulatorImpl::DoDispose ();
}
void
MultithreadedSimulatorImpl::Partition ()
{
NS_LOG_FUNCTION (this);
uint32_t systemId = 0;
const NodeContainer nodes = NodeContainer::GetGlobal ();
bool *visited = new bool[nodes.GetN ()]{false};
std::queue<Ptr<Node>> q;
// if m_minLookahead is not set, calculate the median of delay for every link
if (m_minLookahead == TimeStep (0))
{
std::vector<Time> delays;
for (NodeContainer::Iterator it = nodes.Begin (); it != nodes.End (); it++)
{
Ptr<Node> node = *it;
for (uint32_t i = 0; i < node->GetNDevices (); i++)
{
Ptr<NetDevice> localNetDevice = node->GetDevice (i);
Ptr<Channel> channel = localNetDevice->GetChannel ();
if (channel == 0)
{
continue;
}
// cut-off p2p links for partition
if (localNetDevice->IsPointToPoint ())
{
TimeValue delay;
channel->GetAttribute ("Delay", delay);
delays.push_back (delay.Get ());
}
}
}
std::sort (delays.begin (), delays.end ());
if (delays.size () % 2 == 1)
{
m_minLookahead = delays[delays.size () / 2];
}
else
{
m_minLookahead = (delays[delays.size () / 2 - 1] + delays[delays.size () / 2]) / 2;
}
NS_LOG_INFO ("Min lookahead is set to " << m_minLookahead);
}
// perform a BFS on the whole network topo to assign each node a systemId
for (NodeContainer::Iterator it = nodes.Begin (); it != nodes.End (); it++)
{
Ptr<Node> node = *it;
if (!visited[node->GetId ()])
{
q.push (node);
systemId++;
while (!q.empty ())
{
// pop from BFS queue
node = q.front ();
q.pop ();
visited[node->GetId ()] = true;
// assign this node the current systemId
node->SetSystemId (systemId);
NS_LOG_INFO ("node " << node->GetId () << " is set to system " << systemId);
for (uint32_t i = 0; i < node->GetNDevices (); i++)
{
Ptr<NetDevice> localNetDevice = node->GetDevice (i);
Ptr<Channel> channel = localNetDevice->GetChannel ();
if (channel == 0)
{
continue;
}
// cut-off p2p links for partition
if (localNetDevice->IsPointToPoint ())
{
TimeValue delay;
channel->GetAttribute ("Delay", delay);
// if delay is below threshold, do not cut-off
if (delay.Get () >= m_minLookahead)
{
continue;
}
}
// grab the adjacent nodes
for (uint32_t j = 0; j < channel->GetNDevices (); j++)
{
Ptr<Node> remote = channel->GetDevice (j)->GetNode ();
// if it's not visited, add it to the current partition
if (!visited[remote->GetId ()])
{
q.push (remote);
}
}
}
}
}
}
delete[] visited;
// after the partition, we finally know the system count (# of LPs)
const uint32_t systemCount = systemId;
const uint32_t threadCount = std::min (m_maxThreads, systemCount);
NS_LOG_INFO ("Partition done! " << systemCount << " systems share " << threadCount << " threads");
// create new LPs
const Ptr<Scheduler> events = MtpInterface::GetSystem ()->GetPendingEvents ();
MtpInterface::Disable ();
MtpInterface::Enable (threadCount, systemCount);
// set scheduler
ObjectFactory schedulerFactory;
schedulerFactory.SetTypeId (m_schedulerTypeId);
for (uint32_t i = 0; i <= systemCount; i++)
{
MtpInterface::GetSystem (i)->SetScheduler (schedulerFactory);
}
// transfer events to new LPs
while (!events->IsEmpty ())
{
Scheduler::Event ev = events->RemoveNext ();
// invoke initialization events (at time 0) by their insertion order
// since changing the execution order of these events may cause error,
// they have to be invoked now rather than parallelly executed
if (ev.key.m_ts == 0)
{
MtpInterface::GetSystem (ev.key.m_context == Simulator::NO_CONTEXT
? 0
: NodeList::GetNode (ev.key.m_context)->GetSystemId ())
->InvokeNow (ev);
}
else if (ev.key.m_context == Simulator::NO_CONTEXT)
{
Schedule (TimeStep (ev.key.m_ts), ev.impl);
}
else
{
ScheduleWithContext (ev.key.m_context, TimeStep (ev.key.m_ts), ev.impl);
}
}
}
} // namespace ns3

View File

@@ -0,0 +1,61 @@
/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
#ifndef MULTITHREADED_SIMULATOR_IMPL_H
#define MULTITHREADED_SIMULATOR_IMPL_H
#include "ns3/event-id.h"
#include "ns3/event-impl.h"
#include "ns3/nstime.h"
#include "ns3/object-factory.h"
#include "ns3/simulator-impl.h"
#include <list>
namespace ns3 {
// SimulatorImpl that executes the simulation through MtpInterface: most
// operations are forwarded to the logical process (LP) bound to the calling
// thread, while destroy events are kept in a list owned by this object.
class MultithreadedSimulatorImpl : public SimulatorImpl
{
public:
static TypeId GetTypeId ();
/** Default constructor. */
MultithreadedSimulatorImpl ();
/** Destructor. */
~MultithreadedSimulatorImpl ();
// virtual from SimulatorImpl
virtual void Destroy ();
virtual bool IsFinished () const;
virtual void Stop ();
virtual void Stop (Time const &delay);
virtual EventId Schedule (Time const &delay, EventImpl *event);
virtual void ScheduleWithContext (uint32_t context, Time const &delay, EventImpl *event);
virtual EventId ScheduleNow (EventImpl *event);
virtual EventId ScheduleDestroy (EventImpl *event);
virtual void Remove (const EventId &id);
virtual void Cancel (const EventId &id);
virtual bool IsExpired (const EventId &id) const;
virtual void Run ();
virtual Time Now () const;
virtual Time GetDelayLeft (const EventId &id) const;
virtual Time GetMaximumSimulationTime () const;
virtual void SetScheduler (ObjectFactory schedulerFactory);
virtual uint32_t GetSystemId () const;
virtual uint32_t GetContext () const;
virtual uint64_t GetEventCount () const;
private:
// Inherited from Object
virtual void DoDispose ();
// Automatic topology partition, run from Run () when m_partition is set.
void Partition ();
// true when the constructor found no existing partition, so Run () must partition
bool m_partition;
// "MaxThreads" attribute
uint32_t m_maxThreads;
// "MinLookahead" attribute; zero makes Partition () derive a value
Time m_minLookahead;
// scheduler type recorded by SetScheduler (), re-applied after partitioning
TypeId m_schedulerTypeId;
// events to invoke in Destroy ()
std::list<EventId> m_destroyEvents;
};
} // namespace ns3
#endif /* MULTITHREADED_SIMULATOR_IMPL_H */