#!/usr/bin/env python3
#
# A simple implementation of COMET emulator.
# Copyright (c) 2021, Hiroyuki Ohsaki.
# All rights reserved.
#

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import re
import struct
import sys

from perlcompat import die, warn, getopts
import tbdump

# version number shown in the startup banner
VERSION = 0.3
# NOTE(review): DEBUG is not referenced anywhere else in this file (debug
# output is controlled by the -d option) --- confirm before relying on it
DEBUG = False

# addresses of IN/OUT/EXIT system calls
SYS_IN = 0xfff0
SYS_OUT = 0xfff2
SYS_EXIT = 0xfff4

# values of the flag register
FR_PLUS = 0
FR_ZERO = 1
FR_MINUS = 2

# the top of the stack, which is the upper limit of the stack space
# (also the initial value of GR4 and the upper limit for loaded programs)
STACK_TOP = 0xff00

# maximum/minimum of signed value (16-bit two's complement)
MAX_SIGNED = 32767
MIN_SIGNED = -32768

# COMET instructions
# Each key is an operation code (the upper byte of an instruction word),
# and its value is a (TYPE, MNEMONIC) tuple.  TYPE selects the operand
# format used by the disassembler:
#   op1: GR, adr [, XR]    op2: adr [, XR]    op3: GR    op4: no operand
# MNEMONIC is the instruction name; the emulator dispatches to the method
# named exec_<MNEMONIC>.
INSTTBL = {
    0x10: ('op1', 'LD'),
    0x11: ('op1', 'ST'),
    0x12: ('op1', 'LEA'),
    0x20: ('op1', 'ADD'),
    0x21: ('op1', 'SUB'),
    0x30: ('op1', 'AND'),
    0x31: ('op1', 'OR'),
    # NOTE(review): 0x33 skips 0x32, unlike the other consecutive groups ---
    # confirm against the assembler that generates the object files
    0x33: ('op1', 'EOR'),
    0x40: ('op1', 'CPA'),
    0x41: ('op1', 'CPL'),
    0x50: ('op1', 'SLA'),
    0x51: ('op1', 'SRA'),
    0x52: ('op1', 'SLL'),
    0x53: ('op1', 'SRL'),
    0x60: ('op2', 'JPZ'),
    0x61: ('op2', 'JMI'),
    0x62: ('op2', 'JNZ'),
    0x63: ('op2', 'JZE'),
    0x64: ('op2', 'JMP'),
    0x70: ('op2', 'PUSH'),
    0x71: ('op3', 'POP'),
    0x80: ('op2', 'CALL'),
    0x81: ('op4', 'RET'),
}

# Debugger commands.
# Each entry is (REGEXP, METHOD, NEED_PRINT): a command word matching REGEXP
# is handled by the Comet method named METHOD, and the register status is
# printed afterwards when NEED_PRINT is true.  Entries are tried in order.
#
# Every pattern is fully anchored.  In the unanchored form '^st|stack' the
# '^' applies only to the first alternative, so 'step' was captured by '^st'
# (running the stack command) and any word merely containing a command name
# (e.g. 'profile') matched as well.
CMDTBL = [
    (r'^(?:de|del|delete)$', 'cmd_delete', False),
    (r'^(?:du|dump)$', 'cmd_dump', False),
    (r'^(?:b|break)$', 'cmd_break', False),
    (r'^(?:di|disasm)$', 'cmd_disasm', False),
    (r'^(?:f|file)$', 'cmd_file', True),
    (r'^(?:h|\?|help)$', 'cmd_help', False),
    (r'^(?:i|info)$', 'cmd_info', False),
    (r'^(?:j|jump)$', 'cmd_jump', True),
    (r'^(?:m|memory)$', 'cmd_memory', True),
    (r'^(?:p|print)$', 'cmd_print', False),
    (r'^(?:q|quit)$', 'cmd_quit', False),
    (r'^(?:r|run)$', 'cmd_run', True),
    (r'^(?:st|stack)$', 'cmd_stack', False),
    (r'^(?:s|step)$', 'cmd_step', True),
]

def usage():
    """Report the command-line synopsis through die().

    The synopsis now lists -d as well, matching the option descriptions
    below it.
    """
    die(f"""\
usage: {sys.argv[0]} [-qd] [com-file]
  -q   hide copyright notice at startup
  -d   debug mode
""")

def signed(val):
    """Reinterpret the unsigned 16-bit VAL as a two's-complement signed value.

    Raises struct.error when VAL does not fit in an unsigned 16-bit word.
    """
    raw = struct.pack('>H', val)
    (result,) = struct.unpack('>h', raw)
    return result

def unsigned(val):
    """Reinterpret the signed 16-bit VAL as an unsigned 16-bit word.

    Raises struct.error when VAL does not fit in a signed 16-bit word.
    """
    raw = struct.pack('>h', val)
    (result,) = struct.unpack('>H', raw)
    return result

def parse_number(v):
    """Convert V to a 16-bit unsigned number, or return None if it is invalid.

    V may be an int, a decimal string with an optional sign ('-1', '+10'),
    or a hexadecimal string prefixed with '#' ('#ff00').  The result is
    always masked to 16 bits, so negative numbers wrap (e.g. '-1' gives
    0xffff).  Any other value yields None.
    """
    if isinstance(v, int):
        return v & 0xffff
    if isinstance(v, str):
        if re.search(r'^[-+]?\d+$', v):
            return int(v) & 0xffff
        # only hexadecimal digits are allowed after '#'; accepting every
        # letter made int() raise ValueError for input such as '#xyz'
        m = re.search(r'^#([\da-fA-F]+)$', v)
        if m:
            return int(m.group(1), base=16) & 0xffff
    return None

# ----------------------------------------------------------------
class State:
    """Registers of the machine: program counter, flag register and GR0-GR4."""

    def __init__(self):
        # GR0-GR3 start at zero; GR4 starts at the top of the stack area
        self.gr = [0] * 4 + [STACK_TOP]
        self.fr = FR_ZERO
        self.pc = 0

# ----------------------------------------------------------------
class Comet:
    def __init__(self, debug=False):
        """Create a machine with zero-filled memory and freshly reset registers.

        DEBUG turns on the trace messages emitted through debug().
        """
        self._debug = debug
        self.breakpoints = []
        self.state = State()
        # 64K words of main memory
        self.memory = [0 for _ in range(0x10000)]

    def __repr__(self):
        """Return a one-line dump of PC, GR0-GR4 and FR (used in trace output)."""
        st = self.state
        regs = ', '.join(f'{st.gr[n]:02x}' for n in range(5))
        return f'Comet({st.pc:04x}, {regs} {st.fr:#b})'

    def debug(self, msg):
        """Pass MSG, prefixed with '** ', to warn() when debug mode is on."""
        if not self._debug:
            return
        warn('** ' + msg)

    def copyright(self):
        """Print the program name, its version and the copyright notice."""
        notice = '\n'.join((
            f'This is COMET, version {VERSION}.',
            'Copyright (c) 2021, Hiroyuki Ohsaki.',
            'All rights reserved.',
        ))
        print(notice)

    def load(self, file):
        """Load the COMET object file FILE into memory and set the PC.

        The file starts with a 16-byte header: the magic b'CASL' followed
        by the big-endian start address in bytes 4-5.  The rest is a
        sequence of big-endian 16-bit words, stored from address 0.

        Problems are reported through die() (from perlcompat; presumably
        it prints the message and terminates): a missing or short header,
        a trailing odd byte, or a program that reaches the stack area.
        Previously a short header or an odd byte escaped as struct.error.
        """
        with open(file, 'rb') as f:
            print(f'Reading object from {file}...', end='')

            # parse the file header
            header = f.read(16)
            if len(header) < 16 or header[:4] != b'CASL':
                die(f'{file}: not a COMET object file')
            self.state.pc = struct.unpack('>H', header[4:6])[0]

            # read the whole object at once rather than word by word
            body = f.read()

        if len(body) % 2:
            die(f'{file}: truncated object file')
        nwords = len(body) // 2
        # the program must not reach into the stack area
        if nwords > STACK_TOP:
            die('out of memory')

        # load object into the memory
        self.memory[:nwords] = struct.unpack(f'>{nwords}H', body)
        print('done.')

    def decode(self, addr=None):
        """Split the instruction at ADDR (default: the PC) into its fields.

        Returns the tuple (word, inst, gr, adr, xr):
          word -- the first instruction word
          inst -- operation code (upper 8 bits of WORD)
          gr   -- register field (bits 4-7 of WORD)
          adr  -- the word following WORD
          xr   -- index register field (lower 4 bits of WORD)

        Addresses wrap around at 16 bits, so decoding the last word of
        memory takes ADR from address 0 instead of raising IndexError.
        """
        self.debug(f'decode({self}, {addr})')
        if addr is None:
            addr = self.state.pc
        addr &= 0xffff
        word = self.memory[addr]
        inst = word >> 8
        gr = (word >> 4) & 0xf
        xr = word & 0xf
        adr = self.memory[(addr + 1) & 0xffff]
        return word, inst, gr, adr, xr

    def parse(self, addr=None):
        """Disassemble the instruction at ADDR (default: the PC).

        Returns (mnemonic, operand, size), where MNEMONIC and OPERAND are
        strings and SIZE is the instruction length in words.  A word whose
        operation code is unknown is rendered as a DC constant, and the
        three system-call addresses are rendered as IN/OUT/EXIT.
        """
        self.debug(f'parse({self}, {addr})')
        if addr is None:
            addr = self.state.pc

        # decode the instruction at ADDR
        word, inst, gr, adr, xr = self.decode(addr)

        entry = INSTTBL.get(inst)
        if entry is None:
            # interpret as a binary word by default
            mnemonic, operand, size = 'DC', f'#{word:04x}', 1
        else:
            categ, mnemonic = entry
            # index register suffix, shown only when XR is specified
            index = f', GR{xr}' if xr > 0 else ''
            if categ == 'op1':
                # instructions with GR, adr, and XR
                operand, size = f'GR{gr}, #{adr:04x}' + index, 2
            elif categ == 'op2':
                # instructions with adr and XR
                operand, size = f'#{adr:04x}' + index, 2
            elif categ == 'op3':
                # instructions with GR
                operand, size = f'GR{gr}', 2
            elif categ == 'op4':
                # instructions without operand
                operand, size = '', 1

        # for IN/OUT/EXIT system calls
        syscalls = {SYS_IN: 'IN', SYS_OUT: 'OUT', SYS_EXIT: 'EXIT'}
        if addr in syscalls:
            mnemonic, operand, size = syscalls[addr], 'SYSTEM CALL', 2

        return mnemonic, operand, size

    def update_fr(self, val):
        """Set FR from VAL: minus if bit 15 is set, zero if VAL is 0, else plus."""
        self.debug(f'update_fr({self}, {val})')
        if val & 0x8000:
            flag = FR_MINUS
        elif val == 0:
            flag = FR_ZERO
        else:
            flag = FR_PLUS
        self.state.fr = flag

    def exec_IN(self, inst, gr, adr, xr):
        """Handler of the IN system call --- extract two arguments from the
        stack, read a line from STDIN, store it in specified place."""
        self.debug(f'exec_in({self})')
        gr4 = self.state.gr[4]
        pc = self.memory[gr4]
        len_addr = self.memory[gr4 + 1]
        buf_addr = self.memory[gr4 + 2]
        line = input('IN > ')  # prompt for input
        line = line[:80]  # must be shorter than 80 characters
        self.memory[len_addr] = len(line)
        for c in bytearray(line, encoding='ascii'):
            self.memory[buf_addr] = c
            buf_addr += 1
        self.state.pc = pc  # go back to the caller
        self.state.gr[4] += 1

    def exec_OUT(self, inst, gr, adr, xr):
        """Handler of the OUT system call --- extract two arguments from the
        stack, write a string to STDOUT."""
        self.debug(f'exec_out({self})')
        gr4 = self.state.gr[4]
        pc = self.memory[gr4]
        len_addr = self.memory[gr4 + 1]
        buf_addr = self.memory[gr4 + 2]
        size = self.memory[len_addr]
        print('OUT> ', end='')
        for n in range(size):
            c = self.memory[buf_addr + n] & 0xff
            print(chr(c), end='')
        print()
        self.state.pc = pc  # go back to the caller
        self.state.gr[4] += 1

    def exec_EXIT(self, inst, gr, adr, xr, eadr):
        """Handler of the EXIT system call --- leave the emulator (status 1)."""
        raise SystemExit(1)

    def exec_LD(self, inst, gr, adr, xr, eadr):
        """LD: copy the word at the effective address into GR."""
        st = self.state
        st.gr[gr] = self.memory[eadr]
        st.pc = st.pc + 2

    def exec_ST(self, inst, gr, adr, xr, eadr):
        """ST: store the contents of GR at the effective address."""
        st = self.state
        self.memory[eadr] = st.gr[gr]
        st.pc = st.pc + 2

    def exec_LEA(self, inst, gr, adr, xr, eadr):
        """LEA: load the effective address itself into GR and update FR."""
        st = self.state
        st.gr[gr] = eadr
        self.update_fr(eadr)
        st.pc = st.pc + 2

    def exec_ADD(self, inst, gr, adr, xr, eadr):
        """ADD: add the word at the effective address to GR (16-bit wrap), update FR."""
        st = self.state
        total = (st.gr[gr] + self.memory[eadr]) & 0xffff
        st.gr[gr] = total
        self.update_fr(total)
        st.pc = st.pc + 2

    def exec_SUB(self, inst, gr, adr, xr, eadr):
        """SUB: subtract the word at the effective address from GR (16-bit wrap), update FR."""
        st = self.state
        diff = (st.gr[gr] - self.memory[eadr]) & 0xffff
        st.gr[gr] = diff
        self.update_fr(diff)
        st.pc = st.pc + 2

    def exec_AND(self, inst, gr, adr, xr, eadr):
        """AND: bitwise-AND the word at the effective address into GR, update FR."""
        st = self.state
        result = st.gr[gr] & self.memory[eadr]
        st.gr[gr] = result
        self.update_fr(result)
        st.pc = st.pc + 2

    def exec_OR(self, inst, gr, adr, xr, eadr):
        """OR: bitwise-OR the word at the effective address into GR, update FR."""
        st = self.state
        result = st.gr[gr] | self.memory[eadr]
        st.gr[gr] = result
        self.update_fr(result)
        st.pc = st.pc + 2

    def exec_EOR(self, inst, gr, adr, xr, eadr):
        """EOR: bitwise exclusive-OR the word at the effective address into GR, update FR."""
        st = self.state
        result = st.gr[gr] ^ self.memory[eadr]
        st.gr[gr] = result
        self.update_fr(result)
        st.pc = st.pc + 2

    def exec_CPA(self, inst, gr, adr, xr, eadr):
        """CPA: compare GR with the word at the effective address as signed values.

        Only FR is changed; it reflects the sign of GR minus the operand.
        """
        diff = signed(self.state.gr[gr]) - signed(self.memory[eadr])
        # clamp so that the difference still fits in 16 bits with its sign
        if diff > MAX_SIGNED:
            diff = MAX_SIGNED
        elif diff < MIN_SIGNED:
            diff = MIN_SIGNED
        self.update_fr(unsigned(diff))
        self.state.pc = self.state.pc + 2

    def exec_CPL(self, inst, gr, adr, xr, eadr):
        """CPL: compare GR with the word at the effective address as unsigned values.

        Only FR is changed; it reflects the sign of GR minus the operand.
        """
        diff = self.state.gr[gr] - self.memory[eadr]
        # clamp so that the difference still fits in 16 bits with its sign
        if diff > MAX_SIGNED:
            diff = MAX_SIGNED
        elif diff < MIN_SIGNED:
            diff = MIN_SIGNED
        self.update_fr(unsigned(diff))
        self.state.pc = self.state.pc + 2

    def exec_SLA(self, inst, gr, adr, xr, eadr):
        """SLA: arithmetic left shift of GR by the effective address, update FR.

        The sign bit (bit 15) is kept as it is and only the lower 15 bits
        are shifted; bits shifted out are discarded.  The previous code
        never masked the result, so GR could grow beyond 16 bits and
        magnitude bits could leak into the sign bit.
        """
        st = self.state
        value = st.gr[gr]
        sign = value & 0x8000
        # 15 or more positions clear every magnitude bit; capping the count
        # also avoids building an enormous intermediate integer
        shift = min(eadr, 15)
        st.gr[gr] = sign | (((value & 0x7fff) << shift) & 0x7fff)
        self.update_fr(st.gr[gr])
        st.pc += 2

    def exec_SRA(self, inst, gr, adr, xr, eadr):
        """SRA: arithmetic right shift of GR by the effective address, update FR.

        For a negative value (bit 15 set) the vacated upper bits are filled
        with ones; otherwise they are filled with zeros.
        """
        st = self.state
        value = st.gr[gr]
        if value & 0x8000:
            # ones in every position the 15 magnitude bits have left
            fill = (0x7fff >> eadr) ^ 0xffff
            value = ((value & 0x7fff) >> eadr) + fill
        else:
            value = value >> eadr
        st.gr[gr] = value
        self.update_fr(value)
        st.pc = st.pc + 2

    def exec_SLL(self, inst, gr, adr, xr, eadr):
        """SLL: logical left shift of GR by the effective address, update FR.

        Bits shifted out of the 16-bit word are discarded.  The previous
        code never masked the result, so GR could grow beyond 16 bits.
        """
        st = self.state
        # 16 or more positions clear the whole word; capping the count also
        # avoids building an enormous intermediate integer
        shift = min(eadr, 16)
        st.gr[gr] = (st.gr[gr] << shift) & 0xffff
        self.update_fr(st.gr[gr])
        st.pc += 2

    def exec_SRL(self, inst, gr, adr, xr, eadr):
        """SRL: logical right shift of GR by the effective address, update FR."""
        st = self.state
        shifted = st.gr[gr] >> eadr
        st.gr[gr] = shifted
        self.update_fr(shifted)
        st.pc = st.pc + 2

    def exec_JPZ(self, inst, gr, adr, xr, eadr):
        """JPZ: jump to the effective address unless FR is minus."""
        st = self.state
        taken = st.fr != FR_MINUS
        st.pc = eadr if taken else st.pc + 2

    def exec_JMI(self, inst, gr, adr, xr, eadr):
        """JMI: jump to the effective address when FR is minus."""
        st = self.state
        taken = st.fr == FR_MINUS
        st.pc = eadr if taken else st.pc + 2

    def exec_JNZ(self, inst, gr, adr, xr, eadr):
        """JNZ: jump to the effective address unless FR is zero."""
        st = self.state
        taken = st.fr != FR_ZERO
        st.pc = eadr if taken else st.pc + 2

    def exec_JZE(self, inst, gr, adr, xr, eadr):
        """JZE: jump to the effective address when FR is zero."""
        st = self.state
        taken = st.fr == FR_ZERO
        st.pc = eadr if taken else st.pc + 2

    def exec_JMP(self, inst, gr, adr, xr, eadr):
        """JMP: jump unconditionally to the effective address."""
        st = self.state
        st.pc = eadr

    def exec_PUSH(self, inst, gr, adr, xr, eadr):
        """PUSH: decrement GR4 and store the effective address at the new top.

        GR4 wraps around at 16 bits.  Previously a push with GR4 == 0 left
        -1 in the register, which is not a valid 16-bit value.
        """
        st = self.state
        st.gr[4] = (st.gr[4] - 1) & 0xffff
        self.memory[st.gr[4]] = eadr
        st.pc += 2

    def exec_POP(self, inst, gr, adr, xr, eadr):
        """POP: load the word at the stack top into GR and increment GR4.

        GR4 wraps around at 16 bits.  Previously a pop with GR4 == 0xffff
        left 0x10000 in the register, which is outside the memory.
        """
        st = self.state
        st.gr[gr] = self.memory[st.gr[4] & 0xffff]
        st.gr[4] = (st.gr[4] + 1) & 0xffff
        st.pc += 2

    def exec_CALL(self, inst, gr, adr, xr, eadr):
        """CALL: push the address of the next instruction and jump to the effective address.

        Both GR4 and the saved return address wrap around at 16 bits, so
        neither a register nor a memory word can leave the 16-bit range.
        """
        st = self.state
        st.gr[4] = (st.gr[4] - 1) & 0xffff
        self.memory[st.gr[4]] = (st.pc + 2) & 0xffff
        st.pc = eadr

    def exec_RET(self, inst, gr, adr, xr, eadr):
        """RET: pop the return address from the stack into the PC.

        GR4 wraps around at 16 bits.  Previously a return with
        GR4 == 0xffff left 0x10000 in the register, outside the memory.
        """
        st = self.state
        st.pc = self.memory[st.gr[4] & 0xffff]
        st.gr[4] = (st.gr[4] + 1) & 0xffff

    def exec(self):
        """Execute one instruction from the PC.

        The instruction is decoded, the effective address is computed
        (adr plus the index register when XR is 1-4, wrapped to 16 bits),
        and the handler named exec_<MNEMONIC> is called; the handler is
        responsible for updating the registers and the PC.

        A word that is not a known instruction is reported through die()
        (from perlcompat; presumably it prints the message and terminates)
        instead of failing with AttributeError on a missing handler.
        """
        self.debug(f'exec({self})')
        pc = self.state.pc

        # calculate the effective address
        word, inst, gr, adr, xr = self.decode(pc)
        eadr = adr
        if 1 <= xr <= 4:
            eadr += self.state.gr[xr]
        eadr &= 0xffff

        # obtain the mnemonic for the current address and look up its
        # handler by name (attribute lookup instead of eval)
        nemonic = self.parse(pc)[0]
        handler = getattr(self, f'exec_{nemonic}', None)
        if handler is None:
            die(f'illegal instruction {inst:02x} at {pc:04x}')
            return

        self.debug(f'exec_{nemonic}({inst:02x}, {gr}, {adr:04x}, {xr}, {eadr:04x})')
        handler(inst, gr, adr, xr, eadr)

    # ----------------------------------------------------------------
    def cmd_run(self, *args):
        """Run the program until the PC reaches a breakpoint.

        Breakpoints are checked after each instruction, so at least one
        instruction is always executed.  ARGS is ignored.
        """
        while True:
            self.exec()
            # check the PC is at one of breakpoints
            pc = self.state.pc
            hits = [n for n, addr in enumerate(self.breakpoints) if pc == addr]
            if hits:
                n = hits[0]
                print(f'Breakpoint {n}, #{self.breakpoints[n]:04x}')
                return

    def cmd_step(self, *args):
        """Execute ARGS[0] instructions, or a single one when no count is given.

        An argument that is not a valid number is reported with warn() and
        nothing is executed; previously it ended in a TypeError because
        parse_number() returned None.
        """
        if args:
            count = parse_number(args[0])
            if count is None:
                warn(f'invalid step count "{args[0]}"')
                return
        else:
            count = 1
        for _ in range(count):
            self.exec()

    def cmd_break(self, *args):
        """Add a breakpoint at the address ARGS[0].

        Nothing happens without an argument; an invalid address is
        reported with warn().
        """
        if not args:
            return
        target = args[0]
        addr = parse_number(target)
        if addr is None:
            warn(f'invalid breakpoint address "{target}"')
            return
        self.breakpoints.append(addr)

    def cmd_delete(self, *args):
        try:
            n = parse_number(args[0])
            if n is not None:
                del self.breakpoints[n - 1]
            else:
                resp = input('Delete all breakpoints? (y or n) ')
                if re.search(r'^[yY]', resp):
                    self.breakpoints.clear()
        except IndexError:
            pass

    def cmd_info(self, *args):
        for n, addr in enumerate(self.breakpoints):
            print(f'{n}: #{addr:04x}')

    def cmd_print(self, *args):
        """Print the PC with its disassembled instruction, GR0-GR4 and FR.

        Each general register is shown in hexadecimal and as a signed
        decimal; FR is shown in binary and decimal.  ARGS is ignored.
        """
        # obtain instruction and operand at the current PC
        inst, opr, size = self.parse()
        st = self.state
        cells = [
            f"GR{n} #{st.gr[n]:04x} ({signed(st.gr[n]):6})" for n in range(5)
        ]
        lines = [
            f"PC  #{st.pc:04x} [ {inst} {opr} ]",
            ' '.join(cells[:3]),
            f"{cells[3]} {cells[4]} FR  {st.fr:#b} ({st.fr:6})",
            "",  # the status block ends with an empty line
        ]
        print('\n'.join(lines))

    def cmd_dump(self, *args):
        try:
            addr = parse_number(args[0])
        except IndexError:
            addr = self.state.pc
        for row in range(16):
            base = addr + (row << 3)
            print(f'{base:04x}', end='')
            for col in range(8):
                v = self.memory[base + col]
                print(f' {v:04x}', end='')
            print(' ', end='')
            for col in range(8):
                v = self.memory[base + col] & 0xff
                if 0x20 <= v <= 0x7f:
                    c = chr(v)
                else:
                    c = '.'
                print(c, end='')
            print()

    def cmd_stack(self, *args):
        """Dump the memory starting at the address held in GR4.  ARGS is ignored."""
        self.cmd_dump(self.state.gr[4])

    def cmd_file(self, file=None, *args):
        """Load the object file FILE as the program to be debugged.

        The command loop passes whatever words the user typed, so a
        missing file name (or extra words) must not raise TypeError: a
        missing name is reported with warn() and extra words are ignored.
        A file that cannot be opened is reported as well instead of
        terminating the debugger with a traceback.
        """
        if file is None:
            warn('file command needs a file name')
            return
        try:
            self.load(file)
        except OSError as err:
            warn(f'{file}: {err.strerror}')

    def cmd_jump(self, *args):
        """Set the PC to the address ARGS[0].

        Nothing happens without an argument; an invalid address is
        reported with warn().
        """
        if not args:
            return
        target = parse_number(args[0])
        if target is None:
            warn(f'invalid jump address "{args[0]}"')
        else:
            self.state.pc = target

    def cmd_memory(self, *args):
        """Store the value ARGS[1] at the memory address ARGS[0].

        Missing or invalid arguments are reported with warn().  The
        invalid-argument message is now an f-string; it used to print the
        placeholders '{args[0]}' and '{args[1]}' literally.
        """
        if len(args) < 2:
            warn('memory command needs address and value')
            return
        addr = parse_number(args[0])
        val = parse_number(args[1])
        if addr is None or val is None:
            warn(f'invalid address "{args[0]}" or value "{args[1]}"')
            return
        self.memory[addr] = val

    def cmd_disasm(self, *args):
        """Disassemble 16 instructions starting at ARGS[0] (default: the PC).

        Each line shows the address, the mnemonic and the operand.  The
        address wraps around at 16 bits instead of running past the end of
        the memory, and an invalid address is reported with warn() instead
        of ending in a TypeError.
        """
        if args:
            addr = parse_number(args[0])
            if addr is None:
                warn(f'invalid disasm address "{args[0]}"')
                return
        else:
            addr = self.state.pc
        for _ in range(16):
            inst, opr, size = self.parse(addr)
            print(f'#{addr:04x}\t{inst}\t{opr}')
            addr = (addr + size) & 0xffff

    def cmd_help(self, *args):
        self.debug(f'cmd_help({self}, {args})')
        print("""\
List of commands:

r,  run         Start execution of program.
s,  step        Step execution.  Argument N means do this N times.
b,  break       Set a breakpoint at specified address.
d,  del 	Delete some breakpoints.
i,  info        Print information on breakpoints.
p,  print       Print status of PC/FR/GR0/GR1/GR2/GR3/GR4 registers.
du, dump        Dump 128 words of memory image from specified address.
st, stack       Dump 128 words of stack image.
f,  file        Use FILE as program to be debugged.
j,  jump        Continue program at specifed address.
m,  memory      Change the memory at ADDRESS to VALUE.
di, disasm      Disassemble 32 words from specified address.
h,  help        Print list of commands.
q,  quit        Exit comet.""")

    def cmd_quit(self, *args):
        """Leave the emulator with exit status 1.  ARGS is ignored."""
        raise SystemExit(1)

    def mainloop(self):
        """Read debugger commands from STDIN and run them until input ends.

        The register status is printed first.  Each line is split on
        whitespace into a command word and its arguments, the command word
        is matched against CMDTBL in order, and the first matching handler
        is called; the status is printed again for commands marked so in
        the table.  An empty line repeats the previous command.

        End of input (e.g. Ctrl-D) ends the loop instead of raising
        EOFError, and leading or trailing blanks no longer produce an
        empty command word or an empty argument.
        """
        last_line = ''
        self.cmd_print()
        while True:
            # show prompt and input command from STDIN
            try:
                line = input('comet> ')
            except EOFError:
                print()
                return
            if line == '':
                line = last_line
            last_line = line
            words = line.split()
            if not words:
                continue
            cmd, *args = words

            for regexp, name, need_print in CMDTBL:
                if re.search(regexp, cmd):
                    # look the handler up by name (no eval needed)
                    getattr(self, name)(*args)
                    if need_print:
                        self.cmd_print()
                    break
            else:
                print(f"undefined command: '{cmd}'. Try 'help'")

def main():
    """Parse the command line, optionally load an object file, and start the debugger.

    The copyright notice is now suppressed when -q is given, as the usage
    message promises; the option used to be parsed but ignored.
    """
    opt = getopts('qd') or usage()
    comet = Comet(debug=opt.d)
    if not opt.q:
        comet.copyright()
    # NOTE(review): assumes getopts() removes the parsed options from
    # sys.argv, leaving the object file as sys.argv[1] --- confirm
    if len(sys.argv) >= 2:
        file = sys.argv[1]
        comet.load(file)
    comet.mainloop()

# start the emulator only when this file is run as a script
if __name__ == "__main__":
    main()
