diff --git a/src/internet/model/tcp-l4-protocol.cc b/src/internet/model/tcp-l4-protocol.cc index 6a4a06cdc..c1bf07228 100644 --- a/src/internet/model/tcp-l4-protocol.cc +++ b/src/internet/model/tcp-l4-protocol.cc @@ -42,12 +42,13 @@ #include "ns3/log.h" #include "ns3/node.h" #include "ns3/nstime.h" -#include "ns3/object-vector.h" +#include "ns3/object-map.h" #include "ns3/packet.h" #include "ns3/simulator.h" #include #include +#include #include namespace ns3 @@ -72,30 +73,33 @@ const uint8_t TcpL4Protocol::PROT_NUMBER = 6; TypeId TcpL4Protocol::GetTypeId() { - static TypeId tid = TypeId("ns3::TcpL4Protocol") - .SetParent() - .SetGroupName("Internet") - .AddConstructor() - .AddAttribute("RttEstimatorType", - "Type of RttEstimator objects.", - TypeIdValue(RttMeanDeviation::GetTypeId()), - MakeTypeIdAccessor(&TcpL4Protocol::m_rttTypeId), - MakeTypeIdChecker()) - .AddAttribute("SocketType", - "Socket type of TCP objects.", - TypeIdValue(TcpCubic::GetTypeId()), - MakeTypeIdAccessor(&TcpL4Protocol::m_congestionTypeId), - MakeTypeIdChecker()) - .AddAttribute("RecoveryType", - "Recovery type of TCP objects.", - TypeIdValue(TcpPrrRecovery::GetTypeId()), - MakeTypeIdAccessor(&TcpL4Protocol::m_recoveryTypeId), - MakeTypeIdChecker()) - .AddAttribute("SocketList", - "The list of sockets associated to this protocol.", - ObjectVectorValue(), - MakeObjectVectorAccessor(&TcpL4Protocol::m_sockets), - MakeObjectVectorChecker()); + static TypeId tid = + TypeId("ns3::TcpL4Protocol") + .SetParent() + .SetGroupName("Internet") + .AddConstructor() + .AddAttribute("RttEstimatorType", + "Type of RttEstimator objects.", + TypeIdValue(RttMeanDeviation::GetTypeId()), + MakeTypeIdAccessor(&TcpL4Protocol::m_rttTypeId), + MakeTypeIdChecker()) + .AddAttribute("SocketType", + "Socket type of TCP objects.", + TypeIdValue(TcpCubic::GetTypeId()), + MakeTypeIdAccessor(&TcpL4Protocol::m_congestionTypeId), + MakeTypeIdChecker()) + .AddAttribute("RecoveryType", + "Recovery type of TCP objects.", + TypeIdValue(TcpPrrRecovery::GetTypeId()), + MakeTypeIdAccessor(&TcpL4Protocol::m_recoveryTypeId), + MakeTypeIdChecker()) + .AddAttribute("SocketList", + "A container of sockets associated to this protocol. " + "The underlying type is an unordered map, the attribute name " + "is kept for backward compatibility.", + ObjectMapValue(), + MakeObjectMapAccessor(&TcpL4Protocol::m_sockets), + MakeObjectMapChecker()); return tid; } @@ -213,7 +217,7 @@ TcpL4Protocol::CreateSocket(TypeId congestionTypeId, TypeId recoveryTypeId) socket->SetCongestionControlAlgorithm(algo); socket->SetRecoveryAlgorithm(recovery); - m_sockets.push_back(socket); + m_sockets[m_socketIndex++] = socket; return socket; } @@ -747,36 +751,30 @@ void TcpL4Protocol::AddSocket(Ptr socket) { NS_LOG_FUNCTION(this << socket); - std::vector>::iterator it = m_sockets.begin(); - while (it != m_sockets.end()) + for (auto& socketItem : m_sockets) { - if (*it == socket) + if (socketItem.second == socket) { return; } - - ++it; } - - m_sockets.push_back(socket); + m_sockets[m_socketIndex++] = socket; } bool TcpL4Protocol::RemoveSocket(Ptr socket) { NS_LOG_FUNCTION(this << socket); - std::vector>::iterator it = m_sockets.begin(); - while (it != m_sockets.end()) + for (auto& socketItem : m_sockets) { - if (*it == socket) + if (socketItem.second == socket) { - m_sockets.erase(it); + socketItem.second = nullptr; + m_sockets.erase(socketItem.first); return true; } - - ++it; } return false; diff --git a/src/internet/model/tcp-l4-protocol.h b/src/internet/model/tcp-l4-protocol.h index 7dac8a07b..e68cca596 100644 --- a/src/internet/model/tcp-l4-protocol.h +++ b/src/internet/model/tcp-l4-protocol.h @@ -27,6 +27,7 @@ #include "ns3/sequence-number.h" #include +#include namespace ns3 { @@ -334,13 +335,15 @@ class TcpL4Protocol : public IpL4Protocol const Address& incomingDAddr); private: - Ptr m_node; //!< the node this stack is associated with - Ipv4EndPointDemux* m_endPoints; //!< A list of IPv4 end points. - Ipv6EndPointDemux* m_endPoints6; //!< A list of IPv6 end points. - TypeId m_rttTypeId; //!< The RTT Estimator TypeId - TypeId m_congestionTypeId; //!< The socket TypeId - TypeId m_recoveryTypeId; //!< The recovery TypeId - std::vector> m_sockets; //!< list of sockets + Ptr m_node; //!< the node this stack is associated with + Ipv4EndPointDemux* m_endPoints; //!< A list of IPv4 end points. + Ipv6EndPointDemux* m_endPoints6; //!< A list of IPv6 end points. + TypeId m_rttTypeId; //!< The RTT Estimator TypeId + TypeId m_congestionTypeId; //!< The socket TypeId + TypeId m_recoveryTypeId; //!< The recovery TypeId + std::unordered_map> + m_sockets; //!< Unordered map of socket IDs and corresponding sockets + uint64_t m_socketIndex{0}; //!< index of the next socket to be created IpL4Protocol::DownTargetCallback m_downTarget; //!< Callback to send packets over IPv4 IpL4Protocol::DownTargetCallback6 m_downTarget6; //!< Callback to send packets over IPv6 diff --git a/src/internet/model/tcp-socket-base.cc b/src/internet/model/tcp-socket-base.cc index eaffb3928..49cba7a6e 100644 --- a/src/internet/model/tcp-socket-base.cc +++ b/src/internet/model/tcp-socket-base.cc @@ -2875,6 +2875,11 @@ TcpSocketBase::SendRST() void TcpSocketBase::DeallocateEndPoint() { + // note: it shouldn't be necessary to invalidate the callback and manually call + // TcpL4Protocol::RemoveSocket. Alas, if one relies on the endpoint destruction + // callback, there's a weird memory access to a free'd area. Harmless, but valgrind + // considers it an error. + if (m_endPoint != nullptr) { CancelAllTimers(); diff --git a/src/internet/model/udp-l4-protocol.cc b/src/internet/model/udp-l4-protocol.cc index a34de58a8..9442b530c 100644 --- a/src/internet/model/udp-l4-protocol.cc +++ b/src/internet/model/udp-l4-protocol.cc @@ -35,9 +35,11 @@ #include "ns3/boolean.h" #include "ns3/log.h" #include "ns3/node.h" -#include "ns3/object-vector.h" +#include "ns3/object-map.h" #include "ns3/packet.h" +#include + namespace ns3 { @@ -51,15 +53,18 @@ const uint8_t UdpL4Protocol::PROT_NUMBER = 17; TypeId UdpL4Protocol::GetTypeId() { - static TypeId tid = TypeId("ns3::UdpL4Protocol") - .SetParent() - .SetGroupName("Internet") - .AddConstructor() - .AddAttribute("SocketList", - "The list of sockets associated to this protocol.", - ObjectVectorValue(), - MakeObjectVectorAccessor(&UdpL4Protocol::m_sockets), - MakeObjectVectorChecker()); + static TypeId tid = + TypeId("ns3::UdpL4Protocol") + .SetParent() + .SetGroupName("Internet") + .AddConstructor() + .AddAttribute("SocketList", + "A container of sockets associated to this protocol. " + "The underlying type is an unordered map, the attribute name " + "is kept for backward compatibility.", + ObjectMapValue(), + MakeObjectMapAccessor(&UdpL4Protocol::m_sockets), + MakeObjectMapChecker()); return tid; } @@ -133,9 +138,9 @@ void UdpL4Protocol::DoDispose() { NS_LOG_FUNCTION(this); - for (std::vector>::iterator i = m_sockets.begin(); i != m_sockets.end(); i++) + for (auto i = m_sockets.begin(); i != m_sockets.end(); i++) { - *i = nullptr; + i->second = nullptr; } m_sockets.clear(); @@ -165,7 +170,7 @@ UdpL4Protocol::CreateSocket() Ptr socket = CreateObject(); socket->SetNode(m_node); socket->SetUdp(this); - m_sockets.push_back(socket); + m_sockets[m_socketIndex++] = socket; return socket; } @@ -545,4 +550,21 @@ UdpL4Protocol::GetDownTarget6() const return m_downTarget6; } +bool +UdpL4Protocol::RemoveSocket(Ptr socket) +{ + NS_LOG_FUNCTION(this << socket); + + for (auto& socketItem : m_sockets) + { + if (socketItem.second == socket) + { + socketItem.second = nullptr; + m_sockets.erase(socketItem.first); + return true; + } + } + return false; +} + } // namespace ns3 diff --git a/src/internet/model/udp-l4-protocol.h b/src/internet/model/udp-l4-protocol.h index c5e2e04f0..8c3b24da9 100644 --- a/src/internet/model/udp-l4-protocol.h +++ b/src/internet/model/udp-l4-protocol.h @@ -26,6 +26,7 @@ #include "ns3/ptr.h" #include +#include namespace ns3 { @@ -183,6 +184,14 @@ class UdpL4Protocol : public IpL4Protocol */ void DeAllocate(Ipv6EndPoint* endPoint); + /** + * \brief Remove a socket from the internal list + * + * \param socket socket to remove + * \return true if the socket has been removed + */ + bool RemoveSocket(Ptr socket); + // called by UdpSocket. /** * \brief Send a packet via UDP (IPv4) @@ -283,11 +292,13 @@ class UdpL4Protocol : public IpL4Protocol void NotifyNewAggregate() override; private: - Ptr m_node; //!< the node this stack is associated with + Ptr m_node; //!< The node this stack is associated with Ipv4EndPointDemux* m_endPoints; //!< A list of IPv4 end points. Ipv6EndPointDemux* m_endPoints6; //!< A list of IPv6 end points. - std::vector> m_sockets; //!< list of sockets + std::unordered_map> + m_sockets; //!< Unordered map of socket IDs and corresponding sockets + uint64_t m_socketIndex{0}; //!< Index of the next socket to be created IpL4Protocol::DownTargetCallback m_downTarget; //!< Callback to send packets over IPv4 IpL4Protocol::DownTargetCallback6 m_downTarget6; //!< Callback to send packets over IPv6 }; diff --git a/src/internet/model/udp-socket-impl.cc b/src/internet/model/udp-socket-impl.cc index 8137703f9..0fd3c24d3 100644 --- a/src/internet/model/udp-socket-impl.cc +++ b/src/internet/model/udp-socket-impl.cc @@ -176,6 +176,10 @@ void UdpSocketImpl::Destroy() { NS_LOG_FUNCTION(this); + if (m_udp) + { + m_udp->RemoveSocket(this); + } m_endPoint = nullptr; } @@ -183,6 +187,10 @@ void UdpSocketImpl::Destroy6() { NS_LOG_FUNCTION(this); + if (m_udp) + { + m_udp->RemoveSocket(this); + } m_endPoint6 = nullptr; } @@ -192,13 +200,11 @@ UdpSocketImpl::DeallocateEndPoint() { if (m_endPoint != nullptr) { - m_endPoint->SetDestroyCallback(MakeNullCallback()); m_udp->DeAllocate(m_endPoint); m_endPoint = nullptr; } if (m_endPoint6 != nullptr) { - m_endPoint6->SetDestroyCallback(MakeNullCallback()); m_udp->DeAllocate(m_endPoint6); m_endPoint6 = nullptr; }