/* vm.c */

#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>	/* strcasecmp() (POSIX) */
#include "project.h"

static int normal_insn_count = 0;	/* Normal instructions executed */
static int memory_insn_count = 0;	/* Memory instructions executed */
static char* user_input = NULL;		/* Line of user input (grown by getline) */
static size_t user_input_len = 0;	/* Capacity of user_input; getline() needs a size_t */
static int trace_mode = 1;		/* Pause before each instruction? */
static char** watch_list = NULL;	/* strdup'd label names; exactly watch_list_len
					 * entries, no NULL terminator */
static int watch_list_len = 0;		/* Number of entries in watch_list */
static char* image_filename = NULL;	/* Name of the image file */
static char* map_filename = NULL;	/* Name of the map file */

/* Cool colorized prompts and help text (ANSI escape sequences; \e is ESC) */
static const char* read_prompt =
"\e[30;1m[\e[33;1mread\e[30;1m]\e[0;35m--\e[1;35m>\e[0m ";
static const char* trace_prompt =
"\e[30;1m[\e[36;1mtrace\e[30;1m]\e[0;35m--\e[1;35m>\e[0m ";
/* Commands accepted only at the read prompt */
static const char* read_help =
"in array 1 \e[36mnum_string\e[30;1m:\e[0m     set the first array to "
	"\e[36mnum_string\e[0m\n"
"in array 2 \e[36mnum_string\e[30;1m:\e[0m     set the second array to "
	"\e[36mnum_string\e[0m\n"
"add arrays\e[30;1m:\e[0m                add the two array values and display "
	"the result\n";
/* Horizontal rule printed around results and help */
static const char* divider = 
"   \e[1;30m-  -- ---\e[0m-\e[1;30m--\e[0m--\e[1;30m-\e[0m---\e[1m-\e[0m--"
"\e[1m--\e[0m-\e[1m------\e[0m-\e[1m--\e[0m--\e[1m-\e[0m---\e[1;30m-\e[0m--"
"\e[1;30m--\e[0m-\e[1;30m--- --  -   \e[0m\n";
/* Commands accepted at both prompts; padding is counted in visible
 * characters, so escape sequences must not be counted toward alignment */
static const char* common_help =
"exit\e[30;1m:\e[0m                      terminate the program\n"
"help\e[30;1m:\e[0m                      display this page\n"
"run\e[30;1m:\e[0m                       enable continuous run mode\n"
"trace\e[30;1m:\e[0m                     enable execution trace mode\n"
"reg\e[30;1m:\e[0m                       displays register contents\n"
"dump \e[36mlabel\e[30;1m:\e[0m                show data at \e[36mlabel\e[0m\n"
"watch \e[36mlabel\e[30;1m:\e[0m               watch data at \e[36mlabel\e[0m\n"
"unwatch \e[36mlabel\e[30;1m:\e[0m             remove \e[36mlabel\e[0m from "
	"the watch list\n"
"performance\e[30;1m:\e[0m               display performance statistics\n"
"reboot\e[30;1m:\e[0m                    reload program and restart execution\n";

/* Forward declarations of the file-local helpers defined further down */
static void cleanup(void);
static void watch_clear(void);
static void reboot(void);
static void print_performance(void);
static void prompt(reg_t *reg1, reg_t *reg2, unsigned imm16);
static void print_argument(insn_t *insn, unsigned code, unsigned arg_idx);

/*
 * Executes machine code.
 *
 * image: memory image path, loaded by reboot() via mem_load_image()
 * map:   label map path, loaded by reboot() via label_load_map()
 *
 * Each iteration shows the instruction at $pc.  In trace mode it also shows
 * the registers and every watched region, then pauses at the prompt.  It
 * then fetches, decodes and executes that instruction.  The loop never
 * terminates on its own; the program ends through exit() (e.g. the "exit"
 * command) or die() on an undecodable word.
 */
void vm_simulate(char* image, char* map)
{
	reg_t* pc = reg_by_name("pc");
	unsigned code;
	insn_t* insn;
	unsigned arg1, arg2, arg3;
	unsigned pc_val;
	int i;

	/* Initialize the system */
	image_filename = image;
	map_filename = map;
	reboot();

	/* Run! */
	for (;;) {
		pc_val = reg_read(pc);

		dump_memory(pc_val, pc_val + 1);
		if (trace_mode) {
			dump_regs();
			for (i = 0; i < watch_list_len; i++) {
				label_t* start = label_by_name(watch_list[i]);
				label_t* end;

				/* Both lookups can return NULL (the "dump" command
				 * checks the same calls); skip rather than crash when
				 * a watched label has no following label. */
				if (start == NULL) continue;
				end = label_after_addr(start->addr);
				if (end == NULL) continue;
				dump_memory(start->addr, end->addr);
			}
			prompt(NULL, NULL, 0);

			/* The prompt's "reboot" command reinitializes the registers
			 * (reboot() -> reg_initialize()).  If $pc moved, restart from
			 * the new $pc.  Otherwise the instruction at the stale pc_val
			 * would be executed. */
			if (reg_read(pc) != pc_val) continue;
		}

		/* Fetch, advance $pc, decode, count, execute */
		code = mem_read(pc_val);
		reg_write(pc, pc_val + 1);
		insn_decode(code, &insn, &arg1, &arg2, &arg3);
		if (insn == NULL) die("Invalid instruction '0x%08x'\n", code);
		if (insn->type == I_MEMORY) memory_insn_count++;
		else if (insn->type == I_NORMAL) normal_insn_count++;

		insn->func(arg1, arg2, arg3);
	}
}

/*
 * Turns trace mode on (flag != 0) or off (flag == 0) and tells the user
 * what that means for execution.
 */
void set_trace_mode(int flag)
{
	const char* msg;

	trace_mode = flag;

	if (flag)
		msg = "Trace mode \e[1;32menabled\e[0m; will pause before "
		      "each instruction.\n";
	else
		msg = "Trace mode \e[1;31mdisabled\e[0m; will execute without pause until "
		      "\e[1mread\e[0m is encountered.\n";

	/* Neither message contains a conversion, so print it verbatim */
	fputs(msg, stdout);
}

/*
 * Displays the memory words at addresses start through end - 1, one per line.
 * Each line shows the address, the raw word, the label at that address (if
 * any), and the decoded instruction ("??" when insn_by_code() finds none).
 *
 * When the third argument is a 16-bit immediate, it is also annotated with
 * the label returned by label_before_addr(), plus an offset when the
 * addresses differ.  That label is presumably the nearest one at or below
 * the immediate.
 */
void dump_memory(unsigned start, unsigned end)
{
	unsigned addr;

	for (addr = start; addr < end; addr++) {
		unsigned code = mem_read(addr);
		label_t* label;
		insn_t* insn;

		/* Print address and raw value */
		printf("\e[0;36m%04x\e[1;30m:\e[1;32m%08x\e[0m ", addr, code);

		/* Print label; "name:" plus padding fills 11 columns */
		label = label_by_addr(addr);
		if (label != NULL) {
			/* int arithmetic: 10 - strlen() in size_t wraps for long names */
			int len = 10 - (int)strlen(label->name);
			printf("\e[1;36m%s\e[1;30m:\e[0m", label->name);
			if (len > 0) printf("%*s", len, "");
		} else {
			printf("%11s", "");
		}

		/* Print instruction */
		insn = insn_by_code(code);
		if (insn == NULL) {
			printf("??\n");
			continue;
		}
		printf("\e[%sm", insn->type == I_NORMAL ? "0" : "1;31");
		printf("%-5s\e[1;30m(\e[0m", insn->name);
		print_argument(insn, code, 1);
		printf("\e[1;30m,\e[0m ");
		print_argument(insn, code, 2);
		printf("\e[1;30m,\e[0m ");
		print_argument(insn, code, 3);
		printf("\e[1;30m);\e[0m");

		/* Print possible destination label */
		if (insn->arg3.type == A_IMM16) {
			unsigned target = (code >> insn->arg3.pos) & 0xffff;
			label_t* dest = label_before_addr(target);

			if (dest != NULL) {
				printf(" \e[1;30m[\e[0;36m@\e[1m%s", dest->name);
				if (dest->addr != target)
					printf("\e[0m + 0x\e[1m%x", target - dest->addr);
				printf("\e[1;30m]\e[0m");
			}
		}

		printf("\n");
	}
}

/*
 * Displays all registers.
 *
 * The 16 normal registers are shown four per line.  The two BCD groups $u
 * and $v are each shown as eight hex digits: the low nibble of $x7 first,
 * down to $x0.
 */
void dump_regs(void)
{
	int i, j;

	/* Print normal registers */
	for (i = 0; i < 16; i++) {
		reg_t* reg = reg_by_code(i, R_NORMAL);
		/* Right-align "$name" in 5 columns.  int arithmetic avoids size_t
		 * wraparound; clamp because a negative %*s width would still pad. */
		int len = 4 - (int)strlen(reg->name);
		unsigned val = reg_read(reg);

		if (len < 0) len = 0;
		printf("%*s\e[0;35m$\e[1m%s\e[0m = "
		"\e[1;30m[\e[1;37m%08x\e[1;30m]", len, "", reg->name, val);
		printf("%s", ((i & 3) == 3) ? "\n" : "  ");
	}

	/* Print BCD registers */
	for (j = 0; j < 2; j++) {
		char pre = j == 0 ? 'u' : 'v';
		printf("   \e[0;33m$\e[1m%c\e[0m = ", pre);
		printf("\e[1;30m[\e[1;37m");
		for (i = 7; i >= 0; i--) {
			char name[3] = {pre, (char)('0' + i), 0};
			printf("%01x", (unsigned)(reg_read(reg_by_name(name)) & 0xf));
		}
		printf("\e[1;30m]    (\e[0;33m$\e[1m%c\e[0m = "
				"\e[34m$\e[1m%c7\e[0m...\e[34m$"
				"\e[1m%c0\e[30m)\e[0m", pre, pre, pre);
		printf("%s", j == 0 ? "  " : "\n");
	}
}

/*
 * Implements the read instruction.  Looks up the two normal registers named
 * by nreg1 and nreg2 and hands them, together with imm16, to the interactive
 * prompt.  The prompt's "add arrays" and "in array" commands store their
 * results through these registers.
 */
void vm_read(unsigned nreg1, unsigned nreg2, unsigned imm16)
{
	reg_t* first = reg_by_code(nreg1, R_NORMAL);
	reg_t* second = reg_by_code(nreg2, R_NORMAL);

	prompt(first, second, imm16);
}

/*
 * Implements the disp instruction: prints the BCD sum held in memory, then
 * the performance counters, then reboots the system.
 *
 * nreg1: normal register holding the number of words to print
 * nreg2: normal register holding an offset added to imm16
 * imm16: base address of the result words
 *
 * A leading "1" is printed when $cf is non-zero.  Leading zeroes are
 * suppressed, and "0" is printed if every digit is zero.
 */
void vm_disp(unsigned nreg1, unsigned nreg2, unsigned imm16)
{
	reg_t* reg1 = reg_by_code(nreg1, R_NORMAL);
	reg_t* reg2 = reg_by_code(nreg2, R_NORMAL);
	reg_t* cf = reg_by_name("cf");
	unsigned sig = reg_read(reg1);
	unsigned base;
	unsigned i;	/* unsigned to match sig in the comparison */
	int j;
	int seen_nonzero = 0;

	printf("%s", divider);
	printf("Sum: \e[1;36m");

	/* Print a leading one to indicate carry */
	if (reg_read(cf) != 0) {
		seen_nonzero = 1;
		printf("1");
	}

	/* Print remaining digits: highest address (most significant word) first */
	base = imm16 + reg_read(reg2);	/* loop-invariant */
	for (i = 0; i < sig; i++) {
		unsigned word = mem_read(base + sig - i - 1);

		/* NOTE(review): nibbles are emitted starting at bit 0.  That
		 * assumes the program stores each word's most significant digit in
		 * its low nibble; confirm against the code that writes the sum. */
		for (j = 0; j < 32; j += 4) {
			unsigned digit = (word >> j) & 0xf;

			/* Skip leading zeroes */
			if (!seen_nonzero) {
				if (digit == 0) continue;
				seen_nonzero = 1;
			}
			printf("%01x", digit);
		}
	}

	/* But print 0 if the whole number is zero */
	if (!seen_nonzero) printf("0");

	printf("\e[0m\n");
	printf("%s", divider);

	/* Show the final performance values */
	print_performance();
	printf("%s", divider);

	/* Reboot system */
	reboot();
}

/* Cleanup function. */
static void cleanup(void)
{
	int i;
	if (watch_list != NULL) {
		for (i = 0; watch_list[i] != NULL; i++)
			free(watch_list[i]);
		free(watch_list);
	}
	free(user_input);
}

/*
 * Provides the interactive user interface.
 *
 * Called in two modes:
 *   - trace prompt (reg1 == NULL, from the trace loop): an empty line or
 *     "run" returns so the next instruction executes;
 *   - read prompt (reg1/reg2 set, from the read instruction): a blank line
 *     repeats the prompt.  Only "add arrays" (reg1 <- 2) or
 *     "in array N digits" return.  The latter writes the digits as ASCII
 *     into 32 words at imm16, then sets reg1 <- N - 1 and
 *     reg2 <- digit count.
 *
 * Commands valid in both modes are listed in common_help.  EOF or a read
 * error on stdin terminates the program with EXIT_FAILURE.
 */
static void prompt(reg_t* reg1, reg_t* reg2, unsigned imm16)
{
	static const char delims[] = " \t\n";
	static int cleanup_registered = 0;
	char* result = NULL;
	int in_read = reg1 != NULL;

	/* Free the input buffer and watch list at exit; register only once */
	if (!cleanup_registered) {
		if (atexit(cleanup) != 0) {
			perror("atexit");
			exit(EXIT_FAILURE);
		}
		cleanup_registered = 1;
	}

	/* Get input until a go-to-next-instruction command is issued */
	while (1) {
		/* Print prompt; it has no newline, so flush it explicitly */
		printf("%s", in_read ? read_prompt : trace_prompt);
		fflush(stdout);
		if (getline(&user_input, &user_input_len, stdin) == -1) {
			printf("\n");
			if (ferror(stdin)) perror("getline");
			exit(EXIT_FAILURE);
		}

		/* Expect: first word of command */
		result = strtok(user_input, delims);

		/* Blank line? */
		if (result == NULL) {
			if (in_read) continue;	/* Repeat prompt */
			else break;		/* Go to next instruction */
		}

		/* Determine which command to execute */
		if (strcasecmp(result, "help") == 0) {
			/* Expect: end of line */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			if (in_read) printf("%s%s", divider, read_help);
			printf("%s", common_help);
		} else if (strcasecmp(result, "run") == 0) {
			/* Expect: end of line */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			/* The read prompt still needs input, so only the trace
			 * prompt resumes execution here */
			set_trace_mode(0);
			if (!in_read) break;
		} else if (strcasecmp(result, "trace") == 0) {
			/* Expect: end of line */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			set_trace_mode(1);
		} else if (strcasecmp(result, "exit") == 0) {
			/* Expect: end of line */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			exit(EXIT_SUCCESS);
		} else if (strcasecmp(result, "reg") == 0) {
			/* Expect: end of line */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			dump_regs();
		} else if (strcasecmp(result, "dump") == 0) {
			label_t* label;
			label_t* end;

			/* Expect: label name */
			result = strtok(NULL, delims);
			if (result == NULL) goto bad_input;

			/* Expect: end of string */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			/* Get label entry */
			label = label_by_name(result);
			if (label == NULL) {
				printf("Label \e[1;30m'\e[36m%s\e[30m'\e[0m "
						"does not exist.\n", result);
				continue;
			}

			/* Find next label and print region between */
			end = label_after_addr(label->addr);
			if (end != NULL) dump_memory(label->addr, end->addr);
		} else if (strcasecmp(result, "watch") == 0) {
			label_t* label;
			char** grown;
			char* copy;
			int i;

			/* Expect: label name */
			result = strtok(NULL, delims);
			if (result == NULL) goto bad_input;

			/* Expect: end of string */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			/* Get label entry */
			label = label_by_name(result);
			if (label == NULL) {
				printf("Label \e[1;30m'\e[36m%s\e[30m'\e[0m "
						"does not exist.\n", result);
				continue;
			}

			/* Ensure that the label is not in the watch list */
			for (i = 0; i < watch_list_len; i++)
				if (strcmp(watch_list[i], result) == 0)
					break;
			if (i < watch_list_len) {
				printf("Label \e[1;30m'\e[36m%s\e[30m' "
						"\e[0mis already being "
						"watched.\n", result);
				continue;
			}

			/* Append to the watch list.  On allocation failure the
			 * existing list stays valid and unchanged. */
			copy = strdup(result);
			if (copy == NULL) {
				perror("strdup");
				continue;
			}
			grown = realloc(watch_list,
				(watch_list_len + 1) * sizeof *watch_list);
			if (grown == NULL) {
				perror("realloc");
				free(copy);
				continue;
			}
			watch_list = grown;
			watch_list[watch_list_len++] = copy;
		} else if (strcasecmp(result, "unwatch") == 0) {
			int i;

			/* Expect: label name */
			result = strtok(NULL, delims);
			if (result == NULL) goto bad_input;

			/* Expect: end of string */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			for (i = 0; i < watch_list_len; i++)
				if (strcmp(watch_list[i], result) == 0)
					break;
			if (i == watch_list_len) {
				printf("Not watching "
				"\e[1;30m'\e[0;36m@\e[1m%s\e[30m'\e[0m.\n",
				result);
				continue;
			}

			/* Remove entry i and shift the tail down.  Pointer
			 * arithmetic on char** already scales by the element
			 * size; only the byte count needs sizeof. */
			free(watch_list[i]);
			memmove(watch_list + i, watch_list + i + 1,
					(watch_list_len - i - 1) *
					sizeof *watch_list);
			watch_list_len--;

			if (watch_list_len == 0) {
				/* realloc(p, 0) is implementation-defined */
				free(watch_list);
				watch_list = NULL;
			} else {
				/* Shrink; if this fails the old block is still valid */
				char** shrunk = realloc(watch_list,
						watch_list_len * sizeof *watch_list);
				if (shrunk != NULL) watch_list = shrunk;
			}
		} else if (strcasecmp(result, "performance") == 0) {
			/* Expect: end of line */
			result = strtok(NULL, delims);
			if (result != NULL) goto bad_input;

			print_performance();
		} else if (strcasecmp(result, "reboot") == 0) {
			/* Expect: end of line */
			result = strtok(NULL, delims);
			if (result != NULL) goto bad_input;

			reboot();
		} else if (strcasecmp(result, "add") == 0 && in_read) {
			/* Expect: "arrays" */
			result = strtok(NULL, delims);
			if (result == NULL) goto bad_input;
			if (strcasecmp(result, "arrays") != 0) goto bad_input;

			/* Expect: end of line */
			result = strtok(NULL, delims);
			if (result != NULL) goto bad_input;

			reg_write(reg1, 2);
			break;
		} else if (strcasecmp(result, "in") == 0 && in_read) {
			unsigned char expanded[128];
			int i, j;
			int dest_array;
			int num_digits = 0;

			/* Expect: "array" */
			result = strtok(NULL, delims);
			if (result == NULL ||
			strcasecmp(result, "array") != 0) goto bad_input;

			/* Expect: "1" or "2" */
			result = strtok(NULL, delims);
			if (result == NULL) goto bad_input;
			else if (strcmp(result, "1") == 0) dest_array = 1;
			else if (strcmp(result, "2") == 0) dest_array = 2;
			else goto bad_input;

			/* Expect: numeric string (leading zeroes are dropped) */
			result = strtok(NULL, delims);
			if (result == NULL) goto bad_input;
			while (*result == '0') result++;
			for (i = 0; result[i] != 0; i++) {
				if (result[i] < '0' || result[i] > '9') {
					printf("Not a digit: \e[1;30m'\e[37m%c"
						"\e[30m'\e[0m.\n", result[i]);
					goto bad_input;
				}
				num_digits++;
				if (num_digits > 128) {
					printf("Array cannot be longer than "
						"128 digits.\n");
					goto bad_input;
				}
			}

			/* Expect: end of line */
			if (strtok(NULL, delims) != NULL) goto bad_input;

			/* Right-align the ASCII digits in 128 bytes padded with '0' */
			memset(expanded, '0', 128 - num_digits);
			memcpy(expanded + 128 - num_digits, result, num_digits);

			/* Pack 4 ASCII digits per word, first one in the most
			 * significant byte, into 32 words starting at imm16 */
			for (i = 0; i < 32; i++) {
				unsigned word = 0;
				for (j = 0; j < 4; j++)
					word |= (unsigned)expanded[i * 4 + j]
						<< ((3 - j) * 8);
				mem_write(i + imm16, word);
			}

			/* Update registers */
			reg_write(reg1, dest_array - 1);
			reg_write(reg2, num_digits);
			break;
		} else goto bad_input;
		continue;

bad_input:
		printf("Invalid command.  Try \e[1;30m'\e[1;37mhelp"
					"\e[1;30m'\e[0m for help.\n");
	}
}

/*
 * Displays argument number arg_idx (1, 2 or 3) of insn.  The argument's
 * field is extracted from the encoded word code using its pos/width
 * description.
 *
 * Immediates are printed in hex.  Registers are printed as a colored
 * "$name", right-aligned to 6 columns; a register code with no entry
 * prints "??".  An arg_idx outside 1..3 or an unhandled argument type
 * prints nothing.
 */
static void print_argument(insn_t* insn, unsigned code, unsigned arg_idx)
{
	arg_t* arg;
	reg_t* reg;
	const char* name;
	const char* color;
	unsigned mask;
	int len;

	/* Get the argument description */
	switch (arg_idx) {
	case 1: arg = &insn->arg1; break;
	case 2: arg = &insn->arg2; break;
	case 3: arg = &insn->arg3; break;
	default: return;	/* arg would otherwise be used uninitialized */
	}

	/* Pare code down to the correct bits.  An unsigned 1 keeps width 31
	 * from overflowing int, and a full-width field is special-cased
	 * because shifting by the type's width is undefined. */
	mask = arg->width >= 32 ? ~0u : (1u << arg->width) - 1u;
	code = (code >> arg->pos) & mask;

	/* Immediates print directly; registers share the formatting below */
	switch (arg->type) {
	case A_IMM4:
		printf("\e[0m   0x\e[1m%01x\e[0m", code);
		return;
	case A_IMM16:
		printf("\e[0m0x\e[1m%04x\e[0m", code);
		return;
	case A_NREG:
		reg = reg_by_code(code, R_NORMAL);
		color = "35";
		break;
	case A_BREG:
		reg = reg_by_code(code, R_BCD);
		color = "34";
		break;
	case A_PREG:
		reg = reg_by_code(code, R_PSEUDO);
		color = "33";
		break;
	default:
		return;
	}

	/* int arithmetic avoids size_t wraparound; clamp so long names get no
	 * padding instead of a negative (left-justifying) width */
	name = reg != NULL ? reg->name : "??";
	len = 5 - (int)strlen(name);
	if (len < 0) len = 0;
	printf("%*s\e[%sm$\e[1m%s\e[0m", len, "", color, name);
}

/* Empty the watch list */
static void watch_clear(void)
{
	int i;
	printf("Emptying the watch list...\n");
	if (watch_list == NULL) return;
	for (i = 0; watch_list[i] != NULL; i++) free(watch_list[i]);
	free(watch_list);
}

/*
 * Displays the performance counters.
 *
 * The weighted total counts each memory instruction as 4 and each normal
 * instruction as 1.  It is computed in long long so that
 * memory_insn_count * 4 cannot overflow int.
 */
static void print_performance(void)
{
	long long weighted = 4LL * memory_insn_count + normal_insn_count;

	printf("Normal instructions: \e[1m%d\e[0m\n", normal_insn_count);
	printf("Memory instructions: \e[1;31m%d\e[0m\n", memory_insn_count);
	printf("Weighted total:      \e[1;33m%lld\e[0m\n", weighted);
}

/*
 * Restarts the machine.  In order, it:
 *   1. clears the watch list;
 *   2. zeroes the performance counters;
 *   3. reloads the memory image and label map from the saved file names;
 *   4. reinitializes the registers;
 *   5. re-enters trace mode.
 */
static void reboot(void)
{
	watch_clear();

	puts("Resetting performance counters...");
	normal_insn_count = memory_insn_count = 0;

	puts("Loading memory image and map files...");
	mem_load_image(image_filename);
	label_load_map(map_filename);

	puts("Initializing registers...");
	reg_initialize();

	/* A freshly booted machine always starts paused in trace mode */
	set_trace_mode(1);
}

/* EOF */

