{% include 'imports/pipeline' %}

{% if framework == 'tensorflow' %}
def build_model(encoders):
    """Builds and compiles the model from scratch.

    The model takes one input per non-target field (a datetime field
    contributes a day-of-week and an hour input, plus optional month and
    year inputs). All inputs are merged by a single `concat` layer, and
    the resulting `Model` is compiled before being returned.

    # Arguments
        encoders: dict of encoders (used to set size of text/categorical inputs)

    # Returns
        model: A compiled model which can be used to train or predict.
    """

{# NOTE(review): the includes below presumably define the input_* / *_enc layers referenced further down, at function-body indentation -- confirm against the include files. #}
{% if has_text_input %}
{% include 'models/' ~ text_framework ~ '/text' %}
{% endif %}

{% for field, field_raw, field_type in nontarget_fields %}
{% if field_type != 'text' %}
{% include 'models/' ~ framework ~ '/' ~ field_type %}

{% endif %}
{% endfor %}
    # Combine all the inputs into a single layer
    concat = concatenate([
        {# Text fields contribute their *_enc layer here; a ", " separator follows every entry except the last one. #}
        {% for field, _, field_type in nontarget_fields %}
        {% if field_type == 'text' %}
        {{ field }}_enc{{ ", " if not loop.last }}
        {% elif field_type != 'datetime' %}
        input_{{ field }}{{ ", " if not loop.last }}
        {% else %}
        input_dayofweeks_{{ field }},
        input_hours_{{ field }}{{ ", " if not loop.last and not (params['datetime_month'] or params['datetime_year']) }}
        {% if params['datetime_month'] %}
        ,input_month_{{ field }}{{ ", " if not loop.last and not params['datetime_year'] }}
        {% endif %}
        {% if params['datetime_year'] %}
        ,input_year_{{ field }}{{ ", " if not loop.last }}
        {% endif %}
        {% endif %}
        {% endfor %}
    ], name="concat")

{# NOTE(review): `output`, used as the model output below, is not defined in this template; presumably the mlp include defines it from `concat` -- confirm. #}
{% include 'models/' ~ framework ~ '/mlp' %}

    # Build and compile the model.
    model = Model(inputs=[
        {# Same expansion as the concat list above, except that text fields are listed by their input_* layer. #}
        {% for field, _, field_type in nontarget_fields %}
        {% if field_type != 'datetime' and field != target_field %}
        input_{{ field }}{{ ", " if not loop.last }}
        {% elif field != target_field %}
        input_dayofweeks_{{ field }},
        input_hours_{{ field }}{{ ", " if not loop.last and not (params['datetime_month'] or params['datetime_year'])}}
        {% if params['datetime_month'] %}
        ,input_month_{{ field }}{{ ", " if not loop.last and not params['datetime_year']}}
        {% endif %}
        {% if params['datetime_year'] %}
        ,input_year_{{ field }}{{ ", " if not loop.last }}
        {% endif %}
        {% endif %}
        {% endfor %}
                ],
                      outputs=[output])
    {# Without a TPU address the model is compiled with AdamWOptimizer (learning rate and weight decay); with one, AdamOptimizer (learning rate only) is used. #}
    {% if not tpu_address %}
    model.compile(loss={% include 'models/' ~ framework ~ '/loss' %},
              optimizer=AdamWOptimizer(learning_rate = {{ params['base_lr'] }},
                                        weight_decay = {{ params['weight_decay'] }}))
    {% else %}
    model.compile(loss={% include 'models/' ~ framework ~ '/loss' %},
              optimizer=AdamOptimizer(learning_rate = {{ params['base_lr'] }}))
    {% endif %}

    return model
{% endif %}

def build_encoders(df):
    """Builds encoders for fields to be used when
    processing data for the model.

    All encoder specifications are stored locally
    in /encoders as .json files.

    # Arguments
        df: A pandas DataFrame containing the data.
    """

{# The whole body comes from the encoder includes: one shared text encoder (if any text input exists), one per non-text field, then the target. NOTE(review): the includes are presumably written at function-body indentation -- confirm. #}
{% if has_text_input %}
{% include 'encoders/' ~ text_framework ~ '-text' %}
{% endif %}

{% for field, field_raw, field_type in nontarget_fields %}
{% if field_type != 'text' %}
{% include 'encoders/' ~ field_type %}

{% endif %}

{% endfor %}
{% include 'encoders/target' %}

def load_encoders():
    """Loads the encoders built during `build_encoders`.

    # Returns
        encoders: A dict of encoder objects/specs.
    """

    encoders = {}

{# NOTE(review): the loader includes below presumably add their entries to the `encoders` dict defined above -- confirm against the include files. They follow the same order as in build_encoders: text, non-text fields, target. #}
{% if has_text_input %}
{% include 'loaders/' ~ text_framework ~ '-text' %}
{% endif %}

{% for field, field_raw, field_type in nontarget_fields %}
{% if field_type != 'text' %}
{% include 'loaders/' ~ field_type %}

{% endif %}

{% endfor %}
{% include 'loaders/target' %}

    return encoders

def process_data(df, encoders, process_target=True):
    """Processes an input DataFrame into a format
    suitable for model prediction.

    The `encoders` argument is expected to hold the encoder
    specifications created in `build_encoders`.

    # Arguments
        df: a DataFrame containing the source data
        encoders: a dict of encoders to process the data.
        process_target: boolean to determine if the target should be encoded.

    # Returns
        If `process_target` is True, a tuple: A list containing all the
        processed fields to be fed into the model, and the processed
        target field. Otherwise only the list of processed fields.
    """

{# NOTE(review): the processor includes below presumably define the *_enc variables collected into data_enc -- confirm against the include files. #}
{% if has_text_input %}
{% include 'processors/' ~ text_framework ~ '-text' %}
{% endif %}

{% for field, field_raw, field_type in nontarget_fields %}
{% if field_type != 'text' %}
{% include 'processors/' ~ field_type %}
{% endif %}

{% endfor %}

    {# One entry per non-target field; a datetime field expands to day-of-week and hour entries plus optional month and year entries. A ", " separator follows every entry except the last one. #}
    data_enc = [{% for field, _, field_type in nontarget_fields %}
        {% if field_type != 'datetime' %}
        {{ field }}_enc{{ ", " if not loop.last }}
        {% else %}
        {{ field }}_dayofweeks_enc,
        {{ field }}_hour_enc{{ ", " if not loop.last and not (params['datetime_month'] or params['datetime_year'])}}
        {% if params['datetime_month'] %}
        ,{{ field }}_month_enc{{ ", " if not loop.last and not params['datetime_year']}}
        {% endif %}
        {% if params['datetime_year'] %}
        ,{{ field }}_year_enc{{ ", " if not loop.last }}
        {% endif %}
        {% endif %}
        {% endfor %}
        ]

    if process_target:
{# NOTE(review): this include sits inside the `if` above, so its content must be indented one level deeper than the other processors -- confirm. #}
{% include 'processors/target' %}
        return (data_enc, {{ target_field }}_enc)

    return data_enc


def model_predict(df, model, encoders):
    """Runs a trained model on new data and tabulates its predictions.

    # Arguments
        df: A pandas DataFrame containing the source data.
        model: A compiled model.
        encoders: a dict of encoders to process the data.

    # Returns
        A pandas DataFrame of predictions, with one column per class
        for classification problems and a single column otherwise.
    """

    # The target is not needed (and may not be present) at prediction time.
    features = process_data(df, encoders, process_target=False)
    {% if framework == 'xgboost' %}
    features = xgb.DMatrix(np.hstack(features))
    {% endif %}

    {% if problem_type == 'binary_classification' %}
    column_names = ['probability']
    {% elif problem_type == 'classification' %}
    column_names = encoders['{{ target_field }}_encoder'].classes_
    {% else %}
    column_names = ['{{ target_field }}']
    {% endif %}

    raw_output = model.predict(features)
    return pd.DataFrame(raw_output, columns=column_names)
    

def model_train(df, encoders, args, model=None):
    """Trains a model, and saves the data locally.

    The data is split once into a training and a validation set
    (`args.split` is passed as the training size), and the model is
    trained for `args.epochs` epochs.

    # Arguments
        df: A pandas DataFrame containing the source data.
        encoders: a dict of encoders to process the data.
        args: the arguments passed through the command line; `split`,
            `epochs` and `context` are read from it as attributes.
        model: A compiled model (for TensorFlow, None otherwise).
    """
    {% if framework == 'tensorflow' %}
    X, y = process_data(df, encoders)
    {% endif %}

    {% if framework == 'xgboost' %}
    # The boosters are trained on the raw target column; the encoded
    # target is kept separately and used as `y_true` further down.
    X, y_enc = process_data(df, encoders)
    X = np.hstack(X)
    y = df['{{ target_field_raw }}'].values
    {% endif %}

    {% if problem_type == 'regression' %}
    split = ShuffleSplit(n_splits=1, train_size=args.split, test_size=None, random_state=123)
    {% else %}
    split = StratifiedShuffleSplit(n_splits=1, train_size=args.split, test_size=None, random_state=123)
    {% endif %}

    {% if framework == 'tensorflow' %}
    # n_splits=1, so this loop body runs once and leaves the
    # train/validation subsets bound for the code below.
    for train_indices, val_indices in split.split(np.zeros(y.shape[0]), y):
        X_train = [field[train_indices,] for field in X]
        X_val = [field[val_indices,] for field in X]
        y_train = y[train_indices,]
        y_val = y[val_indices,]

    meta = meta_callback(args, X_val, y_val)

    {% if tpu_address %}
    if args.context == 'automl-gs':
        model = tf.contrib.tpu.keras_to_tpu_model(model,
               strategy=tf.contrib.tpu.TPUDistributionStrategy(
               tf.contrib.cluster_resolver.TPUClusterResolver('{{ tpu_address }}')))
    {% endif %}

    model.fit(X_train, y_train, validation_data=(X_val, y_val),
                epochs=args.epochs,
                callbacks=[meta],
                batch_size={{ params['batch_size'] }}{% if tpu_address %}* 8{% endif %})
    {% endif %}

    {% if framework == 'xgboost' %}
    # n_splits=1, so this loop body runs once and leaves the
    # train/validation matrices (and index arrays) bound for the code below.
    for train_indices, val_indices in split.split(np.zeros(y.shape[0]), y):
        train = xgb.DMatrix(X[train_indices,], y[train_indices,])
        val = xgb.DMatrix(X[val_indices,], y[val_indices,])

    {# NOTE(review): no comma follows the loss include on the 'objective' line, so the included template must supply it itself -- confirm. #}
    params = {
        'eta': {{ params['base_lr'] }},
        'max_depth': {{ params['max_depth'] }},
        'gamma': {{ params['gamma'] }},
        'min_child_weight': {{ params['min_child_weight'] }},
        'subsample': {{ params['subsample'] }},
        'colsample_bytree': {{ params['colsample_bytree'] }},
        'max_bin': {{ params['max_bin'] }},
        'objective': {% include 'models/' ~ framework ~ '/loss' %}
        'tree_method': {% if gpu %}'gpu_hist'{% else %}'hist'{% endif %},
        'silent': 1
    }

    # newline='' is what the csv module requires for files handed to
    # csv.writer; without it rows get an extra '\r' on Windows.
    f = open(os.path.join('metadata', 'results.csv'), 'w', newline='')
    w = csv.writer(f)
    w.writerow(['epoch', 'time_completed'] + {{ metrics }})

    y_true = y_enc[val_indices, ]
    for epoch in range(args.epochs):
        # One boosting round per epoch, continuing from the booster
        # produced by the previous epoch.
        model = xgb.train(params, train, 1,
                          xgb_model=model if epoch > 0 else None)
        y_pred = model.predict(val)

        {# NOTE(review): this include presumably computes the `metrics` list written below from y_true / y_pred, at loop-body indentation -- confirm. #}
        {% include 'callbacks/problem_types/' ~ problem_type %}

        time_completed = "{:%Y-%m-%d %H:%M:%S}".format(datetime.utcnow())
        w.writerow([epoch+1, time_completed] + metrics)

        if args.context == 'automl-gs':
            # Flush after printing so the end-of-epoch marker is not left
            # sitting in the stdout buffer when output is piped.
            print("\nEPOCH_END", flush=True)

    f.close()
    model.save_model('model.bin')
    {% endif %}

{% if framework == 'tensorflow' %}
{% include 'callbacks/tensorflow' %}
{% endif %}

