Fix netusage underflow due to race condition.

Also add a unit test.

Change-Id: Ia00df67444f3dcf8dc0de4c3ff131872bea8deef
diff --git a/cmds/Makefile b/cmds/Makefile
index da69aa9..dfbb363 100644
--- a/cmds/Makefile
+++ b/cmds/Makefile
@@ -48,8 +48,9 @@
 LIB_TARGETS=\
 	stdoutline.so
 TEST_TARGETS=\
-	host-utils_test \
-	host-asus_hosts_test
+	host-asus_hosts_test \
+	host-netusage_test \
+	host-utils_test
 SCRIPT_TARGETS=\
 	is-secure-boot
 ARCH_TARGETS=\
@@ -204,6 +205,7 @@
 eddystone: eddystone.o
 host-dir-monitor dir-monitor: LIBS+=-lstdc++
 netusage: CFLAGS += -Wno-sign-compare
+host-netusage_test: host-netusage_test.o
 wifi_files: wifi_files.o
 wifi_files: LIBS+=-lnl-3 -lnl-genl-3
 
diff --git a/cmds/netusage.c b/cmds/netusage.c
index 36fdeb9..8c0e1b5 100644
--- a/cmds/netusage.c
+++ b/cmds/netusage.c
@@ -33,12 +33,17 @@
 #define SAMPLES 8
 
 
-static uint64_t mono_usecs(void)
+#ifndef UNIT_TESTS
+#define CLOCK_GETTIME clock_gettime
+#endif  /* UNIT_TESTS */
+
+
+uint64_t mono_usecs(void)
 {
   struct timespec ts;
   uint64_t usec;
 
-  if (clock_gettime(CLOCK_MONOTONIC, &ts) < 0) {
+  if (CLOCK_GETTIME(CLOCK_MONOTONIC, &ts) < 0) {
     perror("clock_gettime(CLOCK_MONOTONIC)");
     exit(1);
   }
@@ -70,6 +75,7 @@
 }
 
 
+#ifndef UNIT_TESTS
 void sendreq(int s, const char *ifname)
 {
   struct {
@@ -107,10 +113,12 @@
     exit(1);
   }
 }
+#endif  /* UNIT_TESTS */
 
 
+#ifndef UNIT_TESTS
 void recvresp(int s, uint32_t *tx_bytes, uint32_t *rx_bytes,
-    uint32_t *tx_pkts, uint32_t *rx_unipkts, uint32_t *rx_multipkts)
+    uint32_t *tx_pkts, uint32_t *rx_pkts, uint32_t *rx_multipkts)
 {
   ssize_t len;
   unsigned char buf[4096];
@@ -137,7 +145,7 @@
         *rx_bytes = stats->rx_bytes;
         *tx_bytes = stats->tx_bytes;
         *tx_pkts = stats->tx_packets;
-        *rx_unipkts = stats->rx_packets - stats->multicast;
+        *rx_pkts = stats->rx_packets;
         *rx_multipkts = stats->multicast;
       }
 
@@ -147,8 +155,67 @@
     nh = NLMSG_NEXT(nh, len);
   }
 }
+#endif  /* UNIT_TESTS */
 
 
+struct saved_counters {
+  uint32_t tx_bytes;
+  uint32_t rx_bytes;
+  uint32_t tx_pkts;
+  uint32_t rx_unipkts;
+  uint32_t rx_multipkts;
+};
+
+void accumulate_stats(int s, double delta, const char *interface,
+    double *tx_kbps, double *rx_kbps, double *tx_pps,
+    double *rx_uni_pps, double *rx_multi_pps,
+    struct saved_counters *old)
+{
+  uint32_t tx_bytes, rx_bytes, tx_pkts, rx_pkts, rx_multipkts;
+  uint32_t tx_bytes2, rx_bytes2, tx_pkts2, rx_pkts2, rx_multipkts2;
+  uint32_t rx_unipkts;
+
+  sendreq(s, interface);
+  recvresp(s, &tx_bytes, &rx_bytes, &tx_pkts, &rx_pkts, &rx_multipkts);
+
+  /*
+   * Most hardware platforms do not have an RX unicast packet counter, they
+   * have an overall RX counter and they have a multicast counter. We cannot
+   * read the two counters atomically, the hardware does not provide a way
+   * to do so.
+   *
+   * Therefore there is a race condition. Assume no packets have arrived, so
+   * we read an rx_packets count of zero. At that instant a multicast packet
+   * arrives, so rx_packets and rx_multipackets both increment to 1. We've
+   * already read the rx_packets counter so its too late for that, but we
+   * read the rx_multipkts counter of 1.
+   *
+   * rx_unipkts = (rx_packets - rx_multipkts) is 4,294,967,295, which
+   * is very wrong.
+   *
+   * To resolve this, we read the counters twice. We use the rx_multipkts
+   * count from the first read, and the rx_packets count from the second,
+   * to ensure that rx_packets has a chance to update.
+   */
+  sendreq(s, interface);
+  recvresp(s, &tx_bytes2, &rx_bytes2, &tx_pkts2, &rx_pkts2, &rx_multipkts2);
+
+  *tx_kbps = (8.0 * (tx_bytes - old->tx_bytes) / 1000.0) / delta;
+  *rx_kbps = (8.0 * (rx_bytes - old->rx_bytes) / 1000.0) / delta;
+  *tx_pps = (tx_pkts - old->tx_pkts) / delta;
+  rx_unipkts = rx_pkts2 - rx_multipkts;
+  *rx_uni_pps = (rx_unipkts - old->rx_unipkts) / delta;
+  *rx_multi_pps = (rx_multipkts - old->rx_multipkts) / delta;
+
+  old->tx_bytes = tx_bytes;
+  old->rx_bytes = rx_bytes;
+  old->tx_pkts = tx_pkts;
+  old->rx_unipkts = rx_unipkts;
+  old->rx_multipkts = rx_multipkts;
+}
+
+
+#ifndef UNIT_TESTS
 void usage(const char *progname)
 {
   fprintf(stderr, "usage: %s -i foo0\n", progname);
@@ -164,8 +231,8 @@
   const char *interface = NULL;
   int s = netlink_socket();
   uint64_t start;
-  uint32_t old_tx_bytes, old_rx_bytes, old_tx_pkts;
-  uint32_t old_rx_unipkts, old_rx_multipkts;
+  double junk;
+  struct saved_counters old;
   int i = 0;
 
   while ((c = getopt(argc, argv, "i:")) >= 0) {
@@ -186,14 +253,11 @@
 
   setlinebuf(stdout);
   start = mono_usecs();
-  sendreq(s, interface);
-  recvresp(s, &old_tx_bytes, &old_rx_bytes, &old_tx_pkts, &old_rx_unipkts,
-      &old_rx_multipkts);
+  accumulate_stats(s, 1.0, interface, &junk, &junk, &junk, &junk, &junk, &old);
 
   while (1) {
     uint64_t timestamp;
     double delta;
-    uint32_t tx_bytes, rx_bytes, tx_pkts, rx_unipkts, rx_multipkts;
     double tx_kbps[SAMPLES];
     double rx_kbps[SAMPLES];
     double tx_pps[SAMPLES];
@@ -202,17 +266,12 @@
 
     sleep(1);
     timestamp = mono_usecs();
-    sendreq(s, interface);
-    recvresp(s, &tx_bytes, &rx_bytes, &tx_pkts, &rx_unipkts, &rx_multipkts);
-
     delta = (timestamp - start) / 1000000.0;
-    tx_kbps[i] = (8.0 * (tx_bytes - old_tx_bytes) / 1000.0) / delta;
-    rx_kbps[i] = (8.0 * (rx_bytes - old_rx_bytes) / 1000.0) / delta;
-    tx_pps[i] = (tx_pkts - old_tx_pkts) / delta;
-    rx_uni_pps[i] = (rx_unipkts - old_rx_unipkts) / delta;
-    rx_multi_pps[i] = (rx_multipkts - old_rx_multipkts) / delta;
-    i++;
 
+    accumulate_stats(s, delta, interface, &tx_kbps[i], &rx_kbps[i], &tx_pps[i],
+        &rx_uni_pps[i], &rx_multi_pps[i], &old);
+
+    i++;
     if (i == SAMPLES) {
       printf("%s TX Kbps %.0f,%.0f,%.0f,%.0f,%.0f,%.0f,%.0f,%.0f\n",
           interface,
@@ -237,11 +296,7 @@
       i = 0;
     }
 
-    old_tx_bytes = tx_bytes;
-    old_rx_bytes = rx_bytes;
-    old_tx_pkts = tx_pkts;
-    old_rx_unipkts = rx_unipkts;
-    old_rx_multipkts = rx_multipkts;
     start = timestamp;
   }
 }
+#endif  /* UNIT_TESTS */
diff --git a/cmds/netusage_test.c b/cmds/netusage_test.c
new file mode 100644
index 0000000..56a8ca9
--- /dev/null
+++ b/cmds/netusage_test.c
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2015 Google Inc. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Unit tests for netusage.c */
+
+#include <assert.h>
+#include <math.h>
+#include <stdio.h>
+#include <stdint.h>
+#include <time.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+
+
+struct timespec test_clock_gettime_value = {0, 0};
+int test_clock_gettime(clockid_t clk_id, struct timespec *tp)
+{
+  *tp = test_clock_gettime_value;
+  return 0;
+}
+
+
+void sendreq(int s, const char *ifname)
+{
+  return;
+}
+
+
+int recvresp_count = 0;
+void recvresp(int s, uint32_t *tx_bytes, uint32_t *rx_bytes,
+    uint32_t *tx_pkts, uint32_t *rx_pkts, uint32_t *rx_multipkts)
+{
+  /*
+   * Set up the conditions for an overflow in rx_uni_pkts. On the first
+   * read of the hardware there were 5 total packets, but multipackets
+   * incremented to 6 immediately before we read it.
+   *
+   * On the second read, rx_packets will have also incremented to 6.
+   */
+  if (recvresp_count == 0) {
+    recvresp_count = 1;
+    *tx_bytes = 1000;
+    *rx_bytes = 2000;
+    *tx_pkts = 3000;
+    *rx_pkts = 5000;
+    *rx_multipkts = 6000;
+  } else {
+    recvresp_count = 0;
+    *tx_bytes = 1000;
+    *rx_bytes = 2000;
+    *tx_pkts = 3000;
+    *rx_pkts = 6000;
+    *rx_multipkts = 6000;
+  }
+}
+
+
+#define UNIT_TESTS
+#define CLOCK_GETTIME test_clock_gettime
+#include "netusage.c"
+
+
+int almost_equal(double val, double expected)
+{
+  return fabs(val - expected) < 0.0001;
+}
+
+
+void test_counters()
+{
+  double tx_kbps, rx_kbps, tx_pps, rx_uni_pps, rx_multi_pps;
+  struct saved_counters old;
+
+  test_clock_gettime_value.tv_sec = 1;
+  test_clock_gettime_value.tv_nsec = 0;
+  memset(&old, 0, sizeof(old));
+
+  accumulate_stats(0, 1.0, "foo0", &tx_kbps, &rx_kbps, &tx_pps,
+      &rx_uni_pps, &rx_multi_pps, &old);
+
+  assert(almost_equal(tx_kbps, 1.0 * 8));
+  assert(almost_equal(rx_kbps, 2.0 * 8));
+  assert(almost_equal(tx_pps, 3000.0));
+  assert(almost_equal(rx_uni_pps, 0.0));
+  assert(almost_equal(rx_multi_pps, 6000.0));
+}
+
+
+void test_mono_usecs()
+{
+  uint64_t usecs;
+
+  test_clock_gettime_value.tv_sec = 1;
+  test_clock_gettime_value.tv_nsec = 3000;
+  usecs = mono_usecs();
+  assert(usecs == 1000003);
+}
+
+
+int main(int argc, char** argv)
+{
+  test_mono_usecs();
+  test_counters();
+  exit(0);
+}