#!/usr/bin/env python3
#
# A simple implementation of CASL assembler.
# Copyright (c) 1997-2000, Hiroyuki Ohsaki.
# All rights reserved.
#

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import os
import os.path
import re
import struct
import sys

from perlcompat import die, warn, getopts
import tbdump

# version number of this assembler
VERSION = 0.3
# NOTE(review): DEBUG is not referenced by any code visible in this file;
# debug output is switched on through Assembler(debug=...) instead.
DEBUG = False

# Addresses of the IN/OUT/EXIT system calls.  The IN and OUT macros
# are expanded to push their two operands on the stack and CALL
# SYS_IN / SYS_OUT; the EXIT macro is expanded to a JMP to SYS_EXIT.
SYS_IN = 0xfff0
SYS_OUT = 0xfff2
SYS_EXIT = 0xfff4

# COMET instructions
# Maps each mnemonic to a two-element list [TYPE, CODE].  All valid
# instructions must have an entry in this dict.  CODE is the object
# code of the instruction (None for pseudo instructions and macros,
# which have no opcode of their own).  TYPE is the type of the
# instruction and selects the Assembler.generate_<TYPE> method that
# handles it:
#   op1 ... GR, adr and an optional XR (two or three operands)
#   op2 ... adr and an optional XR (one or two operands)
#   op3 ... exactly one operand (a GR)
#   op4 ... no operand
INSTTBL = {
    'LD': ['op1', 0x10],
    'ST': ['op1', 0x11],
    'LEA': ['op1', 0x12],
    'ADD': ['op1', 0x20],
    'SUB': ['op1', 0x21],
    'AND': ['op1', 0x30],
    'OR': ['op1', 0x31],
    'EOR': ['op1', 0x33],
    'CPA': ['op1', 0x40],
    'CPL': ['op1', 0x41],
    'SLA': ['op1', 0x50],
    'SRA': ['op1', 0x51],
    'SLL': ['op1', 0x52],
    'SRL': ['op1', 0x53],
    'JPZ': ['op2', 0x60],
    'JMI': ['op2', 0x61],
    'JNZ': ['op2', 0x62],
    'JZE': ['op2', 0x63],
    'JMP': ['op2', 0x64],
    'PUSH': ['op2', 0x70],
    'POP': ['op3', 0x71],
    'CALL': ['op2', 0x80],
    'RET': ['op4', 0x81],
    # pseudo instructions (assembler directives; no opcode)
    'START': ['start', None],
    'END': ['end', None],
    'DS': ['ds', None],
    'DC': ['dc', None],
    # CASL macros (expanded into several COMET instructions; no opcode)
    'IN': ['io', None],
    'OUT': ['io', None],
    'EXIT': ['exit', None],
}

def usage():
    """Report the command-line synopsis of this program through die()."""
    synopsis = f"""\
usage: {sys.argv[0]} [-av] file...
 -a          turn on verbose listings
 -v          display version and exit
"""
    die(synopsis)

def parse_number(v):
    """Convert V to an unsigned 16-bit integer, or return None.

    V may be an int, a decimal string with an optional sign (e.g. '-1'
    or '+10'), or a hexadecimal string prefixed with '#' (e.g. '#00ff').
    The result is masked to 16 bits, so negative values wrap around
    (-1 becomes 0xffff).  Any other value --- a label name, None, or a
    malformed literal such as '#XYZ' --- yields None.
    """
    if isinstance(v, int):
        return v & 0xffff
    if not isinstance(v, str):
        return None
    if re.search(r'^[-+]?\d+$', v):
        return int(v) & 0xffff
    # Only genuine hex digits may follow '#'.  Accepting any letter here
    # would let int(..., base=16) raise ValueError on input like '#XYZ'
    # instead of reporting "not a number" with None.
    m = re.search(r'^#([0-9a-fA-F]+)$', v)
    if m:
        # convert hex to decimal
        return int(m.group(1), base=16) & 0xffff
    return None

# ----------------------------------------------------------------
class Instruction:
    """One parsed source instruction: a mnemonic and its split operands.

    OP is the mnemonic and must be a key of INSTTBL.  OPR is the raw
    operand text (or None when the instruction has no operand); it is
    split on commas and/or whitespace into at most three positional
    operands.  ERR_FUNC is called with a message when OP is unknown.
    """

    def __init__(self, op, opr, err_func=die):
        try:
            categ, code = INSTTBL[op]
        except KeyError:
            # A missing dict key raises KeyError (not IndexError), so this
            # is the handler that actually reports an unknown mnemonic.
            err_func(f"illegal instruction '{op}'")
            # ERR_FUNC normally terminates the program; if it returns,
            # do not continue with an undefined instruction type.
            raise

        self.op = op
        self.opr = opr  # raw operands
        if opr is None:
            oprs = []
        else:
            oprs = re.split(r'[,\s]+', opr)
        self.noprs = len(oprs)  # number of operands found
        # first three operands, None when absent
        self.opr1 = oprs[0] if self.noprs >= 1 else None
        self.opr2 = oprs[1] if self.noprs >= 2 else None
        self.opr3 = oprs[2] if self.noprs >= 3 else None
        self.code = code  # op code
        self.categ = categ  # instruction type

    def __repr__(self):
        return f'Instruction(op={self.op}, opr={self.opr}, noprs={self.noprs}, categ={self.categ})'

# ----------------------------------------------------------------
class LabelTable:
    """Symbol table: label name -> value, plus where each label was set.

    DEBUG_FUNC receives trace messages and ERR_FUNC receives error
    messages.  ASSEMBLER supplies the current file name and line number
    that are recorded whenever a label is stored.
    """

    def __init__(self, debug_func=warn, err_func=die, assembler=None):
        self.assembler = assembler
        self.debug_func = debug_func
        self.err_func = err_func
        self.tbl = {}  # label -> value
        self.src = {}  # label -> (file, lineno) of the definition

    def __repr__(self):
        return 'LabelTable()'

    def __setitem__(self, label, val):
        self.tbl[label] = val
        origin = (self.assembler.file, self.assembler.lineno)
        self.src[label] = origin

    def __getitem__(self, label):
        return self.tbl[label]

    def keys(self):
        """Return the view of all registered label names."""
        return self.tbl.keys()

    def src_for(self, addr):
        """Return the (file, lineno) pair recorded for the key ADDR."""
        return self.src[addr]

    def check_label(self, label):
        """Validate the spelling of LABEL: an upper-case letter followed
        by up to five letters or digits.  Report an error through
        ERR_FUNC otherwise."""
        self.debug_func(f'check_label({self}, {label})')
        well_formed = re.search(r'^[A-Z][0-9A-Za-z]{0,5}$', label)
        if well_formed is None:
            self.err_func(f"invalid label '{label}'")

    def register(self, label, val):
        """Store LABEL with value VAL after validating its spelling.
        A label that is already present is reported through ERR_FUNC."""
        self.debug_func(f'register({self}, {label}, {val})')
        self.check_label(label)
        already_known = label in self.tbl
        if already_known:
            self.err_func(f"label '{label}' already defined")
        self[label] = val

    def expand(self, val):
        """Resolve VAL to a number.  A string naming a registered label
        is replaced by that label's value first; the result is then
        converted with parse_number().  If no number can be obtained,
        report an undefined label through ERR_FUNC."""
        self.debug_func(f'expand({self}, {val})')
        if type(val) is str and val in self.tbl:
            val = self[val]
        number = parse_number(val)
        if number is None:
            self.err_func(f"undefined label '{val}'")
        return number

# ----------------------------------------------------------------
class Memory:
    """Growable memory image indexed by address.

    Alongside every stored word the (file, lineno, line) currently
    being processed by ASSEMBLER is remembered, so a listing can show
    the source of each word.  DEBUG_FUNC receives trace messages and
    ERR_FUNC receives error messages.
    """

    def __init__(self, debug_func=warn, err_func=die, assembler=None):
        self.assembler = assembler
        self.debug_func = debug_func
        self.err_func = err_func
        self.memory = []  # stored words, indexed by address
        self.src = []  # (file, lineno, line) per address

    def __repr__(self):
        return 'Memory()'

    def __setitem__(self, addr, val):
        # grow both parallel lists with None so that ADDR is a valid index
        wanted = addr + 1
        for cells in (self.memory, self.src):
            shortfall = wanted - len(cells)
            if shortfall > 0:
                cells.extend([None] * shortfall)
        self.memory[addr] = val
        asm = self.assembler
        self.src[addr] = asm.file, asm.lineno, asm.line

    def __getitem__(self, addr):
        return self.memory[addr]

    def src_at(self, addr):
        """Return the (file, lineno, line) recorded for address ADDR."""
        return self.src[addr]

    def _register_number(self, reg):
        """Translate the register name REG ('GR0' - 'GR4') into its
        number.  A missing register (None) counts as 0.  Any other
        spelling is reported through ERR_FUNC."""
        self.debug_func(f'_register_number({self}, {reg})')
        if reg is None:
            return 0
        matched = re.search(r'^GR([0-4])$', reg)
        if matched is None:
            self.err_func(f"invalid register '{reg}'")
        return int(matched.group(1))

    def generate_code(self, addr, inst, gr, adr, xr):
        """Store a two-word instruction at ADDR: the first word packs
        the opcode INST with the numbers of registers GR and XR, the
        second word is ADR as given."""
        self.debug_func(
            f'generate_code({self}, {addr:04x}, {inst:02x}, {gr}, {adr}, {xr})'
        )
        gr_num = self._register_number(gr)
        xr_num = self._register_number(xr)
        # opcode in the high byte, GR in bits 4-7, XR in bits 0-3
        first_word = (inst << 8) + (gr_num << 4) + xr_num
        self[addr] = first_word
        self[addr + 1] = adr

# ----------------------------------------------------------------
class Assembler:
    """Two-pass CASL assembler.

    load() reads a source file into a line buffer, pass1() parses the
    buffer, registers labels and generates code into the memory image,
    and pass2() resolves the remaining labels and writes the object
    file.  dump_memory() and dump_labels() print a listing.

    With DEBUG true, trace messages are emitted through warn().
    """

    def __init__(self, debug=False):
        self.file = None
        self.enable_debug = debug
        self.buf = []
        self.memory = Memory(debug_func=self.debug, assembler=self)
        self.labeltbl = LabelTable(debug_func=self.debug, assembler=self)

        self.start_addr = 0
        self.end_addr = 0
        # Also (re)set by pass1(); initialised here so that __repr__,
        # which is used by debug messages, works before pass1() runs.
        self.addr = 0
        self.in_block = False
        self.lineno = None
        self.line = None

    def __repr__(self):
        return f'Assembler(start={self.start_addr}, addr={self.addr})'

    def copyright(self):
        """Return the version/copyright banner as a string."""
        return f"""\
This is CASL, version {VERSION}.
Copyright (c) 2021, Hiroyuki Ohsaki.
All rights reserved."""

    def load(self, file):
        """Append the lines of FILE, stripped of trailing whitespace, to
        the line buffer and remember FILE as the current file name."""
        with open(file) as f:
            self.file = file
            self.buf.extend(line.rstrip() for line in f)

    def debug(self, msg):
        """Emit MSG with the current file/line prefix if debugging is on."""
        if self.enable_debug:
            warn(f'{self.file}:{self.lineno}: {msg}')

    def error(self, msg):
        """Report MSG with the current file/line prefix through die()."""
        die(f'{self.file}:{self.lineno}: {msg}')

    def check_decimal(self, number):
        """Check the validity of decimal number NUMBER.  If not valid, display
        error and exit.  Return NUMBER converted to int."""
        self.debug(f'check_decimal({number})')
        if not re.search(r'^[+-]?\d+$', number):
            self.error(f"'{number}' must be decimal")
        return int(number)

    def _remove_comment(self, line):
        """Return LINE without its ``;'' comment and trailing spaces.

        A ``;'' between single quotes belongs to a string constant and
        does not start a comment."""
        in_quote = False
        cut = None
        for pos, ch in enumerate(line):
            if ch == "'":
                in_quote = not in_quote
            elif ch == ';' and not in_quote:
                cut = pos
                break
        if cut is None and in_quote:
            # Unbalanced quote: there is no complete string to protect,
            # so fall back to cutting at the first ``;'' (if any).
            first = line.find(';')
            if first >= 0:
                cut = first
        if cut is not None:
            line = line[:cut]
        # remove trailing spaces
        return line.rstrip()

    def _parse_line(self, line):
        """Split LINE into an optional label, a mnemonic and its raw
        operand text.  Return (label, Instruction)."""
        # extract fields
        m = re.search(r'^(\S+)?\s+([A-Z]+)(\s+(.*)?)?$', line)
        if not m:
            self.error('syntax error')
        label, op, opr = m.group(1), m.group(2), m.group(4)
        self.debug(f'label/op/opr = {label}/{op}/{opr}')
        # report unknown mnemonics here so the message carries file:line
        if op not in INSTTBL:
            self.error(f"illegal instruction '{op}'")
        return label, Instruction(op, opr, err_func=self.error)

    # instructions with GR, adr, and optional XR
    def generate_op1(self, label, inst):
        if not (2 <= inst.noprs <= 3):
            self.error(f"missing operands '{inst}'")
        self.memory.generate_code(self.addr, inst.code, inst.opr1, inst.opr2,
                                  inst.opr3)
        self.addr += 2

    # instructions with adr, and optional XR
    def generate_op2(self, label, inst):
        if not (1 <= inst.noprs <= 2):
            self.error(f"too much/too many operand '{inst}'")
        self.memory.generate_code(self.addr, inst.code, None, inst.opr1,
                                  inst.opr2)
        self.addr += 2

    # instructions with exactly one operand (GR)
    def generate_op3(self, label, inst):
        if inst.noprs != 1:
            self.error(f"expects just one operand '{inst}'")
        self.memory.generate_code(self.addr, inst.code, inst.opr1, None, None)
        self.addr += 2

    # instructions without operand
    def generate_op4(self, label, inst):
        if inst.noprs > 0:
            self.error(f'invalid operand "{inst}"')
        self.memory.generate_code(self.addr, inst.code, inst.opr1, None, None)
        self.addr += 2

    # START instruction
    def generate_start(self, label, inst):
        self.debug(f'generate_start({self}, {label}, {inst})')
        if not label:
            self.error("no label found at START")
        if inst.opr1 is not None:
            start = parse_number(inst.opr1)
            # a non-numeric operand would otherwise leave the address
            # counter set to None and crash on the next instruction
            if start is None:
                self.error(f"invalid start address '{inst.opr1}'")
            self.start_addr = start
            self.addr = start
        self.in_block = True

    # END instruction
    def generate_end(self, label, inst):
        if label:
            self.error(f"can't use label '{label}' at END")
        if inst.noprs > 0:
            self.error(f"END accepts no operand '{inst}'")
        self.end_addr = self.addr - 1
        self.in_block = False

    # DS instruction
    def generate_ds(self, label, inst):
        if inst.noprs != 1:
            self.error(f"DS exepects a single operand '{inst}'")
        v = self.check_decimal(inst.opr1)
        # reserve V words, initialised to zero
        for n in range(v):
            self.memory[self.addr] = 0
            self.addr += 1

    # DC instruction
    def generate_dc(self, label, inst):
        # The raw operand text is used for the string form, because the
        # split operands would break a string at spaces and commas.
        # The raw text is None when DC has no operand at all.
        m = re.search(r"^'([^\']+)'$", inst.opr or '')
        if m:
            try:
                vals = bytearray(m.group(1), encoding='ascii')
            except UnicodeEncodeError:
                self.error(f"DC string must be ASCII '{inst}'")
            # one word per character
            for c in vals:
                self.memory[self.addr] = c
                self.addr += 1
        elif inst.noprs == 1:  # number or label
            # stored as written; resolved to a number in pass2()
            self.memory[self.addr] = inst.opr1
            self.addr += 1
        else:
            self.error(f"DC accepts a word count or a string '{inst}'")

    # IN/OUT macro
    def generate_io(self, label, inst):
        if inst.noprs != 2:
            self.error(f"IN/OUT requires two operands '{inst}'")
        # two operands must be labels, not numbers
        self.labeltbl.check_label(inst.opr1)
        self.labeltbl.check_label(inst.opr2)
        # IN/OUT macro is expanded to push two operands onto the
        # stack, call SYS_IN / SYS_OUT, and restore stack.
        if inst.op == 'IN':
            entry = SYS_IN
        else:
            entry = SYS_OUT
        self.memory.generate_code(self.addr, INSTTBL['PUSH'][1], None,
                                  inst.opr1, None)
        self.memory.generate_code(self.addr + 2, INSTTBL['PUSH'][1], None,
                                  inst.opr2, None)
        self.memory.generate_code(self.addr + 4, INSTTBL['CALL'][1], None,
                                  entry, None)
        self.memory.generate_code(self.addr + 6, INSTTBL['LEA'][1], 'GR4', 2,
                                  'GR4')
        self.addr += 8

    # EXIT macro
    def generate_exit(self, label, inst):
        if inst.noprs > 0:
            self.error(f"EXIT does not accept operand '{inst}'")
        # EXIT macro is replaced with 'JMP SYS_EXIT'
        self.memory.generate_code(self.addr, INSTTBL['JMP'][1], None, SYS_EXIT,
                                  None)
        self.addr += 2

    def generate(self, label, inst):
        """Generate code for INST (found with LABEL, possibly None) by
        dispatching to the generate_<type> method for its type."""
        self.debug(f'generate({self}, {label}, {inst})')
        # START must be the first instruction
        if not self.in_block and inst.categ != 'start':
            self.error(f"no START directive found")
        # GR0 cannot be used as an index register.  For ``op1''
        # instructions (GR, adr [, XR]) the index register is the third
        # operand; otherwise the second operand is checked.
        index_reg = inst.opr3 if inst.categ == 'op1' else inst.opr2
        if index_reg == 'GR0':
            self.error("can't use GR0 as an index register")

        # look the handler up by name instead of evaluating source text
        subr = getattr(self, f'generate_{inst.categ}', None)
        if subr is None:
            self.error(f"instruction type '{inst.categ}' not implemented")
        subr(label, inst)

    def pass1(self):
        """Parse the source file FILE, register all symbols to LABELTBL, and
        generate code in MEMORY."""
        self.addr = 0
        self.in_block = False
        # count lines from 1 so diagnostics and listings match editors
        for self.lineno, self.line in enumerate(self.buf, 1):
            self.debug(self.line)
            line = self._remove_comment(self.line)
            if line == '':
                continue
            label, inst = self._parse_line(line)

            # register label to the symbol table
            if label is not None:
                self.labeltbl.register(label, self.addr)
            # generate object code according the type of instruction
            if inst:
                self.generate(label, inst)
        if self.in_block:
            self.error("No 'END' instruction found")

    def pass2(self):
        """Open the output file, and dump the assembled object code.

        The output name is the source name with a trailing ``.cas''
        (any letter case) replaced by ``.com''."""
        self.debug(f'pass2({self})')
        # strip the extension only; ``.cas'' elsewhere in the name stays
        base = re.sub(r'\.cas$', '', self.file, flags=re.I)
        outfile = base + '.com'
        with open(outfile, 'wb') as f:
            # print object header
            f.write(struct.pack('>4sH10x', b'CASL', self.start_addr))
            # dump memory image
            for addr in range(self.start_addr, self.end_addr + 1):
                val = self.memory[addr]
                v = self.labeltbl.expand(val)
                f.write(struct.pack('>H', v))

    def dump_memory(self):
        """Print a listing of the assembled words with their source lines."""
        self.debug(f'dump_memory({self})')
        print('CASL LISTING file')
        last_lineno = None
        for addr in range(self.start_addr, self.end_addr + 1):
            val = self.memory[addr]
            v = self.labeltbl.expand(val)
            file, lineno, line = self.memory.src_at(addr)
            if lineno != last_lineno:
                # first word of a source line: show address and source
                print(f'{lineno:4} {addr:04x} {v:04x}\t', end='')
                print(line)
                last_lineno = lineno
            else:
                print(f'{lineno:4}      {v:04x}')

    def dump_labels(self):
        """Print every defined label with its value and place of definition."""
        self.debug(f'dump_label({self})')
        print('\nDEFINED LABELS')
        for label in self.labeltbl.keys():
            file, lineno = self.labeltbl.src_for(label)
            val = self.labeltbl[label]
            print(f'\t{file}:{lineno}:\t{val:04x}\t{label}')

def main():
    """Assemble every file named on the command line.

    Options: -a prints a listing of memory and labels after assembly,
    -d turns on debug messages, -v prints the version banner and
    returns without assembling anything.
    """
    opt = getopts('avd') or usage()
    if opt.v:
        # usage() advertises -v, so honour it
        print(Assembler(debug=opt.d).copyright())
        return
    # presumably getopts() has removed the options from sys.argv, leaving
    # only file names --- TODO confirm against perlcompat
    for file in sys.argv[1:]:
        # A fresh assembler per file: source buffer, label table and
        # memory image must not carry over from the previous file.
        asm = Assembler(debug=opt.d)
        asm.load(file)
        asm.pass1()
        asm.pass2()
        if opt.a:
            asm.dump_memory()
            asm.dump_labels()

if __name__ == "__main__":
    main()
