/* server.c */

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/select.h>
#include <sys/time.h>
#include <unistd.h>

#include <netinet/in.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <netdb.h>

#include "message.h"
#include "spite.h"

#define MAX_SOCKS 32
#define BUF_SIZE 4096

/* Event-loop bookkeeping for one file descriptor (a network socket, the
 * listener, stdin or stdout).  Either queue pointer is NULL when the
 * descriptor is never written to / read from through the queues. */
struct sock_t {
	int fd;            /* the file descriptor itself */
	queue_t* send_buf; /* bytes waiting to be written to fd, or NULL */
	queue_t* recv_buf; /* bytes read from fd but not yet consumed, or NULL */
};
typedef struct sock_t sock_t;

/* Creates a TCP listener socket bound to port 7999 on every local IPv4
 * interface, puts it into the listening state, and returns its fd.  Any
 * failing call ends the program through printerr_exit(). */
int create_server_socket(void)
{
	int sock;
	struct sockaddr_in addr;
	int yes = 1;

	/* Create a new socket for listening */
	if ((sock = socket(AF_INET, SOCK_STREAM, 0)) == -1)
		printerr_exit(errno, EXIT_FAILURE);
	printmsg(stdout, "creating listener socket.\n");

	/* Allow the socket to be reused */
	if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes) == -1)
		printerr_exit(errno, EXIT_FAILURE);

	/* Set up connection parameters.  The whole structure is zeroed (not
	 * just sin_zero) so no field is left indeterminate.  s_addr is a
	 * 32-bit field, so it takes htonl(), not htons(). */
	memset(&addr, 0, sizeof addr);
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sin_port = htons(7999);

	/* Bind the socket to a port */
	if (bind(sock, (struct sockaddr*)&addr, sizeof addr) == -1)
		printerr_exit(errno, EXIT_FAILURE);
	printmsg(stdout, "binding listener socket to port.\n");

	/* Listen for connection requests on the port */
	if (listen(sock, 20) == -1)
		printerr_exit(errno, EXIT_FAILURE);
	printmsg(stdout, "listening for connections.\n");

	return sock;
}

/* Creates a TCP socket, connects it to port 7999 on the IPv4 host named by
 * 'server' (a host name or dotted-quad address), and returns its fd.  Any
 * failure ends the program. */
int create_client_socket(const char* server)
{
	int sock;
	struct sockaddr_in addr;
	struct hostent* hp;

	/* Resolve the server.  gethostbyname() reports failure through
	 * h_errno, not errno, so errno would not describe this error. */
	if ((hp = gethostbyname(server)) == NULL)
		printmsg_exit(0, stderr, "could not resolve server address");

	/* The socket below is AF_INET, so only a 4-byte IPv4 address can be
	 * used; reject anything else instead of truncating it.  The hostent
	 * returned by the resolver is only read, never modified. */
	if (hp->h_addrtype != AF_INET ||
	    hp->h_length != (int)sizeof(addr.sin_addr))
		printmsg_exit(0, stderr, "server address is not IPv4");

	/* Set up the address to connect to (first address returned) */
	memset(&addr, 0, sizeof addr);
	addr.sin_family = AF_INET;
	memcpy(&addr.sin_addr, hp->h_addr_list[0], sizeof addr.sin_addr);
	addr.sin_port = htons(7999);

	/* Create the socket */
	if ((sock = socket(AF_INET, SOCK_STREAM, 0)) == -1)
		printerr_exit(errno, EXIT_FAILURE);
	printmsg(stdout, "creating client socket.\n");

	/* Connect the socket to the foreign host */
	if (connect(sock, (struct sockaddr*)&addr, sizeof addr) == -1)
		printerr_exit(errno, EXIT_FAILURE);
	printmsg(stdout, "connecting to server.\n");

	return sock;
}

/* Appends descriptor 'fd' to the 'socks' table, increments '*num_socks', and
 * raises '*max_fd' if 'fd' is larger.  A receive queue is created when
 * 'readable' is nonzero and a send queue when 'writable' is nonzero; the
 * other pointer is set to NULL.  Returns the index of the new entry.  Ends
 * the program if the table is full or 'fd' cannot be used with select(). */
int add_sock(sock_t* socks, int* num_socks, int* max_fd, int fd, int readable,
int writable)
{
	sock_t* entry;

	if (*num_socks == MAX_SOCKS) printmsg_exit(0, stderr, "out of sockets");

	/* Entries of this table are passed to FD_SET()/FD_ISSET(), which are
	 * undefined for descriptors outside [0, FD_SETSIZE). */
	if (fd < 0 || fd >= FD_SETSIZE)
		printmsg_exit(0, stderr, "descriptor out of range for select");

	entry = &socks[*num_socks];
	entry->fd = fd;
	entry->recv_buf = readable ? queue_new(1, BUF_SIZE) : NULL;
	entry->send_buf = writable ? queue_new(1, BUF_SIZE) : NULL;

	(*num_socks)++;

	if (fd > *max_fd) *max_fd = fd;

	return *num_socks - 1;
}

/* Removes the entry whose descriptor is 'fd' from 'socks', deleting its send
 * and receive queues, and decrements '*num_socks'.  The last entry is moved
 * into the vacated slot, so table order is not preserved.  If 'fd' was
 * '*max_fd', the maximum is recomputed from the remaining entries (0 when
 * none remain).  The descriptor itself is not closed.  Does nothing when 'fd'
 * is not in the table. */
void remove_sock(sock_t* socks, int* num_socks, int* max_fd, int fd)
{
	int idx = 0;
	int last;
	int k;

	/* Locate the entry for fd */
	while (idx < *num_socks && socks[idx].fd != fd)
		idx++;
	if (idx == *num_socks)
		return;

	if (socks[idx].send_buf) queue_del(socks[idx].send_buf);
	if (socks[idx].recv_buf) queue_del(socks[idx].recv_buf);

	/* Fill the hole with the final entry and shrink the table */
	last = *num_socks - 1;
	socks[idx] = socks[last];
	*num_socks = last;

	/* Only a removed maximum forces a rescan */
	if (*max_fd == fd) {
		int highest = 0;

		for (k = 0; k < *num_socks; k++)
			if (socks[k].fd > highest)
				highest = socks[k].fd;
		*max_fd = highest;
	}
}

/* Queues 'len' bytes of 'data' for every entry of 'socks' that has a send
 * queue, except the source entry at index 'src' and, when the source is
 * stdin, except stdout (so locally typed text is not printed back locally).
 * Each recipient's fd is added to 'write_set_master' so that select() reports
 * when it can be written. */
static void fan_out_line(sock_t* socks, int num_socks, int src, char* data,
		int len, fd_set* write_set_master)
{
	int j;

	for (j = 0; j < num_socks; j++) {
		/* Skip some sockets */
		if (socks[j].send_buf == NULL || j == src ||
		    (socks[src].fd == fileno(stdin) &&
		     socks[j].fd == fileno(stdout)))
			continue;

		/* Queue data to be written */
		queue_push(socks[j].send_buf, data, len);

		/* Add to the write set */
		FD_SET(socks[j].fd, write_set_master);
	}
}

/* Runs as a server listening on port 7999 when given no arguments, or as a
 * client connected to the host named by argv[1].  Complete lines read from
 * any readable descriptor (including stdin) are relayed to every other
 * writable descriptor (including stdout). */
int main(int argc, char** argv)
{
	int i;
	struct sockaddr_in addr;
	socklen_t addr_len;
	int server_fd = -1;
	int client_fd = -1;
	int num_socks = 0, max_fd = 0;
	sock_t socks[MAX_SOCKS];
	fd_set read_set, read_set_master, write_set, write_set_master;
	int select_count;
	int byte_count;
	char buf[BUF_SIZE];
	int count;
	int chunk;
	int pending;

	/* Initialize the monosodium glutamate library :^) */
	PROGRAM_NAME = argv[0];
	errno = 0;

	FD_ZERO(&read_set_master);
	FD_ZERO(&write_set_master);

	/* Running as client or server? */
	if (argc == 1) {
		/* Create the server (listener) socket */
		server_fd = create_server_socket();
		add_sock(socks, &num_socks, &max_fd, server_fd, 0, 0);
		FD_SET(server_fd, &read_set_master);
	} else {
		/* Create the client socket */
		client_fd = create_client_socket(argv[1]);
		add_sock(socks, &num_socks, &max_fd, client_fd, 1, 1);
		FD_SET(client_fd, &read_set_master);
	}

	/* Treat stdin like a socket */
	add_sock(socks, &num_socks, &max_fd, fileno(stdin), 1, 0);
	FD_SET(fileno(stdin), &read_set_master);

	/* We can do this because UNIX is cool :^) */
	add_sock(socks, &num_socks, &max_fd, fileno(stdout), 0, 1);

	/* Main event loop */
	while (1) {
repeat_select:
		/* Wait for activity on an fd */
		read_set = read_set_master;
		write_set = write_set_master;
		select_count = select(max_fd + 1, &read_set, &write_set,
				NULL, NULL);
		if (select_count == 0) goto repeat_select;
		if (select_count < 0) {
			/* A signal interrupted the wait; simply wait again */
			if (errno == EINTR) goto repeat_select;
			printerr_exit(errno, EXIT_FAILURE);
		}

		/* Any sockets ready to be read? */
		for (i = 0; i < num_socks; i++) {
			if (!FD_ISSET(socks[i].fd, &read_set)) continue;

			if (socks[i].fd == server_fd) {
				/* Activity on a server socket means that
				 * there's a pending connection; accept it */
				int fd;

				/* accept() reads addr_len as the capacity of
				 * addr, so it must be set before every call */
				addr_len = sizeof addr;
				fd = accept(server_fd, (struct sockaddr*)&addr,
						&addr_len);
				if (fd == -1)
					printerr_exit(errno, EXIT_FAILURE);
				add_sock(socks, &num_socks, &max_fd, fd, 1, 1);
				FD_SET(fd, &read_set_master);

				/* If there are no more sockets with activity,
				 * run select again to wait for a new event */
				if (--select_count == 0) goto repeat_select;

				/* This socket's completely processed, so go to
				 * the next one */
				continue;
			}

			/* Read from socket */
			byte_count = read(socks[i].fd, buf, BUF_SIZE);
			if (byte_count == -1) {
				/* Error */
				printerr_exit(errno, EXIT_FAILURE);
			} else if (byte_count == 0) {
				/* End of stream: forget the descriptor
				 * everywhere.  A closed fd left in a master set
				 * makes the next select() fail with EBADF, and
				 * clearing this round's ready sets stops an fd
				 * number reused by a later accept() from being
				 * mistaken for a ready one. */
				int fd = socks[i].fd;

				FD_CLR(fd, &read_set_master);
				FD_CLR(fd, &write_set_master);
				FD_CLR(fd, &read_set);
				FD_CLR(fd, &write_set);

				/* stdin is left open on purpose: closing fd 0
				 * would let accept() hand it to a client, which
				 * would then be treated as stdin */
				if (fd != fileno(stdin))
					close(fd);
				remove_sock(socks, &num_socks, &max_fd, fd);

				/* The last entry was moved into slot i;
				 * revisit this slot */
				i--;
			} else {
				/* Write data to socket's buffer to be read
				 * back again in a moment */
				queue_push(socks[i].recv_buf, buf, byte_count);

				/* Look for the newline; has an entire line
				 * been read yet? */
				while ((count = queue_char_index(
				socks[i].recv_buf, '\n') + 1) != 0) {
					/* A line may be longer than buf, so
					 * move it out in buf-sized pieces */
					while (count > 0) {
						chunk = count < BUF_SIZE ?
							count : BUF_SIZE;
						queue_pop(socks[i].recv_buf,
								buf, chunk);
						fan_out_line(socks, num_socks,
							i, buf, chunk,
							&write_set_master);
						count -= chunk;
					}
				}
			}

			/* If there are no more sockets with activity, run
			 * select again to wait for a new event */
			if (--select_count == 0) goto repeat_select;
		}

		/* Any sockets ready to be written to? */
		for (i = 0; i < num_socks; i++) {
			if (!FD_ISSET(socks[i].fd, &write_set)) continue;

			/* Take at most one buf's worth: the queue may hold
			 * more than BUF_SIZE bytes.  Anything left stays
			 * queued and the fd stays in the write set. */
			pending = socks[i].send_buf->count > BUF_SIZE ?
				BUF_SIZE : (int)socks[i].send_buf->count;
			queue_peek(socks[i].send_buf, buf, pending);

			/* Write it to the socket */
			byte_count = write(socks[i].fd, buf, pending);
			if (byte_count == -1)
				printerr_exit(errno, EXIT_FAILURE);

			/* Skip past the data actually written */
			queue_discard(socks[i].send_buf, byte_count);

			/* If send_buf is empty, we're done writing */
			if (socks[i].send_buf->count == 0)
				FD_CLR(socks[i].fd, &write_set_master);

			/* If there are no more sockets with activity, run
			 * select again to wait for a new event */
			if (--select_count == 0) goto repeat_select;
		}
	}

	/* Close all connections.  NOTE(review): currently unreachable, since
	 * the loop above has no exit other than printerr_exit(). */
	printmsg(stdout, "closing connections.\n");
	for (i = 0; i < num_socks; i++) {
		if (socks[i].fd == fileno(stdin)) continue;
		if (socks[i].fd == fileno(stdout)) continue;
		if (close(socks[i].fd) == -1)
			printerr_exit(errno, EXIT_FAILURE);
	}

	return EXIT_SUCCESS;
}

/* EOF */

