#!/usr/bin/env python3
#
# A simple implementation of CASL II assembler.
# Copyright (c) 2021, Hiroyuki Ohsaki.
# All rights reserved.
#

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import re
import struct
import sys

from perlcompat import die, warn, getopts
import tbdump

# Version number of this program.
VERSION = 0.7

# COMET II machine instructions.  The key of each entry is the op code, and
# its value is a tuple of the instruction type and the mnemonic.
#
# The instruction type names the operand layout and selects the
# Assembler.generate_<type>() method that emits the code:
#   'nopr'    - no operand; one word
#   'r'       - one register operand; one word
#   'r1_r2'   - two register operands; one word
#   'adr_x'   - address with an optional index register; two words
#   'r_adr_x' - register, address, and an optional index register; two words
#
# Several mnemonics (e.g., LD, ADDA) appear twice with different types;
# Instruction.parse_inst() picks the entry whose type agrees with the
# operands actually written in the source line.
INSTTBL = {
    # op code : (type, mnemonic)
    0x00: ('nopr', 'NOP'),
    0x10: ('r_adr_x', 'LD'),
    0x11: ('r_adr_x', 'ST'),
    0x12: ('r_adr_x', 'LAD'),
    0x14: ('r1_r2', 'LD'),
    0x20: ('r_adr_x', 'ADDA'),
    0x21: ('r_adr_x', 'SUBA'),
    0x22: ('r_adr_x', 'ADDL'),
    0x23: ('r_adr_x', 'SUBL'),
    0x24: ('r1_r2', 'ADDA'),
    0x25: ('r1_r2', 'SUBA'),
    0x26: ('r1_r2', 'ADDL'),
    0x27: ('r1_r2', 'SUBL'),
    0x30: ('r_adr_x', 'AND'),
    0x31: ('r_adr_x', 'OR'),
    0x32: ('r_adr_x', 'XOR'),
    0x34: ('r1_r2', 'AND'),
    0x35: ('r1_r2', 'OR'),
    0x36: ('r1_r2', 'XOR'),
    0x40: ('r_adr_x', 'CPA'),
    0x41: ('r_adr_x', 'CPL'),
    0x44: ('r1_r2', 'CPA'),
    0x45: ('r1_r2', 'CPL'),
    0x50: ('r_adr_x', 'SLA'),
    0x51: ('r_adr_x', 'SRA'),
    0x52: ('r_adr_x', 'SLL'),
    0x53: ('r_adr_x', 'SRL'),
    0x61: ('adr_x', 'JMI'),
    0x62: ('adr_x', 'JNZ'),
    0x63: ('adr_x', 'JZE'),
    0x64: ('adr_x', 'JUMP'),
    0x65: ('adr_x', 'JPL'),
    0x66: ('adr_x', 'JOV'),
    0x70: ('adr_x', 'PUSH'),
    0x71: ('r', 'POP'),
    0x80: ('adr_x', 'CALL'),
    0x81: ('nopr', 'RET'),
    0xf0: ('adr_x', 'SVC'),
}

def usage():
    """Hand the command-line synopsis to die().

    The text names the program as it was invoked (sys.argv[0]) and lists
    the -a, -v, and -d options.  die() comes from perlcompat and is
    presumably fatal; this function itself returns whatever die() returns.
    """
    synopsis = (
        f'usage: {sys.argv[0]} [-avd] file...',
        ' -a   turn on verbose listing',
        ' -v   display version and exit',
        ' -d   enable debug mode',
    )
    # every line, including the last one, is newline-terminated
    die('\n'.join(synopsis) + '\n')

def parse_number(v):
    """Interpret V (either integer or string) as an unsigned short.

    Accepted forms:
    - an integer, which is truncated to its low 16 bits
    - a decimal string with an optional sign (e.g., '123', '-1', '+7'),
      truncated to its low 16 bits (so '-1' yields 0xffff)
    - a hexadecimal string written as '#' followed by exactly four
      upper-case hexadecimal digits (e.g., '#12F0')

    Return the value as an integer in the range 0--65535.  If the argument
    cannot be recognized (including any other type), return None.
    """
    # bool is a subclass of int, but it is not a numeric operand
    if isinstance(v, bool):
        return None
    if isinstance(v, int):
        return v & 0xffff
    if isinstance(v, str):
        # fullmatch (rather than search with '$') rejects a trailing
        # newline, and [0-9] restricts the digits to ASCII
        if re.fullmatch(r'[-+]?[0-9]+', v):
            return int(v) & 0xffff
        # hexadecimal number must have four digits (CASL II specification)
        m = re.fullmatch(r'#([0-9A-F]{4})', v)
        if m:
            # convert hex to decimal
            return int(m.group(1), base=16) & 0xffff
    return None

# ----------------------------------------------------------------
class Instruction:
    """One parsed statement: a mnemonic OP and its raw operand text OPR.

    The constructor splits the operand text on commas, keeps the first
    three operands as OPR1--OPR3 (None when absent), and then classifies
    the statement with parse_inst().  ERR_FUNC is called with a message
    when the statement cannot be classified."""

    def __init__(self, op, opr, err_func=die):
        self.op = op
        self._opr = opr  # raw operands
        self.err_func = err_func

        # operands are comma-separated with optional surrounding spaces
        fields = re.split(r'\s*,\s*', opr) if opr is not None else []
        self.noprs = len(fields)  # the number of valid operands
        # pad with None so that exactly three slots can be unpacked
        self.opr1, self.opr2, self.opr3 = (fields + [None, None, None])[:3]
        self.parse_inst()  # interpret the instruction

    def __repr__(self):
        return f'Instruction(op={self.op}, opr={self._opr}, noprs={self.noprs}, categ={self.categ})'

    def guess_type(self):
        """Return the candidate instruction types, as a list of strings,
        that are compatible with the operands given.  Each type is one of
        'nopr', 'r', 'adr_x', 'r1_r2', and 'r_adr_x' (see INSTTBL)."""
        if self.noprs == 0:
            return ['nopr']
        if self.noprs == 1:
            return ['r', 'adr_x']
        # with two operands, a register in the second position means either
        # 'r1, r2' or 'adr, x'; anything else must be 'r, adr [, x]'
        register_second = (self.noprs == 2
                           and re.search(r'^GR[0-7]$', self.opr2))
        return ['r1_r2', 'adr_x'] if register_second else ['r_adr_x']

    def parse_inst(self):
        """Classify the statement.  Assembler instructions and macros get
        CATEG set to the lower-cased mnemonic.  Machine instructions get
        CATEG (instruction type) and CODE (op code) from the first INSTTBL
        entry matching both the mnemonic and a candidate type.  Call
        ERR_FUNC if no entry matches."""
        # assembler instructions and macros
        directives = ('START', 'END', 'DS', 'DC', 'RPUSH', 'RPOP', 'IN',
                      'OUT')
        if self.op in directives:
            self.categ = self.op.lower()
            return

        candidates = self.guess_type()
        # FIXME: first match may fail if instructions have ambiguity
        for opcode, (kind, name) in INSTTBL.items():
            if name == self.op and kind in candidates:
                self.categ = kind
                self.code = opcode
                return
        self.err_func(f"illegal instruction '{self.op}' {candidates}")

# ----------------------------------------------------------------
class LabelTable:
    """Symbol table mapping labels (and literals) to their values.

    Besides the value, the table remembers the file and line number the
    owning ASSEMBLER was processing when each label was stored.
    DEBUG_FUNC and ERR_FUNC receive trace and error messages."""

    def __init__(self, debug_func=warn, err_func=die, assembler=None):
        self.assembler = assembler
        self.debug_func = debug_func
        self.err_func = err_func
        self.tbl = {}
        self.src = {}  # for recording source lines

    def __repr__(self):
        return 'LabelTable'

    def __setitem__(self, label, val):
        self.tbl[label] = val
        # record the line at which the label is defined
        asm = self.assembler
        self.src[label] = (asm.file, asm.lineno)

    def __getitem__(self, label):
        return self.tbl[label]

    def keys(self):
        return self.tbl.keys()

    def src_for(self, label):
        """Return the (file, line number) pair recorded when label LABEL
        was stored."""
        return self.src[label]

    def check_label(self, label):
        """Validate the spelling of label LABEL and call ERR_FUNC if it is
        malformed.  A label is an upper-case letter followed by up to seven
        alphanumerics; a literal is '=' followed by digits, A--F, '#', '+',
        or '-'."""
        self.debug_func(f'check_label(self={self}, label={label})')
        # the maximum label length is 8 characters (CASL II specification)
        pattern = r'^([A-Z][0-9A-Za-z]{0,7}|=[0-9A-F#+-]+)$'
        if re.search(pattern, label) is None:
            self.err_func(f"invalid label '{label}'")

    def register(self, label, val):
        """Store value VAL under label LABEL after validating the label.
        Call ERR_FUNC if LABEL is malformed or already present."""
        self.debug_func(f'register(self={self}, label={label}, val={val})')
        self.check_label(label)
        duplicate = label in self.tbl
        if duplicate:
            self.err_func(f"label '{label}' already defined")
        self[label] = val  # FIXME: force to unsigned short?

    def expand(self, val):
        """Resolve VAL to an unsigned short.  A string found in the table
        is replaced by its stored value first; the result is then converted
        with parse_number().  Call ERR_FUNC if the conversion fails."""
        self.debug_func(f'expand(self={self}, val={val})')
        known_label = type(val) is str and val in self.tbl
        resolved = self[val] if known_label else val
        number = parse_number(resolved)
        if number is None:
            self.err_func(f"invalid label or numeric value '{resolved}'")
        return number

# ----------------------------------------------------------------
class Memory:
    """Growable word array holding the generated object code.

    Each stored word is paired with the (file, line number, line text)
    the owning ASSEMBLER was processing at the time, so that a listing
    can be produced later.  DEBUG_FUNC and ERR_FUNC receive trace and
    error messages."""

    def __init__(self, debug_func=warn, err_func=die, assembler=None):
        self.assembler = assembler
        self.debug_func = debug_func
        self.err_func = err_func
        self.memory = []
        self.src = []  # for recording source lines

    def __repr__(self):
        return 'Memory'

    def __setitem__(self, addr, val):
        # automatically expand the word list if ADDR is beyond its end
        shortfall = addr + 1 - len(self.memory)
        if shortfall > 0:
            self.memory.extend([None] * shortfall)
        self.memory[addr] = val

        # likewise for the parallel list of source locations
        shortfall = addr + 1 - len(self.src)
        if shortfall > 0:
            self.src.extend([None] * shortfall)
        asm = self.assembler
        # record the source line where the word is generated
        self.src[addr] = (asm.file, asm.lineno, asm.line)

    def __getitem__(self, addr):
        return self.memory[addr]

    def src_at(self, addr):
        """Return the (file, line number, line) recorded for address
        ADDR."""
        return self.src[addr]

    def register_number(self, reg):
        """Convert register name REG ('GR0' through 'GR7') to its number.
        None is treated as GR0 and yields 0.  Call ERR_FUNC if REG is not a
        valid register name."""
        self.debug_func(f'register_number(self={self}, reg={reg})')
        if reg is None:
            return 0
        matched = re.search(r'^GR([0-7])$', reg)
        if matched is None:
            self.err_func(f"invalid register '{reg}'")
        return int(matched.group(1))

    def generate_code(self, addr, inst, gr, adr, xr, single_word=False):
        """Store one instruction starting at address ADDR.  INST is the op
        code, GR the GR (or R1) register name, ADR the address operand, and
        XR the index (or R2) register name.  One word is stored when
        SINGLE_WORD is true; otherwise ADR is stored as-is in the following
        word."""
        self.debug_func(
            f'generate_code(self={self}, addr={addr:04x}, inst={inst:02x}, gr={gr}, adr={adr}, xr={xr})'
        )
        # layout of the first word: op code (8 bits), GR/R1 register
        # (4 bits), and XR/R2 register (4 bits)
        gr_no = self.register_number(gr)
        xr_no = self.register_number(xr)
        self[addr] = (inst << 8) | (gr_no << 4) | xr_no
        if single_word:
            return
        # the second word is the address
        self[addr + 1] = adr

# ----------------------------------------------------------------
class Assembler:
    """Two-pass assembler that turns CASL II source into an object file.

    Typical use is load() followed by pass1() and pass2(); dump_memory()
    and dump_labels() print a listing afterwards.  pass1() generates code
    into a Memory instance, leaving label operands as unresolved strings,
    and pass2() resolves them through the LabelTable while writing the
    output file.

    The attributes FILE, LINENO, and LINE always describe the source line
    being processed; Memory, LabelTable, and the error/debug messages read
    them to report source locations."""

    def __init__(self, debug=False):
        self.file = None
        self.enable_debug = debug
        self.buf = []
        self.memory = Memory(debug_func=self.debug,
                             err_func=self.error,
                             assembler=self)
        self.labeltbl = LabelTable(debug_func=self.debug,
                                   err_func=self.error,
                                   assembler=self)
        self.literals = {}
        self.start_addr = 0
        self.addr = 0
        # True between START and END; initialized here so that generate()
        # can be used even before pass1() has run
        self.in_block = False
        self.lineno = None
        self.line = None

    def __repr__(self):
        return f'Assembler(start=#{self.start_addr}, addr=#{self.addr:04x})'

    def copyright(self):
        """Return the version and copyright notice as a string."""
        # must be an f-string; otherwise '{VERSION}' is shown literally
        return f"""\
This is CASL, version {VERSION}.
Copyright (c) 2021, Hiroyuki Ohsaki.
All rights reserved."""

    def load(self, file):
        """Simply load the source file FILE into the buffer.  This method
        does not perform any assembly.  Lines are appended to the buffer
        with their trailing whitespace removed."""
        with open(file) as f:
            self.file = file
            for line in f:
                # remove the trailing CR+LF (and any other whitespace)
                self.buf.append(line.rstrip())

    def debug(self, msg):
        """Display MSG with the current source location if debug mode is
        enabled."""
        if self.enable_debug:
            warn(f'{self.file}:{self.lineno}: {msg}')

    def error(self, msg):
        """Report MSG with the current source location and line through
        die()."""
        die(f'{self.file}:{self.lineno}: {msg}\n->{self.line}')

    def parse_line(self, line):
        """Parse line LINE and extract the label (if specified) and the
        instruction.  Return a tuple of the label (None if absent) and an
        instance of the Instruction class.  If LINE cannot be parsed,
        report a syntax error; should the error handler return, the result
        is (None, None)."""
        # optional label at the first column, mnemonic, optional operands
        m = re.search(r'^(\S+)?\s+([A-Z]+)(\s+(.*)?)?$', line)
        if not m:
            self.error(f'syntax error {line}')
            return None, None
        label, op, opr = m.group(1), m.group(2), m.group(4)
        self.debug(f'parse_line: label={label}, op={op}, opr={opr}')
        return label, Instruction(op, opr, err_func=self.error)

    def generate_r_adr_x(self, label, inst):
        """Generate two-words code for instruction INST at the current
        address.  The instruction takes at most three operands: GR, ADDRESS,
        and XR."""
        if not (2 <= inst.noprs <= 3):
            self.error(f"missing operands '{inst}'")
        self.memory.generate_code(self.addr, inst.code, inst.opr1, inst.opr2,
                                  inst.opr3)
        self.addr += 2

    def generate_r1_r2(self, label, inst):
        """Generate one-word code for instruction INST at the current
        address.  The instruction takes two operands: R1 and R2."""
        if not (inst.noprs == 2):
            self.error(f"missing operands '{inst}'")
        self.memory.generate_code(self.addr,
                                  inst.code,
                                  inst.opr1,
                                  None,
                                  inst.opr2,
                                  single_word=True)
        self.addr += 1

    def generate_adr_x(self, label, inst):
        """Generate two-words code for instruction INST at the current
        address.  The instruction takes at most two operands: ADDRESS and
        XR."""
        if not (1 <= inst.noprs <= 2):
            self.error(f"too much/too many operand '{inst}'")
        self.memory.generate_code(self.addr, inst.code, None, inst.opr1,
                                  inst.opr2)
        self.addr += 2

    def generate_r(self, label, inst):
        """Generate one-word code for instruction INST at the current
        address.  The instruction takes a single operand: R."""
        if inst.noprs != 1:
            self.error(f"expects just one operand '{inst}'")
        self.memory.generate_code(self.addr,
                                  inst.code,
                                  inst.opr1,
                                  None,
                                  None,
                                  single_word=True)
        self.addr += 1

    def generate_nopr(self, label, inst):
        """Generate one-word code for instruction INST at the current
        address.  The instruction takes no operand."""
        if inst.noprs > 0:
            self.error(f'invalid operand "{inst}"')
        self.memory.generate_code(self.addr,
                                  inst.code,
                                  inst.opr1,
                                  None,
                                  None,
                                  single_word=True)
        self.addr += 1

    def generate_start(self, label, inst):
        """Handler for START assembler instruction.  START assembler
        instruction specifies the start address of the program.  The
        operand (if any) is kept as-is and resolved in pass2()."""
        if not label:
            self.error("no label found at START")
        if inst.opr1 is not None:
            # FIXME: argument must be label, not numbers (CASL II specification)
            self.start_addr = inst.opr1
        self.in_block = True

    def generate_end(self, label, inst):
        """Handler for END assembler instruction.  END takes neither a
        label nor an operand."""
        if label:
            self.error(f"can't use label '{label}' at END")
        if inst.noprs > 0:
            self.error(f"END accepts no operand '{inst}'")
        self.in_block = False

    def generate_ds(self, label, inst):
        """Handler for DS assembler instruction.  DS assembler instruction
        allocates a memory block of the specified size, filled with
        zeros."""
        if inst.noprs != 1:
            self.error(f"DS exepects a single argument '{inst}'")
        # FIXME: DS accepts only decimal argument (CASL II specification)
        v = parse_number(inst.opr1)
        if v is None:
            self.error(f"DS exepects a single numeric argument '{inst}'")
            return
        for _ in range(v):
            self.memory[self.addr] = 0
            self.addr += 1

    def generate_dc(self, label, inst):
        """Handler for DC assembler instruction.  DC assembler instruction
        allocates a word with the specified value for every argument (and
        for every character of a string argument)."""
        if inst._opr is None:
            self.error(f"DC expects at least one argument '{inst}'")
            return
        # parse DC arguments with regexp:
        # This regexp matches either string, numeric argument or label with
        # its (optional) following comma separator.
        # DC assembler instruction arguments (CASL II specification):
        # - arguments are separated by comma (,) with optional surrounding spaces
        # - string is quoted by single quote (')
        # - quote char in string is escaped by double quotes ('')
        # - string may contain comma
        # e.g., DC 123, #12F0, 'ABC', SYM, ''' = quote', -123, '# = sharp'
        for m in re.finditer(r"\s*('[^']*(''[^']*)*'|[^',]+)(\s*,\s*)?",
                             inst._opr):
            opr = m.group(1)
            if opr.startswith("'"):
                # NOTE: the first and the last chars must be single quote
                astr = opr[1:-1].replace("''", "'")
                for c in astr:
                    self.memory[self.addr] = ord(c) & 0xff
                    self.addr += 1
            else:
                # the pattern swallows blanks before the separating comma
                opr = opr.strip()
                val = parse_number(opr)
                # a non-numeric argument must be a label; it is stored
                # as-is and resolved in pass2()
                self.memory[self.addr] = val if val is not None else opr
                self.addr += 1

    def generate_rpush(self, label, inst):
        """Expand RPUSH macro and generate code for storing registers GR1
        through GR7 on the stack."""
        self.generate_lines(
            [f'        PUSH\t0, GR{n}' for n in range(1, 8)])

    def generate_rpop(self, label, inst):
        """Expand RPOP macro and generate code for retrieving registers GR7
        through GR1 from the stack."""
        self.generate_lines([f'\tPOP\tGR{n}' for n in range(7, 0, -1)])

    def _generate_io_macro(self, inst, name, svc):
        """Expand the IN/OUT macro NAME: save GR1 and GR2, load the two
        operand addresses into them, issue 'SVC svc', and restore the
        registers."""
        if inst.noprs != 2:
            self.error(f'{name} macro requires two labels')
        self.generate_lines([
            '        PUSH\t0, GR1',
            '        PUSH\t0, GR2',
            f'        LAD\tGR1, {inst.opr1}',
            f'        LAD\tGR2, {inst.opr2}',
            f'        SVC\t{svc}',
            '        POP\tGR2',
            '        POP\tGR1',
        ])

    def generate_in(self, label, inst):
        """Expand IN macro and generate code for reading a record from the
        input device.  'IN IBUF, LEN' stores the record and its size at IBUF
        and LEN."""
        self._generate_io_macro(inst, 'IN', 1)

    def generate_out(self, label, inst):
        """Expand OUT macro and generate code for writing a record to the
        output device.  'OUT OBUF, LEN' writes LEN words of data starting from
        OBUF."""
        self._generate_io_macro(inst, 'OUT', 2)

    def remove_comment(self, line):
        """Return line LINE with its comment and trailing whitespace
        removed."""
        # regexp for splitting non-comment and comment parts
        # line format (CASL II specification):
        # - comment starts with semicolon ';'
        # - semicolon may appear in strings (e.g., DC assembler instruction)
        # e.g., MSG  DC 'string may contain ; (semicolon)', 123 ; comment on the right
        m = re.search(r"^([^';]*('[^']*'[^';]*)*)\s*;.*$", line)
        if m:
            line = m.group(1)
        return line.rstrip()

    def generate(self, label, inst):
        """Generate code at the current address for the instruction INST by
        dispatching to the generate_<categ>() method for its type."""
        self.debug(f'generate(self={self}, label={label}, ist={inst})')
        # START must be the first instruction
        if not self.in_block and inst.categ != 'start':
            self.error(f"no START directive found")
        # GR0 cannot be used as an index register.  Where the index
        # register is written depends on the instruction type; the second
        # operand of an 'r1, r2' instruction is an ordinary register, for
        # which GR0 is allowed.
        if inst.categ == 'r_adr_x':
            index_reg = inst.opr3
        elif inst.categ == 'adr_x':
            index_reg = inst.opr2
        else:
            index_reg = None
        if index_reg == 'GR0':
            self.error("can't use GR0 as an index register")

        # look the handler up by name; a missing handler yields None
        # instead of raising, so that the error below is reachable
        handler = getattr(self, f'generate_{inst.categ}', None)
        if callable(handler):
            self.debug(
                f'generate_{inst.categ}(self={self}, label={label}, ist={inst})'
            )
            handler(label, inst)
        else:
            self.error(f"instruction type '{inst.categ}' not implemented")

    def generate_lines(self, lines):
        """Generate codes starting from the current address for lines stored
        in LINES, which can be either the full program or a fragment (e.g.,
        an expanded macro)."""
        self.debug(f'generate_lines(self={self}, lines={lines})')
        for line in lines:
            self.debug(f'generate_lines: {line}')
            line = self.remove_comment(line)
            if line == '':
                continue
            label, inst = self.parse_line(line)

            # register label to the symbol table
            if label is not None:
                self.labeltbl.register(label, self.addr)
            # generate object code according to the instruction
            if inst:
                self.generate(label, inst)
                # remember every literal used as the second operand; the
                # words for them are allocated by generate_literals()
                if inst.noprs >= 2 and inst.opr2.startswith('='):
                    self.literals[inst.opr2] = True

    def generate_literals(self):
        """Allocate memory for all literals (i.e., address operands starting
        with '=') used in the program.  Literals whose value cannot be
        parsed as a number are skipped."""
        self.debug(f'generate_literals(self={self})')
        for literal in self.literals:
            # literal always starts with = (e.g., '=-123', '=#32F4')
            val = parse_number(literal[1:])
            if val is not None:
                self.labeltbl.register(literal, self.addr)
                self.memory[self.addr] = val
                self.addr += 1

    def pass1(self):
        """Parse the source code stored in the buffer and generate code in the
        memory with pending (i.e., non-resolved) labels.  The source code must
        have been loaded with load() method."""
        self.debug(f'pass1(self={self})')
        self.addr = 0
        self.in_block = False
        # line numbers shown in messages and listings start from one
        for self.lineno, self.line in enumerate(self.buf, 1):
            # invoke for every line to locate the source code
            self.generate_lines([self.line])
        if self.in_block:
            self.error("no 'END' instruction found")
        self.generate_literals()

    def pass2(self):
        """Write the assembled object code to the output file, resolving
        labels on the way.  The output file name is the source file name
        with its '.cas' extension (if any) replaced by '.com'."""
        self.debug(f'pass2(self={self})')
        # strip the extension only; '.cas' elsewhere in the path is kept
        base = re.sub(r'\.cas$', '', self.file, flags=re.I)
        outfile = base + '.com'
        with open(outfile, 'wb') as f:
            # object header: magic 'CASL', the start address as a
            # big-endian unsigned short, and 10 bytes of padding
            addr = self.labeltbl.expand(self.start_addr)
            f.write(struct.pack('>4sH10x', b'CASL', addr))
            # dump memory image
            for addr in range(self.addr):
                val = self.memory[addr]
                # FIXME: undefined labels must be regarded as start addresses
                # of other programs (CASL II specification)
                v = self.labeltbl.expand(val)
                f.write(struct.pack('>H', v))

    def dump_memory(self):
        """Dump the assembled object code to the standard output with
        corresponding source lines.  The source line is shown only with the
        first word generated from it."""
        self.debug(f'dump_memory(self={self})')
        print(f'CASL LISTING {self.file}')
        last_lineno = None
        for addr in range(self.addr):
            val = self.memory[addr]
            v = self.labeltbl.expand(val)
            _file, lineno, line = self.memory.src_at(addr)
            if lineno != last_lineno:
                print(f'{lineno:4} {addr:04x} {v:04x}\t', end='')
                print(line)
                last_lineno = lineno
            else:
                print(f'{lineno:4}      {v:04x}')

    def dump_labels(self):
        """Dump all defined labels to the standard output with the source
        location and the value of each."""
        self.debug(f'dump_label(self={self})')
        print('\nDEFINED LABELS')
        for label in self.labeltbl.keys():
            file, lineno = self.labeltbl.src_for(label)
            val = self.labeltbl[label]
            print(f'               {file}:{lineno}\t{val:04x} {label}')

def main():
    """Assemble every file named on the command line.

    Options: -a prints the listing and the label table after assembling
    each file, -v prints the version and exits, and -d enables debug
    messages.  Each file is assembled independently into its own object
    file."""
    # parse command-line options
    opt = getopts('avd') or usage()
    # NOTE(review): assumes opt.v is readable when -v is absent, in the same
    # way as opt.a and opt.d below -- confirm against perlcompat.getopts
    if opt.v:
        print(f'CASL version {VERSION}')
        return
    # process all specified files
    for file in sys.argv[1:]:
        # use a fresh assembler for every file; the source buffer, labels,
        # and memory image must not carry over from the previous file
        asm = Assembler(debug=opt.d)
        asm.load(file)
        asm.pass1()
        asm.pass2()
        if opt.a:
            asm.dump_memory()
            asm.dump_labels()

if __name__ == "__main__":
    main()
