/* ZNetPacketChannel.cpp Author: Patrick Baggett Created: 6/14/2013 Purpose: ** NOT PART OF PUBLIC SDK ** This class is not part of the public SDK; its fields and methods are not present in the documentation and cannot be guaranteed in future revisions. ** NOT PART OF PUBLIC SDK ** Queue of incoming and outgoing packets. License: Copyright 2013, 762 Studios */ /*************************************************************************/ #include #include #include #include #include #include #include "ZNetPrivate.hpp" /*************************************************************************/ bool ZNetPacketChannel::QueueForSending(ZNetPacket* packet, uint32_t command) { printf("QueueForSending(): This sequence = %u\n", this->nextSequenceNumber); //Check for a packet overflow if(packetCount+1 > overflowLimit) { printf("ZNetPacketChannel::QueueForSending(): Packet channel overflow!\n"); return false; } //Since we're keeping this packet in this channel, increase its reference count packet->AddReference(); //Wrap this packet with some metadata ZNetQueuedPacket qp; qp.timeSent = 0; qp.packet = packet; qp.command = command; qp.subsequence = 0; qp.sequence = this->nextSequenceNumber; //Next sequence this->nextSequenceNumber++; packets.PushBack(qp); return true; } /*************************************************************************/ bool ZNetPacketChannel::FillBuffer(uint8_t* buffer, uint32_t bufferSize, uint32_t packetStartIndex, uint32_t* restartIndexReturn, uint32_t* bytesWritten) { ZList::Iterator it = packets.Begin(); uint32_t bufferSpaceRemaining = bufferSize; uint32_t nrWritten = 0; //Skip to appropriate point in the list while(packetStartIndex > 0) { ++it; --packetStartIndex; } ZBinaryBufferWriter writer(buffer, bufferSize, ZNET_BYTEORDER); //Write packets uint64_t now = SST_OS_GetMilliTime(); do { //Nothing to do? if(it == packets.End()) break; ZNetQueuedPacket& qp = it.Get(); uint32_t payloadSpace = bufferSpaceRemaining - ZNET_WIRESIZE_MESSAGE_HEADER; //This is how much payload we can handle. //No space for a payload? if(payloadSpace == 0) break; /* All packets other than DATA are actually quite small, and as such they are not fragmented -- either they fit into the buffer or they don't. The data packets can easily exceed an MTU. Because of that, they are fragmented at the byte level using the subsequence value. This allows us to fully fill MTU- sized packets as often as possible, however, it requires some special handling because fields are inserted into the front that vary in size. */ uint8_t* packetData; uint32_t sendSize = 0; //MSVC 2012 complains about uninitialized //Data packets are somewhat special if(qp.command == ZNETCMD_DATA) { SST_OS_DebugAssert(sendSize > 0, "Data packet has been fully confirmed, but still attempting to send?"); const uint32_t leftToSend = qp.packet->dataSize - qp.subsequence; const uint32_t dataOverhead = 3*ZNetPrivate::PackedIntegerSize(qp.packet->dataSize); //3 fields, see ZNetPrivate.hpp //Not enough space? if(dataOverhead >= payloadSpace) break; //Subtract overhead from payload to get how much payload we can really deliver. payloadSpace -= dataOverhead; //Can we fit all of the remaining data in this packet? if(leftToSend < payloadSpace) sendSize = leftToSend; else sendSize = payloadSpace; //No, so just send as much as we can printf("Sending %u bytes of ZNETCMD_DATA payload\n", sendSize); packetData = (qp.packet->GetData() + qp.subsequence); } else { packetData = qp.packet->GetData(); sendSize = qp.packet->dataSize; printf("Sending %u bytes of regular (non-ZNETCMD_DATA)\n", sendSize); } qp.timeSent = now; SST_OS_DebugAssert(sendSize != 0, "Should not have 0-sized value"); //Write the header and then the data ZNetPrivate::WriteWireMessageHeader(&writer, this->channelId, qp.packet->flags, qp.command, qp.sequence); writer.WriteU8Array(packetData, sendSize); //Record stats and advance to next packet nrWritten += (sendSize + ZNET_WIRESIZE_MESSAGE_HEADER); //i.e. payload + header packetStartIndex++; bufferSpaceRemaining -= (sendSize + ZNET_WIRESIZE_MESSAGE_HEADER); ++it; } while(bufferSpaceRemaining > 0); bool allDone = (it == packets.End()); //If a restart is needed, record where we left off. if(!allDone) *restartIndexReturn = packetStartIndex; //Save the number of bytes written *bytesWritten = nrWritten; return allDone; } /*************************************************************************/ void ZNetPacketChannel::Deinitialize() { //Unreference all packets and remove them from queues for(ZList::Iterator it=packets.Begin(); it.HasCurrent(); it.Next()) it.Get().packet->ReleaseReference(); packets.Clear(); for(ZList::Iterator it=reassembly.Begin(); it.HasCurrent(); it.Next()) it.Get().packet->ReleaseReference(); reassembly.Clear(); for(ZList::Iterator it=assembled.Begin(); it.HasCurrent(); it.Next()) it.Get().packet->ReleaseReference(); assembled.Clear(); } /*************************************************************************/ void ZNetPacketChannel::UpdateRemoteAck(uint16_t newHighest, int32_t* pingAdjust) { printf("ZNetPacketChannel::UpdateRemoteAck(): current ack = %u, newHighest = %u\n", remoteAck, newHighest); //Nothing new if(newHighest == remoteAck) return; uint16_t ackCount; //Wrap around if(newHighest < remoteAck) { /* In a wrap around, we exceed the the fixed point notation of the next sequence number. To handle it, we count from where we are at to the highest possible value, then add the new local number -- this is the number of packets to dequeue. For example, if our window was 100 and we were at 97, then "ack 4" means that 98, 99, 100, 1, 2, 3, 4 were received. We compute it as 100 - 97 + 4 == 7. */ ackCount = UINT16_MAX - remoteAck + newHighest; } else ackCount = newHighest - remoteAck; printf("AckCount = %u\n", ackCount); if(ackCount > 0) { //Freeze time value now uint64_t now = SST_OS_GetMilliTime(); int32_t currentPing = *pingAdjust; while(ackCount > 0 && !packets.Empty()) { ZNetQueuedPacket& qp = packets.Front(); //Does this packet have a valid timeSent? if(qp.timeSent != 0) { //Find out how much the RTT differs from the ping int32_t delta = currentPing - ((int32_t)(now - qp.timeSent)); //Adjust ping by a fraction of RTT to smooth it out. currentPing += delta / 10; } qp.packet->ReleaseReference(); packets.PopFront(); printf("Removed a packet!\n"); ackCount--; this->packetCount--; } //Save updated ping *pingAdjust = currentPing; } //Because malicious clients can send anything, don't explode if this packet suggests //that we remove more than exist. //TODO: maybe log it? probably can help with debugging printf("Possibly wrong sequencing - this tells us to ACK more than we've sent\n"); } /*************************************************************************/ bool ZNetPacketChannel::QueueLocally(const ZNetPacketChannel::ZNetQueuedPacket* toQueue) { uint16_t seq = toQueue->sequence; ZList::Iterator it = assembled.Begin(); while(it.HasCurrent()) { ZNetQueuedPacket& qp = it.Get(); if(ZNetPrivate::SequenceAfter(seq, qp.sequence)) //Incoming packet sequence number is after this packet it.Next(); else if(ZNetPrivate::SequenceAfter(qp.sequence, seq)) //Incoming packet sequence number is before this packet { ZNetQueuedPacket newEntry; newEntry.command = toQueue->command; newEntry.sequence = toQueue->sequence; newEntry.subsequence = toQueue->subsequence; newEntry.packet = toQueue->packet; newEntry.timeSent = toQueue->timeSent; //Add it in assembled.Insert(it, newEntry); break; } else if(qp.sequence == seq) //Overwrite old data { if(qp.packet != NULL) { qp.packet->ReleaseReference(); qp.packet = toQueue->packet; } break; } } //end //Doesn't appear to be valid return false; } /*************************************************************************/ bool ZNetPacketChannel::QueueData(ZNetPrivate::ZNetMessageContainer* data) { ZList::Iterator it = reassembly.Begin(); /* Check for invalid offsets into packet. The second and third checks look redundant, but for extremely large values, summing them can wrap to a smaller value, e.g. 0xFFFFFFFF + 0x0003 will give 0x0002, and crash later when copying / allocating memory. */ if(data->parsed.data.offset + data->parsed.data.length > data->parsed.data.maxSize || data->parsed.data.offset >= data->parsed.data.maxSize || data->parsed.data.length >= data->parsed.data.maxSize) return false; //If whole data logical packet was received if(data->parsed.data.length == data->parsed.data.maxSize) { if(data->parsed.data.offset == 0) { ZNetQueuedPacket qp; ZNetPacket* packet = GetHost()->CreatePacket(data->parsed.data.data, data->parsed.data.length, 0); if(packet == NULL)//Out of memory return false; qp.command = ZNETCMD_DATA; qp.packet = packet; qp.sequence = data->sequence; qp.timeSent = SST_OS_GetMilliTime(); qp.subsequence = 0; //field not used return this->QueueLocally(&qp); } else return false; //invalid packet } //Required reassembly while(it.HasNext()) { ZNetDataReassemblyPacket& reasm = it.Get(); //If the new packet comes after this packet, try next if(ZNetPrivate::SequenceAfter(data->sequence, reasm.sequence)) { it.Next(); continue; } if(reasm.sequence == data->sequence) { SST_OS_DebugAssert(reasm.packet != NULL, "Should not have NULL packet here"); const uint32_t maxSize = reasm.packet->dataSize; //TODO: when done debugging, remove this assert as untrusted input should not cause a crash SST_OS_DebugAssert(maxSize == data->parsed.data.maxSize, "Does not match."); if(maxSize != data->parsed.data.maxSize) return false; const uint32_t begin = data->parsed.data.offset; const uint32_t end = begin + data->parsed.data.length; //Does this packet contain data that can be appended directly? if(begin <= reasm.subsequence && end > reasm.subsequence) { uint8_t* rawbits = reasm.packet->GetData(); //Get raw logical packet /* Copy from the subsequence point to the end point. This can be less than the incoming amount: reasm.subsequence = 3 filled xxxxxxxxx [0][1][2][3][4][5][6][7][8] ^-----^ new packet. offset = 2, length = 3. In this case, we copy only `end - subsequence` == 5 - 3 == 2 bytes. */ memcpy(rawbits+reasm.subsequence, data->parsed.data.data, end - reasm.subsequence); reasm.subsequence = end; //Packet is fully assembled if(end == maxSize) { ZNetQueuedPacket qp; qp.command = ZNETCMD_DATA; qp.packet = reasm.packet; qp.sequence = data->sequence; qp.timeSent = SST_OS_GetMilliTime(); qp.subsequence = 0; //field not used return this->QueueLocally(&qp); } } } else //This packet comes before it.Get() { ZNetDataReassemblyPacket newAssembly; ZNetPacket* packet = GetHost()->CreatePacket(NULL, data->parsed.data.maxSize, 0); if(packet == NULL) return false; //Copy when offset is zero, otherwise, the other side is going to do a resend //anyways. if(data->parsed.data.offset == 0) { memcpy(packet->GetData(), data->parsed.data.data, data->parsed.data.length); newAssembly.subsequence = data->parsed.data.length; } else newAssembly.subsequence = 0; newAssembly.packet = packet; newAssembly.sequence = data->sequence; //Insert before 'it' reassembly.Insert(it, newAssembly); } } return true; } /*************************************************************************/ void ZNetPacketChannel::UpdateLocalAck(uint16_t seq) { sequencesFound.PushBack(seq); } /*************************************************************************/ void ZNetPacketChannel::ProcessLocalAcks() { ZListAlgo::Sort(sequencesFound); ZList::Iterator it = sequencesFound.Begin(); int32_t firstFound = -1; //This is the first value found before wraparounds /* We want to find consecutive values starting at "localAck" and continuing. This takes two passes because there can be a wrap around and the data is sorted ascending, but values lower than "localAck" logically come _after_ it due to the wraparound. In the first pass, we ignore wraparounds (i.e. localAck > seq). We check to see if the next is expected value of localAck+1. If it is, we bump the counter and check the next number. Stop when the list is reached or a discontinuity found. In the second pass, we handle wraparounds, stopping when we reach the first value accepted in pass 1. */ //PASS 1: handling pre-wrap arounds while(it.HasCurrent()) { uint16_t thisSeq = it.Get(); //Skip smaller values as they are wraparound values if(thisSeq > localAck) { if(localAck+1 == thisSeq) //because > operator succeeded, we know that localAck+1 cannot overflow { if(firstFound == -1) firstFound = (int32_t)thisSeq; printf("Acked a packet!\n"); localAck++; } else { //discontinuity, stop here sequencesFound.Clear(); return; } } //Next value it.Next(); } //If we made it here, we didn't hit any discontinuities (or the list was empty) //PASS 2: handling wrap arounds it = sequencesFound.Begin(); while(it.HasCurrent() && it.Get() != (uint16_t)firstFound) { uint16_t thisSeq = it.Get(); if(thisSeq != (uint16_t)firstFound && //stop when we hit the first value we processed in pass 1 thisSeq == localAck+1)//localAck+1 may wraparound now { printf("(WRAP) Acked a packet!\n"); localAck++; } else //discontinuity hit break; } //Now, remove all acks sequencesFound.Clear(); }