Caching algorithm outputs#

Many algorithms implemented in gaitmap have a runtime of multiple seconds on larger datasets. In the context of algorithm evaluation, for example when performing a cross-validation, algorithms are sometimes repeatedly called on the same data and even with the same parameters. In these cases, it can be helpful to cache results to ensure that you do not need to recalculate values.

The joblib Python package makes cashing extremely easy and you should read their guide first, before continuing with this example. However, one of the caveats with joblib caching is that it only works on pure functions without side effects and should not be used with methods.

Unfortunately, gaitmap is mostly object oriented and all the computational expensive things you might want to do are hidden behind a method call.

Therefore, many gaitmap algorithms have caching built-in. These algorithms support an additional keyword argument called memory in their init-function. If you pass a joblib.Memory object to these, it will be used to cache the most time consuming function calls. Note, that this will usually not cache all the calculations in a method, but only the ones that are considered worth caching by the algorithm developer.

If you really want to cache the full method calls (on your own risk), see the last section of this example.

Example Pipeline#

We will simply copy the stride segmentation example to have some data to work with.

from gaitmap.example_data import get_healthy_example_imu_data
from gaitmap.utils.coordinate_conversion import convert_to_fbf

data = get_healthy_example_imu_data().iloc[:2000]
sampling_rate_hz = 204.8

data = convert_to_fbf(data, left_like="left_", right_like="right_")

Creating the cash#

First we will create a memory instance for our cash. We can use the same cash to cash the output of multiple algorithms. The cash stays valid even after you restart Python, if you didn’t delete the folder.

However, in this example, we will use a temp-directory that will be deleted at the end of the example.

from tempfile import TemporaryDirectory

from joblib import Memory

tmp_dir = TemporaryDirectory()
# We will activate some more debug output for this example
mem = Memory(tmp_dir.name, verbose=2)

Initialize algorithm#

We initialize our algorithm as normal, but pass the memory instance as an additional parameter.

from gaitmap.stride_segmentation import BarthDtw

dtw = BarthDtw(memory=mem)

Calling cached methods#

The first time we call segment now, all calculation will run as normal, but the output of certain calculations will be cached. They are then reused when we call segment again with the same data and configuration.

Observe the print output to see what happens.

first_call_results = dtw.segment(data=data, sampling_rate_hz=204.8)
first_call_stride_list = first_call_results.stride_list_.copy()
________________________________________________________________________________
[Memory] Calling gaitmap_mad.stride_segmentation.dtw._vendored_tslearn.subsequence_cost_matrix...
subsequence_cost_matrix(array([[-0.23287, ..., -0.97019],
       ...,
       [-0.24131, ..., -0.99925]]), array([[ 0.000225, ...,  0.000064],
       ...,
       [-0.031871, ..., -0.001388]]))
__________________________________________subsequence_cost_matrix - 0.0s, 0.0min
________________________________________________________________________________
[Memory] Calling gaitmap_mad.stride_segmentation.dtw._vendored_tslearn.subsequence_cost_matrix...
subsequence_cost_matrix(array([[-0.23287, ..., -0.97019],
       ...,
       [-0.24131, ..., -0.99925]]), array([[-0.000646, ..., -0.000169],
       ...,
       [-0.001773, ..., -0.090264]]))
__________________________________________subsequence_cost_matrix - 0.0s, 0.0min

According to the debug output, two internal functions of BarthDtw are cached. Each twice with different value inputs, because our data had two sensors. It depends on the actual algorithm, which and how internal components are cached.

Independent of that if we call the method again, we can see in the debug output that the results of these methods are now loaded from disk. If we would use a larger dataset, we would see dramatic speed improvements.

second_call_results = dtw.segment(data=data, sampling_rate_hz=204.8)
second_call_stride_list = second_call_results.stride_list_.copy()
[Memory]0.9s, 0.0min    : Loading subsequence_cost_matrix...
[Memory]0.9s, 0.0min    : Loading subsequence_cost_matrix...

We can verify that the results are actually identical

first_call_stride_list["left_sensor"].equals(second_call_stride_list["left_sensor"])
True

Partially cached calls#

As you have seen before, BarthDtw caches its internal call to subsequence_cost_matrix. This is only part of the processing. This ensures that we can change some parameters while still making use of the some cached results. As the cost-matrix only depends only on the template and the constrains, we can reuse the cash, if we change any other parameter.

If we change the max_cost for example, only the stride detection part needs to be recalculated.

new_instance = BarthDtw(max_cost=5.0, memory=mem)
new_instance.segment(data=data, sampling_rate_hz=204.8)
[Memory]2.0s, 0.0min    : Loading subsequence_cost_matrix...
[Memory]2.0s, 0.0min    : Loading subsequence_cost_matrix...

BarthDtw(conflict_resolution=True, find_matches_method='find_peaks', max_cost=5.0, max_match_length_s=3.0, max_signal_stretch_ms=None, max_template_stretch_ms=None, memory=Memory(location=/tmp/tmpj26mhm6f/joblib), min_match_length_s=0.6, resample_template=True, snap_to_min_axis='gyr_ml', snap_to_min_win_ms=300, template=BarthOriginalTemplate(scaling=FixedScaler(offset=0, scale=500.0), use_cols=None))

As you can see in the debug output, we loaded the results of subsequence_cost_matrix, but recalculated the second step.

Some Note#

  • Caching support will vary from algorithm to algorithm

  • Caching supports multi-processing

  • Do not use you cache as permanent storage of results. It is way too easy to delete it.

  • If you try a lot of things with a lot of data, your cache can become really large.

  • Clear your cache, before you do your final calculations for a publication!

  • Make sure you add you cache dir to your “.gitignore” file.

Caching Full method calls#

In some cases it might still be desirable to cache the entire output of an algorithm. To do this safely you need to be aware of how cashing works under the hood. The Memory class calculates a hash of all inputs to a function and stores a pickeled version of the results together with this input-hast. If the function is called again, the hash of the input is compared with hashes stored on the disk. Depending on this, a cached result can be selected.

import joblib

We can calculate the hash of our algorithm.

joblib.hash(dtw)
'1ca6757c44aade987dc5515b59d43a75'

If we recreate the object with the same parameters, the hash is identical.

joblib.hash(BarthDtw())
'050b7eb3aa62e4807d8dface0085fb56'

The same is true for cloning

joblib.hash(dtw.clone())
'457b73b21546c0919c15d0112e4c22e3'

However, if we change any parameters the hash of the object changes.

joblib.hash(BarthDtw(max_cost=100))
'fe7b97e7289e8bea9c8d274a200e5692'

It is important to note that the hash always changes, if any of the attributes are modified, not just the ones accessible through the init. This means, if e.g. after you call segment and the algorithm object will have all results stored, the hash will change. The same will happen, if you add custom attributes to the instance. The hash will change and the cache would be invalidated.

test_dtw = BarthDtw()
test_dtw.custom_value = 4
joblib.hash(test_dtw)
'575ae6008b1330d505058c59ecf849b4'

This observation becomes an issue when caching class methods. As python passes the class instance itself as the first argument to this method. This means the input-hash used for caching will change whenever anything on the class instance changes, even if the change might not affect the actual output of the method. In many cases this is less of an issue with gaitmap, as we can reasonably assume that the main action method should only depend on the params of an algorithm (self.get_params()) and the actual action method.

Therefore, we can cache action methods reliably when cloning the algorithm before hand and using a wrapper method. Cloning the algorithm instance ensures that all instance data, except the params are reset.

def call_segment(algo, data, sampling_rate_hz):
    return algo.segment(data=data, sampling_rate_hz=sampling_rate_hz)


# Cache the wrapper:
cached_call_segment = Memory(tmp_dir.name, verbose=2).cache(call_segment)

# Then we need to clone the algorithm every time we call the cached wrapper, to reset the params:
reset_dtw = dtw.clone()
results = cached_call_segment(reset_dtw, data, sampling_rate_hz)
________________________________________________________________________________
[Memory] Calling __main__--home-docs-checkouts-readthedocs.org-user_builds-gaitmap-checkouts-v2.6.0-examples-advanced_features-caching.call_segment...
call_segment(BarthDtw(conflict_resolution=True, find_matches_method='find_peaks', max_cost=4.0, max_match_length_s=3.0, max_signal_stretch_ms=None, max_template_stretch_ms=None, memory=Memory(location=/tmp/tmpj26mhm6f/joblib), min_match_length_s=0.6, resample_template=True, snap_to_min_axis='gyr_ml', snap_to_min_win_ms=300, template=BarthOriginalTemplate(scaling=FixedScaler(offset=0, scale=500.0), use_cols=None)),
         left_sensor                      ... right_sensor
              acc_pa    acc_ml    acc_si  ...       gyr_pa      gyr_ml      gyr_si
0.000000    0.880811  2.762208 -9.408650  ...    -0.323037   -0.084604   -0.025288
0.004883    0.885007  2.746448 -9.465895  ...    -0.075961    0.035851    0.152090
0.009766    0.865777  2.686106 -9.436033  ...    -0.200378   -0.206538   -0.028626
0.014648    0.876128  2.771787 -9.403943  ...     0.347912   -0.075574   -0.390202
0.019531    0.928267  2.682286 -9.393766  ...    -0.260534   -0.025164    0.093895
...              ...       ...       ...  ...          ...         ...         ...
9.741211    0.445489  2.638538 -9.353027  ...  -327.681622 -627.092147  196.539542
9.746094    0.679787  2.586746 -9.312401  ...  -234.217050 -538.949858   78.782630
9.750977    0.828875  2.607384 -9.223829  ...  -225.952242 -385.386142   71.534917
9.755859    0.743285  2.789532 -9.297119  ...  -123.250826 -209.382027  -25.810329
9.760742    0.505652  2.906063 -9.214244  ...    -0.886364  -45.132010  -76.377843

[2000 rows x 12 columns],
204.8)
[Memory]0.0s, 0.0min    : Loading subsequence_cost_matrix...
[Memory]0.0s, 0.0min    : Loading subsequence_cost_matrix...
_____________________________________________________call_segment - 0.0s, 0.0min

On this first call, we can see that the cached call actually modified the reset_dtw object in place.

id(reset_dtw) == id(results)
True

However, on the second call, it will return a copy (loaded from the cache)

reset_dtw = dtw.clone()
results = cached_call_segment(reset_dtw, data, sampling_rate_hz)
id(reset_dtw) == id(results)
[Memory]0.9s, 0.0min    : Loading call_segment...

False

While it is possible to cache methods this way, this might be error prone. The safest option (and remember, we are already in the unsafe territory), is to use a nested wrapper resolve potential user errors.

In the general case you can use the recipe below. It will always ensure that the algo object is cloned and will return a copy of the algorithm in any case.

Warning

While this expected to work, cashing an entire algorithm object as return value can take a lot of storage space as it usually stores a copy of the input data. Whenever possible you should only return the parts of the result you are really interested inside the cached function.

def cached_call_method(_algo, _method_name: str, _memory: Memory, *args, **kwargs):
    """Call a method on the algo object and cache the output.

    Repeated calls to this function with the same algorithm and the same args, and kwargs, will return cached results
    saved on disk.

    .. warning ::
        This method will clone the algorithm object before calling the method.
        This ensures that the cache is not invalidated because of results stored on the object.

    Parameters
    ----------
    _algo
        The algorithm instance to use
    _method_name
        The name of the method to call
    _memory
        A instance of `joblib.memory` used for caching
    args
        Positional arguments passed to the called method
    kwargs
        Keyword arguments passed to the called method.

    Returns
    -------
    method_return
        The return value of the called methods either calculated or cached.

    See Also
    --------
    gaitmap.utils.caching.cached_call_action

    """

    def _call_method(_algo, _method_name, *args, **kwargs):
        return getattr(_algo, _method_name)(*args, **kwargs)

    _algo = _algo.clone()
    return _memory.cache(_call_method)(_algo, _method_name, *args, **kwargs)


mem = Memory(tmp_dir.name, verbose=2)

cached_result = cached_call_method(
    BarthDtw(), _method_name="segment", _memory=mem, data=data, sampling_rate_hz=sampling_rate_hz
)
________________________________________________________________________________
[Memory] Calling __main__--home-docs-checkouts-readthedocs.org-user_builds-gaitmap-checkouts-v2.6.0-examples-advanced_features-caching.cached_call_method.<locals>._call_method...
_call_method(BarthDtw(conflict_resolution=True, find_matches_method='find_peaks', max_cost=4.0, max_match_length_s=3.0, max_signal_stretch_ms=None, max_template_stretch_ms=None, memory=None, min_match_length_s=0.6, resample_template=True, snap_to_min_axis='gyr_ml', snap_to_min_win_ms=300, template=BarthOriginalTemplate(scaling=FixedScaler(offset=0, scale=500.0), use_cols=None)),
'segment', data=         left_sensor                      ... right_sensor
              acc_pa    acc_ml    acc_si  ...       gyr_pa      gyr_ml      gyr_si
0.000000    0.880811  2.762208 -9.408650  ...    -0.323037   -0.084604   -0.025288
0.004883    0.885007  2.746448 -9.465895  ...    -0.075961    0.035851    0.152090
0.009766    0.865777  2.686106 -9.436033  ...    -0.200378   -0.206538   -0.028626
0.014648    0.876128  2.771787 -9.403943  ...     0.347912   -0.075574   -0.390202
0.019531    0.928267  2.682286 -9.393766  ...    -0.260534   -0.025164    0.093895
...              ...       ...       ...  ...          ...         ...         ...
9.741211    0.445489  2.638538 -9.353027  ...  -327.681622 -627.092147  196.539542
9.746094    0.679787  2.586746 -9.312401  ...  -234.217050 -538.949858   78.782630
9.750977    0.828875  2.607384 -9.223829  ...  -225.952242 -385.386142   71.534917
9.755859    0.743285  2.789532 -9.297119  ...  -123.250826 -209.382027  -25.810329
9.760742    0.505652  2.906063 -9.214244  ...    -0.886364  -45.132010  -76.377843

[2000 rows x 12 columns], sampling_rate_hz=204.8)
______________________________________________________call_method - 0.0s, 0.0min

And the second call will load the results.

cached_result = cached_call_method(
    BarthDtw(), _method_name="segment", _memory=mem, data=data, sampling_rate_hz=sampling_rate_hz
)
[Memory]0.4s, 0.0min    : Loading _call_method...

Finally remove the tempdir

Total running time of the script: ( 0 minutes 8.256 seconds)

Estimated memory usage: 9 MB

Gallery generated by Sphinx-Gallery