33#include " chat.h"
44#include " chatmodel.h"
55#include " modellist.h"
6+ #include " mwhttpserver.h"
67#include " mysettings.h"
78#include " utils.h" // IWYU pragma: keep
89
910#include < fmt/format.h>
1011#include < gpt4all-backend/llmodel.h>
1112
13+ #include < QAbstractSocket>
1214#include < QByteArray>
1315#include < QCborArray>
1416#include < QCborMap>
5153
5254using namespace std ::string_literals;
5355using namespace Qt ::Literals::StringLiterals;
56+ using namespace gpt4all ::ui;
5457
5558// #define DEBUG
5659
@@ -443,6 +446,8 @@ Server::Server(Chat *chat)
443446 connect (chat, &Chat::collectionListChanged, this , &Server::handleCollectionListChanged, Qt::QueuedConnection);
444447}
445448
449+ Server::~Server () = default ;
450+
446451static QJsonObject requestFromJson (const QByteArray &request)
447452{
448453 QJsonParseError err;
@@ -455,17 +460,57 @@ static QJsonObject requestFromJson(const QByteArray &request)
455460 return document.object ();
456461}
457462
463+ // / @brief Check if a host is safe to use to connect to the server.
464+ // /
465+ // / GPT4All's local server is not safe to expose to the internet, as it does not provide
466+ // / any form of authentication. DNS rebind attacks bypass CORS and without additional host
467+ // / header validation, malicious websites can access the server in client-side js.
468+ // /
469+ // / @param host The value of the "Host" header or ":authority" pseudo-header
470+ // / @return true if the host is unsafe, false otherwise
471+ static bool isHostUnsafe (const QString &host)
472+ {
473+ QHostAddress addr;
474+ if (addr.setAddress (host) && addr.protocol () == QAbstractSocket::IPv4Protocol)
475+ return false ; // ipv4
476+
477+ // ipv6 host is wrapped in square brackets
478+ static const QRegularExpression ipv6Re (uR"( ^\[(.+)\]$)" _s);
479+ if (auto match = ipv6Re.match (host); match.hasMatch ()) {
480+ auto ipv6 = match.captured (1 );
481+ if (addr.setAddress (ipv6) && addr.protocol () == QAbstractSocket::IPv6Protocol)
482+ return false ; // ipv6
483+ }
484+
485+ if (!host.contains (' .' ))
486+ return false ; // dotless hostname
487+
488+ static const QStringList allowedTlds { u" .local" _s, u" .test" _s, u" .internal" _s };
489+ for (auto &tld : allowedTlds)
490+ if (host.endsWith (tld, Qt::CaseInsensitive))
491+ return false ; // local TLD
492+
493+ return true ; // unsafe
494+ }
495+
458496void Server::start ()
459497{
460- m_server = std::make_unique<QHttpServer>(this );
461- auto *tcpServer = new QTcpServer (m_server.get ());
498+ m_server = std::make_unique<MwHttpServer>();
499+
500+ m_server->addBeforeRequestHandler ([](const QHttpServerRequest &req) -> std::optional<QHttpServerResponse> {
501+ // this works for HTTP/1.1 "Host" header and HTTP/2 ":authority" pseudo-header
502+ auto host = req.url ().host ();
503+ if (!host.isEmpty () && isHostUnsafe (host))
504+ return QHttpServerResponse (QHttpServerResponder::StatusCode::Forbidden);
505+ return std::nullopt ;
506+ });
462507
463508 auto port = MySettings::globalInstance ()->networkPort ();
464- if (!tcpServer->listen (QHostAddress::LocalHost, port)) {
509+ if (!m_server-> tcpServer () ->listen (QHostAddress::LocalHost, port)) {
465510 qWarning () << " Server ERROR: Failed to listen on port" << port;
466511 return ;
467512 }
468- if (!m_server->bind (tcpServer )) {
513+ if (!m_server->bind ()) {
469514 qWarning () << " Server ERROR: Failed to HTTP server to socket" << port;
470515 return ;
471516 }
@@ -490,7 +535,7 @@ void Server::start()
490535 }
491536 );
492537
493- m_server->route (" /v1/models/<arg>" , QHttpServerRequest::Method::Get,
538+ m_server->route < const QString &> (" /v1/models/<arg>" , QHttpServerRequest::Method::Get,
494539 [](const QString &model, const QHttpServerRequest &) {
495540 if (!MySettings::globalInstance ()->serverChat ())
496541 return QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
@@ -562,7 +607,7 @@ void Server::start()
562607
563608 // Respond with code 405 to wrong HTTP methods:
564609 m_server->route (" /v1/models" , QHttpServerRequest::Method::Post,
565- [] {
610+ []( const QHttpServerRequest &) {
566611 if (!MySettings::globalInstance ()->serverChat ())
567612 return QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
568613 return QHttpServerResponse (
@@ -573,8 +618,8 @@ void Server::start()
573618 }
574619 );
575620
576- m_server->route (" /v1/models/<arg>" , QHttpServerRequest::Method::Post,
577- [](const QString &model) {
621+ m_server->route < const QString &> (" /v1/models/<arg>" , QHttpServerRequest::Method::Post,
622+ [](const QString &model, const QHttpServerRequest & ) {
578623 (void )model;
579624 if (!MySettings::globalInstance ()->serverChat ())
580625 return QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
@@ -587,7 +632,7 @@ void Server::start()
587632 );
588633
589634 m_server->route (" /v1/completions" , QHttpServerRequest::Method::Get,
590- [] {
635+ []( const QHttpServerRequest &) {
591636 if (!MySettings::globalInstance ()->serverChat ())
592637 return QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
593638 return QHttpServerResponse (
@@ -598,7 +643,7 @@ void Server::start()
598643 );
599644
600645 m_server->route (" /v1/chat/completions" , QHttpServerRequest::Method::Get,
601- [] {
646+ []( const QHttpServerRequest &) {
602647 if (!MySettings::globalInstance ()->serverChat ())
603648 return QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
604649 return QHttpServerResponse (
0 commit comments