{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "abd06395",
   "metadata": {},
   "source": [
    "# Making predictions for TOIs"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c88a7b1a",
   "metadata": {},
   "source": [
    "\n",
    "Load required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e91f799a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING (theano.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "import pandas as pd\n",
    "import lightkurve as lk\n",
    "import numpy as np\n",
    "import astropy.units as u\n",
    "import astropy.constants as c\n",
    "from scipy.constants import G\n",
    "from IPython.display import display\n",
    "from ldtk import LDPSetCreator, BoxcarFilter, TabulatedFilter\n",
    "from exoInfoMatrixTOI import exoInfoMatrix\n",
    "import ldtk.filters as filters\n",
    "import exoplanet as xo\n",
    "import torch.multiprocessing as mp"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d58a624f",
   "metadata": {},
   "source": [
    "\n",
    "Select only planet candidates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfcb1d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Read the TOI table exported from the NASA Exoplanet Archive (NEA).\n",
    "# header=90 makes pandas take the column names from row 90 of the file (presumably the rows\n",
    "# above it are the descriptive preamble of the export -- verify if the file is regenerated).\n",
    "nea_tois = pd.read_csv(\"nea_tois.csv\", header=90)\n",
    "\n",
    "# Keep planet candidates (\"PC\") that are the only planet listed for their host.\n",
    "# .copy() gives an independent frame, so the column assignments below never write into a\n",
    "# slice of the table that was read above.\n",
    "is_candidate = nea_tois[\"tfopwg_disp\"] == \"PC\"\n",
    "is_single = nea_tois[\"pl_pnum\"] == 1\n",
    "nea_tois = nea_tois[is_candidate & is_single].copy()\n",
    "\n",
    "print(f\"Initial number of planet candidates is {len(nea_tois)}\\n\")\n",
    "\n",
    "# Only keep rows with values for stellar logg, radius and effective temperature, needed later on\n",
    "nea_tois = nea_tois.dropna(axis=0, subset=[\"st_logg\", \"st_rad\", \"st_teff\"])\n",
    "\n",
    "# We also need one error per quantity. Upper and lower bounds are reported in separate columns;\n",
    "# we keep the largest magnitude of the two (or the only one given). Taking the absolute value\n",
    "# first guarantees a positive uncertainty even when only the lower bound is reported, which\n",
    "# matches how the priors are built from these columns when making the predictions.\n",
    "for quantity in (\"rad\", \"logg\", \"teff\"):\n",
    "    bounds = nea_tois[[f\"st_{quantity}err1\", f\"st_{quantity}err2\"]]\n",
    "    # max() skips NaN, so the result is NaN only when both bounds are missing\n",
    "    nea_tois[f\"st_{quantity}_err\"] = bounds.abs().max(axis=1)\n",
    "\n",
    "# The errors are needed later on too, so drop the rows without them\n",
    "nea_tois = nea_tois.dropna(axis=0, subset=[\"st_rad_err\", \"st_logg_err\", \"st_teff_err\"])\n",
    "\n",
    "print(f\"{len(nea_tois)} planet candidates with values for logg, R_star and T_eff along with errors\\n\")\n",
    "\n",
    "# Reset indices so they run from 0 to len(nea_tois) - 1\n",
    "nea_tois = nea_tois.reset_index(drop=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c4eb9d81",
   "metadata": {},
   "source": [
    "\n",
    "Out of these, we want only those which were observed __only__ with 1800s cadence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a7afb4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Search the available lightcurves for each candidate and flag the candidates whose\n",
    "# observations were all taken with a 1800s exposure time.\n",
    "\n",
    "# Start with every candidate not accepted, so the \"accepted\" column exists even when no\n",
    "# candidate passes the check (the cell that saves the results filters on this column).\n",
    "nea_tois[\"accepted\"] = False\n",
    "\n",
    "# One search per candidate: this can take long\n",
    "for i, row in nea_tois.iterrows():\n",
    "    print(f\"\\n{i} out of {len(nea_tois) - 1}\")\n",
    "\n",
    "    TID = f\"TIC {row['tid']}\"\n",
    "\n",
    "    # Results of the lightcurve search\n",
    "    search = lk.search_lightcurve(TID, mission=\"TESS\")\n",
    "\n",
    "    # If there were no matches, reading the exposure times raises a KeyError.\n",
    "    # Treat that the same as an empty list of exposure times.\n",
    "    try:\n",
    "        exptimes = set(search.exptime.value)\n",
    "    except KeyError:\n",
    "        exptimes = set()\n",
    "\n",
    "    # An empty set is a subset of anything, so it must be rejected explicitly;\n",
    "    # otherwise a candidate without any lightcurve would be flagged as accepted.\n",
    "    if not exptimes:\n",
    "        # We add a note letting us know this PC wasn't found\n",
    "        nea_tois.at[i, \"notes\"] = \"NOT FOUND\"\n",
    "        print(\"NOT FOUND\")\n",
    "        continue\n",
    "\n",
    "    # Accept the candidate only if 1800s is the sole exposure time among its lightcurves\n",
    "    if exptimes <= {1800}:\n",
    "        nea_tois.at[i, \"accepted\"] = True\n",
    "        print(\"ACCEPTED\")\n",
    "    else:\n",
    "        nea_tois.at[i, \"notes\"] = \"NOT ONLY 1800s\"\n",
    "        print(\"NOT ONLY 1800s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69b6038d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Write out only the candidates flagged as observed exclusively with 1800s cadence\n",
    "accepted_only = nea_tois.loc[nea_tois[\"accepted\"] == True]\n",
    "accepted_only.to_csv(\"tois_with_only_1800s.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8499a440",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Load the saved 1800s-only candidates back into their own dataframe\n",
    "only1800 = pd.read_csv(filepath_or_buffer=\"tois_with_only_1800s.csv\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b73eab03",
   "metadata": {},
   "source": [
    "\n",
    "Add columns with some values needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "6f9a805f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# astropy units take care of every unit conversion in this cell.\n",
    "\n",
    "\n",
    "logg = only1800[\"st_logg\"] # log10 of the surface gravity\n",
    "g = np.power(10, logg.to_numpy()) * (u.cm / u.s ** 2) # Surface gravity in cm/s^2\n",
    "R = only1800[\"st_rad\"].to_numpy() * c.R_sun # Stellar radius as a length (table value times the solar radius)\n",
    "P = only1800[\"pl_orbper\"].to_numpy() * u.day # Orbital period in days\n",
    "T = only1800[\"pl_trandurh\"].to_numpy() * u.hour # Transit duration in hours\n",
    "\n",
    "# Stellar density; only the bare number in g/cm^3 is stored in the table\n",
    "rho_star = 3/(4 * np.pi * c.G) * g / R\n",
    "only1800[\"st_rho\"] = rho_star.to(u.g / u.cm ** 3).value\n",
    "\n",
    "# Semi-major axis divided by the stellar radius (dimensionless after decompose)\n",
    "a = (((g * R ** 2 * P ** 2) / (4 * np.pi ** 2)) ** (1/3) / R).decompose()\n",
    "only1800[\"a\"] = a.value\n",
    "\n",
    "# Impact parameter from the transit duration; it comes out as NaN when the term under the root is negative\n",
    "duration_term = (a * np.pi * T) / P\n",
    "b = np.sqrt(1 - duration_term ** 2).decompose()\n",
    "only1800[\"b\"] = b.value"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "082d05fc",
   "metadata": {},
   "source": [
    "\n",
    "To make estimates, we need fiducial values for the limb-darkening parameters. We obtain approximate values using 'PyLDTk'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "32fc761e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Row 0 out of 193\n",
      "Row 1 out of 193\n",
      "Row 2 out of 193\n",
      "Row 3 out of 193\n",
      "Row 4 out of 193\n",
      "Row 5 out of 193\n",
      "Row 6 out of 193\n",
      "Row 7 out of 193\n",
      "Row 8 out of 193\n",
      "Row 9 out of 193\n",
      "Row 10 out of 193\n",
      "Row 11 out of 193\n",
      "Row 12 out of 193\n",
      "Row 13 out of 193\n",
      "Row 13 (81831095) did not converge\n",
      "Row 14 out of 193\n",
      "Row 15 out of 193\n",
      "Row 16 out of 193\n",
      "Row 17 out of 193\n",
      "Row 18 out of 193\n",
      "Row 19 out of 193\n",
      "Row 20 out of 193\n",
      "Row 21 out of 193\n",
      "Row 22 out of 193\n",
      "Row 23 out of 193\n",
      "Row 24 out of 193\n",
      "Row 25 out of 193\n",
      "Row 26 out of 193\n",
      "Row 27 out of 193\n",
      "Row 28 out of 193\n",
      "Row 29 out of 193\n",
      "Row 30 out of 193\n",
      "Row 31 out of 193\n",
      "Row 32 out of 193\n",
      "Row 33 out of 193\n",
      "Row 34 out of 193\n",
      "Row 35 out of 193\n",
      "Row 36 out of 193\n",
      "Row 37 out of 193\n",
      "Row 38 out of 193\n",
      "Row 39 out of 193\n",
      "Row 39 (160930264) did not converge\n",
      "Row 40 out of 193\n",
      "Row 41 out of 193\n",
      "Row 42 out of 193\n",
      "Row 43 out of 193\n",
      "Row 43 (174426662) did not converge\n",
      "Row 44 out of 193\n",
      "Row 45 out of 193\n",
      "Row 46 out of 193\n",
      "Row 47 out of 193\n",
      "Row 48 out of 193\n",
      "Row 49 out of 193\n",
      "Row 50 out of 193\n",
      "Row 51 out of 193\n",
      "Row 52 out of 193\n",
      "Row 53 out of 193\n",
      "Row 54 out of 193\n",
      "Row 55 out of 193\n",
      "Row 56 out of 193\n",
      "Row 56 (65440953) did not converge\n",
      "Row 57 out of 193\n",
      "Row 58 out of 193\n",
      "Row 59 out of 193\n",
      "Row 60 out of 193\n",
      "Row 61 out of 193\n",
      "Row 62 out of 193\n",
      "Row 63 out of 193\n",
      "Row 64 out of 193\n",
      "Row 65 out of 193\n",
      "Row 66 out of 193\n",
      "Row 67 out of 193\n",
      "Row 68 out of 193\n",
      "Row 69 out of 193\n",
      "Row 70 out of 193\n",
      "Row 71 out of 193\n",
      "Row 72 out of 193\n",
      "Row 73 out of 193\n",
      "Row 74 out of 193\n",
      "Row 75 out of 193\n",
      "Row 76 out of 193\n",
      "Row 77 out of 193\n",
      "Row 78 out of 193\n",
      "Row 79 out of 193\n",
      "Row 80 out of 193\n",
      "Row 81 out of 193\n",
      "Row 82 out of 193\n",
      "Row 83 out of 193\n",
      "Row 84 out of 193\n",
      "Row 85 out of 193\n",
      "Row 86 out of 193\n",
      "Row 87 out of 193\n",
      "Row 88 out of 193\n",
      "Row 88 (468777766) did not converge\n",
      "Row 89 out of 193\n",
      "Row 90 out of 193\n",
      "Row 91 out of 193\n",
      "Row 92 out of 193\n",
      "Row 93 out of 193\n",
      "Row 94 out of 193\n",
      "Row 95 out of 193\n",
      "Row 96 out of 193\n",
      "Row 97 out of 193\n",
      "Row 98 out of 193\n",
      "Row 99 out of 193\n",
      "Row 100 out of 193\n",
      "Row 101 out of 193\n",
      "Row 102 out of 193\n",
      "Row 103 out of 193\n",
      "Row 104 out of 193\n",
      "Row 105 out of 193\n",
      "Row 106 out of 193\n",
      "Row 107 out of 193\n",
      "Row 108 out of 193\n",
      "Row 109 out of 193\n",
      "Row 110 out of 193\n",
      "Row 111 out of 193\n",
      "Row 112 out of 193\n",
      "Row 113 out of 193\n",
      "Row 114 out of 193\n",
      "Row 115 out of 193\n",
      "Row 116 out of 193\n",
      "Row 117 out of 193\n",
      "Row 118 out of 193\n",
      "Row 119 out of 193\n",
      "Row 120 out of 193\n",
      "Row 121 out of 193\n",
      "Row 122 out of 193\n",
      "Row 123 out of 193\n",
      "Row 124 out of 193\n",
      "Row 125 out of 193\n",
      "Row 126 out of 193\n",
      "Row 127 out of 193\n",
      "Row 128 out of 193\n",
      "Row 129 out of 193\n",
      "Row 129 (190986054) did not converge\n",
      "Row 130 out of 193\n",
      "Row 131 out of 193\n",
      "Row 132 out of 193\n",
      "Row 133 out of 193\n",
      "Row 134 out of 193\n",
      "Row 135 out of 193\n",
      "Row 136 out of 193\n",
      "Row 137 out of 193\n",
      "Row 138 out of 193\n",
      "Row 139 out of 193\n",
      "Row 140 out of 193\n",
      "Row 140 (356472238) did not converge\n",
      "Row 141 out of 193\n",
      "Row 142 out of 193\n",
      "Row 143 out of 193\n",
      "Row 144 out of 193\n",
      "Row 145 out of 193\n",
      "Row 146 out of 193\n",
      "Row 147 out of 193\n",
      "Row 148 out of 193\n",
      "Row 149 out of 193\n",
      "Row 150 out of 193\n",
      "Row 151 out of 193\n",
      "Row 152 out of 193\n",
      "Row 153 out of 193\n",
      "Row 154 out of 193\n",
      "Row 155 out of 193\n",
      "Row 156 out of 193\n",
      "Row 157 out of 193\n",
      "Row 158 out of 193\n",
      "Row 159 out of 193\n",
      "Row 160 out of 193\n",
      "Row 161 out of 193\n",
      "Row 162 out of 193\n",
      "Row 163 out of 193\n",
      "Row 164 out of 193\n",
      "Row 165 out of 193\n",
      "Row 166 out of 193\n",
      "Row 167 out of 193\n",
      "Row 168 out of 193\n",
      "Row 169 out of 193\n",
      "Row 170 out of 193\n",
      "Row 171 out of 193\n",
      "Row 172 out of 193\n",
      "Row 173 out of 193\n",
      "Row 174 out of 193\n",
      "Row 175 out of 193\n",
      "Row 176 out of 193\n",
      "Row 177 out of 193\n",
      "Row 178 out of 193\n",
      "Row 179 out of 193\n",
      "Row 180 out of 193\n",
      "Row 181 out of 193\n",
      "Row 182 out of 193\n",
      "Row 183 out of 193\n",
      "Row 183 (407495930) did not converge\n",
      "Row 184 out of 193\n",
      "Row 185 out of 193\n",
      "Row 186 out of 193\n",
      "Row 187 out of 193\n",
      "Row 188 out of 193\n",
      "Row 189 out of 193\n",
      "Row 190 out of 193\n",
      "Row 191 out of 193\n",
      "Row 192 out of 193\n",
      "Row 193 out of 193\n"
     ]
    }
   ],
   "source": [
    "\n",
    "filt = filters.create_tess() # Create TESS filters profiles\n",
    "\n",
    "copy = only1800.copy() # Iterate over a copy so that writing results into only1800 is safe\n",
    "\n",
    "# Iterate through all rows\n",
    "for i, row in copy.iterrows():\n",
    "    print(f\"Row {i} out of {len(copy) - 1}\")\n",
    "\n",
    "    # Read effective temperature and logg values, each with its own uncertainty\n",
    "    # (the temperature error must come from st_teff_err, not from the logg error)\n",
    "    teff = row[\"st_teff\"]\n",
    "    teff_err = row[\"st_teff_err\"]\n",
    "\n",
    "    logg = row[\"st_logg\"]\n",
    "    logg_err = row[\"st_logg_err\"]\n",
    "\n",
    "    # Just to be sure, we check there are no NaN values and report which ones if there are\n",
    "    names = np.array([\"teff\", \"teff_err\", \"logg\", \"logg_err\"])\n",
    "    anynan = np.isnan(np.array([teff, teff_err, logg, logg_err]))\n",
    "\n",
    "    if anynan.any():\n",
    "        print(f\"{row['tid']} has NaN value in {names[anynan]}\")\n",
    "\n",
    "    # Create profiles. Because we have no z value from the table we use 0.25 with error 0.125\n",
    "    sc = LDPSetCreator(teff=(teff, teff_err), logg=(logg, logg_err), z=(0.25, 0.125), filters=[filt])\n",
    "\n",
    "    ps = sc.create_profiles(nsamples=1000)\n",
    "\n",
    "    # Do a mcmc to get the quadratic coefficients. If it fails with a LinAlgError we report it,\n",
    "    # store empty values for this row (it is dropped when saving) and move on\n",
    "    try:\n",
    "        qc, qe = ps.coeffs_qd(do_mc=True)\n",
    "    except np.linalg.LinAlgError:\n",
    "        print(f\"Row {i} ({row['tid']}) did not converge\")\n",
    "        only1800.at[i, \"u_star1\"] = None\n",
    "        only1800.at[i, \"u_star2\"] = None\n",
    "        only1800.at[i, \"u_star1_sd\"] = None\n",
    "        only1800.at[i, \"u_star2_sd\"] = None\n",
    "        continue\n",
    "\n",
    "    # Check no NaN values in results\n",
    "    if np.isnan([qc,qe]).any():\n",
    "        print(f\"Row {i} ({row['tid']}) calculated values are nan somewhere\")\n",
    "\n",
    "    # Only one filter was passed, so the values for it are in the first row of qc (coefficients) and qe (errors)\n",
    "    only1800.at[i, \"u_star1\"] = qc[0][0]\n",
    "    only1800.at[i, \"u_star2\"] = qc[0][1]\n",
    "    only1800.at[i, \"u_star1_sd\"] = qe[0][0]\n",
    "    only1800.at[i, \"u_star2_sd\"] = qe[0][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "e907726b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Keep only the rows for which limb-darkening coefficients were obtained, then\n",
    "# read the saved table back so the following cells work from the file on disk\n",
    "\n",
    "has_coeffs = ~np.isnan(only1800[\"u_star1\"].to_numpy())\n",
    "only1800[has_coeffs].to_csv(\"tois_with_only_1800s_limbdark.csv\", index=False)\n",
    "\n",
    "limbdarkened = pd.read_csv(\"tois_with_only_1800s_limbdark.csv\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6f6b5d1b",
   "metadata": {},
   "source": [
    "\n",
    "Now that we have limb-darkening values we can get an approximate value for the radius ratio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "1a3bf766",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Estimate the radius ratio of every candidate, iterating over a snapshot of the table\n",
    "snapshot = limbdarkened.copy()\n",
    "\n",
    "for row_label, candidate in snapshot.iterrows():\n",
    "    # Transit depth from the table scaled by 1e-6 (presumably ppm to a fraction -- confirm against the archive column)\n",
    "    depth = candidate[\"pl_trandep\"]*1e-6\n",
    "\n",
    "    # Limb-darkened star from the exoplanet package, built from the two coefficients of this row\n",
    "    star = xo.LimbDarkLightCurve(candidate[\"u_star1\"], candidate[\"u_star2\"])\n",
    "\n",
    "    # 'get_ror_from_approx_transit_depth' gives an approximate radius ratio for this depth and impact parameter\n",
    "    limbdarkened.at[row_label, \"ror\"] = star.get_ror_from_approx_transit_depth(depth, candidate[\"b\"]).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f4df5422",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "157 final planet candidates to be passed onto prediction calculation\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Drop the candidates without a radius ratio and store what is left as the final input table\n",
    "\n",
    "has_ror = ~np.isnan(limbdarkened[\"ror\"])\n",
    "limbdarkened = limbdarkened[has_ror]\n",
    "\n",
    "print(f\"{len(limbdarkened)} final planet candidates to be passed onto prediction calculation\")\n",
    "\n",
    "limbdarkened.to_csv(\"final_dataframe.csv\", index=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "78d59607",
   "metadata": {},
   "source": [
    "\n",
    "Now we make the actual predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "3ce24c18",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# To make it faster, we will parallelize the calculations\n",
    "\n",
    "# CHANGE THIS TO THE NUMBER OF CORES YOU WISH TO USE\n",
    "NCORES = 12\n",
    "\n",
    "# Read the final input table\n",
    "table = pd.read_csv(\"final_dataframe.csv\")\n",
    "\n",
    "# Now we split it into NCORES tables of (almost) equal length.\n",
    "# We split the row positions and select them with iloc instead of passing the dataframe itself\n",
    "# to np.array_split, which relies on the deprecated DataFrame.swapaxes. The chunks are the same.\n",
    "tables = [table.iloc[rows] for rows in np.array_split(np.arange(len(table)), NCORES)]\n",
    "\n",
    "# We will calculate predicted precisions for the following exposure times (in seconds)\n",
    "calc_expt = {20, 120, 600, 1800}\n",
    "\n",
    "indices = np.arange(0, NCORES, 1) # One label per chunk, just to keep track of how each worker is doing"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "90520ad4",
   "metadata": {},
   "source": [
    "\n",
    "This function will calculate predictions for each of the tables. We need to include it in a function so as to be able to do multiprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "0b7d7b34",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def calculate_prediction(args):\n",
    "    \"\"\"Predict parameter precisions for every candidate in one chunk of the input table.\n",
    "\n",
    "    For each row a 1800s TESS lightcurve is downloaded to obtain the reference timestamps and\n",
    "    the reference mean flux error. Then, for every exposure time in the global ``calc_expt``,\n",
    "    the error is rescaled, an ``exoInfoMatrix`` is evaluated and the square roots of the\n",
    "    absolute diagonal values of the resulting matrix are stored in columns named\n",
    "    ``<parameter>_<exptime>_sd``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    args : tuple\n",
    "        ``(df, index)``: the chunk of candidates to process and a label that is only used\n",
    "        in the progress messages.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        A re-indexed copy of ``df`` with ``ref_exptime``, ``ref_sigma`` and the predicted\n",
    "        standard deviations added. Candidates whose lightcurve could not be obtained are\n",
    "        skipped and keep empty values in those columns.\n",
    "    \"\"\"\n",
    "    df, index = args\n",
    "\n",
    "    copy = df.copy()\n",
    "    copy.reset_index(inplace=True, drop=True)\n",
    "\n",
    "    ref_exptime = 1800 # Our reference exposure time is 1800s, used to download a reference lightcurve\n",
    "\n",
    "    # Table columns handed to the information matrix; all of them must be present\n",
    "    input_cols = [\"pl_orbper\", \"pl_tranmid\", \"ror\", \"b\", \"u_star1\", \"u_star2\", \"st_rho\", \"st_rad\"]\n",
    "    names = np.array(input_cols)\n",
    "\n",
    "    # Labels used for the output columns, in the order this code assumes for the rows/columns of the matrix\n",
    "    param_labels = [\"period\", \"t0\", \"ror\", \"b\", \"u_star1\", \"u_star2\", \"rho_star\", \"r_star\"]\n",
    "\n",
    "    # Loop through all rows\n",
    "    for idx, row in copy.iterrows():\n",
    "        print(f\"THREAD {index}: {idx+1} out of {len(copy)}\\n\")\n",
    "\n",
    "        # Read the hostname\n",
    "        host = f\"TIC {row['tid']}\"\n",
    "\n",
    "        # Search the lightcurve\n",
    "        search = lk.search_lightcurve(host, mission=\"TESS\", exptime=ref_exptime)\n",
    "\n",
    "        # Without any match there is nothing to download, so skip this candidate\n",
    "        if len(search) == 0:\n",
    "            print(f\"{host} has no {ref_exptime}s lightcurve available\")\n",
    "            continue\n",
    "\n",
    "        # We give priority to SPOC lightcurves, then QLP and then CDIPS. No reason beyond keeping lightcurves as homogeneous as possible.\n",
    "        # If none of them is available the search results are left as they are.\n",
    "        for pipeline in (\"SPOC\", \"QLP\", \"CDIPS\"):\n",
    "            by_pipeline = [pipeline in author for author in search.author]\n",
    "            if any(by_pipeline):\n",
    "                search = search[by_pipeline]\n",
    "                break\n",
    "\n",
    "        # Download the lightcurve. If it fails we must skip the candidate: going on would either\n",
    "        # fail because no lightcurve exists yet or silently reuse the one of the previous candidate\n",
    "        try:\n",
    "            lc = search[-1].download_all().stitch().remove_nans().remove_outliers(sigma_lower=float('inf'))\n",
    "        except lk.LightkurveError:\n",
    "            print(f\"{host} lightcurve can't be downloaded ({search.author})\")\n",
    "            continue\n",
    "\n",
    "        # Set the reference mean error of measurements as the mean error for the measurements in the 1800s lightcurve\n",
    "        ref_sigma = np.mean(np.array(lc.flux_err.value))\n",
    "\n",
    "        # And the reference timestamps array is also obtained from the lightcurve\n",
    "        ref_t = np.array(lc.time.value)\n",
    "\n",
    "        # We also keep track of these values\n",
    "        copy.at[idx, \"ref_exptime\"] = ref_exptime\n",
    "        copy.at[idx, \"ref_sigma\"] = ref_sigma\n",
    "\n",
    "        # Now we make the actual predictions for each exposure time\n",
    "        for exptime in calc_expt:\n",
    "            # New array of timestamps spaced by one exposure time (converted from seconds to days)\n",
    "            # and spanning the same time range as the reference lightcurve\n",
    "            t = np.arange(min(ref_t), max(ref_t), exptime / (3600 * 24))\n",
    "\n",
    "            # Calculate the new mean error for this exposure time (scales with the inverse square root of the exposure time)\n",
    "            sigma = ref_sigma * np.sqrt(ref_exptime)/np.sqrt(exptime)\n",
    "\n",
    "            # Initialize the information matrix. Oversample of ~100 should be fine but can also do 1000, it will just take longer\n",
    "            infomatrix = exoInfoMatrix(exptime, oversample=100)\n",
    "\n",
    "            # This is just to make sure there are no nan values\n",
    "            anynan = np.isnan(np.array([row[col] for col in input_cols]))\n",
    "\n",
    "            if np.isnan(t).any():\n",
    "                print(f\"{host} has NaN values for t\")\n",
    "                continue\n",
    "            if anynan.any():\n",
    "                print(f\"{host} has NaN values for {names[anynan]}\")\n",
    "                continue\n",
    "\n",
    "            # If there are no NaNs then we set the data\n",
    "            infomatrix.set_data(\n",
    "                time_array = t,\n",
    "                period_val = row[\"pl_orbper\"],\n",
    "                t0_val     = row[\"pl_tranmid\"],\n",
    "                ror_val    = row[\"ror\"],\n",
    "                b_val      = row[\"b\"],\n",
    "                u1_val     = row[\"u_star1\"],\n",
    "                u2_val     = row[\"u_star2\"],\n",
    "                rho_star_val = row[\"st_rho\"],\n",
    "                r_star_val = row[\"st_rad\"],\n",
    "            )\n",
    "\n",
    "            # Then we set the priors. We do not use a prior on stellar density.\n",
    "            # For the table errors we take the largest magnitude of the two reported bounds\n",
    "            infomatrix.set_priors(\n",
    "                period_prior = np.nanmax(np.abs(row[[\"pl_orbpererr1\", \"pl_orbpererr2\"]])),\n",
    "                t0_prior = np.nanmax(np.abs(row[[\"pl_tranmiderr1\", \"pl_tranmiderr2\"]])),\n",
    "                r_star_prior = np.nanmax(np.abs(row[[\"st_raderr1\", \"st_raderr2\"]])),\n",
    "                b_prior = 1/np.sqrt(12),\n",
    "                u1_prior = 0.4713,\n",
    "                u2_prior = 0.4084,\n",
    "            )\n",
    "\n",
    "            # And we calculate the matrix; a ValueError means the inversion failed\n",
    "            try:\n",
    "                matrix = infomatrix.eval_cov(sigma = np.mean(sigma))\n",
    "            except ValueError:\n",
    "                print(f\"{host} inversion of matrix failed\")\n",
    "                continue\n",
    "\n",
    "            # The diagonal gives the predicted precision (standard deviation) of each parameter\n",
    "            for k, label in enumerate(param_labels):\n",
    "                copy.at[idx, f\"{label}_{exptime}_sd\"] = np.sqrt(np.abs(matrix[k, k]))\n",
    "\n",
    "    return copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "1f331c7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "THREAD 0: 1 out of 14\n",
      "THREAD 3: 1 out of 13\n",
      "THREAD 2: 1 out of 13\n",
      "THREAD 1: 1 out of 13\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "THREAD 6: 1 out of 13\n",
      "THREAD 5: 1 out of 13\n",
      "THREAD 8: 1 out of 13\n",
      "THREAD 11: 1 out of 13\n",
      "THREAD 4: 1 out of 13\n",
      "THREAD 7: 1 out of 13\n",
      "THREAD 10: 1 out of 13\n",
      "THREAD 9: 1 out of 13\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "THREAD 3: 2 out of 13\n",
      "\n",
      "THREAD 0: 2 out of 14\n",
      "\n",
      "THREAD 4: 2 out of 13\n",
      "\n",
      "THREAD 7: 2 out of 13\n",
      "\n",
      "THREAD 1: 2 out of 13\n",
      "\n",
      "THREAD 5: 2 out of 13\n",
      "\n",
      "THREAD 9: 2 out of 13\n",
      "\n",
      "THREAD 8: 2 out of 13\n",
      "\n",
      "THREAD 10: 2 out of 13\n",
      "\n",
      "THREAD 2: 2 out of 13\n",
      "\n",
      "THREAD 3: 3 out of 13\n",
      "\n",
      "THREAD 11: 2 out of 13\n",
      "\n",
      "THREAD 6: 2 out of 13\n",
      "\n",
      "THREAD 0: 3 out of 14\n",
      "\n",
      "THREAD 9: 3 out of 13\n",
      "\n",
      "TIC 252928337 lightcurve can't be downloaded (['DIAMANTE'])\n",
      "THREAD 1: 3 out of 13\n",
      "\n",
      "THREAD 10: 3 out of 13\n",
      "\n",
      "THREAD 3: 4 out of 13\n",
      "\n",
      "THREAD 6: 3 out of 13\n",
      "\n",
      "THREAD 4: 3 out of 13\n",
      "\n",
      "THREAD 7: 3 out of 13\n",
      "\n",
      "THREAD 5: 3 out of 13\n",
      "\n",
      "THREAD 9: 4 out of 13\n",
      "\n",
      "THREAD 2: 3 out of 13\n",
      "\n",
      "THREAD 8: 3 out of 13\n",
      "\n",
      "THREAD 10: 4 out of 13\n",
      "\n",
      "THREAD 0: 4 out of 14\n",
      "\n",
      "THREAD 11: 3 out of 13\n",
      "\n",
      "THREAD 3: 5 out of 13\n",
      "\n",
      "THREAD 4: 4 out of 13\n",
      "\n",
      "THREAD 8: 4 out of 13\n",
      "\n",
      "THREAD 0: 5 out of 14\n",
      "\n",
      "THREAD 1: 4 out of 13\n",
      "\n",
      "THREAD 6: 4 out of 13\n",
      "\n",
      "THREAD 5: 4 out of 13\n",
      "\n",
      "THREAD 7: 4 out of 13\n",
      "\n",
      "THREAD 9: 5 out of 13\n",
      "\n",
      "THREAD 11: 4 out of 13\n",
      "\n",
      "THREAD 10: 5 out of 13\n",
      "\n",
      "THREAD 4: 5 out of 13\n",
      "\n",
      "THREAD 0: 6 out of 14\n",
      "\n",
      "THREAD 2: 4 out of 13\n",
      "\n",
      "THREAD 5: 5 out of 13\n",
      "\n",
      "THREAD 3: 6 out of 13\n",
      "\n",
      "THREAD 6: 5 out of 13\n",
      "\n",
      "THREAD 9: 6 out of 13\n",
      "\n",
      "THREAD 10: 6 out of 13\n",
      "\n",
      "THREAD 4: 6 out of 13\n",
      "\n",
      "THREAD 8: 5 out of 13\n",
      "\n",
      "THREAD 5: 6 out of 13\n",
      "\n",
      "THREAD 7: 5 out of 13\n",
      "\n",
      "THREAD 0: 7 out of 14\n",
      "\n",
      "THREAD 1: 5 out of 13\n",
      "\n",
      "THREAD 11: 5 out of 13\n",
      "\n",
      "THREAD 6: 6 out of 13\n",
      "\n",
      "THREAD 9: 7 out of 13\n",
      "\n",
      "THREAD 10: 7 out of 13\n",
      "\n",
      "THREAD 3: 7 out of 13\n",
      "\n",
      "THREAD 0: 8 out of 14\n",
      "\n",
      "THREAD 7: 6 out of 13\n",
      "\n",
      "THREAD 2: 5 out of 13\n",
      "\n",
      "THREAD 11: 6 out of 13\n",
      "\n",
      "THREAD 8: 6 out of 13\n",
      "\n",
      "THREAD 4: 7 out of 13\n",
      "\n",
      "THREAD 1: 6 out of 13\n",
      "\n",
      "THREAD 5: 7 out of 13\n",
      "\n",
      "THREAD 6: 7 out of 13\n",
      "\n",
      "THREAD 7: 7 out of 13\n",
      "\n",
      "THREAD 9: 8 out of 13\n",
      "\n",
      "THREAD 0: 9 out of 14\n",
      "\n",
      "THREAD 8: 7 out of 13\n",
      "\n",
      "THREAD 10: 8 out of 13\n",
      "\n",
      "THREAD 4: 8 out of 13\n",
      "\n",
      "THREAD 11: 7 out of 13\n",
      "\n",
      "THREAD 3: 8 out of 13\n",
      "\n",
      "THREAD 1: 7 out of 13\n",
      "\n",
      "THREAD 6: 8 out of 13\n",
      "\n",
      "THREAD 0: 10 out of 14\n",
      "\n",
      "THREAD 2: 6 out of 13\n",
      "\n",
      "THREAD 8: 8 out of 13\n",
      "\n",
      "THREAD 7: 8 out of 13\n",
      "\n",
      "THREAD 4: 9 out of 13\n",
      "\n",
      "THREAD 11: 8 out of 13\n",
      "\n",
      "THREAD 9: 9 out of 13\n",
      "\n",
      "THREAD 5: 8 out of 13\n",
      "\n",
      "THREAD 0: 11 out of 14\n",
      "\n",
      "THREAD 6: 9 out of 13\n",
      "\n",
      "THREAD 10: 9 out of 13\n",
      "\n",
      "THREAD 7: 9 out of 13\n",
      "\n",
      "THREAD 4: 10 out of 13\n",
      "\n",
      "THREAD 3: 9 out of 13\n",
      "\n",
      "THREAD 1: 8 out of 13\n",
      "\n",
      "THREAD 9: 10 out of 13\n",
      "\n",
      "THREAD 11: 9 out of 13\n",
      "\n",
      "THREAD 2: 7 out of 13\n",
      "\n",
      "THREAD 8: 9 out of 13\n",
      "\n",
      "THREAD 0: 12 out of 14\n",
      "\n",
      "THREAD 7: 10 out of 13\n",
      "\n",
      "THREAD 6: 10 out of 13\n",
      "\n",
      "THREAD 4: 11 out of 13\n",
      "\n",
      "THREAD 5: 9 out of 13\n",
      "\n",
      "THREAD 10: 10 out of 13\n",
      "\n",
      "THREAD 11: 10 out of 13\n",
      "\n",
      "THREAD 1: 9 out of 13\n",
      "\n",
      "THREAD 3: 10 out of 13\n",
      "\n",
      "THREAD 2: 8 out of 13\n",
      "\n",
      "THREAD 7: 11 out of 13\n",
      "\n",
      "THREAD 9: 11 out of 13\n",
      "\n",
      "THREAD 5: 10 out of 13\n",
      "\n",
      "THREAD 11: 11 out of 13\n",
      "\n",
      "THREAD 8: 10 out of 13\n",
      "\n",
      "THREAD 1: 10 out of 13\n",
      "\n",
      "THREAD 0: 13 out of 14\n",
      "\n",
      "THREAD 4: 12 out of 13\n",
      "\n",
      "THREAD 10: 11 out of 13\n",
      "\n",
      "THREAD 6: 11 out of 13\n",
      "\n",
      "THREAD 7: 12 out of 13\n",
      "\n",
      "THREAD 2: 9 out of 13\n",
      "\n",
      "THREAD 3: 11 out of 13\n",
      "\n",
      "THREAD 8: 11 out of 13\n",
      "\n",
      "THREAD 6: 12 out of 13\n",
      "\n",
      "THREAD 11: 12 out of 13\n",
      "\n",
      "THREAD 5: 11 out of 13\n",
      "\n",
      "THREAD 1: 11 out of 13\n",
      "\n",
      "THREAD 2: 10 out of 13\n",
      "\n",
      "THREAD 9: 12 out of 13\n",
      "\n",
      "THREAD 0: 14 out of 14\n",
      "\n",
      "THREAD 7: 13 out of 13\n",
      "\n",
      "THREAD 4: 13 out of 13\n",
      "\n",
      "THREAD 3: 12 out of 13\n",
      "\n",
      "THREAD 10: 12 out of 13\n",
      "\n",
      "THREAD 11: 13 out of 13\n",
      "\n",
      "THREAD 5: 12 out of 13\n",
      "\n",
      "THREAD 1: 12 out of 13\n",
      "\n",
      "THREAD 2: 11 out of 13\n",
      "\n",
      "THREAD 6: 13 out of 13\n",
      "\n",
      "THREAD 8: 12 out of 13\n",
      "\n",
      "THREAD 9: 13 out of 13\n",
      "\n",
      "THREAD 3: 13 out of 13\n",
      "\n",
      "THREAD 10: 13 out of 13\n",
      "\n",
      "THREAD 2: 12 out of 13\n",
      "\n",
      "THREAD 8: 13 out of 13\n",
      "\n",
      "THREAD 5: 13 out of 13\n",
      "\n",
      "THREAD 1: 13 out of 13\n",
      "\n",
      "THREAD 2: 13 out of 13\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Now we parallelize the calculation and execute it\n",
    "# May have problems downloading lightcurves authored by DIAMANTE\n",
    "\n",
    "arguments = [(df, index) for df, index in zip(tables, indices)]\n",
    "\n",
    "p = mp.Pool(NCORES)\n",
    "\n",
    "result = list(p.imap(calculate_prediction, arguments))\n",
    "\n",
    "p.close()\n",
    "p.join()\n",
    "\n",
    "final_df = pd.DataFrame()\n",
    "\n",
    "\n",
    "for df in result:\n",
    "    final_df = pd.concat([final_df, df])\n",
    "\n",
    "final_df.to_csv(\"tois_with_predictions.csv\", index=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f9e80d3b",
   "metadata": {},
   "source": [
    "\n",
    "Now we use the predicted precisions to calculate the percentage improvement in the radius ratio precision of the 20 s, 120 s and 600 s predictions relative to the 1800 s one, and collect the results in a nice table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f65367d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>toi</th>\n",
       "      <th>tid</th>\n",
       "      <th>20_improv</th>\n",
       "      <th>120_improv</th>\n",
       "      <th>ror_sd_20</th>\n",
       "      <th>ror_sd_120</th>\n",
       "      <th>ror_sd_1800</th>\n",
       "      <th>ror_sd_600</th>\n",
       "      <th>600_improv</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>1677.01</td>\n",
       "      <td>87090944</td>\n",
       "      <td>77.932100</td>\n",
       "      <td>77.709268</td>\n",
       "      <td>0.010058</td>\n",
       "      <td>0.010160</td>\n",
       "      <td>0.045579</td>\n",
       "      <td>0.012324</td>\n",
       "      <td>72.961336</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>2784.01</td>\n",
       "      <td>302766000</td>\n",
       "      <td>70.941629</td>\n",
       "      <td>70.354258</td>\n",
       "      <td>0.023336</td>\n",
       "      <td>0.023808</td>\n",
       "      <td>0.080307</td>\n",
       "      <td>0.036841</td>\n",
       "      <td>54.125414</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>3786.01</td>\n",
       "      <td>321250206</td>\n",
       "      <td>66.841361</td>\n",
       "      <td>66.794015</td>\n",
       "      <td>0.009582</td>\n",
       "      <td>0.009595</td>\n",
       "      <td>0.028897</td>\n",
       "      <td>0.009857</td>\n",
       "      <td>65.889581</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>1701.01</td>\n",
       "      <td>274215536</td>\n",
       "      <td>66.328226</td>\n",
       "      <td>65.003455</td>\n",
       "      <td>0.002721</td>\n",
       "      <td>0.002828</td>\n",
       "      <td>0.008080</td>\n",
       "      <td>0.004805</td>\n",
       "      <td>40.533684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>2578.01</td>\n",
       "      <td>104986789</td>\n",
       "      <td>65.538223</td>\n",
       "      <td>65.110207</td>\n",
       "      <td>0.001560</td>\n",
       "      <td>0.001580</td>\n",
       "      <td>0.004528</td>\n",
       "      <td>0.001907</td>\n",
       "      <td>57.888268</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>3244.01</td>\n",
       "      <td>208091447</td>\n",
       "      <td>4.763420</td>\n",
       "      <td>4.733792</td>\n",
       "      <td>0.007618</td>\n",
       "      <td>0.007620</td>\n",
       "      <td>0.007999</td>\n",
       "      <td>0.007662</td>\n",
       "      <td>4.211289</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>139</th>\n",
       "      <td>5664.01</td>\n",
       "      <td>158022899</td>\n",
       "      <td>3.407098</td>\n",
       "      <td>3.350451</td>\n",
       "      <td>0.005839</td>\n",
       "      <td>0.005842</td>\n",
       "      <td>0.006045</td>\n",
       "      <td>0.005887</td>\n",
       "      <td>2.607795</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>5599.01</td>\n",
       "      <td>159160230</td>\n",
       "      <td>2.995921</td>\n",
       "      <td>2.898366</td>\n",
       "      <td>0.002671</td>\n",
       "      <td>0.002674</td>\n",
       "      <td>0.002753</td>\n",
       "      <td>0.002705</td>\n",
       "      <td>1.761623</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>5699.01</td>\n",
       "      <td>224328450</td>\n",
       "      <td>2.988077</td>\n",
       "      <td>2.958728</td>\n",
       "      <td>0.006361</td>\n",
       "      <td>0.006363</td>\n",
       "      <td>0.006557</td>\n",
       "      <td>0.006392</td>\n",
       "      <td>2.518649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>5637.01</td>\n",
       "      <td>136992839</td>\n",
       "      <td>2.825291</td>\n",
       "      <td>2.802809</td>\n",
       "      <td>0.003070</td>\n",
       "      <td>0.003071</td>\n",
       "      <td>0.003159</td>\n",
       "      <td>0.003091</td>\n",
       "      <td>2.167626</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>157 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         toi        tid  20_improv  120_improv  ror_sd_20  ror_sd_120  \\\n",
       "11   1677.01   87090944  77.932100   77.709268   0.010058    0.010160   \n",
       "22   2784.01  302766000  70.941629   70.354258   0.023336    0.023808   \n",
       "78   3786.01  321250206  66.841361   66.794015   0.009582    0.009595   \n",
       "12   1701.01  274215536  66.328226   65.003455   0.002721    0.002828   \n",
       "21   2578.01  104986789  65.538223   65.110207   0.001560    0.001580   \n",
       "..       ...        ...        ...         ...        ...         ...   \n",
       "29   3244.01  208091447   4.763420    4.733792   0.007618    0.007620   \n",
       "139  5664.01  158022899   3.407098    3.350451   0.005839    0.005842   \n",
       "127  5599.01  159160230   2.995921    2.898366   0.002671    0.002674   \n",
       "145  5699.01  224328450   2.988077    2.958728   0.006361    0.006363   \n",
       "131  5637.01  136992839   2.825291    2.802809   0.003070    0.003071   \n",
       "\n",
       "     ror_sd_1800  ror_sd_600  600_improv  \n",
       "11      0.045579    0.012324   72.961336  \n",
       "22      0.080307    0.036841   54.125414  \n",
       "78      0.028897    0.009857   65.889581  \n",
       "12      0.008080    0.004805   40.533684  \n",
       "21      0.004528    0.001907   57.888268  \n",
       "..           ...         ...         ...  \n",
       "29      0.007999    0.007662    4.211289  \n",
       "139     0.006045    0.005887    2.607795  \n",
       "127     0.002753    0.002705    1.761623  \n",
       "145     0.006557    0.006392    2.518649  \n",
       "131     0.003159    0.003091    2.167626  \n",
       "\n",
       "[157 rows x 9 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "final_df = pd.read_csv(\"tois_with_predictions.csv\")\n",
    "\n",
    "improvements = pd.DataFrame(columns=['toi', 'tid', '20_improv', '120_improv', 'ror_sd_20', 'ror_sd_120', 'ror_sd_1800'])\n",
    "\n",
    "improvements['toi'] = final_df['toi']\n",
    "improvements['tid'] = final_df['tid']\n",
    "improvements['ror_sd_20'] = final_df['ror_20_sd']\n",
    "improvements['ror_sd_120'] = final_df['ror_120_sd']\n",
    "improvements['ror_sd_600'] = final_df['ror_600_sd']\n",
    "improvements['ror_sd_1800'] = final_df['ror_1800_sd']\n",
    "\n",
    "improvements['20_improv'] = (1 - improvements['ror_sd_20'] / improvements['ror_sd_1800']) * 100\n",
    "improvements['120_improv'] = (1 - improvements['ror_sd_120'] / improvements['ror_sd_1800']) * 100\n",
    "improvements['600_improv'] = (1 - improvements['ror_sd_600'] / improvements['ror_sd_1800']) * 100\n",
    "\n",
    "improvements.sort_values(by=['20_improv'], ascending=False, inplace=True)\n",
    "\n",
    "# And this is our nice table with all predictions\n",
    "improvements"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e77da2fb",
   "metadata": {},
   "source": [
    "\n",
    "We can also convert the table to a LaTeX table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3a50c7f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Take the first 10 rows of the sorted table, keep only the TOI and the\n",
    "# improvement columns, and give them the headers used in the LaTeX table\n",
    "\n",
    "table = (\n",
    "    improvements\n",
    "    .head(10)\n",
    "    .drop(columns=['tid', 'ror_sd_20', 'ror_sd_120', 'ror_sd_600', 'ror_sd_1800'])\n",
    "    .rename(columns={\n",
    "        'toi': 'TOI',\n",
    "        '20_improv': '20s Improv. [%]',\n",
    "        '120_improv': '120s Improv [%]',\n",
    "        '600_improv': '600s Improv [%]'\n",
    "    })\n",
    ")\n",
    "\n",
    "table.to_latex('improvements_table.tex', index=False, float_format=\"%.2f\")\n",
    "\n",
    "# Which is the table used in the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9fe97a51",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# We also generate a longer table for the appendix\n",
    "\n",
    "# NOTE(review): N is not used anywhere in this cell; every row is exported\n",
    "N = 100\n",
    "\n",
    "# Build a new frame instead of aliasing `improvements` and editing it in place:\n",
    "# an in-place drop/rename on the alias would also strip and rename the columns\n",
    "# of `improvements` itself, which the following cell still needs untouched\n",
    "table = (\n",
    "    improvements\n",
    "    .drop(columns=['tid', 'ror_sd_20', 'ror_sd_120', 'ror_sd_600', 'ror_sd_1800'])\n",
    "    .rename(columns={\n",
    "        'toi': 'TOI',\n",
    "        '20_improv': '20s Improv. [%]',\n",
    "        '120_improv': '120s Improv [%]',\n",
    "        '600_improv': '600s Improv [%]'\n",
    "    })\n",
    ")\n",
    "\n",
    "table.to_latex('long_improvements_table.tex', index=False, float_format=\"%.2f\", caption=\"Predictions for all the planet candidates considered in order of decreasing improvements to the radius ratio precision.\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7dfe43da",
   "metadata": {},
   "source": [
    "\n",
    "And save a formatted CSV file of all results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "540ef117",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Keep only the TOI and the improvement columns, with csv-friendly headers\n",
    "formatted_table = (\n",
    "    improvements\n",
    "    .drop(columns=['tid', 'ror_sd_20', 'ror_sd_120', 'ror_sd_600', 'ror_sd_1800'])\n",
    "    .rename(columns={\n",
    "        'toi': 'TOI',\n",
    "        '20_improv': '20s_improv',\n",
    "        '120_improv': '120s_improv',\n",
    "        '600_improv': '600s_improv'\n",
    "    })\n",
    ")\n",
    "\n",
    "formatted_table.to_csv(\"formatted_toi_predictions.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
