#!/bin/bash

## Usage ################################
# ./bigdl-nano-init <command to run under openmp and jemalloc environment variables>
# Example:
# bigdl-nano-init python pytorch-lenet.py --device ipex
#########################################


# Get options
# Handler for -o / --disable-openmp.
# Prints the chosen option and exports DISABLE_OPENMP_VAR=1. Later parts of
# this script test that flag before preloading libiomp5.so and before
# unsetting the OMP_*/KMP_* variables.
function disable-openmp-var {
	# Spelling matches the wording used for -o in the help text.
	echo "Option: Disable openMP and unset related variables"
	export DISABLE_OPENMP_VAR=1
}

# Handler for -j / --enable-jemalloc.
# Reports the choice on stdout, then publishes ENABLE_JEMALLOC_VAR=1 to the
# environment.
function enable-jemalloc-var {
	printf '%s\n' "Option: Enable jemalloc and unset related variables"
	ENABLE_JEMALLOC_VAR=1
	export ENABLE_JEMALLOC_VAR
}

# Handler for -c / --disable-tcmalloc.
# Reports the choice on stdout, then publishes DISABLE_TCMALLOC_VAR=1 to the
# environment.
function disable-tcmalloc-var {
	printf '%s\n' "Option: Disable tcmalloc and unset related variables"
	DISABLE_TCMALLOC_VAR=1
	export DISABLE_TCMALLOC_VAR
}

# Handler for --enable-tensorflow (there is no short form).
# Prints the chosen option and exports ENABLE_TF_OPTS=1, which the OpenMP
# section of this script reads when choosing OMP_NUM_THREADS and
# NANO_TF_INTER_OP.
function enable-tensorflow-var {
	echo "Option: Enable tensorflow optimization"
	export ENABLE_TF_OPTS=1
}

# Complain about an unrecognised option and terminate with status 1.
#   $1 - option text; it is printed after a leading "-"
# Both lines go to stderr, nothing is written to stdout.
function display-error {
	{
		echo "Invalid Option: -$1"
		echo "Try to run 'bigdl-nano-run -h' for detailed usage."
	} 1>&2
	exit 1
}

# Print the usage text on stdout.
# The function only prints; callers decide whether to exit afterwards.
# Every option accepted by the option parser is listed here, including the
# long-only --enable-tensorflow.
function display-help {
	echo "Usage: bigdl-nano-run [-o] [-j] [-c] <subcommand>"
	echo ""
	echo "bigdl-nano-run is a tool to automatically configure and run the subcommand under"
	echo "environment variables for accelerating pytorch."
	echo ""
	echo "Optional options:"
	echo "    -h, --help                Display this help message and exit."
	echo "    -o, --disable-openmp      Disable openMP and unset related variables"
	echo "    -j, --enable-jemalloc     Enable jemalloc and unset related variables"
	echo "    -c, --disable-tcmalloc    Disable tcmalloc and unset related variables"
	echo "        --enable-tensorflow   Enable tensorflow optimization"
}

# Parse the leading command-line options.
#
# getopts only knows short options. Long options are accepted through the
# "-:" entry of the optstring: "--name" is reported as option "-" with
# OPTARG set to "name", and the inner case dispatches on that.
#
# OPTIND is reset first because this file is also sourced (for example from
# the conda activate hook it installs). A sourced file inherits OPTIND from
# any earlier getopts run in the same shell, which would make the loop start
# in the middle of the argument list and silently skip options.
OPTIND=1
while getopts ":ojhc-:" opt; do
	case ${opt} in
		- )
			case "${OPTARG}" in
				disable-openmp)
					disable-openmp-var
					;;
				enable-jemalloc)
					enable-jemalloc-var
					;;
				disable-tcmalloc)
					disable-tcmalloc-var
					;;
				enable-tensorflow)
					enable-tensorflow-var
					;;
				help)
					display-help
					exit 0
					;;
				*)
					# display-error prepends one "-", giving "--<name>"
					display-error "-$OPTARG"
					;;
			esac
			;;

		o )
			disable-openmp-var
			;;
		j )
			enable-jemalloc-var
			;;
		c )
			disable-tcmalloc-var
			;;
		h )
			display-help
			exit 0
			;;
		\? )
			# quoted so the option text reaches display-error as one word
			display-error "$OPTARG"
			;;
	esac
done

# Drop the parsed options; what remains is the command to run.
shift $((OPTIND -1))

# Init
# Library flags; the sections below overwrite them.
OPENMP=0
JEMALLOC=0
TCMALLOC=0

# Find conda dir
# The prefix is taken to be the parent of the directory holding this script.
# BASH_SOURCE is quoted so that an install path containing whitespace is
# handled as a single word by both the test and dirname.
if [ -n "$BASH_SOURCE" ]; then
	# using bash
	BASE_DIR="$(dirname "$BASH_SOURCE")/.."
else
	# using zsh (BASH_SOURCE is empty there); %N prompt expansion yields the
	# name of the sourced file
	BASE_DIR="$(dirname ${(%):-%N})/.."
fi
echo "conda dir found: $BASE_DIR"
LIB_DIR="$BASE_DIR/lib"

# "<major>.<minor>" of the python3 found on PATH
PYTHON_VERSION=$(python3 -c "import platform; major, mnior, patch = platform.python_version_tuple();print(major+'.'+mnior)")

# <directory of the bigdl package>/nano/ (note the trailing slash)
NANO_DIR="$(dirname "$(python3 -c "import bigdl; print(bigdl.__file__)")")/nano/"



# Detect Intel openMP library
# When libiomp5.so sits in LIB_DIR, mark OpenMP as available and derive the
# thread settings from the CPU layout reported by lscpu.
if [ ! -f "${LIB_DIR}/libiomp5.so" ]; then
	echo "No openMP library found in ${LIB_DIR}."
else
	echo "OpenMP library found..."
	OPENMP=1

	# One "cpu,socket,core" row per logical CPU; header/comment rows are
	# filtered out by the pattern.
	cpu_infos=($(lscpu -p=CPU,Socket,Core | grep -P '^(\d*),(\d*),(\d*)$'))
	# Split the last row into its three numbers. The array is filled from a
	# command substitution so the fields are word-split under zsh as well.
	max_cpu_info=($(echo ${cpu_infos[-1]} | tr ',' ' '))
	# bash counts array elements from 0 and zsh from 1; negative indices
	# address the same fields in both shells.
	cpu_=$(( ${max_cpu_info[-3]} + 1 ))
	sockets_=$(( ${max_cpu_info[-2]} + 1 ))
	core_=$(( ${max_cpu_info[-1]} + 1 ))
	threads_per_core=$(( cpu_ / core_ ))
	cores_per_socket=$(( core_ / sockets_ ))

	# set env variables
	echo "Setting OMP_NUM_THREADS..."

	if [ -n "${ENABLE_TF_OPTS:-}" ]; then
		# tensorflow mode: one socket's worth of threads, one inter-op slot
		# per socket
		export OMP_NUM_THREADS=${cores_per_socket}
		export NANO_TF_INTER_OP=$sockets_
	else
		echo "Setting OMP_NUM_THREADS specified for pytorch..."
		export OMP_NUM_THREADS=$(( cores_per_socket * sockets_ ))
		export NANO_TF_INTER_OP=1
	fi

	echo "Setting KMP_AFFINITY..."
	export KMP_AFFINITY=granularity=fine

	echo "Setting KMP_BLOCKTIME..."
	export KMP_BLOCKTIME=1
fi

# Allocator flags
# Nothing is probed at this point: tcmalloc and jemalloc are both marked as
# candidates unconditionally.
TCMALLOC=1
JEMALLOC=1

# set env variables
# MALLOC_CONF is exported here unconditionally.
printf '%s\n' "Setting MALLOC_CONF..."
MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export MALLOC_CONF

# Set LD_PRELOAD
# Only done when LD_PRELOAD is currently empty; a value that is already set
# is left exactly as it is.
if [[ -z "${LD_PRELOAD:-}" ]]; then
	echo "Setting LD_PRELOAD..."

	# Intel OpenMP runtime: library was detected and -o was not given
	if [[ $OPENMP -eq 1 && -z "${DISABLE_OPENMP_VAR:-}" ]]; then
		LD_PRELOAD="${LIB_DIR}/libiomp5.so"
		export LD_PRELOAD
	fi

	# jemalloc is opt-in, and choosing it turns tcmalloc off
	if [[ $JEMALLOC -eq 1 && -n "${ENABLE_JEMALLOC_VAR:-}" ]]; then
		DISABLE_TCMALLOC_VAR=1
		# the copy under NANO_DIR wins over the one in LIB_DIR
		if [[ -f "${NANO_DIR}/libs/libjemalloc.so" ]]; then
			JEMALLOC_PATH="${NANO_DIR}/libs/libjemalloc.so"
		elif [[ -f "${LIB_DIR}/libjemalloc.so" ]]; then
			JEMALLOC_PATH="${LIB_DIR}/libjemalloc.so"
		fi
		# Either operand may be empty; passing both through echo joins them
		# with one space and leaves no stray blank at either end.
		LD_PRELOAD=$(echo ${LD_PRELOAD} ${JEMALLOC_PATH})
		export LD_PRELOAD
	fi

	# Set TCMALLOC lib path
	if [[ $TCMALLOC -eq 1 && -z "${DISABLE_TCMALLOC_VAR:-}" ]]; then
		if [[ -f "${NANO_DIR}/libs/libtcmalloc.so" ]]; then
			TCMALLOC_PATH="${NANO_DIR}/libs/libtcmalloc.so"
		elif [[ -f "${LIB_DIR}/libtcmalloc.so" ]]; then
			TCMALLOC_PATH="${LIB_DIR}/libtcmalloc.so"
		fi
		# joined through echo for the same reason as the jemalloc path
		LD_PRELOAD=$(echo ${LD_PRELOAD} ${TCMALLOC_PATH})
		export LD_PRELOAD
	fi
fi

# Set TF_ENABLE_ONEDNN_OPTS
export TF_ENABLE_ONEDNN_OPTS=1
# NOTE(review): ENABLE_TF_OPTS is exported unconditionally, so a later run of
# this script in the same environment sees it set and sizes OMP_NUM_THREADS
# for tensorflow - confirm this is intended.
export ENABLE_TF_OPTS=1

# Print the whitespace-separated list $2 without the entries that name
# library $1.
#   $1 - library file name to drop, e.g. libiomp5.so
#   $2 - current LD_PRELOAD value
# Entries are examined one by one, so everything that does not contain $1 is
# kept, in its original order, joined by single spaces. (A greedy pattern over
# the whole string would also swallow unrelated entries listed before the
# library.)
function strip-ld-preload-entry {
	local kept="" entry
	# the list is split by the command substitution, which works the same
	# way in bash and zsh
	for entry in $(echo "$2"); do
		case "$entry" in
			*"$1"*)
				;;
			*)
				kept="${kept:+${kept} }${entry}"
				;;
		esac
	done
	echo "$kept"
}

# Disable openmp or jemalloc according to options
if [ -n "${DISABLE_OPENMP_VAR:-}" ]; then
	unset OMP_NUM_THREADS
	unset KMP_AFFINITY
	unset KMP_BLOCKTIME
	unset DISABLE_OPENMP_VAR
	export LD_PRELOAD="$(strip-ld-preload-entry libiomp5.so "${LD_PRELOAD:-}")"
fi
# jemalloc is opt-in: without -j its settings and library are removed
if [ -z "${ENABLE_JEMALLOC_VAR:-}" ]; then
	unset MALLOC_CONF
	export LD_PRELOAD="$(strip-ld-preload-entry libjemalloc.so "${LD_PRELOAD:-}")"
fi

if [ -n "${DISABLE_TCMALLOC_VAR:-}" ]; then
	unset DISABLE_TCMALLOC_VAR
	export LD_PRELOAD="$(strip-ld-preload-entry libtcmalloc.so "${LD_PRELOAD:-}")"
fi

# Install conda activate/deactivate hooks (first run inside a conda env only).
#
# - outside a conda env: nothing is written
# - hook already present: nothing is written
# - bigdl-nano-init is not in $CONDA_PREFIX/bin: only print advice
# - otherwise: write nano_vars.sh into activate.d and deactivate.d; the hooks
#   source bigdl-nano-init / bigdl-nano-unset-env from the env's bin directory
#
# Every path built from CONDA_PREFIX is quoted so a prefix containing
# whitespace works, both here and inside the generated hook files.
if [ -z "${CONDA_DEFAULT_ENV:-}" ]; then
	echo "Not in a conda env"
else
	if [ -f "${CONDA_PREFIX}/etc/conda/activate.d/nano_vars.sh" ]; then
		echo "nano_vars.sh already exists"
	elif [ ! -f "${CONDA_PREFIX}/bin/bigdl-nano-init" ]; then
		echo "It seems that you are using bigdl-nano installed by system's pip, "
		echo "you may need to 'source bigdl-nano-init' before using bigdl-nano "
		echo "and 'source bigdl-nano-unset-env' after using bigdl-nano yourself"
	else
		echo "Setting environment variables in current conda env"
		ACTIVATED_PATH="${CONDA_PREFIX}/etc/conda/activate.d"
		DEACTIVATED_PATH="${CONDA_PREFIX}/etc/conda/deactivate.d"
		mkdir -p "$ACTIVATED_PATH"
		mkdir -p "$DEACTIVATED_PATH"

		# bigdl-nano-init
		{
			echo "if [ -f '${CONDA_PREFIX}/bin/bigdl-nano-init' ]; then"
			echo "    source '${CONDA_PREFIX}/bin/bigdl-nano-init'"
			echo "else"
			echo "    echo 'Cannot find bigdl-nano-init, if you have uninstalled bigdl-nano, you may want to delete $ACTIVATED_PATH/nano_vars.sh and $DEACTIVATED_PATH/nano_vars.sh'"
			echo "fi"
		} > "$ACTIVATED_PATH/nano_vars.sh"

		# bigdl-nano-unset-env
		{
			echo "if [ -f '${CONDA_PREFIX}/bin/bigdl-nano-unset-env' ]; then"
			echo "    source '${CONDA_PREFIX}/bin/bigdl-nano-unset-env'"
			echo "else"
			# names the script this hook actually looks for
			echo "    echo 'Cannot find bigdl-nano-unset-env, if you have uninstalled bigdl-nano, you may want to delete $ACTIVATED_PATH/nano_vars.sh and $DEACTIVATED_PATH/nano_vars.sh'"
			echo "fi"
		} > "$DEACTIVATED_PATH/nano_vars.sh"

		# warning
		echo "Added nano_vars.sh script to $ACTIVATED_PATH and $DEACTIVATED_PATH. You may want to delete them if you want to uninstall bigdl-nano."
	fi
fi

# Show the resulting environment
echo "+++++ Env Variables +++++"
echo "LD_PRELOAD=${LD_PRELOAD}"
echo "MALLOC_CONF=${MALLOC_CONF}"
echo "OMP_NUM_THREADS=${OMP_NUM_THREADS}"
echo "KMP_AFFINITY=${KMP_AFFINITY}"
echo "KMP_BLOCKTIME=${KMP_BLOCKTIME}"
echo "TF_ENABLE_ONEDNN_OPTS=${TF_ENABLE_ONEDNN_OPTS}"
echo "ENABLE_TF_OPTS=${ENABLE_TF_OPTS}"
echo "NANO_TF_INTER_OP=${NANO_TF_INTER_OP}"
echo "+++++++++++++++++++++++++"
# Run the commands after bigdl-nano-init
echo "Complete."
# The remaining arguments are the command to run under the prepared
# environment. Nothing is run when there are none, or when the first one is
# "activate"/"deactivate".
# "$@" keeps each argument as one word; unquoted, an argument such as
# "hello world" would be split in two before reaching the command.
if [ -n "${1:-}" ] && [ "$1" != "activate" ] && [ "$1" != "deactivate" ]; then
	"$@"
fi
