r"""
.. _example_segmentation_hmm_training:

SegmentationModel Training
==========================

This example illustrates how a Hidden Markov Model (HMM) implemented by the
:class:`~gaitmap.stride_segmentation.hmm.RothSegmentationHmm` can be trained from IMU data and presegmented stride lists.
The implementation used here is based on the work of Roth et al. [1]_.

.. [1] Roth, N., Küderle, A., Ullrich, M., Gladow, T., Marxreiter F., Klucken, J., Eskofier, B. & Kluge F. (2021).
   Hidden Markov Model based Stride Segmentation on Unsupervised Free-living Gait Data in Parkinson’s Disease Patients.
   Journal of NeuroEngineering and Rehabilitation, (JNER).
"""

import numpy as np
from matplotlib import pyplot as plt

# Seed NumPy's global random number generator so repeated runs of this example start from the same random state.
np.random.seed(1)

# %%
# Getting some example data
# --------------------------
#
# For this we take some example data that contains the regular walking movement during a 2x20m walk test of a healthy
# subject. The IMU signals are already rotated so that they align with the gaitmap SF coordinate system.
# The data contains information from two sensors - one from the right and one from the left foot.
from gaitmap.example_data import get_healthy_example_imu_data

data = get_healthy_example_imu_data()
# Sampling rate that is passed to the training and segmentation calls further below.
sampling_rate_hz = 204.8
# Show the first row (with sorted columns) to get an impression of the data layout.
data.sort_index(axis=1).head(1)

# %%
# Preparing the data
# ------------------
# The HMM only makes use of the gyro information.
# Further, if you use this model, your data is expected to be in the gaitmap body-frame to be able to use the
# same model for the left and the right foot.
# Therefore, we need to transform the dataset into the body frame.
from gaitmap.utils.coordinate_conversion import convert_to_fbf

# We use the `..._like` parameters to identify the data of the left and the right foot based on the name of the sensor.
# The converted data is accessed per sensor name (``left_sensor`` / ``right_sensor``) further below.
bf_data = convert_to_fbf(data, left_like="left_", right_like="right_")

# %%
# Getting the example stride list
# -------------------------------
#
# For this we take the ground truth stride list provided with the example data.
# For new data, such a stride list can be generated by running the algorithms provided in the
# :mod:`~gaitmap.stride_segmentation` module and manually correcting the result, or by deriving it from ground truth
# data.
from gaitmap.example_data import get_healthy_example_stride_borders

# The stride borders are accessed per sensor name (``left_sensor`` / ``right_sensor``) further below.
stride_list = get_healthy_example_stride_borders()

from gaitmap.data_transform import ButterworthFilter

# %%
# Initialize Model Parameters - Feature Transformation
# ----------------------------------------------------
#
# Here we define the feature space in which model training and later prediction will take place. You can choose
# different axes and/or feature combinations as well as downsampling, filter and standardization steps. The following
# example has proved to work well in most cases.
from gaitmap.stride_segmentation.hmm import RothHmmFeatureTransformer

# Filter handed to the feature transformer as its low-pass stage (4th order Butterworth, 10 Hz cutoff).
gyr_low_pass = ButterworthFilter(order=4, cutoff_freq_hz=10)

feature_transform = RothHmmFeatureTransformer(
    # Which signal axes and which features derived from them form the feature space.
    axes=["gyr_ml"],
    features=["raw", "gradient"],
    # Pre-processing settings of the transformation.
    sampling_rate_feature_space_hz=51.2,
    low_pass_filter=gyr_low_pass,
    window_size_s=0.2,
    standardization=True,
)

# %%
# Initialize Model Parameters - Sub HMMs
# --------------------------------------
#
# The segmentation process is defined as a two-class problem, namely "strides" and "transitions/null".
# For each class we define a separate HMM and define all its components. Note that the stride and transition models
# differ in architecture, number of states, and number of Gaussian mixture model (GMM) components.
# In this example all configurable parameters are exposed.
# These parameters might require optimization for your specific type of dataset!
from gaitmap.stride_segmentation.hmm import SimpleHmm

# Training settings that are identical for both sub-models in this example.
shared_training_params = {
    "algo_train": "baum-welch",
    "stop_threshold": 1e-9,
    "max_iterations": 5,
    "verbose": True,
}

# Sub-model for the "stride" class.
stride_model = SimpleHmm(
    n_states=20,
    n_gmm_components=6,
    architecture="left-right-strict",
    name="stride_model",
    **shared_training_params,
)

# Sub-model for the "transition/null" class: fewer states and GMM components and a different architecture.
transition_model = SimpleHmm(
    n_states=5,
    n_gmm_components=3,
    architecture="left-right-loose",
    name="transition_model",
    **shared_training_params,
)

# %%
# Initialize Model Parameters - Segmentation Model
# ------------------------------------------------
#
# Finally, we can combine the feature extraction and our defined sub-HMMs into the actual segmentation model, on which
# we can invoke the training process.
# Again, all configurable parameters are exposed for demonstration purposes.
# These parameters should again work for most use cases.
from gaitmap.stride_segmentation.hmm import RothSegmentationHmm

segmentation_model = RothSegmentationHmm(
    # Building blocks that were configured in the previous sections.
    stride_model=stride_model,
    transition_model=transition_model,
    feature_transform=feature_transform,
    # Settings of the combined model itself; these are passed independently of the sub-model settings above.
    algo_predict="viterbi",
    algo_train="baum-welch",
    stop_threshold=1e-9,
    # Note that this is lower than the `max_iterations=5` used for the two sub-models.
    max_iterations=1,
    # NOTE(review): presumably the combined model is initialized from the provided stride labels - confirm.
    initialization="labels",
    verbose=True,
    name="segmentation_model",
)

# %%
# Prepare Data for Training
# --------------------------------------
#
# The HMM does not differentiate between left or right strides, (this is why we must have our data in the body-frame
# convention!).
# The main inputs to the training process are gait sequences, which include transitions as well as valid strides.
# To train on multiple sequences, we can just feed a list of gait sequences into the model for training.
# For each gait sequence we also need to have a valid stride list. In this example we handle the data from the left and
# right foot as separate gait sequences and add them to a simple list.
# We have to do the same for the stride lists.

# Build both lists from the same sensor names so that data and stride lists are guaranteed to be in matching order.
training_sensors = ("left_sensor", "right_sensor")
data_train_sequence = [bf_data[sensor_name] for sensor_name in training_sensors]
stride_list_sequence = [stride_list[sensor_name] for sensor_name in training_sensors]

# %%
# Training
# --------------------------------------
#
# Finally! Sit back, relax, and let the magic happen (depending on the number of input sequences, this can take more
# than 30 min).
# However, this small example runs quite fast!
# The model will internally perform the feature transformation of the dataset, train the individual sub-models, and
# finally combine them into a flattened segmentation model.

# The name is rebound to the return value of `self_optimize`, which is the model used in all following sections.
# The data and stride-list sequences must be in the same order.
segmentation_model = segmentation_model.self_optimize(
    data_train_sequence, stride_list_sequence, sampling_rate_hz=sampling_rate_hz
)

# %%
# Inspecting the Results
# --------------------------------------
#
# Now all internal models which were initialized as "None" should be populated by pomegranate models.
# We can now have a look at the final transition matrix or the trained distributions (GMMs).
# You could now either use the model to predict stride borders on an unseen sequence or save it to a json file for later
# use.

# Compact number formatting so the transition matrix stays readable when printed.
np.set_printoptions(suppress=True, precision=3, linewidth=180)

trained_model = segmentation_model.model

# Print the transition matrix without its last two rows and columns.
# NOTE(review): presumably these belong to the model's silent start/end states - confirm.
transition_matrix = trained_model.dense_transition_matrix()
print(transition_matrix[:-2, :-2])

# Print one of the model states (index 10) as an example.
print(trained_model.states[10])

# %%
# Applying the Model to a Sequence
# --------------------------------
# In the following, we will apply the model to the same sequence we used for training, just to show that the model
# "learned" something.
# We will also plot the results to see how well the model performs.
from gaitmap.stride_segmentation.hmm import HmmStrideSegmentation

# Wrap the trained model in the segmentation algorithm, then run it on the body-frame data of both sensors.
hmm = HmmStrideSegmentation(segmentation_model)
hmm = hmm.segment(bf_data, sampling_rate_hz=sampling_rate_hz)
# Display the resulting stride list as the output of this cell.
hmm.stride_list_

# %%
# Plotting the Results
# --------------------
# Sensor to visualize. Every lookup below uses this name, so changing it switches the whole figure consistently.
sensor = "left_sensor"

fig, axs = plt.subplots(nrows=2, sharex=True, figsize=(10, 5))

# Top row: medio-lateral gyroscope signal with the segmented strides highlighted.
axs[0].set_title("gaitmap Body Frame Dataset")
axs[0].plot(bf_data.reset_index(drop=True)[sensor]["gyr_ml"])
# Use `sensor` here (not a hard-coded name) and select the border columns explicitly, so that additional columns in
# the stride list cannot break the two-value unpacking.
for start, end in hmm.stride_list_[sensor][["start", "end"]].to_numpy():
    axs[0].axvline(start, c="r")
    axs[0].axvline(end, c="r")
    axs[0].axvspan(start, end, alpha=0.2)
axs[0].set_ylabel("gyr-ml [deg/s]")

# Bottom row: predicted hidden state sequence with the start/end matches highlighted.
axs[1].set_title("Predicted Hidden State Sequence")
axs[1].plot(hmm.hidden_state_sequence_[sensor])
for start, end in hmm.matches_start_end_original_[sensor]:
    axs[1].axvline(start, c="g")
    axs[1].axvline(end, c="g")
    axs[1].axvspan(start, end, alpha=0.2)
axs[1].set_ylabel("Hidden State [N]")

# Format the sampling rate without truncation; "%d" would render 204.8 as "204".
axs[1].set_xlabel(f"Samples @ {sampling_rate_hz} Hz")
# The x-axis is shared, so zooming one axes zooms both rows.
axs[1].set_xlim(6000, 7200)
fig.tight_layout()
plt.show()
