/*
 * tt.c -- a souped-up tee(1).
 *
 * Copyright (C) 2004 Matt Harang <matt@ioioio.net>
 *                    Andy Goth   <unununium@openverse.com>
 * This code is available under the GNU General Public License.
 *
 * Compile with:
 * gcc -o tt tt.c
 *
 * To learn how to use it:
 * ./tt -h
 */

#include <assert.h>
#include <errno.h>
#include <fcntl.h>	/* flags for open() */
#include <signal.h>	/* signal() */
#include <stdarg.h>	/* va_list macros */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>	/* strerror() */
#include <sys/poll.h>	/* poll() */
#include <sys/stat.h>	/* flags for open() */
#include <sys/time.h>	/* setitimer() */
#include <sys/types.h>	/* pid_t */
#include <sys/wait.h>	/* WIFEXITED(), ... */
#include <unistd.h>	/* fork(), getopt(), open(), execve() */


/* open() flag sets: -f truncates the target, -a appends to it. */
#define OVERWRITE	(O_CREAT | O_TRUNC | O_WRONLY)
#define APPEND		(O_CREAT | O_APPEND | O_WRONLY)


/*
 * Global variables.
 */

/* The process environment, handed to execve() when starting children.
 * (The unused sys_siglist declaration was dropped: glibc removed that
 * symbol from its headers in 2.32 and nothing here referenced it.) */
extern char** environ;

/* Basename of argv[0]; prefixes every diagnostic. */
static char* prog_name;

/* Output file/pipe info. */
typedef struct stream {
	int   fd;	/* File descriptor for writing to the stream.	*/
	char* name;	/* Name of program or file.			*/
	pid_t pid;	/* If >=0, PID of child; otherwise, ignored.	*/
	int   dead;	/* If true, this stream needs to die.		*/
} stream_t;

/* Dynamic array of output streams and its length. */
static int num_streams;
static stream_t* streams;

static int use_stdout = 1;	/* Cleared by -i. */

/* Set by the SIGALRM handler and polled by the code it interrupts, so it
 * must be a volatile sig_atomic_t to be accessed reliably from both. */
static volatile sig_atomic_t reap_timed_out;

static int verbose = 1;		/* Raised by -v, zeroed by -q. */
static int stdin_fl;		/* stdin's file status flags at startup. */

/* Diagnostics and termination. */
static void cleanup(int e);
static void print_(const char* fmt, va_list args);
static void print(int level, const char* fmt, ...);
static void error_exit(const char* fmt, ...);

/* Output stream bookkeeping. */
static stream_t* add_stream(int fd);
static void del_stream(int idx);
static void open_file(char* file, int flags);

/* Child process management and the signal handlers that support it. */
static void open_child(char* command);
static void close_child(int idx);
static pid_t reap_child(pid_t pid, struct timeval* timeout);
static void reap_sigalrm(int sig, siginfo_t* info, void* moo);
static void sigchld(int sig);

/* Command-line handling. */
static void show_help(void);
static void scan_params(int argc, char* argv[]);

/*
 * Functions.
 */

/* Tear down every stream (closing files and ending child processes), put
 * stdin's file status flags back to what main() saved at startup, and
 * exit with status `e'.  Never returns. */
static void cleanup(int e)
{
	del_stream(-1);
	(void) fcntl(STDIN_FILENO, F_SETFL, stdin_fl);	/* undo O_NONBLOCK */
	exit(e);
}


/* Print "<prog_name>: <formatted message>\n" to stderr.  errno is preserved
 * across the call, and a `%m' in `fmt' reports the caller's errno rather
 * than whatever the prefix write may have left behind. */
static void print_(const char* fmt, va_list args)
{
	int saved_errno = errno;

	fprintf(stderr, "%s: ", prog_name);

	/* The fprintf() above may clobber errno; restore it for `%m'. */
	errno = saved_errno;
	vfprintf(stderr, fmt, args);
	fputc('\n', stderr);

	errno = saved_errno;
}


/* Emit a printf-style diagnostic through print_(), but only when `level'
 * does not exceed the current verbosity.  Level 0 is therefore always
 * shown, since verbosity never drops below 0. */
static void print(int level, const char* fmt, ...)
{
	va_list ap;

	if (level > verbose)
		return;

	va_start(ap, fmt);
	print_(fmt, ap);
	va_end(ap);
}


/* Fatal error: print the printf-style message unconditionally (no
 * verbosity check), then clean up and exit with EXIT_FAILURE.
 * Never returns. */
static void error_exit(const char* fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	print_(fmt, ap);
	va_end(ap);

	cleanup(EXIT_FAILURE);
}


/* Append a new stream writing to `fd' (no name, no child, not dead) and
 * return a pointer to it.  The pointer is only valid until the next
 * add_stream()/del_stream() call, since both may move the array. */
static stream_t* add_stream(int fd)
{
	stream_t* grown;
	stream_t* s;

	/* Never assign realloc()'s result straight to `streams': on failure
	 * that would drop the array that error_exit()'s cleanup still walks
	 * (num_streams would be non-zero with streams == NULL). */
	grown = realloc(streams, sizeof(*streams) * (num_streams + 1));
	if (grown == NULL)
		error_exit("realloc(): %m");
	streams = grown;

	s = &streams[num_streams];
	num_streams++;

	s->fd   = fd;
	s->name = NULL;
	s->pid  = -1;
	s->dead = 0;

	return s;
}


/* Deletes stream(s), closing files and processes.  Meaning of `idx':
 *
 * -1  : Deletes all streams.
 * >= 0: Deletes the stream with the given index. */
static void del_stream(int idx)
{
	int i;

	assert(idx == -1 || (idx >= 0 && idx < num_streams));

	if (idx == -1) {
		/* Close all files. */
		for (i = 0; i < num_streams; ++i) {
			if (streams[idx].pid < 0 && streams[idx].fd > -1)
				close(streams[i].fd);
		}

		/* End all child processes. */
		close_child(-1);

		/* Free all stream names. */
		for (i = 0; i < num_streams; ++i) {
			free(streams[i].name);
		}

		num_streams = 0;
	} else {
		/* Delete one stream. */
		if (streams[idx].pid >= 0) {
			close_child(idx);
		} else {
			close(streams[idx].fd);
		}
		free(streams[idx].name);

		--num_streams;
		if (idx != num_streams) {
			memmove(&streams[idx], &streams[idx + 1],
				sizeof(*streams) * (num_streams - idx));
		}
	}

	streams = realloc(streams, sizeof(*streams) * num_streams);
	if (num_streams != 0 && streams == NULL)
		error_exit("realloc(): %m");
}


/* Open `file' with `flags' (OVERWRITE or APPEND, mode 0644 if created) and
 * add it as an output stream.  Exits via error_exit() on failure. */
static void open_file(char* file, int flags)
{
	int fd;
	stream_t* s;

	fd = open(file, flags, 0644);
	if (fd == -1)
		error_exit("open(`%s'): %m", file);

	/* Like the pipes in open_child(), keep this descriptor out of any
	 * child command started later by -e. */
	fcntl(fd, F_SETFD, FD_CLOEXEC);

	s = add_stream(fd);
	s->name = strdup(file);
	if (s->name == NULL)
		error_exit("strdup(): %m");

	print(4, "%s to file `%s'",
			flags == OVERWRITE ? "Writing" : "Appending", file);
}


/* Run `command' via "/bin/sh -c" with its stdin connected to a new pipe,
 * and add the pipe's write end to the streams array. */
static void open_child(char* command)
{
	int pipefds[2];
	pid_t pid;
	char* args[] = { "/bin/sh", "-c", command, NULL };
	stream_t* s;

	/* pipe for forwarding data to child */
	if (pipe(pipefds) == -1)
		error_exit("pipe(): %m");

	s = add_stream(pipefds[1]);

	if ((pid = fork()) == -1) {
		close(pipefds[0]);
		error_exit("fork(): %m");
	}

	if (pid == 0) {
		/* This is the child.  It must never call error_exit(): cleanup()
		 * would walk this process's copy of `streams' and close, signal
		 * and wait for tt's other children as though they were its own. */

		/* Make the pipe's read end our stdin.  If stdin was closed when
		 * tt started, pipe() may already have handed us fd 0. */
		if (pipefds[0] != STDIN_FILENO) {
			if (dup2(pipefds[0], STDIN_FILENO) == -1) {
				print(0, "child: dup2(): %m");
				_exit(127);
			}
			close(pipefds[0]);
		}
		close(pipefds[1]);

		/* run the user's program in a shell */
		execve(args[0], args, environ);

		/* If exec returns, something broke.  Exit with 127, as the shell
		 * does for a command it cannot run, via _exit() so no atexit or
		 * stdio-flush work duplicated from the parent runs here. */
		print(0, "child: execve(): %m");
		_exit(127);
	}

	/* This is the parent; it only uses the write end. */
	close(pipefds[0]);

	/* Children started later must not inherit the write end, or this
	 * child would not see EOF on its stdin until they exit too. */
	fcntl(pipefds[1], F_SETFD, FD_CLOEXEC);

	/* Record the pid first so cleanup() can still end the child if the
	 * strdup() below fails. */
	s->pid  = pid;
	s->name = strdup(command);
	if (s->name == NULL)
		error_exit("strdup(): %m");

	print(4, "Child `%s' has pid %d", s->name, (int) s->pid);
}


/* Terminate a child and mark its stream for elimination.  Meaning of `idx':
 *
 * -1  : Terminate all child processes in the streams array.
 * >= 0: Terminate the child process associated with stream number `idx'.
 *
 * Escalates in rounds: first close the child's stdin pipe, then send
 * SIGTERM, then SIGKILL.  After each round, terminated children are reaped,
 * waiting up to one second per reap_child() call.  A child that survives all
 * three rounds is given up on: its stream is marked dead without being
 * reaped. */
static void close_child(int idx)
{
	int i, method, end, found;
	struct timeval tv = {1, 0};	/* One second. */

	assert(idx == -1 || (idx >= 0 && idx < num_streams));
	if (idx >= 0 && idx < num_streams) {
		assert(streams[idx].pid >= 0);
	}

	for (method = 0;; ++method) {
		if (idx == -1) {
			/* Terminate all children. */
			i   = 0;
			end = num_streams;
		} else {
			/* Terminate one child. */
			i   = idx;
			end = idx + 1;
		}

		/* Must start at 0 every round, or the "nothing left" check below
		 * reads an indeterminate value (first round) or a stale one. */
		found = 0;

		for (; i < end; ++i) {
			if (streams[i].pid < 0 || streams[i].dead)
				continue;
			found = 1;

			switch (method) {
			case 0:
				/* First, try closing the child's stdin. */
				if (streams[i].fd != -1) {
					close(streams[i].fd);
					streams[i].fd = -1;
				}
				break;
			case 1:
				/* Next, send SIGTERM. */
				kill(streams[i].pid, SIGTERM);
				break;
			case 2:
				/* Finally, use SIGKILL. */
				kill(streams[i].pid, SIGKILL);
				break;
			default:
				/* That didn't work?  Pretend success. :^) */
				streams[i].pid  = -1;
				streams[i].dead =  1;
				break;
			}
		}

		/* If there were no more children to kill, quit now. */
		if (!found) break;

		/* If last-resort method has been performed, quit now. */
		if (method == 3) break;

		/* Reap terminated children.  reap_child() returns 0 per reaped
		 * child, -1 when the timeout expires and -2 when nothing is left;
		 * stop on either of the latter, otherwise a child that shrugs off
		 * the current method blocks escalation forever. */
		while (reap_child(-1, &tv) == 0)
			;
	}
}


/* waitpid() for terminated children.  Meaning of `pid':
 *
 * -1  : Wait for any child to terminate.
 * >= 0: Wait for one particular child to terminate.
 *
 * If `timeout' is not NULL, it signifies the maximum amount of time to wait
 * for waitpid() to execute.  If it's NULL, then waitpid() may hang forever.
 * Set `*timeout' to zero to return immediately if the child(ren) indicated by
 * `pid' isn't/aren't terminated.
 *
 * Return value:
 *
 * -2  : There were no children to wait for (for `pid' == -1).
 * -1  : The timeout expired.
 * >= 0: process id of the child; for `pid' == -1, always 0. */
static pid_t reap_child(pid_t pid, struct timeval* timeout)
{
	struct sigaction act, old;
	struct itimerval itv;
	int i, val, retval, status, flags = 0, have_sigalrm = 0;
	pid_t waitret;

	assert(pid >= -1);

	if (pid == -1) {
		/* Make sure there's a child to waitpid() for. */
		for (i = 0; i < num_streams; ++i) {
			if (streams[i].pid > 0) goto found;
		}

		/* Nope. */
		return -2;
	}

found:
	if (timeout != NULL) {
		if (timeout->tv_sec == 0 && timeout->tv_usec == 0) {
			/* Don't hang. */
			flags = WNOHANG;
		} else {
			/* Timeout requested. */
			have_sigalrm = 1;

			act.sa_sigaction = reap_sigalrm;
			sigemptyset(&act.sa_mask);
			act.sa_flags = 0;
			sigaction(SIGALRM, &act, &old);

			reap_timed_out = 0;
			itv.it_interval.tv_sec  = 0;
			itv.it_interval.tv_usec = 0;
			itv.it_value = *timeout;
			setitimer(ITIMER_REAL, &itv, NULL);
		}
	}

	/* waitpid() for an exited child.  Repeat on EINTR. */
	do waitret = waitpid(pid, &status, flags);
	while (waitret == -1 && errno == EINTR && !reap_timed_out);

	if (have_sigalrm) {
		/* Deactivate the SIGALRM handler. */
		itv.it_value.tv_sec  = 0;
		itv.it_value.tv_usec = 0;
		setitimer(ITIMER_REAL, &itv, NULL);

		sigaction(SIGALRM, &old, NULL);
	}

	if (waitret == -1) {
		if (errno == EINTR) {
			/* Timeout expired. */
			return -1;
		} else {
			/* Uh oh. */
			error_exit("waitpid(): %m");
		}
	} else if (waitret == 0) {
		/* Zero timeout and no exited child. */
		if (pid == -1) return -2;
		else           return -1;
	} else {
		/* Found an exited child.  Fall through to handle. */
	}

	/* Identify which stream goes with this pid. */
	for (i = 0; i < num_streams; ++i) {
		if (streams[i].pid != waitret) continue;

		/* Post-mortem. */
		if (WIFEXITED(status)) {
			/* Child exited. */
			val = WEXITSTATUS(status);
			print(val == 0 ? 3 : 2, "`%s' (%d) exited with code %d",
					streams[i].name, streams[i].pid, val);
		} else if (WIFSIGNALED(status)) {
			/* Child signaled. */
			val = WTERMSIG(status);
			print(2, "`%s' (%d) terminated by signal %d",
					streams[i].name, streams[i].pid, val);
		}

		/* Mark this stream as garbage to be cleaned later. */
		if (streams[i].fd != -1) {
			close(streams[i].fd);
			streams[i].fd = -1;
		}
		streams[i].pid  = -1;
		streams[i].dead =  1;
		break;
	}

	return 0;
}


/* SIGALRM handler that reap_child() installs while it waits with a
 * timeout.  It only records that the timer fired; the arguments are
 * unused. */
static void reap_sigalrm(int sig, siginfo_t* info, void* moo)
{
	(void) sig;
	(void) info;
	(void) moo;

	reap_timed_out = 1;
}


/*
 * Tell the user 'sup: print the usage text (program name substituted via
 * the positional `%1$s') to stdout, then clean up and exit successfully.
 */
static void show_help(void)
{
	printf("%1$s - copy stdin to files, program pipes, and stdout\n"
	       "Syntax: `%1$s [OPTIONS]' < data\n"
	       "\n"
	       "Options:\n"
	       "  -a FILE  append stdin to FILE\n"
	       "  -f FILE  write stdin to FILE (overwriting any previous contents)\n"
	       "  -e CMD   write tt's stdin to CMD's stdin\n"
	       "  -i       inhibit sending to stdout\n"
	       "  -v       increase verbosity by one notch\n"
	       "  -q       only display fatal errors\n"
	       "  -h       show this help and exit\n"
	       "\n"
	       "`-a', `-f', `-e', and `-v' may be specified multiple times.\n"
	       "`-v' and `-q' are only valid when given as the first arguments.\n",
	       prog_name);

	cleanup(EXIT_SUCCESS);
}


/*
 * Attempt to mind meld with the user: act on each command-line option in
 * order.  An unknown option, a missing option argument, or any non-option
 * argument left over ends the program with a usage hint.
 */
static void scan_params(int argc, char* argv[])
{
	int opt;
	int bad = 0;

	while (!bad && (opt = getopt(argc, argv, "a:f:e:ivqh")) != -1) {
		switch (opt) {
		case 'a':
			open_file(optarg, APPEND);
			break;
		case 'f':
			open_file(optarg, OVERWRITE);
			break;
		case 'e':
			open_child(optarg);
			break;
		case 'i':
			use_stdout = 0;
			break;
		case 'v':
			verbose++;
			break;
		case 'q':
			verbose = 0;
			break;
		case 'h':
			show_help();	/* does not return */
			break;
		case '?':
			bad = 1;	/* stop scanning at the first error */
			break;
		default:
			break;
		}
	}

	if (bad || optind != argc)
		error_exit("Try `%s -h' for more information.", prog_name);
}

/*
 * Handler for SIGCHLD: reap every child that has already terminated,
 * without blocking.  errno is saved and restored so that the interrupted
 * code (e.g. main()'s `errno == EINTR' checks after read()/write()) sees
 * its own value, not one left behind by waitpid()/close()/stdio here.
 */
static void sigchld(int sig)
{
	int saved_errno = errno;
	struct timeval tv = {0, 0};

	(void) sig;

	/* Reap all terminated children. */
	while (reap_child(-1, &tv) != -2)
		;

	errno = saved_errno;
}


int main(int argc, char* argv[])
{
	char* p;
	char buf[4096];
	int i, j, r, e;
	struct pollfd pollfd = {STDIN_FILENO, POLLIN};

	stdin_fl = fcntl(STDIN_FILENO, F_GETFL);

	/* Remove the path component from argv[0]. */
	p = strrchr(argv[0], '/');
	if (p != NULL)
		argv[0] = p + 1;

	prog_name = argv[0];
	scan_params(argc, argv);
	if (use_stdout) {
		add_stream(STDOUT_FILENO)->name = strdup("stdout");
	}
	signal(SIGCHLD, sigchld);

	/* Enable nonblocking mode for stdin. */
	fcntl(0, F_SETFL, stdin_fl | O_NONBLOCK);

	while (1) {
		/* Wait for readability on stdin. */
		r = poll(&pollfd, 1, -1);
		if (r == -1 && errno == EINTR || r == 0) continue;
		if (r == -1) error_exit("poll(`stdin'): %m");

		/* Get data from stdin. */
		r = read(STDIN_FILENO, buf, sizeof(buf));
		if (r == -1) {
			/* Error of some sort. */
			if (errno == EINTR) {
				/* Interrupted by a signal.  Try again. */
				continue;
			} else {
				/* Some other error... whine and give up. */
				error_exit("read(`stdin'): %m");
			}
		} else if (r == 0) {
			/* stdin closed. */
			break;
		} else {
			/* Got data.  Now fall through to send to streams. */
		}

		/* For each stream... */
		i = 0;
		while (1) {
			if (i >= num_streams) {
				/* Exhausted all streams, so read more data. */
				break;
			} else if (streams[i].dead) {
				/* This stream needs to be eliminated. */
				del_stream(i);
				continue;
			} else {
				/* Write data to this stream.  Fall through. */
			}

			/* Try writing the buffer to the stream. */
			e = write(streams[i].fd, buf, r);

			if (e == -1) {
				if (errno == EINTR) {
					/* Interrupted by signal. */
				} else {
					/* Other error. */
					print(1, "write(`%s'): %m",
							streams[i].name);
					del_stream(i);
				}

				/* Try again.  Note that if del_stream() was
				 * used above, the same `i' value now points to
				 * a new stream. */
				continue;
			}

			/* Will this always be true? */
			assert(e == r);

			/* Next stream. */
			++i;
		}
	}

	signal(SIGCHLD, SIG_DFL);
	cleanup(EXIT_SUCCESS);
}


/* EOF */
