diff --git a/Makefile b/Makefile index 2c0fc2b72..d457937b4 100644 --- a/Makefile +++ b/Makefile @@ -305,6 +305,7 @@ install: build $(INSTALL_DATA) ${srcdir}/MQTTReasonCodes.h $(DESTDIR)${includedir} $(INSTALL_DATA) ${srcdir}/MQTTSubscribeOpts.h $(DESTDIR)${includedir} $(INSTALL_DATA) ${srcdir}/MQTTExportDeclarations.h $(DESTDIR)${includedir} + $(INSTALL_DATA) ${srcdir}/SSLPluginInterface.h $(DESTDIR)${includedir} - $(INSTALL_DATA) doc/man/man1/paho_c_pub.1 $(DESTDIR)${man1dir} - $(INSTALL_DATA) doc/man/man1/paho_c_sub.1 $(DESTDIR)${man1dir} - $(INSTALL_DATA) doc/man/man1/paho_cs_pub.1 $(DESTDIR)${man1dir} diff --git a/src/SSLPluginInterface.h b/src/SSLPluginInterface.h new file mode 100644 index 000000000..e504721cb --- /dev/null +++ b/src/SSLPluginInterface.h @@ -0,0 +1,11 @@ +#if defined(__cplusplus) +extern "C" { +#endif +#include "MQTTExportDeclarations.h" +#include + +LIBMQTT_API void SSLPluginInterface_setcallback(int (*callback)(SSL*)); + +#if defined(__cplusplus) +} +#endif \ No newline at end of file diff --git a/src/SSLSocket.c b/src/SSLSocket.c index 8cb090c47..3b3064892 100644 --- a/src/SSLSocket.c +++ b/src/SSLSocket.c @@ -27,7 +27,7 @@ */ #if defined(OPENSSL) - +#include "SSLPluginInterface.h" #include "SocketBuffer.h" #include "MQTTClient.h" #include "MQTTProtocolOut.h" @@ -61,6 +61,7 @@ int SSL_create_mutex(ssl_mutex_type* mutex); int SSL_lock_mutex(ssl_mutex_type* mutex); int SSL_unlock_mutex(ssl_mutex_type* mutex); int SSL_destroy_mutex(ssl_mutex_type* mutex); +static int (*handleSSLObjectAtConnect)(SSL*) = NULL; #if (OPENSSL_VERSION_NUMBER >= 0x010000000) extern void SSLThread_id(CRYPTO_THREADID *id); #else @@ -766,6 +767,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u) { int rc = 0; + int sslCallbackReturnedValue = 1; FUNC_ENTRY; @@ -781,50 +783,63 @@ int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, i rc = TCPSOCKET_INTERRUPTED; } #if (OPENSSL_VERSION_NUMBER >= 0x010002000) /* 1.0.2 and later */ - else if (verify) + else { - char* peername = NULL; - int port; - size_t hostname_len; + if(handleSSLObjectAtConnect) + { + sslCallbackReturnedValue = handleSSLObjectAtConnect(ssl); + } + + if (verify) + { + char* peername = NULL; + int port; + size_t hostname_len; - X509* cert = SSL_get_peer_certificate(ssl); - hostname_len = MQTTProtocol_addressPort(hostname, &port, NULL, MQTT_DEFAULT_PORT); + X509* cert = SSL_get_peer_certificate(ssl); + hostname_len = MQTTProtocol_addressPort(hostname, &port, NULL, MQTT_DEFAULT_PORT); - rc = X509_check_host(cert, hostname, hostname_len, 0, &peername); - if (rc == 1) - Log(TRACE_PROTOCOL, -1, "peername from X509_check_host is %s", peername); - else - Log(TRACE_PROTOCOL, -1, "X509_check_host for hostname %.*s failed, rc %d", - (int)hostname_len, hostname, rc); + rc = X509_check_host(cert, hostname, hostname_len, 0, &peername); + if (rc == 1) + Log(TRACE_PROTOCOL, -1, "peername from X509_check_host is %s", peername); + else + Log(TRACE_PROTOCOL, -1, "X509_check_host for hostname %.*s failed, rc %d", + (int)hostname_len, hostname, rc); - if (peername != NULL) - OPENSSL_free(peername); + if (peername != NULL) + OPENSSL_free(peername); - /* 0 == fail, -1 == SSL internal error, -2 == malformed input */ - if (rc == 0 || rc == -1 || rc == -2) - { - char* ip_addr = malloc(hostname_len + 1); - /* cannot use = strndup(hostname, hostname_len); here because of custom Heap */ - if (ip_addr) + /* 0 == fail, -1 == SSL internal error, -2 == malformed input */ + if (rc == 0 || rc == -1 || rc == -2) { - strncpy(ip_addr, hostname, hostname_len); - ip_addr[hostname_len] = '\0'; + char* ip_addr = malloc(hostname_len + 1); + /* cannot use = strndup(hostname, hostname_len); here because of custom Heap */ + if (ip_addr) + { + strncpy(ip_addr, hostname, hostname_len); + ip_addr[hostname_len] = '\0'; - rc = X509_check_ip_asc(cert, ip_addr, 0); - Log(TRACE_MIN, -1, "rc from X509_check_ip_asc is %d", rc); + rc = X509_check_ip_asc(cert, ip_addr, 0); + Log(TRACE_MIN, -1, "rc from X509_check_ip_asc is %d", rc); - free(ip_addr); + free(ip_addr); + } + + if (rc == 0 || rc == -1 || rc == -2) + rc = SSL_FATAL; } - if (rc == 0 || rc == -1 || rc == -2) - rc = SSL_FATAL; + if (cert) + X509_free(cert); } - - if (cert) - X509_free(cert); } #endif + if (sslCallbackReturnedValue != 1) + { + FUNC_EXIT_RC(sslCallbackReturnedValue); + return sslCallbackReturnedValue; + } FUNC_EXIT_RC(rc); return rc; } @@ -1114,4 +1129,9 @@ int SSLSocket_abortWrite(pending_writes* pw) FUNC_EXIT_RC(rc); return rc; } + +void SSLPluginInterface_setcallback(int(*callback)(SSL*)) +{ + handleSSLObjectAtConnect = callback; +} #endif