Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ex2/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
CFLAGS += -Wall -Werror
LDLIBS += -ltls -lssl -lcrypto

all: echo client

Expand Down
79 changes: 55 additions & 24 deletions ex2/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <tls.h>
#include <unistd.h>


Expand All @@ -56,15 +57,17 @@ struct server {
int state;
unsigned char *readptr, *writeptr, *nextptr;
unsigned char buf[BUFLEN];
struct tls *ctx;
};

static struct server server;

static void
server_init(struct server *server)
server_init(struct server *server, struct tls *ctx)
{
server->readptr = server->writeptr = server->nextptr = server->buf;
server->state = STATE_NONE;
server->ctx = ctx;
}

static ssize_t
Expand Down Expand Up @@ -134,6 +137,13 @@ server_put(struct server *server, const unsigned char *inbuf, size_t inlen)
static void
closeconn (struct pollfd *pfd)
{
int i;

do {
i = tls_close(server.ctx);
} while (i == TLS_WANT_POLLIN || i == TLS_WANT_POLLOUT);
tls_free(server.ctx);

close(pfd->fd);
pfd->fd = -1;
pfd->revents = 0;
Expand Down Expand Up @@ -166,8 +176,16 @@ handle_server(struct pollfd *pfd, struct server *server)
if (server->state == STATE_READING) {
ssize_t w = 0;
ssize_t written = 0;
len = read(pfd->fd, buf, sizeof(buf));
if (len > 0) {
len = tls_read(server->ctx, buf, sizeof(buf));
if (len == TLS_WANT_POLLIN)
pfd->events = POLLIN | POLLHUP;
else if (len == TLS_WANT_POLLOUT)
pfd->events = POLLOUT | POLLHUP;
else if (len < 0)
err(1, "tls_write: %s", tls_error(server->ctx));
else if (len == 0)
closeconn(pfd);
else {
do {
w = write(STDOUT_FILENO, buf, len);
if (w == -1) {
Expand All @@ -182,26 +200,21 @@ handle_server(struct pollfd *pfd, struct server *server)
pfd->events = POLLHUP;
}
}
else if (len == 0)
closeconn(pfd);
else
pfd->events = POLLIN | POLLHUP;
} else if (server->state == STATE_WRITING) {
ssize_t w = 0;
ssize_t written = 0;
do {
len = server_get(server, buf, sizeof(buf));
w = write(pfd->fd, buf, len);
if (w == -1) {
if (errno != EINTR)
closeconn(pfd);
}
else {
written += w;
server_consume(server, w);
}
} while (written < len);
if (pfd->fd > 0) {
ssize_t ret = 0;
len = server_get(server, buf, sizeof(buf));
if (len) {
ret = tls_write(server->ctx, buf, len);
if (ret == TLS_WANT_POLLIN)
pfd->events = POLLIN | POLLHUP;
else if (ret == TLS_WANT_POLLOUT)
pfd->events = POLLOUT | POLLHUP;
else if (ret < 0)
err(1, "tls_write: %s", tls_error(server->ctx));
else
server_consume(server, ret);
}
if (ret == len) {
server->state = STATE_READING;
pfd->events = POLLIN | POLLHUP;
}
Expand All @@ -210,7 +223,8 @@ handle_server(struct pollfd *pfd, struct server *server)
}

int main(int argc, char **argv) {

struct tls_config *tls_cfg = NULL;
struct tls *tls_ctx = NULL;
struct addrinfo hints, *res;
int serverfd, error;
struct pollfd pollfd;
Expand All @@ -227,14 +241,31 @@ int main(int argc, char **argv) {
usage();
}

/* now set up TLS */
if (tls_init() == -1)
errx(1, "unable to initialize TLS");
if ((tls_cfg = tls_config_new()) == NULL)
errx(1, "unable to allocate TLS config");
if (tls_config_set_ca_file(tls_cfg, "../CA/root.pem") == -1)
errx(1, "unable to set root CA file");

if ((serverfd = socket(AF_INET, SOCK_STREAM, 0)) == -1)
err(1, "socket failed");

if (connect(serverfd, res->ai_addr, res->ai_addrlen) == -1)
err(1, "connect failed");

if ((tls_ctx = tls_client()) == NULL)
errx(1, "tls client creation failed");
if (tls_configure(tls_ctx, tls_cfg) == -1)
errx(1, "tls configuration failed (%s)",
tls_error(tls_ctx));
if (tls_connect_socket(tls_ctx, serverfd, "localhost") == -1)
errx(1, "tls connection failed (%s)",
tls_error(tls_ctx));

newconn(&pollfd, serverfd, 0);
server_init(&server);
server_init(&server, tls_ctx);

while(1) {
if (server.state == STATE_NONE) {
Expand Down
100 changes: 69 additions & 31 deletions ex2/echo.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <tls.h>
#include <unistd.h>

#define MAX_CONNECTIONS 256
Expand All @@ -55,17 +56,19 @@ struct client {
int state;
unsigned char *readptr, *writeptr, *nextptr;
unsigned char buf[BUFLEN];
struct tls *ctx;
};

static struct client clients[MAX_CONNECTIONS];
static struct pollfd pollfds[MAX_CONNECTIONS];
static int throttle = 0;

static void
client_init(struct client *client)
client_init(struct client *client, struct tls *ctx)
{
client->readptr = client->writeptr = client->nextptr = client->buf;
client->state = STATE_READING;
client->ctx = ctx;
}

static ssize_t
Expand Down Expand Up @@ -133,8 +136,15 @@ client_put(struct client *client, const unsigned char *inbuf, size_t inlen)
}

static void
closeconn (struct pollfd *pfd)
closeconn (struct pollfd *pfd, struct client *client)
{
int i;

do {
i = tls_close(client->ctx);
} while (i == TLS_WANT_POLLIN || i == TLS_WANT_POLLOUT);
tls_free(client->ctx);

close(pfd->fd);
pfd->fd = -1;
pfd->revents = 0;
Expand All @@ -159,43 +169,46 @@ handle_client(struct pollfd *pfd, struct client *client)
{
if ((pfd->revents & (POLLERR | POLLNVAL)))
errx(1, "bad fd %d", pfd->fd);
if (pfd->revents & POLLHUP)
closeconn(pfd);
if (pfd->revents & POLLHUP) {
closeconn(pfd, client);
}
else if (pfd->revents & pfd->events) {
char buf[BUFLEN];
ssize_t len = 0;
if (client->state == STATE_READING) {
len = read(pfd->fd, buf, sizeof(buf));
if (len > 0) {
if (client_put(client, buf, len)
!= len) {
len = tls_read(client->ctx, buf, sizeof(buf));
if (len == TLS_WANT_POLLIN)
pfd->events = POLLIN | POLLHUP;
else if (len == TLS_WANT_POLLOUT)
pfd->events = POLLOUT | POLLHUP;
else if (len < 0)
warn("tls_read: %s", tls_error(client->ctx));
else if (len == 0)
closeconn(pfd, client);
else {
if (client_put(client, buf, len) != len) {
warnx("client buffer failed");
closeconn(pfd);
closeconn(pfd, client);
} else {
client->state=STATE_WRITING;
pfd->events = POLLOUT | POLLHUP;
}
}
else if (len == 0)
closeconn(pfd);
else
pfd->events = POLLIN | POLLHUP;
} else if (client->state == STATE_WRITING) {
ssize_t w = 0;
ssize_t written = 0;
do {
len = client_get(client, buf, sizeof(buf));
w = write(pfd->fd, buf, len);
if (w == -1) {
if (errno != EINTR)
closeconn(pfd);
}
else {
written += w;
client_consume(client, w);
}
} while (written < len);
if (pfd->fd > 0) {
ssize_t ret = 0;
len = client_get(client, buf, sizeof(buf));
if (len) {
ret = tls_write(client->ctx, buf, len);
if (ret == TLS_WANT_POLLIN)
pfd->events = POLLIN | POLLHUP;
else if (ret == TLS_WANT_POLLOUT)
pfd->events = POLLOUT | POLLHUP;
else if (ret < 0)
warn("tls_write: %s", tls_error(client->ctx));
else
client_consume(client, ret);
}
if (ret == len) {
client->state = STATE_READING;
pfd->events = POLLIN | POLLHUP;
}
Expand All @@ -204,13 +217,31 @@ handle_client(struct pollfd *pfd, struct client *client)
}

int main(int argc, char **argv) {

struct tls_config *tls_cfg = NULL;
struct tls *tls_ctx = NULL;
struct tls *tls_cctx = NULL;
struct addrinfo hints, *res;
int i, listenfd, error;


if (argc != 3)
usage();

/* now set up TLS */

if ((tls_cfg = tls_config_new()) == NULL)
errx(1, "unable to allocate TLS config");
if (tls_config_set_ca_file(tls_cfg, "../CA/root.pem") == -1)
errx(1, "unable to set root CA filet");
if (tls_config_set_cert_file(tls_cfg, "../CA/server.crt") == -1)
errx(1, "unable to set TLS certificate file");
if (tls_config_set_key_file(tls_cfg, "../CA/server.key") == -1)
errx(1, "unable to set TLS key file");
if ((tls_ctx = tls_server()) == NULL)
errx(1, "tls server creation failed");
if (tls_configure(tls_ctx, tls_cfg) == -1)
errx(1, "tls configuration failed (%s)", tls_error(tls_ctx));

bzero(&hints, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
Expand Down Expand Up @@ -258,9 +289,16 @@ int main(int argc, char **argv) {
throttle = 1;
for (i = 1; fd >= 0 && i < MAX_CONNECTIONS; i++) {
if (pollfds[i].fd == -1) {
newconn(&pollfds[i], fd);
client_init(&clients[i]);
throttle = 0;
if (tls_accept_socket(tls_ctx,
&tls_cctx, fd) == -1) {
warnx("tls accept failed (%s)",
tls_error(tls_ctx));
close(fd);
break;
}
newconn(&pollfds[i], fd);
client_init(&clients[i], tls_cctx);
break;
}
}
Expand Down