506 lines
14 KiB
C++
506 lines
14 KiB
C++
/*
|
|
ZNetPacketChannel.cpp
|
|
Author: Patrick Baggett <ptbaggett@762studios.com>
|
|
Created: 6/14/2013
|
|
|
|
Purpose:
|
|
|
|
** NOT PART OF PUBLIC SDK **
|
|
This class is not part of the public SDK; its fields and methods are not present
|
|
in the documentation and cannot be guaranteed in future revisions.
|
|
** NOT PART OF PUBLIC SDK **
|
|
|
|
Queue of incoming and outgoing packets.
|
|
|
|
License:
|
|
|
|
Copyright 2013, 762 Studios
|
|
*/
|
|
|
|
/*************************************************************************/
|
|
|
|
#include <ZNet/ZNetPacketChannel.hpp>
|
|
#include <ZNet/ZNetPacket.hpp>
|
|
#include <ZNet/ZNetHost.hpp>
|
|
#include <ZUtil/ZBinaryBufferWriter.hpp>
|
|
#include <ZSTL/ZListAlgo.hpp>
|
|
#include <SST/SST_OS.h>
|
|
#include "ZNetPrivate.hpp"
|
|
|
|
/*************************************************************************/
|
|
|
|
/*
ZNetPacketChannel::QueueForSending

Wraps an outgoing packet with sequencing metadata and appends it to this channel's
send queue. The channel takes its own reference on the packet for as long as it is
queued.

@param packet - the packet to queue; NULL is rejected
@param command - wire command to send this packet under (e.g. ZNETCMD_DATA)
@return (bool) - true if the packet was queued, false if the packet was NULL or the
                 channel already holds 'overflowLimit' packets
*/
bool ZNetPacketChannel::QueueForSending(ZNetPacket* packet, uint32_t command)
{
    printf("QueueForSending(): This sequence = %u\n", this->nextSequenceNumber);

    //Nothing to queue
    if(packet == NULL)
        return false;

    //Check for a packet overflow
    if(packetCount+1 > overflowLimit)
    {
        printf("ZNetPacketChannel::QueueForSending(): Packet channel overflow!\n");
        return false;
    }

    //Since we're keeping this packet in this channel, increase its reference count
    packet->AddReference();

    //Wrap this packet with some metadata
    ZNetQueuedPacket qp;
    qp.timeSent = 0;        //0 == "not sent yet"; FillBuffer() stamps the real time
    qp.packet = packet;
    qp.command = command;
    qp.subsequence = 0;
    qp.sequence = this->nextSequenceNumber;

    //Next sequence
    this->nextSequenceNumber++;

    packets.PushBack(qp);

    //Account for the queued packet. UpdateRemoteAck() decrements this counter as packets
    //are acknowledged, so without this increment the overflow check above never triggers.
    this->packetCount++;

    return true;
}
|
|
|
|
/*************************************************************************/
|
|
|
|
bool ZNetPacketChannel::FillBuffer(uint8_t* buffer, uint32_t bufferSize, uint32_t packetStartIndex, uint32_t* restartIndexReturn, uint32_t* bytesWritten)
|
|
{
|
|
ZList<ZNetQueuedPacket>::Iterator it = packets.Begin();
|
|
uint32_t bufferSpaceRemaining = bufferSize;
|
|
uint32_t nrWritten = 0;
|
|
|
|
//Skip to appropriate point in the list
|
|
while(packetStartIndex > 0)
|
|
{
|
|
++it;
|
|
--packetStartIndex;
|
|
}
|
|
|
|
ZBinaryBufferWriter writer(buffer, bufferSize, ZNET_BYTEORDER);
|
|
|
|
//Write packets
|
|
uint64_t now = SST_OS_GetMilliTime();
|
|
do
|
|
{
|
|
//Nothing to do?
|
|
if(it == packets.End())
|
|
break;
|
|
|
|
ZNetQueuedPacket& qp = it.Get();
|
|
uint32_t payloadSpace = bufferSpaceRemaining - ZNET_WIRESIZE_MESSAGE_HEADER; //This is how much payload we can handle.
|
|
|
|
//No space for a payload?
|
|
if(payloadSpace == 0)
|
|
break;
|
|
|
|
/*
|
|
All packets other than DATA are actually quite small, and as such they are not fragmented -- either
|
|
they fit into the buffer or they don't. The data packets can easily exceed an MTU. Because of that,
|
|
they are fragmented at the byte level using the subsequence value. This allows us to fully fill MTU-
|
|
sized packets as often as possible, however, it requires some special handling because fields are
|
|
inserted into the front that vary in size.
|
|
*/
|
|
|
|
uint8_t* packetData;
|
|
uint32_t sendSize = 0; //MSVC 2012 complains about uninitialized
|
|
|
|
//Data packets are somewhat special
|
|
if(qp.command == ZNETCMD_DATA)
|
|
{
|
|
SST_OS_DebugAssert(sendSize > 0, "Data packet has been fully confirmed, but still attempting to send?");
|
|
const uint32_t leftToSend = qp.packet->dataSize - qp.subsequence;
|
|
const uint32_t dataOverhead = 3*ZNetPrivate::PackedIntegerSize(qp.packet->dataSize); //3 fields, see ZNetPrivate.hpp
|
|
|
|
//Not enough space?
|
|
if(dataOverhead >= payloadSpace)
|
|
break;
|
|
|
|
//Subtract overhead from payload to get how much payload we can really deliver.
|
|
payloadSpace -= dataOverhead;
|
|
|
|
//Can we fit all of the remaining data in this packet?
|
|
if(leftToSend < payloadSpace)
|
|
sendSize = leftToSend;
|
|
else
|
|
sendSize = payloadSpace; //No, so just send as much as we can
|
|
|
|
|
|
printf("Sending %u bytes of ZNETCMD_DATA payload\n", sendSize);
|
|
packetData = (qp.packet->GetData() + qp.subsequence);
|
|
}
|
|
else
|
|
{
|
|
packetData = qp.packet->GetData();
|
|
sendSize = qp.packet->dataSize;
|
|
printf("Sending %u bytes of regular (non-ZNETCMD_DATA)\n", sendSize);
|
|
}
|
|
qp.timeSent = now;
|
|
|
|
SST_OS_DebugAssert(sendSize != 0, "Should not have 0-sized value");
|
|
|
|
//Write the header and then the data
|
|
ZNetPrivate::WriteWireMessageHeader(&writer, this->channelId, qp.packet->flags, qp.command, qp.sequence);
|
|
writer.WriteU8Array(packetData, sendSize);
|
|
|
|
|
|
//Record stats and advance to next packet
|
|
nrWritten += (sendSize + ZNET_WIRESIZE_MESSAGE_HEADER); //i.e. payload + header
|
|
packetStartIndex++;
|
|
bufferSpaceRemaining -= (sendSize + ZNET_WIRESIZE_MESSAGE_HEADER);
|
|
++it;
|
|
} while(bufferSpaceRemaining > 0);
|
|
|
|
bool allDone = (it == packets.End());
|
|
|
|
//If a restart is needed, record where we left off.
|
|
if(!allDone)
|
|
*restartIndexReturn = packetStartIndex;
|
|
|
|
//Save the number of bytes written
|
|
*bytesWritten = nrWritten;
|
|
|
|
return allDone;
|
|
}
|
|
|
|
/*************************************************************************/
|
|
|
|
void ZNetPacketChannel::Deinitialize()
|
|
{
|
|
|
|
//Unreference all packets and remove them from queues
|
|
for(ZList<ZNetQueuedPacket>::Iterator it=packets.Begin(); it.HasCurrent(); it.Next())
|
|
it.Get().packet->ReleaseReference();
|
|
packets.Clear();
|
|
|
|
for(ZList<ZNetDataReassemblyPacket>::Iterator it=reassembly.Begin(); it.HasCurrent(); it.Next())
|
|
it.Get().packet->ReleaseReference();
|
|
reassembly.Clear();
|
|
|
|
for(ZList<ZNetQueuedPacket>::Iterator it=assembled.Begin(); it.HasCurrent(); it.Next())
|
|
it.Get().packet->ReleaseReference();
|
|
assembled.Clear();
|
|
|
|
|
|
|
|
}
|
|
|
|
/*************************************************************************/
|
|
|
|
void ZNetPacketChannel::UpdateRemoteAck(uint16_t newHighest, int32_t* pingAdjust)
|
|
{
|
|
printf("ZNetPacketChannel::UpdateRemoteAck(): current ack = %u, newHighest = %u\n", remoteAck, newHighest);
|
|
//Nothing new
|
|
if(newHighest == remoteAck)
|
|
return;
|
|
|
|
uint16_t ackCount;
|
|
|
|
//Wrap around
|
|
if(newHighest < remoteAck)
|
|
{
|
|
/*
|
|
In a wrap around, we exceed the the fixed point notation of the next sequence number. To
|
|
handle it, we count from where we are at to the highest possible value, then add the new
|
|
local number -- this is the number of packets to dequeue. For example, if our window was
|
|
100 and we were at 97, then "ack 4" means that 98, 99, 100, 1, 2, 3, 4 were received. We
|
|
compute it as 100 - 97 + 4 == 7.
|
|
*/
|
|
ackCount = UINT16_MAX - remoteAck + newHighest;
|
|
}
|
|
else
|
|
ackCount = newHighest - remoteAck;
|
|
|
|
printf("AckCount = %u\n", ackCount);
|
|
if(ackCount > 0)
|
|
{
|
|
//Freeze time value now
|
|
uint64_t now = SST_OS_GetMilliTime();
|
|
int32_t currentPing = *pingAdjust;
|
|
|
|
while(ackCount > 0 && !packets.Empty())
|
|
{
|
|
ZNetQueuedPacket& qp = packets.Front();
|
|
|
|
|
|
//Does this packet have a valid timeSent?
|
|
if(qp.timeSent != 0)
|
|
{
|
|
|
|
//Find out how much the RTT differs from the ping
|
|
int32_t delta = currentPing - ((int32_t)(now - qp.timeSent));
|
|
//Adjust ping by a fraction of RTT to smooth it out.
|
|
currentPing += delta / 10;
|
|
}
|
|
|
|
qp.packet->ReleaseReference();
|
|
packets.PopFront();
|
|
printf("Removed a packet!\n");
|
|
ackCount--;
|
|
this->packetCount--;
|
|
}
|
|
|
|
//Save updated ping
|
|
*pingAdjust = currentPing;
|
|
}
|
|
|
|
//Because malicious clients can send anything, don't explode if this packet suggests
|
|
//that we remove more than exist.
|
|
//TODO: maybe log it? probably can help with debugging
|
|
printf("Possibly wrong sequencing - this tells us to ACK more than we've sent\n");
|
|
}
|
|
|
|
/*************************************************************************/
|
|
|
|
/*
ZNetPacketChannel::QueueLocally

Inserts a packet into the 'assembled' list, keeping the list ordered by sequence number.
If an entry with the same sequence number already exists, its packet is released and
replaced by the new one.

@param toQueue - the entry to queue; it is copied, and the list keeps the packet pointer
                 as-is (no additional reference is taken here)
@return (bool) - true if the entry was inserted or replaced an existing one, false if its
                 sequence number could not be ordered against an existing entry
*/
bool ZNetPacketChannel::QueueLocally(const ZNetPacketChannel::ZNetQueuedPacket* toQueue)
{
    const uint16_t seq = toQueue->sequence;

    ZList<ZNetQueuedPacket>::Iterator it = assembled.Begin();

    while(it.HasCurrent())
    {
        ZNetQueuedPacket& qp = it.Get();

        if(ZNetPrivate::SequenceAfter(seq, qp.sequence)) //Incoming packet sequence number is after this packet
        {
            it.Next();
            continue;
        }

        if(ZNetPrivate::SequenceAfter(qp.sequence, seq)) //Incoming packet sequence number is before this packet
        {
            //Add it in, just before the first entry that comes after it
            const ZNetQueuedPacket newEntry = *toQueue;
            assembled.Insert(it, newEntry);
            return true;
        }

        if(qp.sequence == seq) //Overwrite old data
        {
            //Don't release the packet we are about to store
            if(qp.packet != toQueue->packet)
            {
                if(qp.packet != NULL)
                    qp.packet->ReleaseReference();
                qp.packet = toQueue->packet;
            }
            return true;
        }

        //Neither before, after, nor equal: cannot be ordered. Bail out instead of
        //looping forever on the same entry.
        return false;
    }

    //List was empty, or the new packet comes after everything already queued
    const ZNetQueuedPacket newEntry = *toQueue;
    assembled.PushBack(newEntry);
    return true;
}
|
|
|
|
/*************************************************************************/
|
|
|
|
/*
ZNetPacketChannel::QueueData

Handles one received ZNETCMD_DATA fragment. A fragment that carries the whole logical
packet is queued locally right away. Otherwise the fragment is merged into the matching
entry of the 'reassembly' list (creating that entry, ordered by sequence, if needed);
once the last byte has arrived the finished packet is moved to the local queue.

@param data - parsed message; reads data->sequence and data->parsed.data.{offset,length,maxSize,data}
@return (bool) - false if the fragment is malformed, a packet could not be allocated, or
                 local queueing failed; true otherwise (including fragments that were
                 ignored because they do not extend the data received so far)
*/
bool ZNetPacketChannel::QueueData(ZNetPrivate::ZNetMessageContainer* data)
{
    const uint32_t offset = data->parsed.data.offset;
    const uint32_t length = data->parsed.data.length;
    const uint32_t maxSize = data->parsed.data.maxSize;

    /*
        Check for invalid offsets into packet.

        The comparison is written as 'length > maxSize - offset' rather than
        'offset + length > maxSize' because for extremely large values the sum can wrap
        to a smaller value, e.g. 0xFFFFFFFF + 0x0003 gives 0x0002, and crash later
        when copying / allocating memory. A fragment with length == maxSize is valid:
        it is the whole logical packet.
    */
    if(offset >= maxSize || length > maxSize - offset)
        return false;

    //If whole data logical packet was received (the check above guarantees offset == 0 here)
    if(length == maxSize)
    {
        ZNetQueuedPacket qp;
        ZNetPacket* packet = GetHost()->CreatePacket(data->parsed.data.data, length, 0);

        if(packet == NULL)//Out of memory
            return false;

        qp.command = ZNETCMD_DATA;
        qp.packet = packet;
        qp.sequence = data->sequence;
        qp.timeSent = SST_OS_GetMilliTime();
        qp.subsequence = 0; //field not used
        return this->QueueLocally(&qp);
    }

    //Required reassembly: find the entry for this sequence, or the spot where it belongs
    ZList<ZNetDataReassemblyPacket>::Iterator it = reassembly.Begin();
    while(it.HasCurrent())
    {
        ZNetDataReassemblyPacket& reasm = it.Get();

        //If the new packet comes after this packet, try next
        if(ZNetPrivate::SequenceAfter(data->sequence, reasm.sequence))
        {
            it.Next();
            continue;
        }

        //This packet comes before it.Get(): a new entry is inserted at 'it' below
        if(reasm.sequence != data->sequence)
            break;

        SST_OS_DebugAssert(reasm.packet != NULL, "Should not have NULL packet here");
        if(reasm.packet == NULL)
            return false;

        //TODO: when done debugging, remove this assert as untrusted input should not cause a crash
        SST_OS_DebugAssert(reasm.packet->dataSize == maxSize, "Does not match.");
        if(reasm.packet->dataSize != maxSize)
            return false;

        const uint32_t begin = offset;
        const uint32_t end = begin + length; //cannot wrap, validated above

        //Does this packet contain data that can be appended directly?
        if(begin <= reasm.subsequence && end > reasm.subsequence)
        {
            uint8_t* rawbits = reasm.packet->GetData(); //Get raw logical packet

            /* Copy from the subsequence point to the end point. This can be less than the incoming amount:

                reasm.subsequence = 3

                 filled
                xxxxxxxxx
                [0][1][2][3][4][5][6][7][8]
                       ^-----^
                       new packet. offset = 2, length = 3.

                In this case, we copy only `end - subsequence` == 5 - 3 == 2 bytes, and they
                are the *last* 2 bytes of the fragment, so the first `subsequence - begin`
                == 1 byte of the incoming data must be skipped.
            */
            const uint32_t alreadyHave = reasm.subsequence - begin;
            memcpy(rawbits+reasm.subsequence, ((const uint8_t*)data->parsed.data.data) + alreadyHave, end - reasm.subsequence);
            reasm.subsequence = end;

            //Packet is fully assembled
            if(end == maxSize)
            {
                ZNetQueuedPacket qp;
                qp.command = ZNETCMD_DATA;
                qp.packet = reasm.packet;
                qp.sequence = data->sequence;
                qp.timeSent = SST_OS_GetMilliTime();
                qp.subsequence = 0; //field not used

                //Ownership of the packet moves to the local queue, so drop the reassembly
                //entry; otherwise both lists would release the same reference later.
                //'reasm' must not be used after this point.
                reassembly.Erase(it);

                return this->QueueLocally(&qp);
            }
        }

        //Fragment was merged, or it does not extend what we have (duplicate / gap) and
        //the other side will resend. Either way we are done with it.
        return true;
    }

    //No entry exists for this sequence yet, start a new one.
    //NOTE(review): maxSize comes from the wire; confirm CreatePacket() bounds the allocation.
    ZNetDataReassemblyPacket newAssembly;

    ZNetPacket* packet = GetHost()->CreatePacket(NULL, maxSize, 0);
    if(packet == NULL)
        return false;

    //Copy when offset is zero, otherwise, the other side is going to do a resend
    //anyways.
    if(offset == 0)
    {
        memcpy(packet->GetData(), data->parsed.data.data, length);
        newAssembly.subsequence = length;
    }
    else
        newAssembly.subsequence = 0;
    newAssembly.packet = packet;
    newAssembly.sequence = data->sequence;

    //Insert before 'it', or at the end if every existing entry comes before this one
    //(this also covers an empty list)
    if(it.HasCurrent())
        reassembly.Insert(it, newAssembly);
    else
        reassembly.PushBack(newAssembly);

    return true;
}
|
|
|
|
/*************************************************************************/
|
|
|
|
void ZNetPacketChannel::UpdateLocalAck(uint16_t seq)
|
|
{
|
|
|
|
sequencesFound.PushBack(seq);
|
|
}
|
|
|
|
/*************************************************************************/
|
|
|
|
void ZNetPacketChannel::ProcessLocalAcks()
|
|
{
|
|
ZListAlgo::Sort(sequencesFound);
|
|
|
|
ZList<uint16_t>::Iterator it = sequencesFound.Begin();
|
|
int32_t firstFound = -1; //This is the first value found before wraparounds
|
|
|
|
/*
|
|
We want to find consecutive values starting at "localAck" and continuing. This takes two
|
|
passes because there can be a wrap around and the data is sorted ascending, but values
|
|
lower than "localAck" logically come _after_ it due to the wraparound.
|
|
|
|
In the first pass, we ignore wraparounds (i.e. localAck > seq). We check to see
|
|
if the next is expected value of localAck+1. If it is, we bump the counter and check the
|
|
next number. Stop when the list is reached or a discontinuity found.
|
|
|
|
In the second pass, we handle wraparounds, stopping when we reach the first value accepted
|
|
in pass 1.
|
|
*/
|
|
|
|
//PASS 1: handling pre-wrap arounds
|
|
while(it.HasCurrent())
|
|
{
|
|
uint16_t thisSeq = it.Get();
|
|
|
|
//Skip smaller values as they are wraparound values
|
|
if(thisSeq > localAck)
|
|
{
|
|
if(localAck+1 == thisSeq) //because > operator succeeded, we know that localAck+1 cannot overflow
|
|
{
|
|
if(firstFound == -1)
|
|
firstFound = (int32_t)thisSeq;
|
|
|
|
printf("Acked a packet!\n");
|
|
localAck++;
|
|
}
|
|
else
|
|
{
|
|
//discontinuity, stop here
|
|
sequencesFound.Clear();
|
|
return;
|
|
}
|
|
|
|
}
|
|
|
|
//Next value
|
|
it.Next();
|
|
}
|
|
|
|
//If we made it here, we didn't hit any discontinuities (or the list was empty)
|
|
|
|
//PASS 2: handling wrap arounds
|
|
it = sequencesFound.Begin();
|
|
while(it.HasCurrent() && it.Get() != (uint16_t)firstFound)
|
|
{
|
|
uint16_t thisSeq = it.Get();
|
|
|
|
if(thisSeq != (uint16_t)firstFound && //stop when we hit the first value we processed in pass 1
|
|
thisSeq == localAck+1)//localAck+1 may wraparound now
|
|
{
|
|
printf("(WRAP) Acked a packet!\n");
|
|
localAck++;
|
|
}
|
|
else //discontinuity hit
|
|
break;
|
|
|
|
}
|
|
|
|
//Now, remove all acks
|
|
sequencesFound.Clear();
|
|
} |