#!/usr/bin/env python

""" Runs the Australian Senate Election Audit. """

from os import unlink

from aus_senate_audit.audit_recorder import AuditRecorder
from aus_senate_audit.audits.bayesian_audit import audit
from aus_senate_audit.audit_validator import AuditValidator
from aus_senate_audit.cli import parse_command_line_args
from aus_senate_audit.constants import REAL_MODE
from aus_senate_audit.constants import SIMULATION_MODE
from aus_senate_audit.sampler.sampler_wrapper import SamplerWrapper
from aus_senate_audit.senate_election.real_senate_election import RealSenateElection
from aus_senate_audit.senate_election.simulated_senate_election import SimulatedSenateElection


def main():
    """ Parses the command line arguments and dispatches on the requested audit mode.

    * ``SIMULATION_MODE``: builds a ``SimulatedSenateElection`` and audits it with ``quick=True``.
    * ``REAL_MODE``: when no selected ballots file was given, only runs ``SamplerWrapper``; otherwise runs
      ``AuditValidator.compare`` on the given file and then audits a ``RealSenateElection``.
    * Any other mode: repeats sample -> validate -> audit until ``audit`` returns a truthy value, then deletes
      ``selected_ballots.csv`` from the current working directory.
    """
    args = parse_command_line_args()
    seed = args.seed
    threshold = args.unpopular_frequency_threshold

    if args.mode == SIMULATION_MODE:
        simulated = SimulatedSenateElection(seed, args.num_ballots, args.num_candidates, args.sample_increment_size)
        audit(simulated, seed, threshold, quick=True)
        return

    # Both remaining modes record their progress through an AuditRecorder for the chosen state.
    recorder = AuditRecorder(args.state)

    if args.mode == REAL_MODE:
        if args.selected_ballots is None:
            # No ballots file supplied: only the sampling step runs on this invocation.
            SamplerWrapper(seed, args.state, args.sample_increment_size, args.data, recorder)
            return
        AuditValidator(args.selected_ballots, recorder).compare()
        real_election = RealSenateElection(seed, args.state, args.data)
        audit(real_election, seed, threshold, stage_counter=recorder.get_current_audit_stage() - 1)
        return

    # Fallback mode: every stage runs back to back without waiting for user input.
    # NOTE(review): presumably SamplerWrapper is what writes this file - confirm. It is only removed when the loop
    # finishes normally, so an exception raised mid-loop leaves it on disk.
    ballots_file = 'selected_ballots.csv'
    while True:
        SamplerWrapper(seed, args.state, args.sample_increment_size, args.data, recorder, quick=True)
        AuditValidator(ballots_file, recorder).compare()
        real_election = RealSenateElection(seed, args.state, args.data, max_ballots=args.max_ballots)
        stage = recorder.get_current_audit_stage() - 1
        if audit(real_election, seed, threshold, stage_counter=stage):
            break
    unlink(ballots_file)

# Run the audit only when this file is executed as a script, not when it is imported as a module.
if __name__ == '__main__':
    main()
