Skip to content

Commit faec111

Browse files
committed
server: block server access via non-local domains
Signed-off-by: Jared Van Bortel <[email protected]>
1 parent b666d16 commit faec111

File tree

6 files changed

+154
-13
lines changed

6 files changed

+154
-13
lines changed

gpt4all-chat/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ qt_add_executable(chat
244244
src/localdocsmodel.cpp src/localdocsmodel.h
245245
src/logger.cpp src/logger.h
246246
src/modellist.cpp src/modellist.h
247+
src/mwhttpserver.cpp src/mwhttpserver.h
247248
src/mysettings.cpp src/mysettings.h
248249
src/network.cpp src/network.h
249250
src/server.cpp src/server.h

gpt4all-chat/src/mwhttpserver.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include <QTcpServer>
2+
3+
#include "mwhttpserver.h"
4+
5+
6+
namespace gpt4all::ui {
7+
8+
9+
MwHttpServer::MwHttpServer()
10+
: m_httpServer()
11+
, m_tcpServer (new QTcpServer(&m_httpServer))
12+
{}
13+
14+
15+
} // namespace gpt4all::ui

gpt4all-chat/src/mwhttpserver.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#pragma once
2+
3+
#include <QHttpServer>
4+
#include <QHttpServerRequest>
5+
6+
#include <functional>
7+
#include <optional>
8+
#include <utility>
9+
#include <vector>
10+
11+
class QHttpServerResponse;
12+
class QHttpServerRouterRule;
13+
class QString;
14+
15+
16+
namespace gpt4all::ui {
17+
18+
19+
/// @brief QHttpServer wrapper with middleware support.
20+
///
21+
/// This class wraps QHttpServer and provides addBeforeRequestHandler() to add middleware.
22+
class MwHttpServer
23+
{
24+
using BeforeRequestHandler = std::function<std::optional<QHttpServerResponse>(const QHttpServerRequest &)>;
25+
26+
public:
27+
explicit MwHttpServer();
28+
29+
bool bind() { return m_httpServer.bind(m_tcpServer); }
30+
31+
void addBeforeRequestHandler(BeforeRequestHandler handler)
32+
{ m_beforeRequestHandlers.push_back(std::move(handler)); }
33+
34+
template <typename Handler>
35+
void addAfterRequestHandler(
36+
const typename QtPrivate::ContextTypeForFunctor<Handler>::ContextType *context, Handler &&handler
37+
) {
38+
return m_httpServer.addAfterRequestHandler(context, std::forward<Handler>(handler));
39+
}
40+
41+
template <typename... Args>
42+
QHttpServerRouterRule *route(
43+
const QString &pathPattern,
44+
QHttpServerRequest::Methods method,
45+
std::function<QHttpServerResponse(Args..., const QHttpServerRequest &)> viewHandler
46+
);
47+
48+
QTcpServer *tcpServer() { return m_tcpServer; }
49+
50+
private:
51+
QHttpServer m_httpServer;
52+
QTcpServer *m_tcpServer;
53+
std::vector<BeforeRequestHandler> m_beforeRequestHandlers;
54+
};
55+
56+
57+
} // namespace gpt4all::ui
58+
59+
60+
#include "mwhttpserver.inl" // IWYU pragma: export

gpt4all-chat/src/mwhttpserver.inl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
namespace gpt4all::ui {
2+
3+
4+
template <typename... Args>
5+
QHttpServerRouterRule *MwHttpServer::route(
6+
const QString &pathPattern,
7+
QHttpServerRequest::Methods method,
8+
std::function<QHttpServerResponse(Args..., const QHttpServerRequest &)> viewHandler
9+
) {
10+
auto wrapped = [this, vh = std::move(viewHandler)](Args ...args, const QHttpServerRequest &req) {
11+
for (auto &handler : m_beforeRequestHandlers)
12+
if (auto resp = handler(req))
13+
return *std::move(resp);
14+
return vh(std::forward<Args>(args)..., req);
15+
};
16+
return m_httpServer.route(pathPattern, method, std::move(wrapped));
17+
}
18+
19+
20+
} // namespace gpt4all::ui

gpt4all-chat/src/server.cpp

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
#include "chat.h"
44
#include "chatmodel.h"
55
#include "modellist.h"
6+
#include "mwhttpserver.h"
67
#include "mysettings.h"
78
#include "utils.h" // IWYU pragma: keep
89

910
#include <fmt/format.h>
1011
#include <gpt4all-backend/llmodel.h>
1112

13+
#include <QAbstractSocket>
1214
#include <QByteArray>
1315
#include <QCborArray>
1416
#include <QCborMap>
@@ -51,6 +53,7 @@
5153

5254
using namespace std::string_literals;
5355
using namespace Qt::Literals::StringLiterals;
56+
using namespace gpt4all::ui;
5457

5558
//#define DEBUG
5659

@@ -443,6 +446,8 @@ Server::Server(Chat *chat)
443446
connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
444447
}
445448

449+
Server::~Server() = default;
450+
446451
static QJsonObject requestFromJson(const QByteArray &request)
447452
{
448453
QJsonParseError err;
@@ -455,17 +460,57 @@ static QJsonObject requestFromJson(const QByteArray &request)
455460
return document.object();
456461
}
457462

463+
/// @brief Check if a host is safe to use to connect to the server.
464+
///
465+
/// GPT4All's local server is not safe to expose to the internet, as it does not provide
466+
/// any form of authentication. DNS rebind attacks bypass CORS and without additional host
467+
/// header validation, malicious websites can access the server in client-side js.
468+
///
469+
/// @param host The value of the "Host" header or ":authority" pseudo-header
470+
/// @return true if the host is unsafe, false otherwise
471+
static bool isHostUnsafe(const QString &host)
472+
{
473+
QHostAddress addr;
474+
if (addr.setAddress(host) && addr.protocol() == QAbstractSocket::IPv4Protocol)
475+
return false; // ipv4
476+
477+
// ipv6 host is wrapped in square brackets
478+
static const QRegularExpression ipv6Re(uR"(^\[(.+)\]$)"_s);
479+
if (auto match = ipv6Re.match(host); match.hasMatch()) {
480+
auto ipv6 = match.captured(1);
481+
if (addr.setAddress(ipv6) && addr.protocol() == QAbstractSocket::IPv6Protocol)
482+
return false; // ipv6
483+
}
484+
485+
if (!host.contains('.'))
486+
return false; // dotless hostname
487+
488+
static const QStringList allowedTlds { u".local"_s, u".test"_s, u".internal"_s };
489+
for (auto &tld : allowedTlds)
490+
if (host.endsWith(tld, Qt::CaseInsensitive))
491+
return false; // local TLD
492+
493+
return true; // unsafe
494+
}
495+
458496
void Server::start()
459497
{
460-
m_server = std::make_unique<QHttpServer>(this);
461-
auto *tcpServer = new QTcpServer(m_server.get());
498+
m_server = std::make_unique<MwHttpServer>();
499+
500+
m_server->addBeforeRequestHandler([](const QHttpServerRequest &req) -> std::optional<QHttpServerResponse> {
501+
// this works for HTTP/1.1 "Host" header and HTTP/2 ":authority" pseudo-header
502+
auto host = req.url().host();
503+
if (!host.isEmpty() && isHostUnsafe(host))
504+
return QHttpServerResponse(QHttpServerResponder::StatusCode::Forbidden);
505+
return std::nullopt;
506+
});
462507

463508
auto port = MySettings::globalInstance()->networkPort();
464-
if (!tcpServer->listen(QHostAddress::LocalHost, port)) {
509+
if (!m_server->tcpServer()->listen(QHostAddress::LocalHost, port)) {
465510
qWarning() << "Server ERROR: Failed to listen on port" << port;
466511
return;
467512
}
468-
if (!m_server->bind(tcpServer)) {
513+
if (!m_server->bind()) {
469514
qWarning() << "Server ERROR: Failed to HTTP server to socket" << port;
470515
return;
471516
}
@@ -490,7 +535,7 @@ void Server::start()
490535
}
491536
);
492537

493-
m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Get,
538+
m_server->route<const QString &>("/v1/models/<arg>", QHttpServerRequest::Method::Get,
494539
[](const QString &model, const QHttpServerRequest &) {
495540
if (!MySettings::globalInstance()->serverChat())
496541
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
@@ -562,7 +607,7 @@ void Server::start()
562607

563608
// Respond with code 405 to wrong HTTP methods:
564609
m_server->route("/v1/models", QHttpServerRequest::Method::Post,
565-
[] {
610+
[](const QHttpServerRequest &) {
566611
if (!MySettings::globalInstance()->serverChat())
567612
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
568613
return QHttpServerResponse(
@@ -573,8 +618,8 @@ void Server::start()
573618
}
574619
);
575620

576-
m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Post,
577-
[](const QString &model) {
621+
m_server->route<const QString &>("/v1/models/<arg>", QHttpServerRequest::Method::Post,
622+
[](const QString &model, const QHttpServerRequest &) {
578623
(void)model;
579624
if (!MySettings::globalInstance()->serverChat())
580625
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
@@ -587,7 +632,7 @@ void Server::start()
587632
);
588633

589634
m_server->route("/v1/completions", QHttpServerRequest::Method::Get,
590-
[] {
635+
[](const QHttpServerRequest &) {
591636
if (!MySettings::globalInstance()->serverChat())
592637
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
593638
return QHttpServerResponse(
@@ -598,7 +643,7 @@ void Server::start()
598643
);
599644

600645
m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get,
601-
[] {
646+
[](const QHttpServerRequest &) {
602647
if (!MySettings::globalInstance()->serverChat())
603648
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
604649
return QHttpServerResponse(

gpt4all-chat/src/server.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "chatllm.h"
55
#include "database.h"
66

7-
#include <QHttpServer>
87
#include <QHttpServerResponse>
98
#include <QJsonObject>
109
#include <QList>
@@ -18,6 +17,7 @@
1817
class Chat;
1918
class ChatRequest;
2019
class CompletionRequest;
20+
namespace gpt4all::ui { class MwHttpServer; }
2121

2222

2323
class Server : public ChatLLM
@@ -26,7 +26,7 @@ class Server : public ChatLLM
2626

2727
public:
2828
explicit Server(Chat *chat);
29-
~Server() override = default;
29+
~Server() override;
3030

3131
public Q_SLOTS:
3232
void start();
@@ -44,7 +44,7 @@ private Q_SLOTS:
4444

4545
private:
4646
Chat *m_chat;
47-
std::unique_ptr<QHttpServer> m_server;
47+
std::unique_ptr<gpt4all::ui::MwHttpServer> m_server;
4848
QList<ResultInfo> m_databaseResults;
4949
QList<QString> m_collections;
5050
};

0 commit comments

Comments
 (0)