Make tls_connection_get_keyblock_size() internal to tls_*.c

This function exposes internal state of the TLS negotiated parameters
for the sole purpose of being able to implement PRF for EAP-FAST. Since
tls_connection_prf() is now taking care of all TLS-based key derivation
cases, it is cleaner to keep this detail internal to each tls_*.c
wrapper implementation.

Signed-off-by: Jouni Malinen <jouni@qca.qualcomm.com>
diff --git a/src/crypto/tls.h b/src/crypto/tls.h
index 97ed8c1..f9e2e10 100644
--- a/src/crypto/tls.h
+++ b/src/crypto/tls.h
@@ -323,6 +323,7 @@
  * @label: Label (e.g., description of the key) for PRF
  * @server_random_first: seed is 0 = client_random|server_random,
  * 1 = server_random|client_random
+ * @skip_keyblock: Skip TLS key block from the beginning of PRF output
  * @out: Buffer for output data from TLS-PRF
  * @out_len: Length of the output buffer
  * Returns: 0 on success, -1 on failure
@@ -340,6 +341,7 @@
 				     struct tls_connection *conn,
 				     const char *label,
 				     int server_random_first,
+				     int skip_keyblock,
 				     u8 *out, size_t out_len);
 
 /**
@@ -526,16 +528,6 @@
 				    struct tls_connection *conn);
 
 /**
- * tls_connection_get_keyblock_size - Get TLS key_block size
- * @tls_ctx: TLS context data from tls_init()
- * @conn: Connection context data from tls_connection_init()
- * Returns: Size of the key_block for the negotiated cipher suite or -1 on
- * failure
- */
-int tls_connection_get_keyblock_size(void *tls_ctx,
-				     struct tls_connection *conn);
-
-/**
  * tls_capabilities - Get supported TLS capabilities
  * @tls_ctx: TLS context data from tls_init()
  * Returns: Bit field of supported TLS capabilities (TLS_CAPABILITY_*)
diff --git a/src/crypto/tls_gnutls.c b/src/crypto/tls_gnutls.c
index 65db6fc..c7f6464 100644
--- a/src/crypto/tls_gnutls.c
+++ b/src/crypto/tls_gnutls.c
@@ -747,9 +747,9 @@
 
 int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
 		       const char *label, int server_random_first,
-		       u8 *out, size_t out_len)
+		       int skip_keyblock, u8 *out, size_t out_len)
 {
-	if (conn == NULL || conn->session == NULL)
+	if (conn == NULL || conn->session == NULL || skip_keyblock)
 		return -1;
 
 	return gnutls_prf(conn->session, os_strlen(label), label,
@@ -1476,14 +1476,6 @@
 }
 
 
-int tls_connection_get_keyblock_size(void *tls_ctx,
-				     struct tls_connection *conn)
-{
-	/* TODO */
-	return -1;
-}
-
-
 unsigned int tls_capabilities(void *tls_ctx)
 {
 	return 0;
diff --git a/src/crypto/tls_internal.c b/src/crypto/tls_internal.c
index 0c955da..19a2d5a 100644
--- a/src/crypto/tls_internal.c
+++ b/src/crypto/tls_internal.c
@@ -348,25 +348,57 @@
 }
 
 
-int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
-		       const char *label, int server_random_first,
-		       u8 *out, size_t out_len)
+static int tls_get_keyblock_size(struct tls_connection *conn)
 {
 #ifdef CONFIG_TLS_INTERNAL_CLIENT
+	if (conn->client)
+		return tlsv1_client_get_keyblock_size(conn->client);
+#endif /* CONFIG_TLS_INTERNAL_CLIENT */
+#ifdef CONFIG_TLS_INTERNAL_SERVER
+	if (conn->server)
+		return tlsv1_server_get_keyblock_size(conn->server);
+#endif /* CONFIG_TLS_INTERNAL_SERVER */
+	return -1;
+}
+
+
+int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
+		       const char *label, int server_random_first,
+		       int skip_keyblock, u8 *out, size_t out_len)
+{
+	int ret = -1, skip = 0;
+	u8 *tmp_out = NULL;
+	u8 *_out = out;
+
+	if (skip_keyblock) {
+		skip = tls_get_keyblock_size(conn);
+		if (skip < 0)
+			return -1;
+		tmp_out = os_malloc(skip + out_len);
+		if (!tmp_out)
+			return -1;
+		_out = tmp_out;
+	}
+
+#ifdef CONFIG_TLS_INTERNAL_CLIENT
 	if (conn->client) {
-		return tlsv1_client_prf(conn->client, label,
-					server_random_first,
-					out, out_len);
+		ret = tlsv1_client_prf(conn->client, label,
+				       server_random_first,
+				       _out, out_len);
 	}
 #endif /* CONFIG_TLS_INTERNAL_CLIENT */
 #ifdef CONFIG_TLS_INTERNAL_SERVER
 	if (conn->server) {
-		return tlsv1_server_prf(conn->server, label,
-					server_random_first,
-					out, out_len);
+		ret = tlsv1_server_prf(conn->server, label,
+				       server_random_first,
+				       _out, out_len);
 	}
 #endif /* CONFIG_TLS_INTERNAL_SERVER */
-	return -1;
+	if (ret == 0 && skip_keyblock)
+		os_memcpy(out, _out + skip, out_len);
+	bin_clear_free(tmp_out, skip);
+
+	return ret;
 }
 
 
@@ -637,21 +669,6 @@
 }
 
 
-int tls_connection_get_keyblock_size(void *tls_ctx,
-				     struct tls_connection *conn)
-{
-#ifdef CONFIG_TLS_INTERNAL_CLIENT
-	if (conn->client)
-		return tlsv1_client_get_keyblock_size(conn->client);
-#endif /* CONFIG_TLS_INTERNAL_CLIENT */
-#ifdef CONFIG_TLS_INTERNAL_SERVER
-	if (conn->server)
-		return tlsv1_server_get_keyblock_size(conn->server);
-#endif /* CONFIG_TLS_INTERNAL_SERVER */
-	return -1;
-}
-
-
 unsigned int tls_capabilities(void *tls_ctx)
 {
 	return 0;
diff --git a/src/crypto/tls_none.c b/src/crypto/tls_none.c
index a6d210a..1b1ba56 100644
--- a/src/crypto/tls_none.c
+++ b/src/crypto/tls_none.c
@@ -87,7 +87,7 @@
 
 int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
 		       const char *label, int server_random_first,
-		       u8 *out, size_t out_len)
+		       int skip_keyblock, u8 *out, size_t out_len)
 {
 	return -1;
 }
@@ -181,13 +181,6 @@
 }
 
 
-int tls_connection_get_keyblock_size(void *tls_ctx,
-				     struct tls_connection *conn)
-{
-	return -1;
-}
-
-
 unsigned int tls_capabilities(void *tls_ctx)
 {
 	return 0;
diff --git a/src/crypto/tls_openssl.c b/src/crypto/tls_openssl.c
index 00e4479..935add5 100644
--- a/src/crypto/tls_openssl.c
+++ b/src/crypto/tls_openssl.c
@@ -2643,9 +2643,43 @@
 }
 
 
+static int openssl_get_keyblock_size(SSL *ssl)
+{
+	const EVP_CIPHER *c;
+	const EVP_MD *h;
+	int md_size;
+
+	if (ssl->enc_read_ctx == NULL || ssl->enc_read_ctx->cipher == NULL ||
+	    ssl->read_hash == NULL)
+		return -1;
+
+	c = ssl->enc_read_ctx->cipher;
+#if OPENSSL_VERSION_NUMBER >= 0x00909000L
+	h = EVP_MD_CTX_md(ssl->read_hash);
+#else
+	h = conn->ssl->read_hash;
+#endif
+	if (h)
+		md_size = EVP_MD_size(h);
+#if OPENSSL_VERSION_NUMBER >= 0x10000000L
+	else if (ssl->s3)
+		md_size = ssl->s3->tmp.new_mac_secret_size;
+#endif
+	else
+		return -1;
+
+	wpa_printf(MSG_DEBUG, "OpenSSL: keyblock size: key_len=%d MD_size=%d "
+		   "IV_len=%d", EVP_CIPHER_key_length(c), md_size,
+		   EVP_CIPHER_iv_length(c));
+	return 2 * (EVP_CIPHER_key_length(c) +
+		    md_size +
+		    EVP_CIPHER_iv_length(c));
+}
+
+
 static int openssl_tls_prf(void *tls_ctx, struct tls_connection *conn,
 			   const char *label, int server_random_first,
-			   u8 *out, size_t out_len)
+			   int skip_keyblock, u8 *out, size_t out_len)
 {
 #ifdef CONFIG_FIPS
 	wpa_printf(MSG_ERROR, "OpenSSL: TLS keys cannot be exported in FIPS "
@@ -2655,6 +2689,9 @@
 	SSL *ssl;
 	u8 *rnd;
 	int ret = -1;
+	int skip = 0;
+	u8 *tmp_out = NULL;
+	u8 *_out = out;
 
 	/*
 	 * TLS library did not support key generation, so get the needed TLS
@@ -2670,6 +2707,16 @@
 	    ssl->session->master_key == NULL)
 		return -1;
 
+	if (skip_keyblock) {
+		skip = openssl_get_keyblock_size(ssl);
+		if (skip < 0)
+			return -1;
+		tmp_out = os_malloc(skip + out_len);
+		if (!tmp_out)
+			return -1;
+		_out = tmp_out;
+	}
+
 	rnd = os_malloc(2 * SSL3_RANDOM_SIZE);
 	if (rnd == NULL)
 		return -1;
@@ -2688,9 +2735,12 @@
 	if (tls_prf_sha1_md5(ssl->session->master_key,
 			     ssl->session->master_key_length,
 			     label, rnd, 2 * SSL3_RANDOM_SIZE,
-			     out, out_len) == 0)
+			     _out, skip + out_len) == 0)
 		ret = 0;
 	os_free(rnd);
+	if (ret == 0 && skip_keyblock)
+		os_memcpy(out, _out + skip, out_len);
+	bin_clear_free(tmp_out, skip);
 
 	return ret;
 #endif /* CONFIG_FIPS */
@@ -2699,15 +2749,16 @@
 
 int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
 		       const char *label, int server_random_first,
-		       u8 *out, size_t out_len)
+		       int skip_keyblock, u8 *out, size_t out_len)
 {
 #if OPENSSL_VERSION_NUMBER >= 0x10001000L
 	SSL *ssl;
 	if (conn == NULL)
 		return -1;
-	if (server_random_first)
+	if (server_random_first || skip_keyblock)
 		return openssl_tls_prf(tls_ctx, conn, label,
-				       server_random_first, out, out_len);
+				       server_random_first, skip_keyblock,
+				       out, out_len);
 	ssl = conn->ssl;
 	if (SSL_export_keying_material(ssl, out, out_len, label,
 				       os_strlen(label), NULL, 0, 0) == 1) {
@@ -2716,7 +2767,7 @@
 	}
 #endif
 	return openssl_tls_prf(tls_ctx, conn, label, server_random_first,
-			       out, out_len);
+			       skip_keyblock, out, out_len);
 }
 
 
@@ -3569,43 +3620,6 @@
 }
 
 
-int tls_connection_get_keyblock_size(void *tls_ctx,
-				     struct tls_connection *conn)
-{
-	const EVP_CIPHER *c;
-	const EVP_MD *h;
-	int md_size;
-
-	if (conn == NULL || conn->ssl == NULL ||
-	    conn->ssl->enc_read_ctx == NULL ||
-	    conn->ssl->enc_read_ctx->cipher == NULL ||
-	    conn->ssl->read_hash == NULL)
-		return -1;
-
-	c = conn->ssl->enc_read_ctx->cipher;
-#if OPENSSL_VERSION_NUMBER >= 0x00909000L
-	h = EVP_MD_CTX_md(conn->ssl->read_hash);
-#else
-	h = conn->ssl->read_hash;
-#endif
-	if (h)
-		md_size = EVP_MD_size(h);
-#if OPENSSL_VERSION_NUMBER >= 0x10000000L
-	else if (conn->ssl->s3)
-		md_size = conn->ssl->s3->tmp.new_mac_secret_size;
-#endif
-	else
-		return -1;
-
-	wpa_printf(MSG_DEBUG, "OpenSSL: keyblock size: key_len=%d MD_size=%d "
-		   "IV_len=%d", EVP_CIPHER_key_length(c), md_size,
-		   EVP_CIPHER_iv_length(c));
-	return 2 * (EVP_CIPHER_key_length(c) +
-		    md_size +
-		    EVP_CIPHER_iv_length(c));
-}
-
-
 unsigned int tls_capabilities(void *tls_ctx)
 {
 	return 0;
diff --git a/src/eap_common/eap_fast_common.c b/src/eap_common/eap_fast_common.c
index 5b41189..151cc78 100644
--- a/src/eap_common/eap_fast_common.c
+++ b/src/eap_common/eap_fast_common.c
@@ -97,24 +97,16 @@
 			 const char *label, size_t len)
 {
 	u8 *out;
-	int block_size;
 
-	block_size = tls_connection_get_keyblock_size(ssl_ctx, conn);
-	if (block_size < 0)
-		return NULL;
-
-	out = os_malloc(block_size + len);
+	out = os_malloc(len);
 	if (out == NULL)
 		return NULL;
 
-	if (tls_connection_prf(ssl_ctx, conn, label, 1, out, block_size + len))
-	{
+	if (tls_connection_prf(ssl_ctx, conn, label, 1, 1, out, len)) {
 		os_free(out);
 		return NULL;
 	}
 
-	os_memmove(out, out + block_size, len);
-	os_memset(out + len, 0, block_size);
 	return out;
 }
 
diff --git a/src/eap_peer/eap_tls_common.c b/src/eap_peer/eap_tls_common.c
index e5a6ee5..b1180d5 100644
--- a/src/eap_peer/eap_tls_common.c
+++ b/src/eap_peer/eap_tls_common.c
@@ -319,7 +319,8 @@
 	if (out == NULL)
 		return NULL;
 
-	if (tls_connection_prf(data->ssl_ctx, data->conn, label, 0, out, len)) {
+	if (tls_connection_prf(data->ssl_ctx, data->conn, label, 0, 0,
+			       out, len)) {
 		os_free(out);
 		return NULL;
 	}
diff --git a/src/eap_server/eap_server_tls_common.c b/src/eap_server/eap_server_tls_common.c
index 09f2a82..23498c9 100644
--- a/src/eap_server/eap_server_tls_common.c
+++ b/src/eap_server/eap_server_tls_common.c
@@ -106,7 +106,8 @@
 	if (out == NULL)
 		return NULL;
 
-	if (tls_connection_prf(sm->ssl_ctx, data->conn, label, 0, out, len)) {
+	if (tls_connection_prf(sm->ssl_ctx, data->conn, label, 0, 0,
+			       out, len)) {
 		os_free(out);
 		return NULL;
 	}