Add unit tests for isoping socket handling.

Change-Id: I7245394b528b59734d88392d65243ac443ff8584
diff --git a/cmds/isoping.cc b/cmds/isoping.cc
index c6b123e..aac7935 100644
--- a/cmds/isoping.cc
+++ b/cmds/isoping.cc
@@ -84,6 +84,8 @@
 Session::Session(uint32_t now)
     : usec_per_pkt(1e6 / packets_per_sec),
       usec_per_print(prints_per_sec > 0 ? 1e6 / prints_per_sec : 0),
+      remoteaddr(NULL),
+      remoteaddr_len(0),
       next_tx_id(1),
       next_rx_id(0),
       next_rxack_id(0),
@@ -225,17 +227,15 @@
   s->tx.first_ack = htonl(s->next_txack_index);
 }
 
-static int send_packet(struct Session *s,
-                       int sock,
-                       struct sockaddr *remoteaddr,
-                       socklen_t remoteaddr_len) {
+
+static int send_packet(struct Session *s, int sock) {
   // note: tx.acks[] is filled in incrementally; we just transmit the current
   // state of it here.  The reason we keep a list of the most recent acks is in
   // case our packet gets lost, so the receiver will have more chances to
   // receive the timing information for the packets it sent us.
   if (is_server) {
     if (sendto(sock, &s->tx, sizeof(s->tx), 0,
-               remoteaddr, remoteaddr_len) < 0) {
+               s->remoteaddr, s->remoteaddr_len) < 0) {
       perror("sendto");
     }
   } else {
@@ -250,6 +250,63 @@
 }
 
 
+int maybe_send_packet(struct Session *s, int sock, uint32_t now) {
+  if (s->remoteaddr && DIFF(now, s->next_send) >= 0) {
+    prepare_tx_packet(s);
+    int err = send_packet(s, sock);
+    if (err != 0) {
+      return err;
+    }
+  }
+  return 0;
+}
+
+
+int read_incoming_packet(struct Session *s, int sock, uint32_t now) {
+  struct sockaddr_in6 rxaddr;
+  socklen_t rxaddr_len = 0;
+  // TODO(pmccurdy): Temporary until we properly support multiple clients.
+  static struct sockaddr_in6 last_rxaddr;
+
+  rxaddr_len = sizeof(rxaddr);
+  ssize_t got = recvfrom(sock, &s->rx, sizeof(s->rx), 0,
+                         (struct sockaddr *)&rxaddr, &rxaddr_len);
+  if (got < 0) {
+    int e = errno;
+    perror("recvfrom");
+    return e;
+  }
+  if (got != sizeof(s->rx) || s->rx.magic != htonl(MAGIC)) {
+    fprintf(stderr, "got invalid packet of length %ld\n", (long)got);
+    return EINVAL;
+  }
+
+  // is it a new client?
+  if (is_server) {
+    // TODO(pmccurdy): Maintain a hash table of Sessions, look up based
+    // on rxaddr, create a new one if necessary, remove this resetting code.
+    if (!s->remoteaddr ||
+        memcmp(&rxaddr, &last_rxaddr, sizeof(rxaddr)) != 0) {
+      fprintf(stderr, "new client connected: %s\n",
+              sockaddr_to_str((struct sockaddr *)&rxaddr));
+      memcpy(&last_rxaddr, &rxaddr, sizeof(rxaddr));
+      s->remoteaddr = (struct sockaddr *)&last_rxaddr;
+      s->remoteaddr_len = rxaddr_len;
+
+      s->next_send = now + 10*1000;
+      s->next_tx_id = 1;
+      s->next_rx_id = s->next_rxack_id = 0;
+      s->start_rtxtime = s->start_rxtime = 0;
+      s->num_lost = 0;
+      s->next_txack_index = 0;
+      s->usec_per_pkt = ntohl(s->rx.usec_per_pkt);
+      memset(&s->tx, 0, sizeof(s->tx));
+    }
+  }
+  return 0;
+}
+
+
 void handle_packet(struct Session *s, uint32_t now) {
   // process the incoming packet header.
   // Most of the complexity here comes from the fact that the remote
@@ -441,9 +498,7 @@
 
 
 int isoping_main(int argc, char **argv) {
-  struct sockaddr_in6 listenaddr, rxaddr, last_rxaddr;
-  struct sockaddr *remoteaddr = NULL;
-  socklen_t remoteaddr_len = 0, rxaddr_len = 0;
+  struct sockaddr_in6 listenaddr;
   struct addrinfo *ai = NULL;
   int sock = -1;
 
@@ -494,6 +549,10 @@
     return 1;
   }
 
+  uint32_t now = ustime();       // current time
+
+  struct Session s(now);
+
   if (argc - optind == 0) {
     is_server = 1;
     memset(&listenaddr, 0, sizeof(listenaddr));
@@ -529,8 +588,8 @@
       perror("connect");
       return 1;
     }
-    remoteaddr = ai->ai_addr;
-    remoteaddr_len = ai->ai_addrlen;
+    s.remoteaddr = ai->ai_addr;
+    s.remoteaddr_len = ai->ai_addrlen;
   } else {
     usage_and_die(argv[0]);
   }
@@ -553,16 +612,12 @@
     }
   }
 
-  uint32_t now = ustime();       // current time
-
   struct sigaction act;
   memset(&act, 0, sizeof(act));
   act.sa_handler = sighandler;
   act.sa_flags = SA_RESETHAND;
   sigaction(SIGINT, &act, NULL);
 
-  struct Session s(now);
-
   while (!want_to_die) {
     fd_set rfds;
     FD_ZERO(&rfds);
@@ -576,7 +631,7 @@
     } else {
       tv.tv_usec = DIFF(s.next_send, now);
     }
-    int nfds = select(sock + 1, &rfds, NULL, NULL, remoteaddr ? &tv : NULL);
+    int nfds = select(sock + 1, &rfds, NULL, NULL, s.remoteaddr ? &tv : NULL);
     now = ustime();
     if (nfds < 0 && errno != EINTR) {
       perror("select");
@@ -584,59 +639,23 @@
     }
 
     // time to send the next packet?
-    if (remoteaddr && DIFF(now, s.next_send) >= 0) {
-      prepare_tx_packet(&s);
-      int err = send_packet(&s, sock, remoteaddr, remoteaddr_len);
-      if (err != 0) {
-        return err;
-      }
-      // TODO(pmccurdy): Track disconnections across multiple clients.  Use
-      // recvmsg with the MSG_ERRQUEUE flag to detect connection refused.
-      if (is_server && DIFF(now, s.last_rxtime) > 60*1000*1000) {
-        fprintf(stderr, "client disconnected.\n");
-        remoteaddr = NULL;
-      }
+    int err = maybe_send_packet(&s, sock, now);
+    if (err != 0) {
+      return err;
+    }
+    // TODO(pmccurdy): Track disconnections across multiple clients.  Use
+    // recvmsg with the MSG_ERRQUEUE flag to detect connection refused.
+    if (is_server && DIFF(now, s.last_rxtime) > 60 * 1000 * 1000) {
+      fprintf(stderr, "client disconnected.\n");
+      s.remoteaddr = NULL;
     }
 
     if (nfds > 0) {
-      // incoming packet
-      rxaddr_len = sizeof(rxaddr);
-      ssize_t got = recvfrom(sock, &s.rx, sizeof(s.rx), 0,
-                             (struct sockaddr *)&rxaddr, &rxaddr_len);
-      if (got < 0) {
-        int e = errno;
-        perror("recvfrom");
-        if (!is_server && e == ECONNREFUSED) return 2;
+      err = read_incoming_packet(&s, sock, now);
+      if (!is_server && err == ECONNREFUSED) return 2;
+      if (err != 0) {
         continue;
       }
-      if (got != sizeof(s.rx) || s.rx.magic != htonl(MAGIC)) {
-        fprintf(stderr, "got invalid packet of length %ld\n", (long)got);
-        continue;
-      }
-
-      // is it a new client?
-      if (is_server) {
-        // TODO(pmccurdy): Maintain a hash table of Sessions, look up based
-        // on rxaddr, create a new one if necessary, remove this resetting code.
-        if (!remoteaddr ||
-            memcmp(&rxaddr, &last_rxaddr, sizeof(rxaddr)) != 0) {
-          fprintf(stderr, "new client connected: %s\n",
-                  sockaddr_to_str((struct sockaddr *)&rxaddr));
-          memcpy(&last_rxaddr, &rxaddr, sizeof(rxaddr));
-          remoteaddr = (struct sockaddr *)&last_rxaddr;
-          remoteaddr_len = rxaddr_len;
-
-          s.next_send = now + 10*1000;
-          s.next_tx_id = 1;
-          s.next_rx_id = s.next_rxack_id = 0;
-          s.start_rtxtime = s.start_rxtime = 0;
-          s.num_lost = 0;
-          s.next_txack_index = 0;
-          s.usec_per_pkt = ntohl(s.rx.usec_per_pkt);
-          memset(&s.tx, 0, sizeof(s.tx));
-        }
-      }
-
       handle_packet(&s, now);
     }
   }
diff --git a/cmds/isoping.h b/cmds/isoping.h
index 52fcccb..11feae2 100644
--- a/cmds/isoping.h
+++ b/cmds/isoping.h
@@ -17,6 +17,7 @@
 #define ISOPING_H
 
 #include <stdint.h>
+#include <sys/socket.h>
 
 // Layout of the UDP packets exchanged between client and server.
 // All integers are in network byte order.
@@ -43,6 +44,10 @@
   int32_t usec_per_pkt;
   int32_t usec_per_print;
 
+  // The peer's address.
+  struct sockaddr *remoteaddr;
+  socklen_t remoteaddr_len;
+
   // WARNING: lots of math below relies on well-defined uint32/int32
   // arithmetic overflow behaviour, plus the fact that when we subtract
   // two successive timestamps (for example) they will be less than 2^31
@@ -76,6 +81,13 @@
 // Sets all the elements of s->tx to be ready to be sent to the other side.
 void prepare_tx_packet(struct Session *s);
 
+// Sends a packet to the socket if the appropriate amount of time has passed.
+int maybe_send_packet(struct Session *s, int sock, uint32_t now);
+
+// Reads a packet from sock and stores it in s->rx.  Assumes a packet is
+// currently readable.
+int read_incoming_packet(struct Session *s, int sock, uint32_t now);
+
 // Parses arguments and runs the main loop.  Distinct from main() for unit test
 // purposes.
 int isoping_main(int argc, char **argv);
diff --git a/cmds/isoping_test.cc b/cmds/isoping_test.cc
index 84edc46..572490b 100644
--- a/cmds/isoping_test.cc
+++ b/cmds/isoping_test.cc
@@ -16,8 +16,12 @@
 
 #include <arpa/inet.h>
 #include <limits.h>
+#include <memory.h>
 #include <stdio.h>
-
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netdb.h>
+#include <unistd.h>
 #include <wvtest.h>
 
 #include "isoping.h"
@@ -414,7 +418,161 @@
   WVPASSEQ(c.lat_rx, half_rtt + drift_per_round / 2);
   WVPASSEQ(c.lat_tx, half_rtt + total_drift / 2 + 1);
   WVPASSEQ(c.min_cycle_rxdiff, INT_MAX);
+}
 
-  t = send_next_packet(&c, cbase, &s, sbase, cs_latency + total_drift);
+WVTEST_MAIN("Send and receive on sockets") {
+  uint32_t cbase = 1400 * 1000;
+  uint32_t sbase = 1600 * 1000;
 
+  // The states of the client and server.
+  struct Session c(cbase);
+  struct Session s(sbase);
+
+  // Sockets for the client and server.
+  int ssock, csock;
+  struct addrinfo hints, *res;
+
+  // Get local interface information.
+  memset(&hints, 0, sizeof(hints));
+  hints.ai_family = AF_INET6;
+  hints.ai_socktype = SOCK_DGRAM;
+  hints.ai_flags = AI_PASSIVE | AI_V4MAPPED;
+  int err = getaddrinfo(NULL, "0", &hints, &res);
+  if (err != 0) {
+    WVPASSEQ("Error from getaddrinfo: ", gai_strerror(err));
+    return;
+  }
+
+  ssock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  if (!WVPASS(ssock >= 0)) {
+    perror("server socket");
+    return;
+  }
+
+  csock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  if (!WVPASS(csock >= 0)) {
+    perror("client socket");
+    return;
+  }
+
+  if (!WVPASS(!bind(ssock, res->ai_addr, res->ai_addrlen))) {
+    perror("bind");
+    return;
+  }
+
+  // Figure out the local port we got.
+  struct sockaddr_in6 listenaddr;
+  socklen_t listenaddr_len = sizeof(listenaddr);
+  memset(&listenaddr, 0, listenaddr_len);
+  if (!WVPASS(!getsockname(ssock, (struct sockaddr *)&listenaddr,
+                           &listenaddr_len))) {
+    perror("getsockname");
+    return;
+  }
+
+  printf("Bound server socket to port=%d\n", listenaddr.sin6_port);
+
+  // Connect the client's socket.
+  if (!WVPASS(
+          !connect(csock, (struct sockaddr *)&listenaddr, listenaddr_len))) {
+    perror("connect");
+    return;
+  }
+
+  c.remoteaddr = (struct sockaddr *)&listenaddr;
+  c.remoteaddr_len = listenaddr_len;
+
+  uint32_t cs_latency = 4000;
+  uint32_t sc_latency = 5000;
+  uint32_t t = c.usec_per_pkt - 1;
+  WVPASS(!maybe_send_packet(&c, csock, cbase + t));
+
+  // Verify we didn't send a packet before its time.
+  fd_set rfds;
+  FD_ZERO(&rfds);
+  FD_SET(ssock, &rfds);
+  struct timeval tv = {0, 0};
+  int nfds = select(ssock + 1, &rfds, NULL, NULL, &tv);
+  WVPASSEQ(nfds, 0);
+
+  // Send a packet in each direction.
+  t += 1;
+  WVPASS(!maybe_send_packet(&c, csock, cbase + t));
+  WVPASSEQ(c.next_tx_id, 2);
+
+  FD_ZERO(&rfds);
+  FD_SET(ssock, &rfds);
+  nfds = select(ssock + 1, &rfds, NULL, NULL, &tv);
+  WVPASSEQ(nfds, 1);
+
+  t += cs_latency;
+  WVPASS(!read_incoming_packet(&s, ssock, sbase + t));
+
+  WVPASS(s.remoteaddr != NULL);
+  WVPASS(s.remoteaddr_len > 0);
+  WVPASSEQ(s.next_tx_id, 1);
+
+  handle_packet(&s, sbase + t);
+
+  t = s.next_send - sbase;
+  WVPASS(!maybe_send_packet(&s, ssock, sbase + t));
+  WVPASSEQ(s.next_send, sbase + t + s.usec_per_pkt);
+  WVPASSEQ(s.next_tx_id, 2);
+
+  t += sc_latency;
+  WVPASS(!read_incoming_packet(&c, csock, cbase + t));
+  handle_packet(&c, cbase + t);
+  WVPASSEQ(c.lat_rx_count, 1);
+
+  // Verify we reject garbage data.
+  Packet p;
+  p.magic = 0;
+  if (!WVPASSEQ(send(csock, &p, sizeof(p), 0), sizeof(p))) {
+    perror("sendto");
+    return;
+  }
+
+  WVPASSEQ(read_incoming_packet(&s, ssock, sbase + t), EINVAL);
+
+  // Make a new client, getting a new source port.
+  struct Session c2(cbase);
+  c2.usec_per_pkt *= 2;
+  c2.remoteaddr = c.remoteaddr;
+  c2.remoteaddr_len = c.remoteaddr_len;
+  int c2sock = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+  if (!WVPASS(c2sock > 0)) {
+    perror("client socket 2");
+    return;
+  }
+  if (!WVPASS(!connect(c2sock, c.remoteaddr, c.remoteaddr_len))) {
+    perror("connect");
+    return;
+  }
+  struct sockaddr_in6 c2addr;
+  socklen_t c2addr_len = sizeof(c2addr);
+  memset(&c2addr, 0, c2addr_len);
+  if (!WVPASS(!getsockname(c2sock, (struct sockaddr *)&c2addr, &c2addr_len))) {
+    perror("getsockname");
+    return;
+  }
+
+  t = c2.next_send - cbase;
+  WVPASS(!maybe_send_packet(&c2, c2sock, cbase + t));
+
+  t += cs_latency;
+
+  // Check that a new client resets some state.
+  WVPASS(!read_incoming_packet(&s, ssock, sbase + t));
+
+  WVPASSEQ(ntohs(((sockaddr_in6 *)s.remoteaddr)->sin6_port),
+           ntohs(c2addr.sin6_port));
+  WVPASSEQ(s.next_tx_id, 1);
+  WVPASSEQ(s.next_rx_id, 0);
+  WVPASSEQ(s.usec_per_pkt, c2.usec_per_pkt);
+
+  // Cleanup
+  close(ssock);
+  close(csock);
+  close(c2sock);
+  freeaddrinfo(res);
 }