diff --git a/src/wifi/model/he/constant-obss-pd-algorithm.cc b/src/wifi/model/he/constant-obss-pd-algorithm.cc index a26a5b01c..e3bd8f150 100644 --- a/src/wifi/model/he/constant-obss-pd-algorithm.cc +++ b/src/wifi/model/he/constant-obss-pd-algorithm.cc @@ -20,10 +20,10 @@ #include "constant-obss-pd-algorithm.h" #include "he-configuration.h" -#include "he-phy.h" #include "ns3/config.h" #include "ns3/double.h" +#include "ns3/eht-phy.h" #include "ns3/log.h" #include "ns3/node.h" #include "ns3/sta-wifi-mac.h" @@ -56,7 +56,13 @@ ConstantObssPdAlgorithm::GetTypeId() void ConstantObssPdAlgorithm::ConnectWifiNetDevice(const Ptr device) { - Ptr phy = device->GetPhy(); + auto phy = device->GetPhy(); + if (phy->GetStandard() >= WIFI_STANDARD_80211be) + { + auto ehtPhy = DynamicCast(device->GetPhy()->GetPhyEntity(WIFI_MOD_CLASS_EHT)); + NS_ASSERT(ehtPhy); + ehtPhy->SetEndOfHeSigACallback(MakeCallback(&ConstantObssPdAlgorithm::ReceiveHeSigA, this)); + } auto hePhy = DynamicCast(phy->GetPhyEntity(WIFI_MOD_CLASS_HE)); NS_ASSERT(hePhy); hePhy->SetEndOfHeSigACallback(MakeCallback(&ConstantObssPdAlgorithm::ReceiveHeSigA, this)); diff --git a/src/wifi/model/he/obss-pd-algorithm.cc b/src/wifi/model/he/obss-pd-algorithm.cc index dd6e997e3..9b88a85df 100644 --- a/src/wifi/model/he/obss-pd-algorithm.cc +++ b/src/wifi/model/he/obss-pd-algorithm.cc @@ -19,9 +19,8 @@ #include "obss-pd-algorithm.h" -#include "he-phy.h" - #include "ns3/double.h" +#include "ns3/eht-phy.h" #include "ns3/log.h" #include "ns3/wifi-net-device.h" #include "ns3/wifi-phy.h" @@ -85,6 +84,13 @@ ObssPdAlgorithm::ConnectWifiNetDevice(const Ptr device) { NS_LOG_FUNCTION(this << device); m_device = device; + auto phy = device->GetPhy(); + if (phy->GetStandard() >= WIFI_STANDARD_80211be) + { + auto ehtPhy = DynamicCast(device->GetPhy()->GetPhyEntity(WIFI_MOD_CLASS_EHT)); + NS_ASSERT(ehtPhy); + ehtPhy->SetObssPdAlgorithm(this); + } auto hePhy = DynamicCast(device->GetPhy()->GetPhyEntity(WIFI_MOD_CLASS_HE)); NS_ASSERT(hePhy); hePhy->SetObssPdAlgorithm(this);