SegmentationModel Training#

This example illustrates how a Hidden Markov Model (HMM) implemented by the RothSegmentationHmm can be trained from IMU data and pre-segmented stride lists. The implementation is based on the work of Roth et al. [1].

import numpy as np
from matplotlib import pyplot as plt

np.random.seed(1)

Getting some example data#

We use example data recorded during a 2x20 m walk test of a healthy subject performing regular walking. The IMU signals are already rotated so that they align with the gaitmap sensor-frame (SF) coordinate system. The data contains signals from two sensors, one on the left and one on the right foot.

from gaitmap.example_data import get_healthy_example_imu_data

data = get_healthy_example_imu_data()
sampling_rate_hz = 204.8
data.sort_index(axis=1).head(1)
sensor left_sensor                                                            right_sensor
axis         acc_x       acc_y       acc_z       gyr_x       gyr_y       gyr_z       acc_x       acc_y       acc_z       gyr_x       gyr_y       gyr_z
0.0       0.880811    2.762208     9.40865   -0.112402   -0.032157   -0.062261    0.311553   -2.398646    9.513275   -0.323037    0.084604   -0.025288


Preparing the data#

The HMM only uses the gyroscope signals. Further, the model expects the data to be in the gaitmap body frame, so that the same model can be applied to the left and the right foot. Therefore, we need to transform the dataset into the body frame.

from gaitmap.utils.coordinate_conversion import convert_to_fbf

# We use the `..._like` parameters to identify the data of the left and the right foot based on the name of the sensor.
bf_data = convert_to_fbf(data, left_like="left_", right_like="right_")

Getting the example stride list#

For this we take the ground truth stride list provided with the example data. For new data, such a stride list can be generated by running the algorithms in the stride_segmentation module and manually correcting their output, or by deriving it from ground truth data.

from gaitmap.example_data import get_healthy_example_stride_borders

stride_list = get_healthy_example_stride_borders()

from gaitmap.data_transform import ButterworthFilter

Initialize Model Parameters - Feature Transformation#

Here we define the feature space in which model training and later prediction take place. You can choose different axes and/or feature combinations, as well as downsampling, filtering, and standardization steps. The following configuration has proved to work well in most cases.

from gaitmap.stride_segmentation.hmm import RothHmmFeatureTransformer

feature_transform = RothHmmFeatureTransformer(
    sampling_rate_feature_space_hz=51.2,
    axes=["gyr_ml"],
    features=["raw", "gradient"],
    low_pass_filter=ButterworthFilter(order=4, cutoff_freq_hz=10),
    window_size_s=0.2,
    standardization=True,
)
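
To make these parameters more concrete, here is a simplified, pure-Python sketch of the kind of processing chain they describe: downsample to the feature-space rate, derive a gradient feature, and standardize. This is not the gaitmap implementation (which also applies the low-pass filter and works on sliding windows). The helper names `downsample`, `gradient`, and `standardize` and the synthetic sine signal are assumptions made for illustration only.

```python
import math


def downsample(signal, source_hz, target_hz):
    """Keep every n-th sample; assumes source_hz is an integer multiple of target_hz."""
    factor = round(source_hz / target_hz)
    return signal[::factor]


def gradient(signal):
    """First-order finite difference, padded so the length is preserved."""
    diffs = [b - a for a, b in zip(signal, signal[1:])]
    return diffs + diffs[-1:]


def standardize(signal):
    """Scale to zero mean and unit (population) standard deviation."""
    mean = sum(signal) / len(signal)
    std = math.sqrt(sum((x - mean) ** 2 for x in signal) / len(signal))
    return [(x - mean) / std for x in signal]


# Synthetic stand-in for one gyr_ml axis: about 2 s of a 1 Hz sine at 204.8 Hz.
source_hz, target_hz = 204.8, 51.2
gyr_ml = [math.sin(2 * math.pi * i / source_hz) for i in range(410)]

raw = downsample(gyr_ml, source_hz, target_hz)
# One (raw, gradient) pair per feature-space sample, both standardized.
features = list(zip(standardize(raw), standardize(gradient(raw))))

# Length of the 0.2 s window expressed in feature-space samples.
window_size_samples = round(0.2 * target_hz)
```

With the values used in this example, the downsampling factor is 4 and the 0.2 s window spans about 10 feature-space samples.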

Initialize Model Parameters - Sub HMMs#

The segmentation process is defined as a two-class problem with the classes “strides” and “transitions/null”. For each class we define a separate HMM with all of its components. Note that the stride and transition models differ in architecture, number of states, and number of Gaussian mixture model (GMM) components. In this example all configurable parameters are exposed. These parameters might require optimization for your specific type of dataset!

from gaitmap.stride_segmentation.hmm import SimpleHmm

stride_model = SimpleHmm(
    n_states=20,
    n_gmm_components=6,
    algo_train="baum-welch",
    stop_threshold=1e-9,
    max_iterations=5,
    architecture="left-right-strict",
    verbose=True,
    name="stride_model",
)

transition_model = SimpleHmm(
    n_states=5,
    n_gmm_components=3,
    algo_train="baum-welch",
    stop_threshold=1e-9,
    max_iterations=5,
    architecture="left-right-loose",
    verbose=True,
    name="transition_model",
)
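
As a rough picture of what a strict left-right architecture implies, the sketch below builds an initial transition matrix in which every state can only repeat itself or advance to the next state. That reading of "left-right-strict" is an assumption (it matches the banded structure of the trained matrix shown later in this example), and the helper `left_right_strict` as well as the uniform starting value of 0.5 are hypothetical, not taken from gaitmap.

```python
def left_right_strict(n_states, p_stay=0.5):
    """Build a transition matrix where each state either repeats or advances by one."""
    matrix = [[0.0] * n_states for _ in range(n_states)]
    for i in range(n_states):
        if i < n_states - 1:
            matrix[i][i] = p_stay
            matrix[i][i + 1] = 1.0 - p_stay
        else:
            # The last state has no successor inside the sub-model.
            matrix[i][i] = 1.0
    return matrix


# Same number of states as the stride model above.
stride_init = left_right_strict(n_states=20)
```

Training then only has to adjust the non-zero entries; transitions that start at zero stay impossible.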

Initialize Model Parameters - Segmentation Model#

Finally, we combine the feature transformation and the sub-HMMs defined above into the actual segmentation model, on which we can invoke the training process. Again, all configurable parameters are exposed for demonstration purposes. These parameters should work for most use cases.

from gaitmap.stride_segmentation.hmm import RothSegmentationHmm

segmentation_model = RothSegmentationHmm(
    stride_model=stride_model,
    transition_model=transition_model,
    feature_transform=feature_transform,
    algo_predict="viterbi",
    algo_train="baum-welch",
    stop_threshold=1e-9,
    max_iterations=1,
    initialization="labels",
    verbose=True,
    name="segmentation_model",
)
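
The `algo_predict="viterbi"` setting refers to Viterbi decoding, which finds the single most likely hidden state sequence for a series of observations. The following is a minimal, generic pure-Python sketch of that algorithm on a toy two-state model. It is not the pomegranate or gaitmap code, and the function name `viterbi` and all probabilities are made up for illustration.

```python
import math


def viterbi(log_start, log_trans, log_emit):
    """Return the most likely state path.

    log_start[s]: log probability of starting in state s
    log_trans[r][s]: log probability of moving from state r to state s
    log_emit[t][s]: log likelihood of observation t under state s
    """
    n_states = len(log_start)
    score = [log_start[s] + log_emit[0][s] for s in range(n_states)]
    backpointers = []
    for emit_t in log_emit[1:]:
        new_score, pointers = [], []
        for s in range(n_states):
            best_prev = max(range(n_states), key=lambda r: score[r] + log_trans[r][s])
            pointers.append(best_prev)
            new_score.append(score[best_prev] + log_trans[best_prev][s] + emit_t[s])
        score = new_score
        backpointers.append(pointers)
    # Trace back from the best final state.
    state = max(range(n_states), key=lambda s: score[s])
    path = [state]
    for pointers in reversed(backpointers):
        state = pointers[state]
        path.append(state)
    return path[::-1]


log = math.log
# Toy model: state 0 = transition/null, state 1 = stride.
log_start = [log(0.9), log(0.1)]
log_trans = [[log(0.8), log(0.2)], [log(0.2), log(0.8)]]
# Two observations that favour state 0, followed by three that favour state 1.
log_emit = [[log(0.9), log(0.1)]] * 2 + [[log(0.1), log(0.9)]] * 3

path = viterbi(log_start, log_trans, log_emit)
```

Here the decoded path switches from state 0 to state 1 exactly where the observations change character.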

Prepare Data for Training#

The HMM does not differentiate between left and right strides (this is why the data must be in the body-frame convention). The main input to the training process is a set of gait sequences that include transitions as well as valid strides. To train on multiple sequences, we pass a list of gait sequences to the model. Each gait sequence also needs a valid stride list. In this example we treat the data from the left and right foot as separate gait sequences and collect them in a simple list. We do the same for the stride lists.

data_train_sequence = [bf_data["left_sensor"], bf_data["right_sensor"]]
stride_list_sequence = [stride_list["left_sensor"], stride_list["right_sensor"]]
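
One way to picture how a stride list turns a gait sequence into a two-class problem is to assign every sample a label: 1 inside a stride, 0 for everything else (transition/null). The sketch below is only meant to illustrate that idea. How the library derives its training labels internally is not shown here, and the helper `stride_labels`, the example borders, and the treatment of end indices as exclusive are assumptions.

```python
def stride_labels(n_samples, stride_borders):
    """Return one label per sample: 1 inside a stride, 0 otherwise (transition/null)."""
    labels = [0] * n_samples
    for start, end in stride_borders:
        for i in range(start, min(end, n_samples)):
            labels[i] = 1
    return labels


# Hypothetical borders in samples; end indices are treated as exclusive here.
borders = [(3, 6), (6, 9), (12, 15)]
labels = stride_labels(18, borders)
```

Consecutive strides form one uninterrupted run of ones, while the samples between the second and third stride stay labelled as transition.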

Training#

Finally! Sit back, relax, and let the magic happen. Depending on the number of input sequences, training can take more than 30 minutes, but this small example runs quite fast. Internally, the model performs the feature transformation of the dataset, trains the individual sub-models, and finally combines them into a flattened segmentation model.

segmentation_model = segmentation_model.self_optimize(
    data_train_sequence, stride_list_sequence, sampling_rate_hz=sampling_rate_hz
)
[1] Improvement: 341.2822906421552      Time (s): 0.2922
[2] Improvement: 86.89919843315056      Time (s): 0.2922
[3] Improvement: 47.62707182288341      Time (s): 0.2916
[4] Improvement: 52.6968024905882       Time (s): 0.2918
[5] Improvement: 19.602494214435865     Time (s): 0.2918
Total Training Improvement: 548.1078576032132
Total Training Time (s): 1.7503
[1] Improvement: 732.489856825614       Time (s): 0.01444
[2] Improvement: 63.14542963290478      Time (s): 0.01445
[3] Improvement: 47.583919162248094     Time (s): 0.01439
[4] Improvement: 11.989313699388731     Time (s): 0.01437
[5] Improvement: 9.077419204156286      Time (s): 0.01427
Total Training Improvement: 864.2859385243119
Total Training Time (s): 0.0873
[1] Improvement: 40.37531804520768      Time (s): 0.3007
Total Training Improvement: 40.37531804520768
Total Training Time (s): 0.6494
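
The log above contains three training runs, each reporting the improvement per iteration. The sketch below shows how `stop_threshold` and `max_iterations` are commonly interpreted in such a loop: training ends once an iteration improves by less than the threshold or the iteration budget is used up. This interpretation and the helper `run_until_converged` are assumptions for illustration; the improvement values are copied from the first run of the log.

```python
def run_until_converged(improvements, stop_threshold, max_iterations):
    """Consume per-iteration improvements until one is too small or the budget is spent."""
    total = 0.0
    iterations = 0
    for improvement in improvements:
        if iterations >= max_iterations:
            break
        total += improvement
        iterations += 1
        if improvement < stop_threshold:
            break
    return iterations, total


# Per-iteration improvements of the first training run in the log above.
stride_log = [
    341.2822906421552,
    86.89919843315056,
    47.62707182288341,
    52.6968024905882,
    19.602494214435865,
]
iterations, total = run_until_converged(stride_log, stop_threshold=1e-9, max_iterations=5)
```

With a threshold as small as 1e-9, all runs in this example end because they reach their iteration limit, and the summed improvements reproduce the reported total of about 548.11.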

Inspecting the Results#

Now all internal models that were initialized as “None” should be populated with pomegranate models. We can now look at the final transition matrix or the trained distributions (GMMs). From here, you could either use the model to predict stride borders on an unseen sequence or save it to a JSON file for later use.

np.set_printoptions(precision=3, linewidth=180, suppress=True)

print(segmentation_model.model.dense_transition_matrix()[0:-2, 0:-2])

print(segmentation_model.model.states[10])
[[0.857 0.139 0.    0.    0.004 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.912 0.088 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.846 0.154 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.939 0.061 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.224 0.    0.    0.    0.725 0.051 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.666 0.334 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.669 0.331 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.626 0.374 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.717 0.283 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.664 0.336 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.665 0.335 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.652 0.348 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.588 0.412 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.515 0.485 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.529 0.471 0.    0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.528 0.472 0.    0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.431 0.569 0.    0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.627 0.373 0.    0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.823 0.177 0.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.496 0.504 0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.615 0.385 0.    0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.668 0.332 0.    0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.647 0.353 0.   ]
 [0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.643 0.357]
 [0.017 0.    0.    0.    0.    0.303 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.681]]
{
    "class" : "State",
    "distribution" : {
        "class" : "GeneralMixtureModel",
        "distributions" : [
            {
                "class" : "Distribution",
                "name" : "MultivariateGaussianDistribution",
                "parameters" : [
                    [
                        1.5144738349768228,
                        -1.234900676677567
                    ],
                    [
                        [
                            0.005669639522824596,
                            0.002265870792019985
                        ],
                        [
                            0.002265870792019985,
                            0.03979676183722518
                        ]
                    ]
                ],
                "frozen" : false
            },
            {
                "class" : "Distribution",
                "name" : "MultivariateGaussianDistribution",
                "parameters" : [
                    [
                        1.248845840066102,
                        -1.5635661215642773
                    ],
                    [
                        [
                            0.0221711234248743,
                            0.009739873876188316
                        ],
                        [
                            0.009739873876188316,
                            0.0442318803893847
                        ]
                    ]
                ],
                "frozen" : false
            },
            {
                "class" : "Distribution",
                "name" : "MultivariateGaussianDistribution",
                "parameters" : [
                    [
                        -0.14468901603846715,
                        -0.32692163041788336
                    ],
                    [
                        [
                            0.004734575782798828,
                            -0.00766969783039948
                        ],
                        [
                            -0.00766969783039948,
                            0.012456943786971592
                        ]
                    ]
                ],
                "frozen" : false
            },
            {
                "class" : "Distribution",
                "name" : "MultivariateGaussianDistribution",
                "parameters" : [
                    [
                        1.7442617711982327,
                        -1.5069847880391927
                    ],
                    [
                        [
                            0.0009567565993238514,
                            0.0032460394973877677
                        ],
                        [
                            0.0032460394973877677,
                            0.03052401621175055
                        ]
                    ]
                ],
                "frozen" : false
            },
            {
                "class" : "Distribution",
                "name" : "MultivariateGaussianDistribution",
                "parameters" : [
                    [
                        1.1419437803205614,
                        -1.7300063358729145
                    ],
                    [
                        [
                            0.015566915065239745,
                            -0.013796768043679145
                        ],
                        [
                            -0.013796768043679145,
                            0.032134311928016
                        ]
                    ]
                ],
                "frozen" : false
            },
            {
                "class" : "Distribution",
                "name" : "MultivariateGaussianDistribution",
                "parameters" : [
                    [
                        0.7192051305085859,
                        -1.8860271366256733
                    ],
                    [
                        [
                            0.024367270673471816,
                            0.009806360499260019
                        ],
                        [
                            0.009806360499260019,
                            0.020510403166483865
                        ]
                    ]
                ],
                "frozen" : false
            }
        ],
        "weights" : [
            0.3078157128141404,
            0.2727190317841552,
            0.0172892699879804,
            0.03890897167337093,
            0.09430039725866776,
            0.26896661648168546
        ]
    },
    "name" : "sa",
    "weight" : 1.0
}
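
The diagonal of the transition matrix printed above can be read as an expected dwell time: a state with self-transition probability p is occupied for 1 / (1 - p) consecutive samples on average (the mean of a geometric distribution). The short sketch below converts a few diagonal entries into milliseconds, using the feature-space sampling rate of 51.2 Hz from the feature transform. The helper name `expected_dwell_samples` is hypothetical.

```python
def expected_dwell_samples(p_stay):
    """Mean number of consecutive samples spent in a state with self-transition p_stay."""
    return 1.0 / (1.0 - p_stay)


feature_rate_hz = 51.2  # sampling_rate_feature_space_hz from the feature transform

# Diagonal entries of rows 5 to 9 (zero-based) of the matrix printed above.
p_stay_values = [0.666, 0.669, 0.626, 0.717, 0.664]
dwell_ms = [1000 * expected_dwell_samples(p) / feature_rate_hz for p in p_stay_values]
```

A self-transition probability of 0.666, for example, corresponds to about three feature-space samples, or roughly 58 ms per state.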

Applying the Model to a Sequence#

In the following we apply the model to the same sequence we used for training, just to show that the model “learned” something. We also plot the results to see how well the model performs.

from gaitmap.stride_segmentation.hmm import HmmStrideSegmentation

hmm = HmmStrideSegmentation(segmentation_model).segment(bf_data, sampling_rate_hz=sampling_rate_hz)
hmm.stride_list_
{'left_sensor':       start   end
s_id
0       364   584
1       584   802
2       802  1023
3      1023  1242
4      1242  1458
5      1458  1672
6      1672  1887
7      1887  2104
8      2104  2327
9      2327  2546
10     2546  2773
11     2773  2998
12     2998  3231
13     3231  3466
14     3934  4163
15     4163  4382
16     4382  4603
17     4603  4822
18     4822  5043
19     5043  5267
20     5267  5489
21     5489  5713
22     5713  5936
23     5936  6167
24     6167  6395
25     6395  6628
26     6628  6858
27     6858  7107, 'right_sensor':       start   end
s_id
0       475   691
1       691   913
2       913  1133
3      1133  1350
4      1350  1565
5      1565  1779
6      1779  1995
7      1995  2216
8      2216  2436
9      2436  2659
10     2659  2887
11     2887  3114
12     3114  3351
13     3351  3567
14     3567  3816
15     3816  4049
16     4049  4274
17     4274  4492
18     4492  4712
19     4712  4933
20     4933  5153
21     5153  5381
22     5381  5601
23     5601  5826
24     5826  6051
25     6051  6280
26     6280  6511
27     6511  6742
28     6742  6966
29     6966  7246}
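
The start and end values in these stride lists are sample indices, so dividing by the sampling rate gives stride durations in seconds. The sketch below does this for a handful of left-foot strides copied from the output above and also looks for gaps, meaning places where a stride does not start at the end of its predecessor. The variable names are chosen for illustration only.

```python
sampling_rate_hz = 204.8

# Left-foot strides 12 to 15, copied from the output above.
left_strides = [(2998, 3231), (3231, 3466), (3934, 4163), (4163, 4382)]

durations_s = [(end - start) / sampling_rate_hz for start, end in left_strides]

# A gap exists wherever a stride does not start at the end of its predecessor.
gaps = [
    (prev_end, next_start)
    for (_, prev_end), (next_start, _) in zip(left_strides, left_strides[1:])
    if next_start != prev_end
]
gap_durations_s = [(stop - start) / sampling_rate_hz for start, stop in gaps]
```

The strides last a little over one second each, and the single gap between samples 3466 and 3934 spans about 2.3 s, presumably the turn between the two 20 m walks.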

Plotting the Results#

sensor = "left_sensor"

fig, axs = plt.subplots(nrows=2, sharex=True, figsize=(10, 5))
axs[0].set_title("gaitmap Body Frame Dataset")
axs[0].plot(bf_data.reset_index(drop=True)[sensor]["gyr_ml"])
for start, end in hmm.stride_list_[sensor].to_numpy():
    axs[0].axvline(start, c="r")
    axs[0].axvline(end, c="r")
    axs[0].axvspan(start, end, alpha=0.2)
axs[0].set_ylabel("gyr-ml [deg/s]")

axs[1].set_title("Predicted Hidden State Sequence")
axs[1].plot(hmm.hidden_state_sequence_[sensor])
for start, end in hmm.matches_start_end_original_[sensor]:
    axs[1].axvline(start, c="g")
    axs[1].axvline(end, c="g")
    axs[1].axvspan(start, end, alpha=0.2)
axs[1].set_ylabel("Hidden State [N]")

axs[1].set_xlabel(f"Samples @ {sampling_rate_hz} Hz")
plt.xlim([6000, 7200])
fig.tight_layout()
plt.show()
[Figure: two panels, "gaitmap Body Frame Dataset" (top) and "Predicted Hidden State Sequence" (bottom)]

Total running time of the script: ( 0 minutes 6.787 seconds)

Estimated memory usage: 9 MB
