Files
libsst/ZNet/ZNetPacketChannel.cpp
2026-04-03 00:22:39 -05:00

506 lines
14 KiB
C++

/*
ZNetPacketChannel.cpp
Author: Patrick Baggett <ptbaggett@762studios.com>
Created: 6/14/2013
Purpose:
** NOT PART OF PUBLIC SDK **
This class is not part of the public SDK; its fields and methods are not present
in the documentation and cannot be guaranteed in future revisions.
** NOT PART OF PUBLIC SDK **
Queue of incoming and outgoing packets.
License:
Copyright 2013, 762 Studios
*/
/*************************************************************************/
#include <ZNet/ZNetPacketChannel.hpp>
#include <ZNet/ZNetPacket.hpp>
#include <ZNet/ZNetHost.hpp>
#include <ZUtil/ZBinaryBufferWriter.hpp>
#include <ZSTL/ZListAlgo.hpp>
#include <SST/SST_OS.h>
#include "ZNetPrivate.hpp"
/*************************************************************************/
/*
	Appends `packet` to this channel's outgoing queue, tagged with `command` and the channel's next
	sequence number (which is then incremented).

	The channel takes its own reference to the packet (AddReference); the reference is dropped when the
	packet is acknowledged (UpdateRemoteAck) or the channel is torn down (Deinitialize).

	Returns false, without referencing or queueing the packet, if queueing it would push packetCount
	past overflowLimit; true otherwise.
*/
bool ZNetPacketChannel::QueueForSending(ZNetPacket* packet, uint32_t command)
{
	printf("QueueForSending(): This sequence = %u\n", this->nextSequenceNumber);

	//Check for a packet overflow
	if(packetCount+1 > overflowLimit)
	{
		printf("ZNetPacketChannel::QueueForSending(): Packet channel overflow!\n");
		return false;
	}

	//Since we're keeping this packet in this channel, increase its reference count
	packet->AddReference();

	//Wrap this packet with some metadata; timeSent == 0 marks "not yet transmitted"
	ZNetQueuedPacket qp;
	qp.timeSent = 0;
	qp.packet = packet;
	qp.command = command;
	qp.subsequence = 0;
	qp.sequence = this->nextSequenceNumber;

	//Next sequence
	this->nextSequenceNumber++;

	packets.PushBack(qp);

	//Keep the count in step with the queue: UpdateRemoteAck() decrements it for every packet it
	//dequeues, and the overflow check above relies on it. Without this it would only ever go down.
	this->packetCount++;
	return true;
}
/*************************************************************************/
/*
	Serializes queued outgoing packets into `buffer`. It starts with the packet at position
	`packetStartIndex` in the send queue and stops when the queue or the buffer runs out.

	Each packet is written as a wire message header (ZNetPrivate::WriteWireMessageHeader) followed by
	its payload. Non-DATA packets are written whole or not at all. ZNETCMD_DATA packets may be cut short
	to fill the remaining space. Every packet written gets its timeSent stamped with the current time;
	UpdateRemoteAck() later reads that value for RTT estimation.

	Parameters:
		buffer             - destination bytes, handed to a ZBinaryBufferWriter of capacity bufferSize
		bufferSize         - capacity of buffer in bytes
		packetStartIndex   - queue index of the first packet to write (0 = front of the queue)
		restartIndexReturn - on a false return, receives the queue index to pass as packetStartIndex
		                     on the next call; not written on a true return
		bytesWritten       - receives the number of bytes written to buffer

	Returns true when every packet from packetStartIndex to the end of the queue was written.
*/
bool ZNetPacketChannel::FillBuffer(uint8_t* buffer, uint32_t bufferSize, uint32_t packetStartIndex, uint32_t* restartIndexReturn, uint32_t* bytesWritten)
{
	ZList<ZNetQueuedPacket>::Iterator it = packets.Begin();
	uint32_t bufferSpaceRemaining = bufferSize;
	uint32_t nrWritten = 0;

	//Absolute queue index of the packet `it` refers to. This is kept separate from packetStartIndex so
	//the restart index reported below counts the packets skipped here as well as those written.
	uint32_t packetIndex = 0;

	//Skip to the appropriate point in the list, without walking past its end
	while(packetIndex < packetStartIndex && it != packets.End())
	{
		++it;
		++packetIndex;
	}

	ZBinaryBufferWriter writer(buffer, bufferSize, ZNET_BYTEORDER);

	//Write packets
	uint64_t now = SST_OS_GetMilliTime();
	while(it != packets.End())
	{
		//A message needs a header plus at least one byte of payload. Test before subtracting so a
		//nearly-full buffer cannot wrap payloadSpace around to a huge unsigned value.
		if(bufferSpaceRemaining <= ZNET_WIRESIZE_MESSAGE_HEADER)
			break;

		ZNetQueuedPacket& qp = it.Get();
		uint32_t payloadSpace = bufferSpaceRemaining - ZNET_WIRESIZE_MESSAGE_HEADER; //This is how much payload we can handle.

		/*
			All packets other than DATA are quite small, so they are not fragmented: either they fit
			into the buffer or they don't. DATA packets can easily exceed an MTU, so they are fragmented
			at the byte level using the subsequence value. This lets us fill MTU-sized packets as often
			as possible, but needs special handling because variable-size fields precede the payload.
		*/
		uint8_t* packetData;
		uint32_t sendSize = 0; //MSVC 2012 complains about uninitialized

		if(qp.command == ZNETCMD_DATA)
		{
			//Some of this packet must still be unsent
			SST_OS_DebugAssert(qp.subsequence < qp.packet->dataSize, "Data packet has been fully confirmed, but still attempting to send?");

			const uint32_t leftToSend = qp.packet->dataSize - qp.subsequence;
			const uint32_t dataOverhead = 3*ZNetPrivate::PackedIntegerSize(qp.packet->dataSize); //3 fields, see ZNetPrivate.hpp

			//Not enough space for the overhead plus at least one payload byte?
			if(dataOverhead >= payloadSpace)
				break;

			//Subtract overhead from payload to get how much payload we can really deliver.
			payloadSpace -= dataOverhead;

			//Send everything that is left if it fits, otherwise as much as we can
			sendSize = (leftToSend < payloadSpace) ? leftToSend : payloadSpace;

			//NOTE(review): dataOverhead is reserved for three packed-integer fields, but nothing in this
			//function writes them, and qp.subsequence is not advanced after a partial send, so the rest of
			//a fragmented packet is not resumed by the next call. TODO confirm against the receive path.
			printf("Sending %u bytes of ZNETCMD_DATA payload\n", sendSize);
			packetData = (qp.packet->GetData() + qp.subsequence);
		}
		else
		{
			packetData = qp.packet->GetData();
			sendSize = qp.packet->dataSize;

			//Non-DATA packets are never fragmented. If this one does not fit, leave it (and everything
			//after it) for the next buffer rather than overrunning this one.
			if(sendSize > payloadSpace)
				break;

			printf("Sending %u bytes of regular (non-ZNETCMD_DATA)\n", sendSize);
		}

		qp.timeSent = now;
		SST_OS_DebugAssert(sendSize != 0, "Should not have 0-sized value");

		//Write the header and then the data
		ZNetPrivate::WriteWireMessageHeader(&writer, this->channelId, qp.packet->flags, qp.command, qp.sequence);
		writer.WriteU8Array(packetData, sendSize);

		//Record stats and advance to next packet
		const uint32_t messageSize = sendSize + ZNET_WIRESIZE_MESSAGE_HEADER; //i.e. payload + header
		nrWritten += messageSize;
		bufferSpaceRemaining -= messageSize;
		++packetIndex;
		++it;
	}

	const bool allDone = (it == packets.End());

	//If a restart is needed, record where we left off (absolute queue position).
	if(!allDone)
		*restartIndexReturn = packetIndex;

	//Save the number of bytes written
	*bytesWritten = nrWritten;
	return allDone;
}
/*************************************************************************/
void ZNetPacketChannel::Deinitialize()
{
//Unreference all packets and remove them from queues
for(ZList<ZNetQueuedPacket>::Iterator it=packets.Begin(); it.HasCurrent(); it.Next())
it.Get().packet->ReleaseReference();
packets.Clear();
for(ZList<ZNetDataReassemblyPacket>::Iterator it=reassembly.Begin(); it.HasCurrent(); it.Next())
it.Get().packet->ReleaseReference();
reassembly.Clear();
for(ZList<ZNetQueuedPacket>::Iterator it=assembled.Begin(); it.HasCurrent(); it.Next())
it.Get().packet->ReleaseReference();
assembled.Clear();
}
/*************************************************************************/
/*
	Processes an acknowledgement from the remote side. The remote claims to have received every
	sequence number after the previous ack (remoteAck) up to and including `newHighest`.

	That many packets are popped from the front of the send queue. The count is computed modulo 2^16,
	so a wrap from 65535 to 0 is handled. This assumes the queue front holds sequence remoteAck+1 and
	that sequence numbers are consecutive (TODO confirm). For each popped packet that was actually
	transmitted (timeSent != 0), its round-trip time is folded into *pingAdjust: the estimate moves a
	tenth of the way toward the sample. *pingAdjust is both read and written.

	A malicious or confused peer may ack more packets than are queued. In that case only the queued
	packets are removed, and remoteAck advances only by the number actually removed.
*/
void ZNetPacketChannel::UpdateRemoteAck(uint16_t newHighest, int32_t* pingAdjust)
{
	printf("ZNetPacketChannel::UpdateRemoteAck(): current ack = %u, newHighest = %u\n", remoteAck, newHighest);

	//Nothing new
	if(newHighest == remoteAck)
		return;

	/*
		Number of newly acknowledged sequence numbers, counted modulo 2^16. Unsigned wraparound makes
		this correct both normally and across a wrap. Example: remoteAck = 65534 and newHighest = 1
		means 65535, 0 and 1 were received, and (uint16_t)(1 - 65534) == 3.
	*/
	const uint16_t ackCount = (uint16_t)(newHighest - remoteAck);
	printf("AckCount = %u\n", ackCount);

	//Freeze time value now
	uint64_t now = SST_OS_GetMilliTime();
	int32_t currentPing = *pingAdjust;
	uint16_t removed = 0;

	while(removed < ackCount && !packets.Empty())
	{
		ZNetQueuedPacket& qp = packets.Front();

		//Does this packet have a valid timeSent?
		if(qp.timeSent != 0)
		{
			//Move the ping estimate a tenth of the way toward the measured RTT to smooth it out.
			const int32_t rtt = (int32_t)(now - qp.timeSent);
			currentPing += (rtt - currentPing) / 10;
		}

		qp.packet->ReleaseReference();
		packets.PopFront();
		printf("Removed a packet!\n");
		removed++;
		this->packetCount--;
	}

	//Save updated ping
	*pingAdjust = currentPing;

	//Remember how far we have been acknowledged. Otherwise the next ack would be measured from the
	//old value and dequeue packets that were never acknowledged. Advance only by what was removed.
	remoteAck = (uint16_t)(remoteAck + removed);

	//Because malicious clients can send anything, don't explode if this packet suggests
	//that we remove more than exist.
	if(removed < ackCount)
	{
		//TODO: maybe log it? probably can help with debugging
		printf("Possibly wrong sequencing - this tells us to ACK more than we've sent\n");
	}
}
/*************************************************************************/
/*
	Stores a fully assembled incoming packet in `assembled`. The list is kept ordered by sequence number
	as defined by ZNetPrivate::SequenceAfter.

	No AddReference is done: the list takes over the caller's reference to toQueue->packet. If an entry
	with the same sequence number already exists, its packet is replaced and the old one released.

	Returns true once the packet has been stored (inserted, appended, or replacing a duplicate).
*/
bool ZNetPacketChannel::QueueLocally(const ZNetPacketChannel::ZNetQueuedPacket* toQueue)
{
	const uint16_t seq = toQueue->sequence;

	ZList<ZNetQueuedPacket>::Iterator it = assembled.Begin();
	while(it.HasCurrent())
	{
		ZNetQueuedPacket& qp = it.Get();

		//Same sequence: overwrite old data. The new packet is stored even if the slot held NULL,
		//and the old packet is not released when it is the very same object being re-queued.
		if(qp.sequence == seq)
		{
			if(qp.packet != NULL && qp.packet != toQueue->packet)
				qp.packet->ReleaseReference();
			qp.packet = toQueue->packet;
			return true;
		}

		//Incoming packet sequence number is after this packet: keep looking
		if(ZNetPrivate::SequenceAfter(seq, qp.sequence))
		{
			it.Next();
			continue;
		}

		//Incoming packet belongs before this packet. Treating "not after" as "before" also
		//guarantees the loop always makes progress.
		break;
	}

	ZNetQueuedPacket newEntry;
	newEntry.command = toQueue->command;
	newEntry.sequence = toQueue->sequence;
	newEntry.subsequence = toQueue->subsequence;
	newEntry.packet = toQueue->packet;
	newEntry.timeSent = toQueue->timeSent;

	//Insert before the first later entry, or append when it comes after every entry (or the list is empty)
	if(it.HasCurrent())
		assembled.Insert(it, newEntry);
	else
		assembled.PushBack(newEntry);

	return true;
}
/*************************************************************************/
/*
	Accepts one ZNETCMD_DATA message from the remote side. `data->parsed.data` describes the byte slice
	[offset, offset+length) of a logical packet that is maxSize bytes long.

	- A message carrying the whole logical packet (offset 0, length == maxSize) is copied into a new
	  ZNetPacket and handed straight to QueueLocally().
	- Otherwise the slice is merged into the `reassembly` entry with the same sequence number. The list
	  is kept ordered by ZNetPrivate::SequenceAfter, and an entry is created if none exists. Only bytes
	  that extend the contiguous prefix already received (reasm.subsequence) are copied. Other slices
	  are dropped, relying on the remote side to resend them. When the prefix reaches maxSize, the
	  completed packet is handed to QueueLocally().

	Returns false for inconsistent offset/length/maxSize values, a size mismatch with an existing
	reassembly entry, or a failed packet allocation. When a logical packet is completed, QueueLocally()'s
	result is returned. Otherwise the result is true.
*/
bool ZNetPacketChannel::QueueData(ZNetPrivate::ZNetMessageContainer* data)
{
	const uint32_t maxSize = data->parsed.data.maxSize;
	const uint32_t offset = data->parsed.data.offset;
	const uint32_t length = data->parsed.data.length;

	/*
		Check for invalid offsets into the packet without ever computing offset + length, which can wrap
		for huge values (e.g. 0xFFFFFFFF + 0x0003 gives 0x0002) and crash later when copying or allocating.
		length == maxSize must stay valid: that is a complete, unfragmented packet.
	*/
	if(length > maxSize || offset >= maxSize || offset > maxSize - length)
		return false;

	//If the whole logical packet was received (offset is necessarily 0 here, given the checks above)
	if(length == maxSize)
	{
		ZNetPacket* packet = GetHost()->CreatePacket(data->parsed.data.data, length, 0);
		if(packet == NULL) //Out of memory
			return false;

		ZNetQueuedPacket qp;
		qp.command = ZNETCMD_DATA;
		qp.packet = packet;
		qp.sequence = data->sequence;
		qp.timeSent = SST_OS_GetMilliTime();
		qp.subsequence = 0; //field not used
		return this->QueueLocally(&qp);
	}

	//Reassembly required: find this sequence's entry, or the position where it belongs
	ZList<ZNetDataReassemblyPacket>::Iterator it = reassembly.Begin();
	while(it.HasCurrent())
	{
		ZNetDataReassemblyPacket& reasm = it.Get();

		//If the new packet comes after this packet, try next
		if(ZNetPrivate::SequenceAfter(data->sequence, reasm.sequence))
		{
			it.Next();
			continue;
		}

		//This packet comes before it.Get(): a new entry goes here
		if(reasm.sequence != data->sequence)
			break;

		//Existing entry for this sequence
		SST_OS_DebugAssert(reasm.packet != NULL, "Should not have NULL packet here");

		//TODO: when done debugging, remove this assert as untrusted input should not cause a crash
		SST_OS_DebugAssert(reasm.packet->dataSize == maxSize, "Does not match.");
		if(reasm.packet->dataSize != maxSize)
			return false;

		const uint32_t begin = offset;
		const uint32_t end = offset + length; //Cannot wrap: validated above

		//Does this packet contain data that can be appended directly?
		if(begin <= reasm.subsequence && end > reasm.subsequence)
		{
			/*
				Copy from the subsequence point to the end point. This can be less than the incoming amount:

				reasm.subsequence = 3
				filled
				xxxxxxxxx
				[0][1][2][3][4][5][6][7][8]
				      ^-----^
				      new packet. offset = 2, length = 3.

				Here we copy `end - subsequence` == 5 - 3 == 2 bytes. The source starts
				`subsequence - offset` == 1 byte into the incoming slice, because the slice's first
				byte is logical byte `offset`, not logical byte `subsequence`.
			*/
			const uint8_t* src = (const uint8_t*)data->parsed.data.data + (reasm.subsequence - begin);
			memcpy(reasm.packet->GetData() + reasm.subsequence, src, end - reasm.subsequence);
			reasm.subsequence = end;

			//Packet is fully assembled
			if(end == maxSize)
			{
				ZNetQueuedPacket qp;
				qp.command = ZNETCMD_DATA;
				qp.packet = reasm.packet;
				qp.sequence = data->sequence;
				qp.timeSent = SST_OS_GetMilliTime();
				qp.subsequence = 0; //field not used

				//QueueLocally() keeps the reference it is given, while this reassembly entry also keeps
				//pointing at the packet (and Deinitialize() releases both lists), so give each list its own.
				//TODO: remove the completed entry from `reassembly` instead of keeping it around.
				reasm.packet->AddReference();
				return this->QueueLocally(&qp);
			}
		}

		//Fragment handled (or not contiguous yet and dropped)
		return true;
	}

	//No entry for this sequence yet: start one
	ZNetDataReassemblyPacket newAssembly;
	ZNetPacket* packet = GetHost()->CreatePacket(NULL, maxSize, 0);
	if(packet == NULL)
		return false;

	//Copy when offset is zero, otherwise, the other side is going to do a resend
	//anyways.
	if(offset == 0)
	{
		memcpy(packet->GetData(), data->parsed.data.data, length);
		newAssembly.subsequence = length;
	}
	else
		newAssembly.subsequence = 0;

	newAssembly.packet = packet;
	newAssembly.sequence = data->sequence;

	//Insert before 'it', or append when this sequence comes after every existing entry
	if(it.HasCurrent())
		reassembly.Insert(it, newAssembly);
	else
		reassembly.PushBack(newAssembly);

	return true;
}
/*************************************************************************/
void ZNetPacketChannel::UpdateLocalAck(uint16_t seq)
{
sequencesFound.PushBack(seq);
}
/*************************************************************************/
/*
	Advances localAck over the sequence numbers recorded by UpdateLocalAck(), as long as they continue
	consecutively from the current localAck (wrapping from 65535 to 0). The recorded list is always
	cleared afterwards. Values past a gap are discarded rather than kept for a later call.
*/
void ZNetPacketChannel::ProcessLocalAcks()
{
	ZListAlgo::Sort(sequencesFound);
	ZList<uint16_t>::Iterator it = sequencesFound.Begin();
	int32_t firstFound = -1; //First value accepted in pass 1, or -1 if pass 1 accepted nothing

	/*
		We want to find consecutive values starting at "localAck" and continuing. This takes two
		passes because there can be a wrap around and the data is sorted ascending, but values
		lower than "localAck" logically come _after_ it due to the wraparound.

		In the first pass, we ignore wraparounds (i.e. localAck > seq). We check whether the
		next value is the expected localAck+1. If it is, we bump the counter and check the
		next number. Stop when the end of the list is reached or a discontinuity is found.

		In the second pass, we handle wraparounds, stopping when we reach the first value accepted
		in pass 1.
	*/

	//PASS 1: handling pre-wrap arounds
	while(it.HasCurrent())
	{
		uint16_t thisSeq = it.Get();

		//Skip smaller values as they are wraparound values
		if(thisSeq > localAck)
		{
			if(localAck+1 == thisSeq) //because > operator succeeded, we know that localAck+1 cannot overflow
			{
				if(firstFound == -1)
					firstFound = (int32_t)thisSeq;
				printf("Acked a packet!\n");
				localAck++;
			}
			else
			{
				//discontinuity, stop here
				sequencesFound.Clear();
				return;
			}
		}

		//Next value
		it.Next();
	}

	//If we made it here, we didn't hit any discontinuities (or the list was empty)

	//PASS 2: handling wrap arounds
	it = sequencesFound.Begin();
	while(it.HasCurrent())
	{
		const uint16_t thisSeq = it.Get();

		//Stop when we hit the first value we processed in pass 1. The -1 check keeps sequence 65535
		//from being mistaken for the "nothing found" sentinel.
		if(firstFound != -1 && thisSeq == (uint16_t)firstFound)
			break;

		//Compute the next expected value modulo 2^16 so 65535 -> 0 is recognized.
		//A bare localAck+1 is promoted to int and would become 65536.
		const uint16_t expected = (uint16_t)(localAck + 1);
		if(thisSeq == expected)
		{
			printf("(WRAP) Acked a packet!\n");
			localAck = expected;
		}
		else if(thisSeq != localAck) //A repeat of the value just acknowledged is harmless; anything else is a gap
			break;

		//Next value
		it.Next();
	}

	//Now, remove all acks
	sequencesFound.Clear();
}