isoping: Perform a handshake on new client connections.

Since isoping communicates over UDP, there's no built-in protections against an
attacker sending us spoofed packets.  Before allocating any memory, reply with a
cookie based on the client's IP address and port, and don't start sending data
until we see that cookie come back from the client.

We also have to resend handshake packets from the client, to avoid getting stuck
if one of the initial packets gets lost.

Change-Id: Icca7d341757279e49d55aa715fbd3f971b615e9f
diff --git a/cmds/Makefile b/cmds/Makefile
index 4bcca65..2697abc 100644
--- a/cmds/Makefile
+++ b/cmds/Makefile
@@ -184,11 +184,11 @@
 	echo "Building .pb.cc"
 	$(HOST_PROTOC) --cpp_out=. $<
 
-host-isoping isoping: LIBS+=$(RT) -lm -lstdc++
+host-isoping isoping: LIBS+=$(RT) -lm -lstdc++ -lcrypto
 host-isoping: host-isoping.o host-isoping_main.o
 host-isoping_test.o: CXXFLAGS += -D WVTEST_CONFIGURED -I ../wvtest/cpp
 host-isoping_test.o: isoping.cc
-host-isoping_test: LIBS+=$(HOST_LIBS) -lm -lstdc++
+host-isoping_test: LIBS+=$(HOST_LIBS) -lm -lstdc++ -lcrypto
 host-isoping_test: host-isoping_test.o host-isoping.o host-wvtestmain.o host-wvtest.o
 host-isostream isostream: LIBS+=$(RT)
 host-diskbench diskbench: LIBS+=-lpthread $(RT)
diff --git a/cmds/isoping.cc b/cmds/isoping.cc
index 05fffd6..d7b9d1c 100644
--- a/cmds/isoping.cc
+++ b/cmds/isoping.cc
@@ -27,6 +27,7 @@
 #include "isoping.h"
 
 #include <arpa/inet.h>
+#include <assert.h>
 #include <errno.h>
 #include <math.h>
 #include <memory.h>
@@ -65,9 +66,13 @@
 #define DIV(x, y) ((y) ? ((double)(x)/(y)) : 0)
 #define _STR(n) #n
 #define STR(n) _STR(n)
+#ifdef DEBUG
+#define DLOG(args...) fprintf(stderr, args)
+#else
+#define DLOG(args...)
+#endif
 
 // Global flag values.
-int is_server = 1;
 int quiet = 0;
 int ttl = DEFAULT_TTL;
 int want_timestamps = 0;
@@ -81,11 +86,47 @@
   want_to_die = 1;
 }
 
+// Render the given sockaddr as a string.  (Uses a static internal buffer
+// which is overwritten each time.)
+static const char *sockaddr_to_str(struct sockaddr *sa) {
+  static char addrbuf[128];
+  void *aptr;
+
+  switch (sa->sa_family) {
+  case AF_INET:
+    aptr = &((struct sockaddr_in *)sa)->sin_addr;
+    break;
+  case AF_INET6:
+    aptr = &((struct sockaddr_in6 *)sa)->sin6_addr;
+    break;
+  default:
+    return "unknown";
+  }
+
+  if (!inet_ntop(sa->sa_family, aptr, addrbuf, sizeof(addrbuf))) {
+    perror("inet_ntop");
+    exit(98);
+  }
+  return addrbuf;
+}
+
+static void debug_print_hex(unsigned char *data, size_t data_len) {
+  for (size_t i = 0; i < data_len; i++) {
+    DLOG("%02x", data[i]);
+    if (i % 8 == 7) {
+      DLOG(" ");
+    }
+  }
+  DLOG("\n");
+}
+
 Session::Session(uint32_t first_send, uint32_t usec_per_pkt,
                  const struct sockaddr_storage &raddr, size_t raddr_len)
     : usec_per_pkt(usec_per_pkt),
       usec_per_print(prints_per_sec > 0 ? 1e6 / prints_per_sec : 0),
       remoteaddr_len(raddr_len),
+      handshake_state(NEW_SESSION),
+      handshake_retry_count(0),
       next_tx_id(1),
       next_rx_id(0),
       next_rxack_id(0),
@@ -105,6 +146,7 @@
   memcpy(&remoteaddr, &raddr, raddr_len);
   memset(&tx, 0, sizeof(tx));
   strcpy(last_ackinfo, "");
+  DLOG("Handshake state: NEW_SESSION\n");
 }
 
 SessionMap::iterator Sessions::NewSession(uint32_t first_send,
@@ -117,6 +159,103 @@
   return p.first;
 }
 
+bool Sessions::CalculateCookie(Packet *p, struct sockaddr_storage *remoteaddr,
+                               size_t remoteaddr_len) {
+  return CalculateCookieWithSecret(p, remoteaddr, remoteaddr_len,
+                                   cookie_secret, sizeof(cookie_secret));
+}
+
+bool Sessions::CalculateCookieWithSecret(Packet *p,
+                                         struct sockaddr_storage *remoteaddr,
+                                         size_t remoteaddr_len,
+                                         unsigned char *secret,
+                                         size_t secret_len) {
+  if (p->packet_type != PACKET_TYPE_HANDSHAKE) {
+    fprintf(stderr, "Tried to create cookie for a non-handshake packet\n");
+    return false;
+  }
+  if (!EVP_DigestInit_ex(&digest_context, md, NULL)) {
+    fprintf(stderr, "Unable to initialize hash digest\n");
+    return false;
+  }
+
+  // Hash the data
+  EVP_DigestUpdate(&digest_context, secret, secret_len);
+  EVP_DigestUpdate(&digest_context, &p->usec_per_pkt, sizeof(p->usec_per_pkt));
+  EVP_DigestUpdate(&digest_context, remoteaddr, remoteaddr_len);
+
+  unsigned int digest_size = 0;
+  EVP_DigestFinal_ex(&digest_context, p->data.handshake.cookie, &digest_size);
+  if (digest_size != COOKIE_SIZE) {
+    fprintf(stderr, "Invalid digest size %d for cookie; expected %d\n",
+            digest_size, COOKIE_SIZE);
+    return false;
+  }
+  p->data.handshake.cookie_epoch = cookie_epoch;
+  return true;
+}
+
+bool Sessions::ValidateCookie(Packet *p, struct sockaddr_storage *addr,
+                              socklen_t addr_len) {
+  if (p->data.handshake.cookie_epoch != cookie_epoch &&
+      p->data.handshake.cookie_epoch != prev_cookie_epoch) {
+    fprintf(stderr, "Obsolete cookie epoch: %d\n",
+            p->data.handshake.cookie_epoch);
+    return false;
+  }
+  Packet golden;
+  golden.packet_type = PACKET_TYPE_HANDSHAKE;
+  golden.usec_per_pkt = p->usec_per_pkt;
+  if (p->data.handshake.cookie_epoch == cookie_epoch) {
+    CalculateCookieWithSecret(&golden, addr, addr_len, cookie_secret,
+                              sizeof(cookie_secret));
+  } else {
+    CalculateCookieWithSecret(&golden, addr, addr_len, prev_cookie_secret,
+                              sizeof(prev_cookie_secret));
+  }
+  DLOG("Handshake: cookie epoch=%d, cookie=0x",
+       p->data.handshake.cookie_epoch);
+  debug_print_hex(p->data.handshake.cookie, sizeof(p->data.handshake.cookie));
+  DLOG("Expected handshake: cookie epoch=%d, cookie=0x",
+       golden.data.handshake.cookie_epoch);
+  debug_print_hex(golden.data.handshake.cookie,
+                  sizeof(golden.data.handshake.cookie));
+  if (memcmp(golden.data.handshake.cookie, p->data.handshake.cookie,
+             COOKIE_SIZE)) {
+    fprintf(stderr, "Invalid cookie in handshake packet from %s\n",
+            sockaddr_to_str((struct sockaddr *)addr));
+    return false;
+  }
+  return true;
+}
+
+void Sessions::MaybeRotateCookieSecrets() {
+  // Round off the unix timestamp to 64 seconds as an epoch, so we don't have to
+  // track which ones we've already used.
+  uint32_t new_epoch = time(NULL) >> 6;
+  if (new_epoch != cookie_epoch) {
+    RotateCookieSecrets(new_epoch);
+  }
+}
+
+void Sessions::RotateCookieSecrets(uint32_t new_epoch) {
+  prev_cookie_epoch = cookie_epoch;
+  memcpy(&prev_cookie_secret[0], &cookie_secret[0],
+         sizeof(prev_cookie_secret));
+  cookie_epoch = new_epoch;
+  NewRandomCookieSecret();
+}
+
+void Sessions::NewRandomCookieSecret() {
+  uint64_t random;
+  for (size_t i = 0; i < sizeof(cookie_secret); i += sizeof(random)) {
+    random = rng();
+    memcpy(&cookie_secret[i], &random,
+           std::min(sizeof(random), sizeof(cookie_secret) - i));
+  }
+  DLOG("Generated new cookie secret.\n");
+}
+
 // Returns the kernel monotonic timestamp in microseconds, truncated to
 // 32 bits.  That will wrap around every ~4000 seconds, which is okay
 // for our purposes.  We use 32 bits to save space in our packets.
@@ -176,30 +315,6 @@
 }
 
 
-// Render the given sockaddr as a string.  (Uses a static internal buffer
-// which is overwritten each time.)
-static const char *sockaddr_to_str(struct sockaddr *sa) {
-  static char addrbuf[128];
-  void *aptr;
-
-  switch (sa->sa_family) {
-  case AF_INET:
-    aptr = &((struct sockaddr_in *)sa)->sin_addr;
-    break;
-  case AF_INET6:
-    aptr = &((struct sockaddr_in6 *)sa)->sin6_addr;
-    break;
-  default:
-    return "unknown";
-  }
-
-  if (!inet_ntop(sa->sa_family, aptr, addrbuf, sizeof(addrbuf))) {
-    perror("inet_ntop");
-    exit(98);
-  }
-  return addrbuf;
-}
-
 bool CompareSockaddr::operator()(const struct sockaddr_storage &lhs,
                                  const struct sockaddr_storage &rhs) {
   if (lhs.ss_family != rhs.ss_family) {
@@ -254,6 +369,26 @@
   return sqrt(DIV(numer, denom));
 }
 
+static void debug_print_packet(Packet *p) {
+  DLOG("Packet contents: magic=0x%x id=%d usec_per_pkt=%d txtime=%u "
+       "clockdiff=%d num_lost=%d first_ack=%d type=%d\n",
+       ntohl(p->magic), ntohl(p->id), ntohl(p->usec_per_pkt),
+       ntohl(p->txtime), ntohl(p->clockdiff), ntohl(p->num_lost),
+       p->first_ack, p->packet_type);
+  if (p->packet_type == PACKET_TYPE_HANDSHAKE) {
+    DLOG("cookie epoch=%u, cookie=0x", p->data.handshake.cookie_epoch);
+    debug_print_hex(p->data.handshake.cookie, sizeof(p->data.handshake.cookie));
+  } else {
+    DLOG("Acks:\n");
+    for (uint32_t i = 0; i < ARRAY_LEN(p->data.acks); i++) {
+      uint32_t acki = (p->first_ack + i) % ARRAY_LEN(p->data.acks);
+      uint32_t ackid = ntohl(p->data.acks[acki].id);
+      if (!ackid) continue;  // empty slot
+      DLOG(" acki=%d id=%d rxtime=%u\n",
+           acki, ackid, ntohl(p->data.acks[acki].rxtime));
+    }
+  }
+}
 
 void prepare_tx_packet(struct Session *s) {
   s->tx.magic = htonl(MAGIC);
@@ -263,33 +398,89 @@
   s->tx.clockdiff = s->start_rtxtime ?
       htonl(s->start_rxtime - s->start_rtxtime) : 0;
   s->tx.num_lost = htonl(s->num_lost);
-  s->tx.first_ack = htonl(s->next_txack_index);
+  s->tx.first_ack = s->next_txack_index;
+  switch (s->handshake_state) {
+    case Session::NEW_SESSION:
+    case Session::HANDSHAKE_REQUESTED:
+    case Session::COOKIE_GENERATED:
+      DLOG("prepare_tx_packet: Sending handshake packet\n");
+      s->tx.packet_type = PACKET_TYPE_HANDSHAKE;
+      break;
+    case Session::ESTABLISHED:
+      s->tx.packet_type = PACKET_TYPE_ACK;
+      break;
+    default:
+      fprintf(stderr, "Unknown handshake state %d\n", s->handshake_state);
+      exit(2);
+  }
+  // note: tx.data.acks[] is filled in incrementally; we just transmit the
+  // current state of it here.  The reason we keep a list of the most recent
+  // acks is in case our packet gets lost, so the receiver will have more
+  // chances to receive the timing information for the packets it sent us.
+  debug_print_packet(&s->tx);
 }
 
 
-static int send_packet(struct Session *s, int sock) {
-  // note: tx.acks[] is filled in incrementally; we just transmit the current
-  // state of it here.  The reason we keep a list of the most recent acks is in
-  // case our packet gets lost, so the receiver will have more chances to
-  // receive the timing information for the packets it sent us.
+void prepare_handshake_reply_packet(Packet *tx, Packet *rx, uint32_t now) {
+  memset(tx, 0, sizeof(*tx));
+  tx->magic = htonl(MAGIC);
+  tx->id = rx->id;
+  // TODO(pmccurdy): Establish limits on the allowed usec_per_pkt values here
+  tx->usec_per_pkt = rx->usec_per_pkt;
+  tx->txtime = now;
+  tx->clockdiff = htonl(now - ntohl(rx->txtime));
+  tx->num_lost = htonl(0);
+  tx->packet_type = PACKET_TYPE_HANDSHAKE;
+}
+
+
+int send_packet(struct Session *s, int sock, int is_server) {
   if (is_server) {
     if (sendto(sock, &s->tx, sizeof(s->tx), 0,
                (struct sockaddr *)&s->remoteaddr, s->remoteaddr_len) < 0) {
       perror("sendto");
     }
   } else {
+    DLOG("Calling send on socket %d, size=%ld\n", sock, sizeof(s->tx));
     if (send(sock, &s->tx, sizeof(s->tx), 0) < 0) {
       int e = errno;
       perror("send");
       if (e == ECONNREFUSED) return 2;
     }
   }
-  s->next_send += s->usec_per_pkt;
+  if (is_server ||
+      s->handshake_state == Session::ESTABLISHED ||
+      s->handshake_state == Session::COOKIE_GENERATED) {
+    DLOG("send_packet: ack packet, next_send in %d (from %d to %d)\n",
+         s->usec_per_pkt, s->next_send, s->next_send + s->usec_per_pkt);
+    s->next_send += s->usec_per_pkt;
+  } else {
+    // Handle resending handshake packets from the client.  If they get lost
+    // before we get a valid cookie from the server, the server won't know about
+    // us, and our normal retry procedure would get us out of sync.
+    if (s->handshake_state == Session::NEW_SESSION) {
+      DLOG("Handshake state: sending handshake packet, moving to "
+           "HANDSHAKE_REQUESTED\n");
+      s->handshake_state = Session::HANDSHAKE_REQUESTED;
+      s->handshake_retry_count = 0;
+    } else {
+      s->handshake_retry_count++;
+    }
+    // Limit the backoff to a factor of 2^10.
+    uint32_t timeout = Session::handshake_timeout_usec *
+                       (1 << std::min(10, s->handshake_retry_count));
+    DLOG("Sending handshake, retries=%d, next_send in %d us (from %u to %u)\n",
+         s->handshake_retry_count, timeout, s->next_send,
+         s->next_send + timeout);
+    s->next_send += timeout;
+    // Don't count the handshake packet as part of the sequence.
+    s->next_tx_id--;
+  }
   return 0;
 }
 
-
-int send_waiting_packets(Sessions *sessions, int sock, uint32_t now) {
+int send_waiting_packets(Sessions *sessions, int sock, uint32_t now,
+                         int is_server) {
   if (sessions == NULL) {
     return -1;
   }
@@ -304,7 +495,7 @@
     sessions->next_sends.pop();
     Session &s = it->second;
     prepare_tx_packet(&s);
-    int err = send_packet(&s, sock);
+    int err = send_packet(&s, sock, is_server);
     if (err != 0) {
       return err;
     }
@@ -313,6 +504,9 @@
     // instead of waiting for timeout.
     // TODO(pmccurdy): Support very low packet-per-second values, e.g. one
     // packet per hour, without constantly disconnecting the client.
+    // TODO(pmccurdy): Instead of a fixed timeout, evict clients if they miss
+    // a certain number of expected transmissions, that number changing based on
+    // the number of packets they've already sent.
     if (is_server && DIFF(now, s.last_rxtime) > 60 * 1000 * 1000) {
       fprintf(stderr, "client %s disconnected.\n",
               sockaddr_to_str((struct sockaddr *)&s.remoteaddr));
@@ -326,10 +520,9 @@
 
 int read_incoming_packet(Sessions *s, int sock, uint32_t now, int is_server) {
   struct sockaddr_storage rxaddr;
-  socklen_t rxaddr_len = 0;
+  socklen_t rxaddr_len = sizeof(rxaddr);
 
   Packet rx;
-  rxaddr_len = sizeof(rxaddr);
   ssize_t got = recvfrom(sock, &rx, sizeof(rx), 0,
                          (struct sockaddr *)&rxaddr, &rxaddr_len);
   if (got < 0) {
@@ -338,40 +531,171 @@
     return e;
   }
   if (got != sizeof(rx) || rx.magic != htonl(MAGIC)) {
-    fprintf(stderr, "got invalid packet of length %ld\n", (long)got);
+    fprintf(stderr, "got invalid packet of length %ld, magic=%d from %s\n",
+            (long)got, ntohl(rx.magic),
+            sockaddr_to_str((struct sockaddr *)&rxaddr));
     return EINVAL;
   }
+  switch (rx.packet_type) {
+    case PACKET_TYPE_HANDSHAKE:
+    case PACKET_TYPE_ACK:
+      break;
+    default:
+      fprintf(stderr, "received unknown packet type %d\n", rx.packet_type);
+      return EINVAL;
+  }
 
-  SessionMap::iterator it;
+  Session *session = NULL;
   if (is_server) {
-    it = s->session_map.find(rxaddr);
-    if (it == s->session_map.end()) {
-      fprintf(stderr, "New client connection: %s\n",
-              sockaddr_to_str((struct sockaddr *)&rxaddr));
-      // TODO(pmccurdy):  This lets clients unconditionally set the usec_per_pkt
-      // values used.  Add some mechanism to let the server override or reject
-      // this value (e.g. limit to a certain range, or reduce per-client pps as
-      // more clients connect).
-      it = s->NewSession(now + 10 * 1000, ntohl(rx.usec_per_pkt), &rxaddr,
-                         rxaddr_len);
+    SessionMap::iterator it = s->session_map.find(rxaddr);
+    if (it != s->session_map.end()) {
+      session = &it->second;
+    } else {
+      // Note: we don't want to allocate any memory here until the client has
+      // completed the handshake.
+      if (rx.packet_type != PACKET_TYPE_HANDSHAKE) {
+        fprintf(stderr, "Received non-handshake packet from unknown client\n");
+        // TODO(pmccurdy): Reply with a new handshake packet, including a
+        // cookie; we may have dropped a legit client and we need to tell them
+        // to renegotiate.
+        return -1;
+      }
     }
   } else {
-    it = s->session_map.begin();
+    SessionMap::iterator it = s->session_map.begin();
     if (it == s->session_map.end()) {
       fprintf(stderr, "No session configured for %s when receiving packet\n",
               sockaddr_to_str((struct sockaddr *)&rxaddr));
       return EINVAL;
     }
+    DLOG("read_incoming_packet: Client received %s packet from server\n",
+         rx.packet_type == PACKET_TYPE_ACK ? "ack" : "handshake");
+    session = &it->second;
   }
-  Session &session = it->second;
-  memcpy(&session.rx, &rx, sizeof(session.rx));
-  handle_packet(&session, now);
+  handle_packet(s, session, &rx, sock, &rxaddr, rxaddr_len, now, is_server);
 
   return 0;
 }
 
+// Checks what kind of packet we've received, and processes it appropriately.
+// Session may be null if we're dealing with a handshake packet for a new
+// connection.
+void handle_packet(struct Sessions *s, struct Session *session, Packet *rx,
+                   int sock, struct sockaddr_storage *rxaddr,
+                   socklen_t rxaddr_len, uint32_t now, int is_server) {
+  switch (rx->packet_type) {
+    case PACKET_TYPE_HANDSHAKE:
+      if (is_server) {
+        handle_new_client_handshake_packet(s, rx, sock, rxaddr, rxaddr_len,
+                                           now);
+        return;
+      } else {
+        DLOG("Client received handshake packet from server\n");
+        handle_server_handshake_packet(s, rx, now);
+        return;
+      }
+      break;
+    case PACKET_TYPE_ACK:
+      if (session != NULL) {
+        memcpy(&session->rx, rx, sizeof(session->rx));
+        if (!is_server &&
+            session->handshake_state == Session::COOKIE_GENERATED) {
+          // Now we know the server has accepted our connection.  Clear out the
+          // handshake data from the send buffer and prepare to track acks.
+          DLOG("Ack from server on new connection; moving to state "
+               "ESTABLISHED.");
+          session->handshake_state = Session::ESTABLISHED;
+          memset(&session->tx.data.acks, 0, sizeof(session->tx.data.acks));
+        }
+      }
+      handle_ack_packet(session, now);
+      break;
+    default:
+      fprintf(stderr, "handle_packet called for unknown packet type %d\n",
+              rx->packet_type);
+      break;
+  }
+}
 
-void handle_packet(struct Session *s, uint32_t now) {
+void handle_new_client_handshake_packet(Sessions *s, Packet *rx, int sock,
+                                       struct sockaddr_storage *remoteaddr,
+                                       size_t remoteaddr_len, uint32_t now) {
+  assert(s != NULL);
+  assert(rx != NULL);
+  assert(remoteaddr != NULL);
+  DLOG("Server received handshake packet from client; cookie epoch=%u\n",
+       rx->data.handshake.cookie_epoch);
+  if (rx->data.handshake.cookie_epoch == 0) {
+    // New connection with no cookie.  Return a cookie to validate the client.
+    s->session_map.erase(*remoteaddr);
+    fprintf(stderr, "New connection from %s, sending cookie\n",
+            sockaddr_to_str((struct sockaddr *)remoteaddr));
+    Packet tx;
+    memset(&tx, 0, sizeof(tx));
+    prepare_handshake_reply_packet(&tx, rx, now);
+    s->CalculateCookie(&tx, remoteaddr, remoteaddr_len);
+    sendto(sock, &tx, sizeof(tx), 0, (struct sockaddr *)remoteaddr,
+           remoteaddr_len);
+    // The handshake_state is conceptually in the COOKIE_GENERATED state now,
+    // but the whole point of the cookie is to avoid saving state in the server,
+    // so we don't store a Session here.
+  } else {
+    // Cookie provided, validate it to accept or reject the connection.
+    if (!s->ValidateCookie(rx, remoteaddr, remoteaddr_len)) {
+      return;
+    }
+    fprintf(stderr, "New client connection: %s\n",
+            sockaddr_to_str((struct sockaddr *)remoteaddr));
+    SessionMap::iterator it = s->NewSession(
+        now + 10 * 1000, ntohl(rx->usec_per_pkt), remoteaddr, remoteaddr_len);
+    Session &session = it->second;
+    session.handshake_state = Session::ESTABLISHED;
+    memcpy(&session.rx, rx, sizeof(session.rx));
+    // This is a new session we haven't sent any timing packets on, so the
+    // client can't possibly have acknowledged any packets.  Replace the
+    // handshake data with a set of empty acks and process as normal.
+    session.rx.packet_type = PACKET_TYPE_ACK;
+    memset(&session.rx.data.acks, 0, sizeof(session.rx.data.acks));
+    assert(sizeof(session.rx.data.acks) > sizeof(void *));
+    handle_ack_packet(&session, now);
+  }
+}
+
+void handle_server_handshake_packet(Sessions *s, Packet *rx, uint32_t now) {
+  assert(s != NULL);
+  assert(rx != NULL);
+  assert(s->session_map.size() == 1);
+  assert(s->next_sends.size() == 1);
+
+  SessionMap::iterator it = s->session_map.begin();
+  Session &session = it->second;
+  // We don't need to resend the handshake packet any more.
+  s->next_sends.pop();
+
+  session.tx.packet_type = PACKET_TYPE_HANDSHAKE;
+  session.tx.data.handshake.cookie_epoch = rx->data.handshake.cookie_epoch;
+  memcpy(&session.tx.data.handshake.cookie, &rx->data.handshake.cookie,
+         COOKIE_SIZE);
+  int usec_per_pkt = ntohl(rx->usec_per_pkt);
+  if (usec_per_pkt != session.usec_per_pkt) {
+    fprintf(stderr, "Server overrode packets per second to %f\n",
+            1000000.0 / usec_per_pkt);
+    session.usec_per_pkt = usec_per_pkt;
+  }
+  DLOG("Handshake state: client received cookie from server, moving to "
+       "COOKIE_GENERATED; next_send=%d (was %d)\n",
+       now, session.next_send);
+  DLOG("Handshake: cookie epoch=%d, cookie=0x",
+       rx->data.handshake.cookie_epoch);
+  debug_print_hex(rx->data.handshake.cookie, sizeof(rx->data.handshake.cookie));
+  session.handshake_state = Session::COOKIE_GENERATED;
+  session.next_send = now;
+  s->next_sends.push(it);
+}
+
+void handle_ack_packet(struct Session *s, uint32_t now) {
+  assert(s != NULL);
+  assert(s->rx.packet_type == PACKET_TYPE_ACK);
   // process the incoming packet header.
   // Most of the complexity here comes from the fact that the remote
   // system's clock will be skewed vs. ours.  (We use CLOCK_MONOTONIC
@@ -434,6 +758,8 @@
     s->start_rxtime = rxtime - id * s->usec_per_pkt;
   }
   int32_t rxdiff = DIFF(rxtime, s->start_rxtime + id * s->usec_per_pkt);
+  DLOG("ack: rxdiff=%d, rxtime=%u, start_rxtime=%u, id=%d, usec_per_pkt=%d\n",
+       rxdiff, rxtime, s->start_rxtime, id, s->usec_per_pkt);
 
   // Figure out the offset between our clock and the remote's clock, so we can
   // calculate the minimum round trip time (rtt). Then, because the consecutive
@@ -486,6 +812,8 @@
     s->lat_rx_sum += s->lat_rx;
     s->lat_rx_var_sum += s->lat_rx * s->lat_rx;
   }
+  DLOG("ack packet: rx id=%d, clockdiff=%d, rtt=%d, offset=%d, rxdiff=%d\n",
+       id, clockdiff, rtt, offset, rxdiff);
 
   // Note: the way ok_to_print is structured, if there is a dropout in the
   // connection for more than usec_per_print, we will statistically end up
@@ -519,24 +847,27 @@
   }
 
   // schedule this for an ack next time we send the packet
-  s->tx.acks[s->next_txack_index].id = htonl(id);
-  s->tx.acks[s->next_txack_index].rxtime = htonl(rxtime);
-  s->next_txack_index = (s->next_txack_index + 1) % ARRAY_LEN(s->tx.acks);
+  s->tx.data.acks[s->next_txack_index].id = htonl(id);
+  s->tx.data.acks[s->next_txack_index].rxtime = htonl(rxtime);
+  s->next_txack_index = (s->next_txack_index + 1) % ARRAY_LEN(s->tx.data.acks);
 
   // see which of our own transmitted packets have been acked
-  uint32_t first_ack = ntohl(s->rx.first_ack);
-  for (uint32_t i = 0; i < ARRAY_LEN(s->rx.acks); i++) {
-    uint32_t acki = (first_ack + i) % ARRAY_LEN(s->rx.acks);
-    uint32_t ackid = ntohl(s->rx.acks[acki].id);
+  uint32_t first_ack = s->rx.first_ack;
+  for (uint32_t i = 0; i < ARRAY_LEN(s->rx.data.acks); i++) {
+    uint32_t acki = (first_ack + i) % ARRAY_LEN(s->rx.data.acks);
+    uint32_t ackid = ntohl(s->rx.data.acks[acki].id);
     if (!ackid) continue;  // empty slot
     if (DIFF(ackid, s->next_rxack_id) >= 0) {
       // an expected ack
       uint32_t start_txtime = s->next_send - s->next_tx_id * s->usec_per_pkt;
       uint32_t txtime = start_txtime + ackid * s->usec_per_pkt;
-      uint32_t rrxtime = ntohl(s->rx.acks[acki].rxtime);
+      uint32_t rrxtime = ntohl(s->rx.data.acks[acki].rxtime);
       uint32_t rxtime = rrxtime + offset;
       // note: already contains 1/2 rtt, unlike rxdiff
       int32_t txdiff = DIFF(rxtime, txtime);
+      DLOG("acki=%d ackid=%d txdiff=%d rxtime=%u txtime=%u offset=%u, "
+           "start_txtime=%u\n",
+           acki, ackid, txdiff, rxtime, txtime, offset, start_txtime);
       if (!quiet && s->usec_per_print <= 0 && s->last_ackinfo[0]) {
         // only print multiple acks per rx if no usec_per_print limit
         if (want_timestamps) print_timestamp(rxtime);
@@ -560,7 +891,6 @@
   s->last_rxtime = rxtime;
 }
 
-
 int isoping_main(int argc, char **argv) {
   struct sockaddr_in6 listenaddr;
   struct addrinfo *ai = NULL;
@@ -613,6 +943,7 @@
     return 1;
   }
 
+  int is_server;
   uint32_t now = ustime();       // current time
   Sessions sessions;
 
@@ -681,6 +1012,8 @@
   act.sa_flags = SA_RESETHAND;
   sigaction(SIGINT, &act, NULL);
 
+  uint32_t last_secret_update_time = 0;
+
   while (!want_to_die) {
     fd_set rfds;
     FD_ZERO(&rfds);
@@ -703,7 +1036,13 @@
       return 1;
     }
 
-    int err = send_waiting_packets(&sessions, sock, now);
+    // Periodically check if the cookie secrets need updating.
+    if (is_server && (now - last_secret_update_time) > 1000000) {
+      sessions.MaybeRotateCookieSecrets();
+      last_secret_update_time = now;
+    }
+
+    int err = send_waiting_packets(&sessions, sock, now, is_server);
     if (err != 0) {
       return err;
     }
diff --git a/cmds/isoping.h b/cmds/isoping.h
index 7100a51..9c2ee85 100644
--- a/cmds/isoping.h
+++ b/cmds/isoping.h
@@ -18,10 +18,23 @@
 
 #include <map>
 #include <netinet/in.h>
+#include <openssl/evp.h>
 #include <queue>
+#include <random>
 #include <stdint.h>
+#include <string.h>
 #include <sys/socket.h>
 
+// Number of bytes required to store the cookie, which is a SHA-256 hash.
+#define COOKIE_SIZE 32
+// Number of bytes used to store the random cookie secret.
+#define COOKIE_SECRET_SIZE 16
+
+enum {
+  PACKET_TYPE_ACK = 0,
+  PACKET_TYPE_HANDSHAKE,
+};
+
 // Layout of the UDP packets exchanged between client and server.
 // All integers are in network byte order.
 // Packets have exactly the same structure in both directions.
@@ -32,12 +45,22 @@
   uint32_t clockdiff; // estimate of (transmitter's clk) - (receiver's clk)
   uint32_t usec_per_pkt; // microseconds of delay between packets
   uint32_t num_lost;  // number of pkts transmitter expected to get but didn't
-  uint32_t first_ack; // starting index in acks[] circular buffer
-  struct {
-    // txtime==0 for empty elements in this array.
-    uint32_t id;      // id field from a received packet
-    uint32_t rxtime;  // receiver's monotonic time when pkt arrived
-  } acks[64];
+  uint8_t packet_type; // 0 for acks, 1 for handshake packet
+  uint8_t first_ack;  // starting index in acks[] circular buffer
+  union {
+    // Data used for handshake packets.
+    struct {
+      uint32_t version; // max version of the isoping protocol supported
+      uint32_t cookie_epoch; // which cookie we're using
+      unsigned char cookie[COOKIE_SIZE]; // actual cookie value
+    } handshake;
+    // Data used for ack packets.
+    struct {
+      // txtime==0 for empty elements in this array.
+      uint32_t id;      // id field from a received packet
+      uint32_t rxtime;  // receiver's monotonic time when pkt arrived
+    } acks[64];
+  } data;
 };
 
 
@@ -52,6 +75,15 @@
   struct sockaddr_storage remoteaddr;
   socklen_t remoteaddr_len;
 
+  enum {
+    NEW_SESSION = 0,     // No packets exchanged yet.
+    HANDSHAKE_REQUESTED, // Client has sent initial packet to server, i.e. SYN.
+    COOKIE_GENERATED,    // Server has replied with cookie, i.e. SYN|ACK.
+    ESTABLISHED          // Client has echoed cookie back, i.e. ACK.
+  } handshake_state;
+  int handshake_retry_count;
+  static const int handshake_timeout_usec = 1000000;
+
   // WARNING: lots of math below relies on well-defined uint32/int32
   // arithmetic overflow behaviour, plus the fact that when we subtract
   // two successive timestamps (for example) they will be less than 2^31
@@ -98,14 +130,34 @@
 
 struct Sessions {
  public:
-  Sessions() {}
+  Sessions()
+      : md(EVP_sha256()),
+        rng(std::random_device()()),
+        cookie_epoch(0) {
+    NewRandomCookieSecret();
+    EVP_MD_CTX_init(&digest_context);
+  }
 
-  // All active sessions, indexed by remote address/port.
-  SessionMap session_map;
-  // A queue of upcoming send times, ordered most recent first, referencing
-  // entries in the session map.
-  std::priority_queue<SessionMap::iterator, std::vector<SessionMap::iterator>,
-      CompareNextSend> next_sends;
+  ~Sessions() {
+    EVP_MD_CTX_cleanup(&digest_context);
+  }
+
+  // Rotates the cookie secrets if they haven't been changed in a while.
+  void MaybeRotateCookieSecrets();
+
+  // Rotate the cookie secrets using the given epoch directly.  Only for use in
+  // unit tests.
+  void RotateCookieSecrets(uint32_t new_epoch);
+
+  // Calculates a handshake cookie based on the provided client IP address and
+  // the relevant parameters in p, using the current cookie secret, and places
+  // the result in p.
+  bool CalculateCookie(Packet *p, struct sockaddr_storage *remoteaddr,
+                       size_t remoteaddr_len);
+
+  // Returns true if the packet contains a handshake packet with a valid cookie.
+  bool ValidateCookie(Packet *p, struct sockaddr_storage *addr,
+                      socklen_t addr_len);
 
   SessionMap::iterator NewSession(uint32_t first_send,
                                   uint32_t usec_per_pkt,
@@ -118,17 +170,58 @@
     }
     return next_sends.top()->second.next_send;
   }
+
+  // All active sessions, indexed by remote address/port.
+  SessionMap session_map;
+  // A queue of upcoming send times, ordered most recent first, referencing
+  // entries in the session map.
+  std::priority_queue<SessionMap::iterator, std::vector<SessionMap::iterator>,
+      CompareNextSend> next_sends;
+
+ private:
+  void NewRandomCookieSecret();
+  bool CalculateCookieWithSecret(Packet *p, struct sockaddr_storage *remoteaddr,
+                                 size_t remoteaddr_len, unsigned char *secret,
+                                 size_t secret_len);
+
+  // Fields required for calculating and verifying cookies.
+  EVP_MD_CTX digest_context;
+  const EVP_MD *md;
+  std::mt19937_64 rng;
+  uint32_t cookie_epoch;
+  unsigned char cookie_secret[COOKIE_SECRET_SIZE];
+  uint32_t prev_cookie_epoch;
+  unsigned char prev_cookie_secret[COOKIE_SECRET_SIZE];
 };
 
-// Process the Session's incoming packet, from s->rx.
-void handle_packet(struct Session *s, uint32_t now);
+// Process an incoming packet from the socket.
+void handle_packet(struct Sessions *s, struct Session *session, Packet *rx,
+                   int sock, struct sockaddr_storage *rxaddr,
+                   socklen_t rxaddr_len, uint32_t now, int is_server);
+
+// Process an established Session's incoming ack packet, from s->rx.
+void handle_ack_packet(struct Session *s, uint32_t now);
+
+// Server-only: processes a handshake packet from a new client in rx. Replies
+// with a cookie if no cookie provided, or validates the provided cookie and
+// establishes a new Session.
+void handle_new_client_handshake_packet(Sessions *s, Packet *rx, int sock,
+                                       struct sockaddr_storage *remoteaddr,
+                                       size_t remoteaddr_len, uint32_t now);
+
+// Client-only: processes a handshake packet received from the server.
+// Configures the Session to echo the provided cookie back to the server.
+void handle_server_handshake_packet(Sessions *s, Packet *rx, uint32_t now);
 
 // Sets all the elements of s->tx to be ready to be sent to the other side.
 void prepare_tx_packet(struct Session *s);
 
 // Sends a packet to all waiting sessions where the appropriate amount of time
 // has passed.
-int send_waiting_packets(Sessions *s, int sock, uint32_t now);
+int send_waiting_packets(Sessions *s, int sock, uint32_t now, int is_server);
+
+// Sends a packet from the given session to the given socket immediately.
+int send_packet(struct Session *s, int sock, int is_server);
 
 // Reads a packet from sock and stores it in s->rx.  Assumes a packet is
 // currently readable.
diff --git a/cmds/isoping_test.cc b/cmds/isoping_test.cc
index a8da304..ca0c440 100644
--- a/cmds/isoping_test.cc
+++ b/cmds/isoping_test.cc
@@ -27,14 +27,14 @@
 
 #include "isoping.h"
 
-uint32_t send_next_packet(Session *from, uint32_t from_base,
+uint32_t send_next_ack_packet(Session *from, uint32_t from_base,
                           Session *to, uint32_t to_base, uint32_t latency) {
   uint32_t t = from->next_send - from_base;
   prepare_tx_packet(from);
   to->rx = from->tx;
   from->next_send += from->usec_per_pkt;
   t += latency;
-  handle_packet(to, to_base + t);
+  handle_ack_packet(to, to_base + t);
   fprintf(stderr,
           "**Sent packet: txtime=%d, start_txtime=%d, rxtime=%d, "
           "start_rxtime=%d, latency=%d, t_from=%d, t_to=%d\n",
@@ -63,6 +63,8 @@
   struct sockaddr_storage empty_sockaddr;
   struct Session c(cbase, usec_per_pkt, empty_sockaddr, sizeof(empty_sockaddr));
   struct Session s(sbase, usec_per_pkt, empty_sockaddr, sizeof(empty_sockaddr));
+  c.handshake_state = Session::ESTABLISHED;
+  s.handshake_state = Session::ESTABLISHED;
 
   // One-way latencies: cs_latency is the latency from client to server;
   // sc_latency is from server to client.
@@ -75,7 +77,7 @@
 
   // Send the initial packet from client to server.  This isn't enough to let us
   // draw any useful latency conclusions.
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency);
   uint32_t rxtime = sbase + t;
   s.next_send = rxtime + 10 * 1000;
 
@@ -84,16 +86,16 @@
   WVPASSEQ(s.rx.clockdiff, 0);
   WVPASSEQ(s.last_rxtime, rxtime);
   WVPASSEQ(s.min_cycle_rxdiff, 0);
-  WVPASSEQ(ntohl(s.tx.acks[0].id), 1);
+  WVPASSEQ(ntohl(s.tx.data.acks[0].id), 1);
   WVPASSEQ(s.next_txack_index, 1);
-  WVPASSEQ(ntohl(s.tx.acks[ntohl(s.tx.first_ack)].id), 1);
-  WVPASSEQ(ntohl(s.tx.acks[ntohl(s.tx.first_ack)].rxtime), rxtime);
+  WVPASSEQ(ntohl(s.tx.data.acks[s.tx.first_ack].id), 1);
+  WVPASSEQ(ntohl(s.tx.data.acks[s.tx.first_ack].rxtime), rxtime);
   WVPASSEQ(s.start_rxtime, rxtime - c.usec_per_pkt);
   WVPASSEQ(s.start_rtxtime, cbase - c.usec_per_pkt);
   WVPASSEQ(s.next_send, rxtime + 10 * 1000);
 
   // Reply to the client.
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency);
 
   // Now we have enough data to figure out latencies on the client.
   rxtime = cbase + t;
@@ -101,8 +103,8 @@
   WVPASSEQ(c.start_rtxtime, sbase + cs_latency + 10 * 1000 - s.usec_per_pkt);
   WVPASSEQ(c.min_cycle_rxdiff, 0);
   WVPASSEQ(ntohl(c.rx.clockdiff), sbase - cbase + cs_latency);
-  WVPASSEQ(ntohl(c.tx.acks[ntohl(c.tx.first_ack)].id), 1);
-  WVPASSEQ(ntohl(c.tx.acks[ntohl(c.tx.first_ack)].rxtime), rxtime);
+  WVPASSEQ(ntohl(c.tx.data.acks[c.tx.first_ack].id), 1);
+  WVPASSEQ(ntohl(c.tx.data.acks[c.tx.first_ack].rxtime), rxtime);
   WVPASSEQ(c.num_lost, 0);
   WVPASSEQ(c.lat_tx_count, 1);
   WVPASSEQ(c.lat_tx, half_rtt);
@@ -111,15 +113,15 @@
   WVPASSEQ(c.num_lost, 0);
 
   // Round 2
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency);
   rxtime = sbase + t;
 
   // Now the server also knows latencies.
   WVPASSEQ(s.start_rxtime, sbase + cs_latency - s.usec_per_pkt);
   WVPASSEQ(s.start_rtxtime, cbase - c.usec_per_pkt);
   WVPASSEQ(ntohl(s.rx.clockdiff), cbase - sbase + sc_latency);
-  WVPASSEQ(ntohl(s.tx.acks[ntohl(s.tx.first_ack)].id), 2);
-  WVPASSEQ(ntohl(s.tx.acks[ntohl(s.tx.first_ack)].rxtime), rxtime);
+  WVPASSEQ(ntohl(s.tx.data.acks[s.tx.first_ack].id), 2);
+  WVPASSEQ(ntohl(s.tx.data.acks[s.tx.first_ack].rxtime), rxtime);
   WVPASSEQ(s.num_lost, 0);
   WVPASSEQ(s.lat_tx_count, 1);
   WVPASSEQ(s.lat_tx, half_rtt);
@@ -129,15 +131,15 @@
 
   // Increase the latencies in both directions, reply to client.
   int32_t latency_diff = 10 * 1000;
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency + latency_diff);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency + latency_diff);
 
   rxtime = cbase + t;
   WVPASSEQ(ntohl(s.tx.clockdiff), real_clockdiff + cs_latency);
   WVPASSEQ(c.start_rxtime,
            rxtime - ntohl(s.tx.id) * s.usec_per_pkt - latency_diff);
   WVPASSEQ(c.start_rtxtime, sbase + cs_latency + 10 * 1000 - s.usec_per_pkt);
-  WVPASSEQ(ntohl(c.tx.acks[ntohl(c.tx.first_ack)].id), 2);
-  WVPASSEQ(ntohl(c.tx.acks[ntohl(c.tx.first_ack)].rxtime), rxtime);
+  WVPASSEQ(ntohl(c.tx.data.acks[c.tx.first_ack].id), 2);
+  WVPASSEQ(ntohl(c.tx.data.acks[c.tx.first_ack].rxtime), rxtime);
   WVPASSEQ(c.num_lost, 0);
   WVPASSEQ(c.lat_tx_count, 2);
   WVPASSEQ(c.lat_tx, half_rtt);
@@ -146,15 +148,15 @@
   WVPASSEQ(c.num_lost, 0);
 
   // Client replies with increased latency, server notices.
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + latency_diff);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency + latency_diff);
 
   rxtime = sbase + t;
   WVPASSEQ(ntohl(c.tx.clockdiff), - real_clockdiff + sc_latency);
   WVPASSEQ(s.start_rxtime, sbase + cs_latency - s.usec_per_pkt);
   WVPASSEQ(s.start_rtxtime, cbase - c.usec_per_pkt);
   WVPASSEQ(ntohl(s.rx.clockdiff), cbase - sbase + sc_latency);
-  WVPASSEQ(ntohl(s.tx.acks[ntohl(s.tx.first_ack)].id), 3);
-  WVPASSEQ(ntohl(s.tx.acks[ntohl(s.tx.first_ack)].rxtime), rxtime);
+  WVPASSEQ(ntohl(s.tx.data.acks[s.tx.first_ack].id), 3);
+  WVPASSEQ(ntohl(s.tx.data.acks[s.tx.first_ack].rxtime), rxtime);
   WVPASSEQ(s.num_lost, 0);
   WVPASSEQ(s.lat_tx_count, 2);
   WVPASSEQ(s.lat_tx, half_rtt + latency_diff);
@@ -167,12 +169,12 @@
   s.next_send += s.usec_per_pkt;
   s.next_tx_id++;
 
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + latency_diff);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency + latency_diff);
 
   rxtime = sbase + t;
   WVPASSEQ(ntohl(s.rx.clockdiff), cbase - sbase + sc_latency);
-  WVPASSEQ(ntohl(s.tx.acks[ntohl(s.tx.first_ack)].id), 3);
-  WVPASSEQ(ntohl(s.tx.acks[ntohl(s.tx.first_ack)].rxtime),
+  WVPASSEQ(ntohl(s.tx.data.acks[s.tx.first_ack].id), 3);
+  WVPASSEQ(ntohl(s.tx.data.acks[s.tx.first_ack].rxtime),
            rxtime - s.usec_per_pkt);
   WVPASSEQ(s.num_lost, 0);
   WVPASSEQ(s.lat_tx_count, 2);
@@ -183,11 +185,11 @@
 
   // Remove the extra latency from server->client, send the next packet, have
   // the client receive it and notice the lost packet and reduced latency.
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency);
 
   rxtime = cbase + t;
-  WVPASSEQ(ntohl(c.tx.acks[ntohl(c.tx.first_ack)].id), 4);
-  WVPASSEQ(ntohl(c.tx.acks[ntohl(c.tx.first_ack)].rxtime), rxtime);
+  WVPASSEQ(ntohl(c.tx.data.acks[c.tx.first_ack].id), 4);
+  WVPASSEQ(ntohl(c.tx.data.acks[c.tx.first_ack].rxtime), rxtime);
   WVPASSEQ(c.num_lost, 1);
   WVPASSEQ(c.lat_tx_count, 4);
   WVPASSEQ(c.lat_tx, half_rtt + latency_diff);
@@ -198,7 +200,8 @@
   // A tiny reduction in latency shows up in min_cycle_rxdiff.
   latency_diff = 0;
   int32_t latency_mini_diff = -15;
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + latency_mini_diff);
+  t = send_next_ack_packet(&c, cbase, &s, sbase,
+                           cs_latency + latency_mini_diff);
 
   WVPASSEQ(ntohl(s.rx.clockdiff), cbase - sbase + sc_latency);
   WVPASSEQ(s.min_cycle_rxdiff, latency_mini_diff);
@@ -206,7 +209,8 @@
   WVPASSEQ(s.lat_tx, half_rtt);
   WVPASSEQ(s.lat_rx, half_rtt + latency_mini_diff);
 
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency + latency_mini_diff);
+  t = send_next_ack_packet(&s, sbase, &c, cbase,
+                           sc_latency + latency_mini_diff);
 
   WVPASSEQ(ntohl(s.rx.clockdiff), cbase - sbase + sc_latency);
   WVPASSEQ(c.min_cycle_rxdiff, latency_mini_diff);
@@ -216,7 +220,7 @@
   // Reduce the latency dramatically, verify that both sides see it, and the
   // start time is modified (not the min_cycle_rxdiff).
   latency_diff = -22 * 1000;
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + latency_diff);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency + latency_diff);
 
   WVPASSEQ(ntohl(s.rx.clockdiff), cbase - sbase + sc_latency);
   WVPASSEQ(s.min_cycle_rxdiff, latency_mini_diff);
@@ -227,7 +231,7 @@
   WVPASSEQ(s.lat_tx, half_rtt + latency_diff/2 + latency_mini_diff);
   WVPASSEQ(s.lat_rx, half_rtt + latency_diff/2);
 
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency + latency_diff);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency + latency_diff);
 
   // Now we see the new latency applied to both sides.
   WVPASSEQ(ntohl(s.rx.clockdiff), cbase - sbase + sc_latency);
@@ -237,7 +241,7 @@
 
   // Restore latency on one side of the connection, verify that we track it on
   // only one side and we've improved our clock sync.
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency);
 
   WVPASSEQ(ntohl(s.rx.clockdiff), cbase - sbase + sc_latency + latency_diff);
   WVPASSEQ(s.lat_tx, half_rtt + latency_diff);
@@ -245,7 +249,7 @@
 
   // And double-check that the other side also sees the improved clock sync and
   // one-sided latency on the correct side.
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency + latency_diff);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency + latency_diff);
 
   WVPASSEQ(ntohl(c.rx.clockdiff), sbase - cbase + cs_latency + latency_diff);
   WVPASSEQ(c.lat_tx, half_rtt);
@@ -262,6 +266,8 @@
   struct sockaddr_storage empty_sockaddr;
   struct Session c(cbase, usec_per_pkt, empty_sockaddr, sizeof(empty_sockaddr));
   struct Session s(sbase, usec_per_pkt, empty_sockaddr, sizeof(empty_sockaddr));
+  c.handshake_state = Session::ESTABLISHED;
+  s.handshake_state = Session::ESTABLISHED;
   // Send packets infrequently, to get new cycles more often.
   s.usec_per_pkt = 1 * 1000 * 1000;
   c.usec_per_pkt = 1 * 1000 * 1000;
@@ -275,7 +281,7 @@
 
   // Perform the initial setup.
   c.next_send = cbase;
-  uint32_t t = send_next_packet(&c, cbase, &s, sbase, cs_latency);
+  uint32_t t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency);
   s.next_send = sbase + t + 10 * 1000;
 
   uint32_t orig_server_start_rxtime = s.start_rxtime;
@@ -286,7 +292,7 @@
   WVPASSEQ(s.lat_tx, 0);
   WVPASSEQ(s.min_cycle_rxdiff, 0);
 
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency);
 
   uint32_t orig_client_start_rxtime = c.start_rxtime;
   WVPASSEQ(c.start_rxtime, cbase + 2 * half_rtt + 10 * 1000 - c.usec_per_pkt);
@@ -298,7 +304,7 @@
 
   // Clock drift shows up as symmetric changes in one-way latency.
   int32_t total_drift = drift_per_round;
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
 
   WVPASSEQ(s.start_rxtime, orig_server_start_rxtime);
   WVPASSEQ(s.start_rtxtime, cbase - c.usec_per_pkt);
@@ -307,7 +313,7 @@
   WVPASSEQ(s.lat_tx, half_rtt);
   WVPASSEQ(s.min_cycle_rxdiff, 0);
 
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
 
   WVPASSEQ(c.start_rxtime, cbase + 2 * half_rtt + 10 * 1000 - c.usec_per_pkt);
   WVPASSEQ(c.start_rtxtime,
@@ -319,7 +325,7 @@
 
   // Once we exceed -20us of drift, we adjust the client's start_rxtime.
   total_drift += drift_per_round;
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
 
   WVPASSEQ(s.start_rxtime, orig_server_start_rxtime);
   WVPASSEQ(s.start_rtxtime, cbase - c.usec_per_pkt);
@@ -328,7 +334,7 @@
   WVPASSEQ(s.lat_tx, half_rtt - drift_per_round);
   WVPASSEQ(s.min_cycle_rxdiff, 0);
 
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
 
   int32_t clock_adj = total_drift;
   WVPASSEQ(c.start_rxtime,
@@ -350,7 +356,7 @@
   total_drift += packets_to_skip * drift_per_round;
 
   // At first we blame the rx latency for most of the drift.
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
 
   // start_rxtime doesn't change here as the first cycle suppresses positive
   // min_cycle_rxdiff values.
@@ -363,7 +369,7 @@
   WVPASSEQ(s.min_cycle_rxdiff, INT_MAX);
 
   // After one round-trip, we divide the blame for the latency diff evenly.
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
 
   WVPASSEQ(c.start_rxtime, orig_client_start_rxtime - total_drift);
   WVPASSEQ(c.start_rtxtime, sbase + cs_latency + 10 * 1000 - c.usec_per_pkt);
@@ -372,7 +378,7 @@
   WVPASSEQ(c.lat_tx, half_rtt + total_drift / 2);
   WVPASSEQ(c.min_cycle_rxdiff, INT_MAX);
 
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
 
   WVPASSEQ(s.start_rxtime, orig_server_start_rxtime);
   WVPASSEQ(s.start_rtxtime, cbase - c.usec_per_pkt);
@@ -383,7 +389,7 @@
   WVPASSEQ(s.min_cycle_rxdiff, total_drift);
 
   total_drift += drift_per_round;
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
   // And on the client.  The client doesn't notice the total_drift rxdiff as it
   // was swallowed by the new cycle.
   WVPASSEQ(c.min_cycle_rxdiff, -drift_per_round);
@@ -398,7 +404,7 @@
   c.next_tx_id += packets_to_skip;
   total_drift += packets_to_skip * drift_per_round;
   int32_t drift_per_cycle = 10 * drift_per_round;
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
+  t = send_next_ack_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
 
   // The clock drift has worked its way into the RTT calculation.
   half_rtt = (cs_latency + sc_latency - drift_per_cycle) / 2;
@@ -412,7 +418,7 @@
   WVPASSEQ(s.lat_tx, half_rtt - drift_per_round);
   WVPASSEQ(s.min_cycle_rxdiff, INT_MAX);
 
-  t = send_next_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
+  t = send_next_ack_packet(&s, sbase, &c, cbase, sc_latency - total_drift);
 
   WVPASSEQ(c.start_rxtime, orig_client_start_rxtime - total_drift);
   WVPASSEQ(c.start_rtxtime, sbase + cs_latency + 10 * 1000 - c.usec_per_pkt);
@@ -495,26 +501,83 @@
   Sessions s;
   uint32_t usec_per_pkt = 100 * 1000;
 
+  s.MaybeRotateCookieSecrets();
+  // TODO(pmccurdy): Remove +1?
   c.NewSession(cbase + 1, usec_per_pkt, &listenaddr, listenaddr_len);
 
-  uint32_t cs_latency = 4000;
-  uint32_t sc_latency = 5000;
+  int is_server = 1;
+  int is_client = 0;
 
+  // Send the initial handshake packet.
   Session &cSession = c.session_map.begin()->second;
-  uint32_t t = cSession.next_send - cbase - 1;
-  WVPASS(!send_waiting_packets(&c, csock, cbase + t));
+  uint32_t t = cSession.next_send - cbase;
+  WVPASS(!send_waiting_packets(&c, csock, cbase + t, is_client));
 
-  // Verify we didn't send a packet before its time.
+  WVPASSEQ(cSession.handshake_retry_count, 0);
+
   fd_set rfds;
   FD_ZERO(&rfds);
   FD_SET(ssock, &rfds);
   struct timeval tv = {0, 0};
   int nfds = select(ssock + 1, &rfds, NULL, NULL, &tv);
+  WVPASSEQ(nfds, 1);
+
+  WVPASS(!read_incoming_packet(&s, ssock, sbase, is_server));
+
+  // The server returns its handshake cookie immediately.
+  FD_ZERO(&rfds);
+  FD_SET(csock, &rfds);
+  nfds = select(csock + 1, &rfds, NULL, NULL, &tv);
+  WVPASSEQ(nfds, 1);
+
+  // Eat the packet before the client can see it.
+  Packet p;
+  WVPASSEQ(recv(csock, &p, sizeof(p), 0), 540);
+
+  // The client doesn't send more packets until the handshake timeout expires.
+  t += Session::handshake_timeout_usec - 1;
+  WVPASS(!send_waiting_packets(&c, csock, cbase + t, is_client));
+  FD_ZERO(&rfds);
+  FD_SET(csock, &rfds);
+  nfds = select(csock + 1, &rfds, NULL, NULL, &tv);
   WVPASSEQ(nfds, 0);
 
-  // Send a packet in each direction.
+  // Wait for the client to time out and resend the initial handshake packet.
   t += 1;
-  WVPASS(!send_waiting_packets(&c, csock, cbase + t));
+  WVPASS(!send_waiting_packets(&c, csock, cbase + t, is_client));
+
+  FD_ZERO(&rfds);
+  FD_SET(ssock, &rfds);
+  nfds = select(ssock + 1, &rfds, NULL, NULL, &tv);
+  WVPASSEQ(nfds, 1);
+
+  // The server resends its cookie immediately.
+  WVPASS(!read_incoming_packet(&s, ssock, sbase, is_server));
+
+  // The server doesn't store any state for unverified clients
+  WVPASSEQ(s.session_map.size(), 0);
+
+  // Let the client read the cookie, establishing the connection.
+  WVPASS(!read_incoming_packet(&c, csock, cbase + t, is_client));
+  WVPASSEQ(cSession.next_tx_id, 1);
+
+  uint32_t cs_latency = 4000;
+  uint32_t sc_latency = 5000;
+
+  WVPASSEQ(cSession.next_send, cbase + t);
+  t = cSession.next_send - cbase - 1;
+  WVPASS(!send_waiting_packets(&c, csock, cbase + t, is_client));
+
+  // Verify we didn't send a packet before its time.
+  FD_ZERO(&rfds);
+  FD_SET(ssock, &rfds);
+  nfds = select(ssock + 1, &rfds, NULL, NULL, &tv);
+  WVPASSEQ(nfds, 0);
+
+  // Send a packet in each direction.  The server can now verify the client.
+  t += 1;
+  WVPASSEQ(cSession.next_tx_id, 1);
+  WVPASS(!send_waiting_packets(&c, csock, cbase + t, is_client));
   WVPASSEQ(cSession.next_tx_id, 2);
 
   FD_ZERO(&rfds);
@@ -523,8 +586,6 @@
   WVPASSEQ(nfds, 1);
 
   t += cs_latency;
-  int is_server = 1;
-  int is_client = 0;
   WVPASS(!read_incoming_packet(&s, ssock, sbase + t, is_server));
   WVPASSEQ(s.session_map.size(), 1);
   WVPASSEQ(s.next_sends.size(), 1);
@@ -534,9 +595,10 @@
   Session &sSession = s.session_map.begin()->second;
   WVPASS(sSession.remoteaddr_len > 0);
   WVPASSEQ(sSession.next_tx_id, 1);
+  WVPASSEQ(ntohl(sSession.rx.id), 1);
 
   t = s.next_send_time() - sbase;
-  WVPASS(!send_waiting_packets(&s, ssock, sbase + t));
+  WVPASS(!send_waiting_packets(&s, ssock, sbase + t, is_server));
   WVPASSEQ(s.next_send_time(), sbase + t + sSession.usec_per_pkt);
   WVPASSEQ(sSession.next_tx_id, 2);
 
@@ -545,8 +607,7 @@
   WVPASSEQ(cSession.lat_rx_count, 1);
 
   // Verify we reject garbage data.
-  Packet p;
-  p.magic = 0;
+  memset(&p, 0, sizeof(p));
   if (!WVPASSEQ(send(csock, &p, sizeof(p), 0), sizeof(p))) {
     perror("sendto");
     return;
@@ -579,8 +640,17 @@
          ntohs(c2addr.sin6_port));
 
   Session &c2Session = c2.session_map.begin()->second;
+  // Perform the handshake dance so the server knows we're legit.
+  prepare_tx_packet(&c2Session);
+  WVPASS(!send_packet(&c2Session, c2sock, is_client));
+  t = cs_latency;
+  WVPASS(!read_incoming_packet(&s, ssock, sbase+t, is_server));
+  t += sc_latency;
+  WVPASS(!read_incoming_packet(&c2, c2sock, cbase+t, is_client));
+
+  // Now we can send a validated packet to the server.
   t = c2Session.next_send - cbase;
-  WVPASS(!send_waiting_packets(&c2, c2sock, cbase + t));
+  WVPASS(!send_waiting_packets(&c2, c2sock, cbase + t, is_client));
 
   t += cs_latency;
 
@@ -597,3 +667,185 @@
   close(c2sock);
   freeaddrinfo(res);
 }
+
+WVTEST_MAIN("Cookie Validation") {
+  struct addrinfo hints, *res;
+
+  // Get local interface information.
+  memset(&hints, 0, sizeof(hints));
+  hints.ai_family = AF_INET6;
+  hints.ai_socktype = SOCK_DGRAM;
+  hints.ai_flags = AI_PASSIVE | AI_V4MAPPED;
+  int err = getaddrinfo(NULL, "0", &hints, &res);
+  if (err != 0) {
+    WVPASSEQ("Error from getaddrinfo: ", gai_strerror(err));
+    return;
+  }
+
+  // Set up a socket
+  struct sockaddr_storage addr;
+  socklen_t addr_len = sizeof(addr);
+  memset(&addr, 0, addr_len);
+  int sock;
+  sock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  if (!WVPASS(sock >= 0)) {
+    perror("socket");
+    return;
+  }
+  if (!WVPASS(!getsockname(sock, (struct sockaddr *)&addr, &addr_len))) {
+    perror("getsockname");
+    return;
+  }
+
+  Sessions s;
+  uint32_t epoch = 1;
+  s.RotateCookieSecrets(epoch);
+  Packet p;
+  memset(&p, 0, sizeof(p));
+  WVFAIL(s.CalculateCookie(&p, &addr, addr_len));
+
+  p.packet_type = PACKET_TYPE_HANDSHAKE;
+  p.usec_per_pkt = 100000;
+  WVPASS(s.CalculateCookie(&p, &addr, addr_len));
+
+  // We validate cookies we generate.
+  WVPASS(s.ValidateCookie(&p, &addr, addr_len));
+
+  // Validation fails after changing the IP port or address.
+  sockaddr_storage changed_addr;
+  memcpy(&changed_addr, &addr, addr_len);
+  ((sockaddr_in6 *)&changed_addr)->sin6_port++;
+  WVFAIL(s.ValidateCookie(&p, &changed_addr, addr_len));
+
+  memcpy(&changed_addr, &addr, addr_len);
+  ((sockaddr_in6 *)&changed_addr)->sin6_addr.s6_addr[0]++;
+  WVFAIL(s.ValidateCookie(&p, &changed_addr, addr_len));
+
+  // Validation fails after changing the usec_per_pkt.
+  p.usec_per_pkt++;
+  WVFAIL(s.ValidateCookie(&p, &addr, addr_len));
+  p.usec_per_pkt--;
+
+  // Validation fails after plain modifying the cookie.
+  p.data.handshake.cookie[0]++;
+  WVFAIL(s.ValidateCookie(&p, &changed_addr, addr_len));
+  p.data.handshake.cookie[0]--;
+
+  // Cookies generated with the previous secret still validate.
+  epoch++;
+  s.RotateCookieSecrets(epoch);
+  WVPASS(s.ValidateCookie(&p, &addr, addr_len));
+
+  // But secrets older than that don't validate.
+  epoch++;
+  s.RotateCookieSecrets(epoch);
+  WVFAIL(s.ValidateCookie(&p, &addr, addr_len));
+
+  // Cleanup
+  close(sock);
+  freeaddrinfo(res);
+}
+
+WVTEST_MAIN("Exponential Handshake Backoff") {
+  struct addrinfo hints, *res;
+
+  // Get local interface information.
+  memset(&hints, 0, sizeof(hints));
+  hints.ai_family = AF_INET6;
+  hints.ai_socktype = SOCK_DGRAM;
+  hints.ai_flags = AI_PASSIVE | AI_V4MAPPED;
+  int err = getaddrinfo(NULL, "0", &hints, &res);
+  if (err != 0) {
+    WVPASSEQ("Error from getaddrinfo: ", gai_strerror(err));
+    return;
+  }
+
+  // Set up a socket
+  struct sockaddr_storage addr;
+  socklen_t addr_len = sizeof(addr);
+  memset(&addr, 0, addr_len);
+  int sock;
+  sock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  if (!WVPASS(sock >= 0)) {
+    perror("socket");
+    return;
+  }
+  if (!WVPASS(!getsockname(sock, (struct sockaddr *)&addr, &addr_len))) {
+    perror("getsockname");
+    return;
+  }
+
+  uint32_t cbase = 400*1000;
+  uint32_t usec_per_pkt = 100 * 1000;
+  Sessions c;
+  c.NewSession(cbase, usec_per_pkt, &addr, addr_len);
+  Session &cSession = c.session_map.begin()->second;
+  WVPASSEQ(cSession.next_send, cbase);
+
+  // Test that we resend handshake packets on an exponential backoff schedule,
+  // up until round 10.
+  int is_client = 0;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_state, Session::HANDSHAKE_REQUESTED);
+  WVPASSEQ(cSession.handshake_retry_count, 0);
+  WVPASSEQ(cSession.next_send, cbase + Session::handshake_timeout_usec);
+
+  uint32_t t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 1);
+  WVPASSEQ(cSession.next_send, t + 2 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 2);
+  WVPASSEQ(cSession.next_send, t + 4 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 3);
+  WVPASSEQ(cSession.next_send, t + 8 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 4);
+  WVPASSEQ(cSession.next_send, t + 16 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 5);
+  WVPASSEQ(cSession.next_send, t + 32 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 6);
+  WVPASSEQ(cSession.next_send, t + 64 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 7);
+  WVPASSEQ(cSession.next_send, t + 128 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 8);
+  WVPASSEQ(cSession.next_send, t + 256 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 9);
+  WVPASSEQ(cSession.next_send, t + 512 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 10);
+  WVPASSEQ(cSession.next_send, t + 1024 * Session::handshake_timeout_usec);
+
+  t = cSession.next_send;
+  send_packet(&cSession, sock, is_client);
+  WVPASSEQ(cSession.handshake_retry_count, 11);
+  WVPASSEQ(cSession.next_send, t + 1024 * Session::handshake_timeout_usec);
+
+  // Cleanup
+  close(sock);
+  freeaddrinfo(res);
+}