diff --git a/src/common/buffer.cc b/src/common/buffer.cc index 91b35b9f8..d00141f13 100644 --- a/src/common/buffer.cc +++ b/src/common/buffer.cc @@ -1128,6 +1128,12 @@ Buffer::Iterator::CalculateIpChecksum(uint16_t size, uint32_t initialChecksum) return ~sum; } +uint32_t +Buffer::Iterator::GetSize (void) const +{ + return m_dataEnd - m_dataStart; +} + } // namespace ns3 diff --git a/src/common/buffer.h b/src/common/buffer.h index 245f31ea8..4c65518de 100644 --- a/src/common/buffer.h +++ b/src/common/buffer.h @@ -358,6 +358,11 @@ public: */ uint16_t CalculateIpChecksum(uint16_t size, uint32_t initialChecksum); + /** + * \returns the size of the underlying buffer we are iterating + */ + uint32_t GetSize (void) const; + private: friend class Buffer; Iterator (Buffer const*buffer); diff --git a/src/internet-stack/tcp-header.cc b/src/internet-stack/tcp-header.cc index e23c9e281..93c2428ca 100644 --- a/src/internet-stack/tcp-header.cc +++ b/src/internet-stack/tcp-header.cc @@ -23,6 +23,7 @@ #include "tcp-socket-impl.h" #include "tcp-header.h" #include "ns3/buffer.h" +#include "ns3/address-utils.h" namespace ns3 { @@ -37,8 +38,6 @@ TcpHeader::TcpHeader () m_flags (0), m_windowSize (0xffff), m_urgentPointer (0), - m_initialChecksum(0), - m_checksum (0), m_calcChecksum(false), m_goodChecksum(true) {} @@ -52,11 +51,6 @@ TcpHeader::EnableChecksums (void) m_calcChecksum = true; } -void TcpHeader::SetPayloadSize(uint16_t payloadSize) -{ - m_payloadSize = payloadSize; -} - void TcpHeader::SetSourcePort (uint16_t port) { m_sourcePort = port; @@ -85,10 +79,6 @@ void TcpHeader::SetWindowSize (uint16_t windowSize) { m_windowSize = windowSize; } -void TcpHeader::SetChecksum (uint16_t checksum) -{ - m_checksum = checksum; -} void TcpHeader::SetUrgentPointer (uint16_t urgentPointer) { m_urgentPointer = urgentPointer; @@ -122,10 +112,6 @@ uint16_t TcpHeader::GetWindowSize () const { return m_windowSize; } -uint16_t TcpHeader::GetChecksum () const -{ - return m_checksum; -} uint16_t TcpHeader::GetUrgentPointer () const { return m_urgentPointer; @@ -133,29 +119,31 @@ uint16_t TcpHeader::GetUrgentPointer () const void TcpHeader::InitializeChecksum (Ipv4Address source, - Ipv4Address destination, - uint8_t protocol) + Ipv4Address destination, + uint8_t protocol) { - Buffer buf = Buffer(12); - uint8_t tmp[4]; - Buffer::Iterator it; - uint16_t tcpLength = m_payloadSize + GetSerializedSize(); + m_source = source; + m_destination = destination; + m_protocol = protocol; +} - buf.AddAtStart(12); - it = buf.Begin(); +uint16_t +TcpHeader::CalculateHeaderChecksum (uint16_t size) const +{ + Buffer buf = Buffer (12); + buf.AddAtStart (12); + Buffer::Iterator it = buf.Begin (); - source.Serialize(tmp); - it.Write(tmp, 4); /* source IP address */ - destination.Serialize(tmp); - it.Write(tmp, 4); /* destination IP address */ - it.WriteU8(0); /* protocol */ - it.WriteU8(protocol); /* protocol */ - it.WriteU8(tcpLength >> 8); /* length */ - it.WriteU8(tcpLength & 0xff); /* length */ + WriteTo (it, m_source); + WriteTo (it, m_destination); + it.WriteU8 (0); /* protocol */ + it.WriteU8 (m_protocol); /* protocol */ + it.WriteU8 (size >> 8); /* length */ + it.WriteU8 (size & 0xff); /* length */ - it = buf.Begin(); + it = buf.Begin (); /* we don't CompleteChecksum ( ~ ) now */ - m_initialChecksum = ~(it.CalculateIpChecksum(12)); + return ~(it.CalculateIpChecksum (12)); } bool @@ -219,7 +207,6 @@ uint32_t TcpHeader::GetSerializedSize (void) const void TcpHeader::Serialize (Buffer::Iterator start) const { Buffer::Iterator i = start; - uint16_t tcpLength = m_payloadSize + GetSerializedSize(); i.WriteHtonU16 (m_sourcePort); i.WriteHtonU16 (m_destinationPort); i.WriteHtonU32 (m_sequenceNumber); @@ -231,8 +218,9 @@ void TcpHeader::Serialize (Buffer::Iterator start) const if(m_calcChecksum) { + uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ()); i = start; - uint16_t checksum = i.CalculateIpChecksum(tcpLength, m_initialChecksum); + uint16_t checksum = i.CalculateIpChecksum(start.GetSize (), headerChecksum); i = start; i.Next(16); @@ -250,16 +238,16 @@ uint32_t TcpHeader::Deserialize (Buffer::Iterator start) m_flags = field & 0x3F; m_length = field>>12; m_windowSize = i.ReadNtohU16 (); - m_checksum = i.ReadU16 (); + i.Next (2); m_urgentPointer = i.ReadNtohU16 (); if(m_calcChecksum) - { + { + uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ()); i = start; - uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum); - + uint16_t checksum = i.CalculateIpChecksum(start.GetSize (), headerChecksum); m_goodChecksum = (checksum == 0); - } + } return GetSerializedSize (); } diff --git a/src/internet-stack/tcp-header.h b/src/internet-stack/tcp-header.h index 99d91747a..bfcbc81d8 100644 --- a/src/internet-stack/tcp-header.h +++ b/src/internet-stack/tcp-header.h @@ -77,10 +77,6 @@ public: * \param windowSize the window size for this TcpHeader */ void SetWindowSize (uint16_t windowSize); - /** - * \param checksum the checksum for this TcpHeader - */ - void SetChecksum (uint16_t checksum); /** * \param urgentPointer the urgent pointer for this TcpHeader */ @@ -116,10 +112,6 @@ public: * \return the window size for this TcpHeader */ uint16_t GetWindowSize () const; - /** - * \return the checksum for this TcpHeader - */ - uint16_t GetChecksum () const; /** * \return the urgent pointer for this TcpHeader */ @@ -150,11 +142,6 @@ public: virtual void Serialize (Buffer::Iterator start) const; virtual uint32_t Deserialize (Buffer::Iterator start); - /** - * \param size The payload size in bytes - */ - void SetPayloadSize (uint16_t size); - /** * \brief Is the TCP checksum correct ? * \returns true if the checksum is correct, false otherwise. @@ -162,6 +149,7 @@ public: bool IsChecksumOk (void) const; private: + uint16_t CalculateHeaderChecksum (uint16_t size) const; uint16_t m_sourcePort; uint16_t m_destinationPort; uint32_t m_sequenceNumber; @@ -170,13 +158,14 @@ private: uint8_t m_flags; // really a uint6_t uint16_t m_windowSize; uint16_t m_urgentPointer; - uint16_t m_payloadSize; - uint16_t m_initialChecksum; - uint16_t m_checksum; + Ipv4Address m_source; + Ipv4Address m_destination; + uint8_t m_protocol; + + uint16_t m_initialChecksum; bool m_calcChecksum; bool m_goodChecksum; - }; }; // namespace ns3 diff --git a/src/internet-stack/tcp-l4-protocol.cc b/src/internet-stack/tcp-l4-protocol.cc index f0e9c0b9b..3f385cf5a 100644 --- a/src/internet-stack/tcp-l4-protocol.cc +++ b/src/internet-stack/tcp-l4-protocol.cc @@ -448,6 +448,7 @@ TcpL4Protocol::Receive (Ptr packet, if(m_calcChecksum) { tcpHeader.EnableChecksums(); + tcpHeader.InitializeChecksum (source, destination, PROT_NUMBER); } packet->PeekHeader (tcpHeader); @@ -495,7 +496,6 @@ TcpL4Protocol::Send (Ptr packet, TcpHeader tcpHeader; tcpHeader.SetDestinationPort (dport); tcpHeader.SetSourcePort (sport); - tcpHeader.SetPayloadSize(packet->GetSize()); if(m_calcChecksum) { tcpHeader.EnableChecksums(); @@ -529,7 +529,6 @@ TcpL4Protocol::SendPacket (Ptr packet, TcpHeader outgoingHeader, // XXX outgoingHeader cannot be logged outgoingHeader.SetLength (5); //header length in units of 32bit words - outgoingHeader.SetPayloadSize(packet->GetSize()); /* outgoingHeader.SetUrgentPointer (0); //XXX */ if(m_calcChecksum) { diff --git a/src/internet-stack/udp-header.cc b/src/internet-stack/udp-header.cc index 64a83710f..ce0af0766 100644 --- a/src/internet-stack/udp-header.cc +++ b/src/internet-stack/udp-header.cc @@ -19,6 +19,7 @@ */ #include "udp-header.h" +#include "ns3/address-utils.h" namespace ns3 { @@ -32,8 +33,6 @@ UdpHeader::UdpHeader () : m_sourcePort (0xfffd), m_destinationPort (0xfffd), m_payloadSize (0xfffd), - m_initialChecksum (0), - m_checksum(0), m_calcChecksum(false), m_goodChecksum(true) {} @@ -71,35 +70,31 @@ UdpHeader::GetDestinationPort (void) const return m_destinationPort; } void -UdpHeader::SetPayloadSize (uint16_t size) -{ - m_payloadSize = size; -} -void UdpHeader::InitializeChecksum (Ipv4Address source, Ipv4Address destination, uint8_t protocol) { - Buffer buf = Buffer(12); - uint8_t tmp[4]; - Buffer::Iterator it; - uint16_t udpLength = m_payloadSize + GetSerializedSize(); + m_source = source; + m_destination = destination; + m_protocol = protocol; +} +uint16_t +UdpHeader::CalculateHeaderChecksum (uint16_t size) const +{ + Buffer buf = Buffer (12); + buf.AddAtStart (12); + Buffer::Iterator it = buf.Begin (); - buf.AddAtStart(12); - it = buf.Begin(); + WriteTo (it, m_source); + WriteTo (it, m_destination); + it.WriteU8 (0); /* protocol */ + it.WriteU8 (m_protocol); /* protocol */ + it.WriteU8 (size >> 8); /* length */ + it.WriteU8 (size & 0xff); /* length */ - source.Serialize(tmp); - it.Write(tmp, 4); /* source IP address */ - destination.Serialize(tmp); - it.Write(tmp, 4); /* destination IP address */ - it.WriteU8(0); /* protocol */ - it.WriteU8(protocol); /* protocol */ - it.WriteU8(udpLength >> 8); /* length */ - it.WriteU8(udpLength & 0xff); /* length */ - - it = buf.Begin(); + it = buf.Begin (); /* we don't CompleteChecksum ( ~ ) now */ - m_initialChecksum = ~(it.CalculateIpChecksum(12)); + return ~(it.CalculateIpChecksum (12)); } bool @@ -142,17 +137,17 @@ void UdpHeader::Serialize (Buffer::Iterator start) const { Buffer::Iterator i = start; - uint16_t udpLength = m_payloadSize + GetSerializedSize(); i.WriteHtonU16 (m_sourcePort); i.WriteHtonU16 (m_destinationPort); - i.WriteHtonU16 (udpLength); + i.WriteHtonU16 (start.GetSize ()); i.WriteU16 (0); if (m_calcChecksum) { + uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ()); i = start; - uint16_t checksum = i.CalculateIpChecksum(udpLength, m_initialChecksum); + uint16_t checksum = i.CalculateIpChecksum (start.GetSize (), headerChecksum); i = start; i.Next(6); @@ -166,12 +161,13 @@ UdpHeader::Deserialize (Buffer::Iterator start) m_sourcePort = i.ReadNtohU16 (); m_destinationPort = i.ReadNtohU16 (); m_payloadSize = i.ReadNtohU16 () - GetSerializedSize (); - m_checksum = i.ReadU16(); + i.Next (2); if(m_calcChecksum) { + uint16_t headerChecksum = CalculateHeaderChecksum (start.GetSize ()); i = start; - uint16_t checksum = i.CalculateIpChecksum(m_payloadSize + GetSerializedSize(), m_initialChecksum); + uint16_t checksum = i.CalculateIpChecksum (start.GetSize (), headerChecksum); m_goodChecksum = (checksum == 0); } diff --git a/src/internet-stack/udp-header.h b/src/internet-stack/udp-header.h index d3908e64f..23aaa1450 100644 --- a/src/internet-stack/udp-header.h +++ b/src/internet-stack/udp-header.h @@ -66,10 +66,6 @@ public: * \return the destination port for this UdpHeader */ uint16_t GetDestinationPort (void) const; - /** - * \param size The payload size in bytes - */ - void SetPayloadSize (uint16_t size); /** * \param source the ip source to use in the underlying @@ -100,12 +96,14 @@ public: bool IsChecksumOk (void) const; private: + uint16_t CalculateHeaderChecksum (uint16_t size) const; uint16_t m_sourcePort; uint16_t m_destinationPort; uint16_t m_payloadSize; - uint16_t m_initialChecksum; - uint16_t m_checksum; + Ipv4Address m_source; + Ipv4Address m_destination; + uint8_t m_protocol; bool m_calcChecksum; bool m_goodChecksum; }; diff --git a/src/internet-stack/udp-l4-protocol.cc b/src/internet-stack/udp-l4-protocol.cc index 8def85407..75b6e756d 100644 --- a/src/internet-stack/udp-l4-protocol.cc +++ b/src/internet-stack/udp-l4-protocol.cc @@ -163,7 +163,6 @@ UdpL4Protocol::Receive(Ptr packet, udpHeader.EnableChecksums(); } - udpHeader.SetPayloadSize (packet->GetSize () - udpHeader.GetSerializedSize ()); udpHeader.InitializeChecksum (source, destination, PROT_NUMBER); packet->RemoveHeader (udpHeader); @@ -195,13 +194,12 @@ UdpL4Protocol::Send (Ptr packet, if(m_calcChecksum) { udpHeader.EnableChecksums(); + udpHeader.InitializeChecksum (saddr, + daddr, + PROT_NUMBER); } udpHeader.SetDestinationPort (dport); udpHeader.SetSourcePort (sport); - udpHeader.SetPayloadSize (packet->GetSize ()); - udpHeader.InitializeChecksum (saddr, - daddr, - PROT_NUMBER); packet->AddHeader (udpHeader);