Caching algorithm outputs#
Many algorithms implemented in gaitmap have a runtime of multiple seconds on larger datasets. In the context of algorithm evaluation, for example when performing a cross-validation, algorithms are sometimes repeatedly called on the same data and even with the same parameters. In these cases, it can be helpful to cache results to ensure that you do not need to recalculate values.
The joblib Python package makes caching extremely easy, and you should read their
guide first before continuing with this example.
However, one of the caveats with joblib caching is that it only works on pure functions without side effects and should
not be used with methods.
Unfortunately, gaitmap is mostly object-oriented, and all the computationally expensive things you might want to do are hidden behind a method call.
Therefore, many gaitmap algorithms have caching built-in.
These algorithms support an additional keyword argument called memory in their init-function.
If you pass a joblib.Memory object to these, it will be used to cache the most time consuming function calls.
Note that this will usually not cache all the calculations in a method, but only the ones that are considered worth
caching by the algorithm developer.
If you really want to cache the full method calls (at your own risk), see the last section of this example.
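The built-in caching described above can be pictured with a small toy class. This is only a sketch of the general pattern and not gaitmap's actual implementation: `ToyAlgorithm`, `_expensive_cost`, and `cost_calls` are hypothetical names made up for illustration. The class accepts an optional `memory` argument and uses it to cache one pure helper function, while the method itself stays uncached.

```python
from tempfile import TemporaryDirectory

from joblib import Memory

cost_calls = []  # records every real execution of the helper (illustration only)


def _expensive_cost(signal, template):
    """Pure helper (hypothetical): the result depends only on its arguments."""
    cost_calls.append(1)
    return [abs(s - t) for s, t in zip(signal, template)]


class ToyAlgorithm:
    """Hypothetical algorithm that caches one internal function call."""

    def __init__(self, threshold=1.0, memory=None):
        self.threshold = threshold
        self.memory = memory

    def run(self, signal, template):
        # Memory(None) performs no caching, so the code path is identical with or without a cache.
        memory = self.memory if self.memory is not None else Memory(None, verbose=0)
        cost = memory.cache(_expensive_cost)(signal, template)
        # The cheap post-processing is always recomputed.
        self.matches_ = [i for i, c in enumerate(cost) if c <= self.threshold]
        return self


with TemporaryDirectory() as cache_dir:
    signal, template = [1.0, 5.0, 2.0], [1.5, 1.5, 1.5]

    # Without a memory object every call recomputes the helper.
    ToyAlgorithm().run(signal, template)
    ToyAlgorithm().run(signal, template)
    calls_without_cache = len(cost_calls)

    # With a shared memory object the helper runs once and is then loaded from disk.
    mem = Memory(cache_dir, verbose=0)
    first = ToyAlgorithm(memory=mem).run(signal, template)
    second = ToyAlgorithm(memory=mem).run(signal, template)
    calls_with_cache = len(cost_calls) - calls_without_cache
```

Only the module-level helper is handed to `memory.cache`, so the instance itself is never hashed. This sidesteps the problem with caching methods mentioned above.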
Example Pipeline#
We will simply copy the stride segmentation example to have some data to work with.
from gaitmap.example_data import get_healthy_example_imu_data
from gaitmap.utils.coordinate_conversion import convert_to_fbf
data = get_healthy_example_imu_data().iloc[:2000]
sampling_rate_hz = 204.8
data = convert_to_fbf(data, left_like="left_", right_like="right_")
Creating the cache#
First we will create a memory instance for our cache. We can use the same cache to store the output of multiple algorithms. The cache stays valid even after you restart Python, as long as you do not delete the folder.
However, in this example, we will use a temp-directory that will be deleted at the end of the example.
from tempfile import TemporaryDirectory
from joblib import Memory
tmp_dir = TemporaryDirectory()
# We will activate some more debug output for this example
mem = Memory(tmp_dir.name, verbose=2)
Initialize algorithm#
We initialize our algorithm as normal, but pass the memory instance as an additional parameter.
from gaitmap.stride_segmentation import BarthDtw
dtw = BarthDtw(memory=mem)
Calling cached methods#
The first time we call segment now, all calculations will run as normal, but the output of certain calculations will
be cached.
They are then reused when we call segment again with the same data and configuration.
Observe the print output to see what happens.
first_call_results = dtw.segment(data=data, sampling_rate_hz=204.8)
first_call_stride_list = first_call_results.stride_list_.copy()
________________________________________________________________________________
[Memory] Calling gaitmap_mad.stride_segmentation.dtw._vendored_tslearn.subsequence_cost_matrix...
subsequence_cost_matrix(array([[-0.23287, ..., -0.97019],
...,
[-0.24131, ..., -0.99925]]), array([[ 0.000225, ..., 0.000064],
...,
[-0.031871, ..., -0.001388]]))
__________________________________________subsequence_cost_matrix - 0.0s, 0.0min
________________________________________________________________________________
[Memory] Calling gaitmap_mad.stride_segmentation.dtw._vendored_tslearn.subsequence_cost_matrix...
subsequence_cost_matrix(array([[-0.23287, ..., -0.97019],
...,
[-0.24131, ..., -0.99925]]), array([[-0.000646, ..., -0.000169],
...,
[-0.001773, ..., -0.090264]]))
__________________________________________subsequence_cost_matrix - 0.0s, 0.0min
According to the debug output, the internal function subsequence_cost_matrix of BarthDtw is cached.
It is called twice with different inputs, because our data has two sensors.
Which internal components are cached, and how, depends on the actual algorithm.
Independent of that, if we call the method again, we can see in the debug output that the results of these calls are now loaded from disk. With a larger dataset, where the cached calculation dominates the runtime, the speed improvement would be much more noticeable.
second_call_results = dtw.segment(data=data, sampling_rate_hz=204.8)
second_call_stride_list = second_call_results.stride_list_.copy()
[Memory]0.8s, 0.0min : Loading subsequence_cost_matrix...
[Memory]0.8s, 0.0min : Loading subsequence_cost_matrix...
We can verify that the results are actually identical:
first_call_stride_list["left_sensor"].equals(second_call_stride_list["left_sensor"])
True
Partially cached calls#
As you have seen before, BarthDtw caches its internal call to subsequence_cost_matrix.
This is only part of the processing.
This ensures that we can change some parameters while still making use of some of the cached results.
As the cost matrix depends only on the template and the constraints, we can reuse the cache if we change any other
parameter.
If we change max_cost, for example, only the stride detection part needs to be recalculated.
new_instance = BarthDtw(max_cost=5.0, memory=mem)
new_instance.segment(data=data, sampling_rate_hz=204.8)
[Memory]1.7s, 0.0min : Loading subsequence_cost_matrix...
[Memory]1.8s, 0.0min : Loading subsequence_cost_matrix...
BarthDtw(conflict_resolution=True, find_matches_method='find_peaks', max_cost=5.0, max_match_length_s=3.0, max_signal_stretch_ms=None, max_template_stretch_ms=None, memory=Memory(location=/tmp/tmpma4frxa9/joblib), min_match_length_s=0.6, resample_template=True, snap_to_min_axis='gyr_ml', snap_to_min_win_ms=300, template=BarthOriginalTemplate(scaling=FixedScaler(offset=0, scale=500.0), use_cols=None))
As you can see in the debug output, we loaded the results of subsequence_cost_matrix, but recalculated the second
step.
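The idea of a partially cached call can be reproduced with plain joblib and two toy functions. This is a minimal sketch under stated assumptions: `cost_profile`, `detect`, and `executed` are hypothetical stand-ins for the expensive cost-matrix step and the cheap matching step, and do not mirror the real BarthDtw internals.

```python
from tempfile import TemporaryDirectory

from joblib import Memory

executed = []  # records real executions of the expensive step (illustration only)


def cost_profile(signal, template):
    """Expensive step (hypothetical): depends only on the signal and the template."""
    executed.append("cost_profile")
    return [(s - template) ** 2 for s in signal]


def detect(signal, template, max_cost, memory):
    """Two-step pipeline: cached cost calculation, then uncached thresholding."""
    cost = memory.cache(cost_profile)(signal, template)
    # max_cost is not an input of the cached function, so changing it keeps the cache valid.
    return [i for i, c in enumerate(cost) if c <= max_cost]


with TemporaryDirectory() as cache_dir:
    mem = Memory(cache_dir, verbose=0)
    strict = detect([0.0, 1.0, 3.0], 1.0, max_cost=0.5, memory=mem)
    relaxed = detect([0.0, 1.0, 3.0], 1.0, max_cost=1.0, memory=mem)
```

Both calls share one execution of `cost_profile`, because the parameter that changed is only used after the cached step.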
Some Notes#
Caching support will vary from algorithm to algorithm.
Caching supports multi-processing.
Do not use your cache as permanent storage of results. It is way too easy to delete it.
If you try a lot of things with a lot of data, your cache can become really large.
Clear your cache before you do your final calculations for a publication!
Make sure you add your cache dir to your “.gitignore” file.
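Clearing the cache, as recommended in the notes above, can be done through the Memory object itself. The sketch below uses a hypothetical toy function `expensive` and a call log `runs` to show that a cleared cache forces a fresh calculation.

```python
from tempfile import TemporaryDirectory

from joblib import Memory

runs = []  # records real executions (illustration only)


def expensive(x):
    runs.append(x)
    return x + 1


with TemporaryDirectory() as cache_dir:
    mem = Memory(cache_dir, verbose=0)
    cached_expensive = mem.cache(expensive)

    cached_expensive(1)  # computed
    cached_expensive(1)  # loaded from the cache
    runs_before_clear = len(runs)

    # Remove everything stored in this cache; warn=False suppresses the warning message.
    mem.clear(warn=False)

    cached_expensive(1)  # computed again, because the stored result is gone
    runs_after_clear = len(runs)
```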
Caching Full method calls#
In some cases it might still be desirable to cache the entire output of an algorithm.
To do this safely you need to be aware of how caching works under the hood.
The Memory class calculates a hash of all inputs to a function and stores a pickled version of the results together
with this input hash.
If the function is called again, the hash of the input is compared with hashes stored on the disk.
Depending on this, a cached result can be selected.
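The hash-then-lookup principle described above can be sketched with the standard library alone. This is a deliberately simplified illustration and not joblib's implementation: `toy_cached_call` and `total` are hypothetical names, and the real Memory class additionally tracks the function's source code, metadata, and more robust hashing.

```python
import hashlib
import pickle
from pathlib import Path
from tempfile import TemporaryDirectory


def toy_cached_call(func, cache_dir, *args):
    """Very simplified hash-then-pickle cache (illustration only)."""
    # 1. Hash the pickled inputs to obtain a key for this exact call.
    key = hashlib.sha256(pickle.dumps((func.__name__, args))).hexdigest()
    result_file = Path(cache_dir) / f"{key}.pkl"
    # 2. If a result is stored under this key, load it instead of recomputing.
    if result_file.exists():
        return pickle.loads(result_file.read_bytes()), "loaded"
    # 3. Otherwise compute the result and store its pickled version under the key.
    result = func(*args)
    result_file.write_bytes(pickle.dumps(result))
    return result, "computed"


def total(values):
    return sum(values)


with TemporaryDirectory() as cache_dir:
    first = toy_cached_call(total, cache_dir, [1, 2, 3])
    second = toy_cached_call(total, cache_dir, [1, 2, 3])
    third = toy_cached_call(total, cache_dir, [1, 2, 4])
```

The second call finds a file for the same input hash and loads it, while the third call has different inputs and therefore a different key.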
import joblib
We can calculate the hash of our algorithm.
joblib.hash(dtw)
'73207a60d30a296a6e2f2e2922ccf5d8'
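If we recreate the object with the same parameters, the hash should be identical. Note that the value below differs from the one above, which is expected here: dtw was created with memory=mem and already holds the results of the segment call, whereas BarthDtw() is a fresh instance with neither.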
joblib.hash(BarthDtw())
'050b7eb3aa62e4807d8dface0085fb56'
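The same holds for cloning: a clone carries only the parameters (here including memory), so its hash depends on the parameters alone.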
joblib.hash(dtw.clone())
'30473d1158bc00f629a3f4b6401d981d'
However, if we change any parameters the hash of the object changes.
joblib.hash(BarthDtw(max_cost=100))
'fe7b97e7289e8bea9c8d274a200e5692'
It is important to note that the hash changes whenever any attribute is modified, not just the ones
accessible through the init.
This means that after you call segment, for example, the algorithm object has all results stored on it and its hash will
have changed.
The same will happen if you add custom attributes to the instance.
The hash will change and the cache will be invalidated.
test_dtw = BarthDtw()
test_dtw.custom_value = 4
joblib.hash(test_dtw)
'575ae6008b1330d505058c59ecf849b4'
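The same behavior can be checked without gaitmap. In the sketch below, a `types.SimpleNamespace` serves as a hypothetical stand-in for an algorithm object, and `stride_list_` is just an illustrative attribute name for a stored result.

```python
from types import SimpleNamespace

import joblib

# Two separate objects with identical attributes produce the same hash.
algo_a = SimpleNamespace(max_cost=4.0)
algo_b = SimpleNamespace(max_cost=4.0)
same_params_same_hash = joblib.hash(algo_a) == joblib.hash(algo_b)

# Storing a "result" on the object changes its hash, although no parameter changed.
hash_before = joblib.hash(algo_a)
algo_a.stride_list_ = [10, 20, 30]
hash_after = joblib.hash(algo_a)
hash_changed_by_result = hash_before != hash_after
```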
This observation becomes an issue when caching methods, because Python passes the instance itself as the first argument to the method.
This means the input hash used for caching will change whenever anything on the instance changes, even if the
change does not affect the actual output of the method.
In many cases this is less of an issue with gaitmap, as we can reasonably assume that the main action method should
only depend on the params of an algorithm (self.get_params()) and the inputs passed to the action method itself.
Therefore, we can cache action methods reliably by cloning the algorithm beforehand and using a wrapper function. Cloning the algorithm instance ensures that all instance data except the params is reset.
def call_segment(algo, data, sampling_rate_hz):
    return algo.segment(data=data, sampling_rate_hz=sampling_rate_hz)
# Cache the wrapper:
cached_call_segment = Memory(tmp_dir.name, verbose=2).cache(call_segment)
# Then we need to clone the algorithm every time we call the cached wrapper, to reset everything except the params:
reset_dtw = dtw.clone()
results = cached_call_segment(reset_dtw, data, sampling_rate_hz)
________________________________________________________________________________
[Memory] Calling __main__--home-docs-checkouts-readthedocs.org-user_builds-gaitmap-checkouts-v2.5.2-examples-advanced_features-caching.call_segment...
call_segment(BarthDtw(conflict_resolution=True, find_matches_method='find_peaks', max_cost=4.0, max_match_length_s=3.0, max_signal_stretch_ms=None, max_template_stretch_ms=None, memory=Memory(location=/tmp/tmpma4frxa9/joblib), min_match_length_s=0.6, resample_template=True, snap_to_min_axis='gyr_ml', snap_to_min_win_ms=300, template=BarthOriginalTemplate(scaling=FixedScaler(offset=0, scale=500.0), use_cols=None)),
left_sensor ... right_sensor
acc_pa acc_ml acc_si ... gyr_pa gyr_ml gyr_si
0.000000 0.880811 2.762208 -9.408650 ... -0.323037 -0.084604 -0.025288
0.004883 0.885007 2.746448 -9.465895 ... -0.075961 0.035851 0.152090
0.009766 0.865777 2.686106 -9.436033 ... -0.200378 -0.206538 -0.028626
0.014648 0.876128 2.771787 -9.403943 ... 0.347912 -0.075574 -0.390202
0.019531 0.928267 2.682286 -9.393766 ... -0.260534 -0.025164 0.093895
... ... ... ... ... ... ... ...
9.741211 0.445489 2.638538 -9.353027 ... -327.681622 -627.092147 196.539542
9.746094 0.679787 2.586746 -9.312401 ... -234.217050 -538.949858 78.782630
9.750977 0.828875 2.607384 -9.223829 ... -225.952242 -385.386142 71.534917
9.755859 0.743285 2.789532 -9.297119 ... -123.250826 -209.382027 -25.810329
9.760742 0.505652 2.906063 -9.214244 ... -0.886364 -45.132010 -76.377843
[2000 rows x 12 columns],
204.8)
[Memory]0.0s, 0.0min : Loading subsequence_cost_matrix...
[Memory]0.0s, 0.0min : Loading subsequence_cost_matrix...
_____________________________________________________call_segment - 0.0s, 0.0min
On this first call, we can see that the cached call actually modified the reset_dtw object in place.
id(reset_dtw) == id(results)
True
However, on the second call, it will return a copy (loaded from the cache), and reset_dtw itself is left untouched:
reset_dtw = dtw.clone()
results = cached_call_segment(reset_dtw, data, sampling_rate_hz)
id(reset_dtw) == id(results)
[Memory]0.8s, 0.0min : Loading call_segment...
False
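The difference between the first and the second call can be reproduced with a toy function that, like an action method, stores a result on its input and returns it. This is a sketch only: `append_result` is a hypothetical function, and a plain list stands in for the algorithm object.

```python
from tempfile import TemporaryDirectory

from joblib import Memory


def append_result(container):
    """Mimics an action method: stores a result on its input and returns it."""
    container.append("result")
    return container


with TemporaryDirectory() as cache_dir:
    cached_append = Memory(cache_dir, verbose=0).cache(append_result)

    first_input = ["params"]
    first_output = cached_append(first_input)  # cache miss: the function really runs

    second_input = ["params"]
    second_output = cached_append(second_input)  # cache hit: the function is skipped

# On the miss, the input was modified in place and returned as-is.
first_is_same_object = first_output is first_input
# On the hit, a copy is loaded from disk and the input is never touched.
second_is_same_object = second_output is second_input
```

After the cache hit, `second_input` still contains only `"params"`. Code that relies on the input object being modified in place therefore behaves differently depending on whether the cache was hit, which is why the return value should always be used.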
While it is possible to cache methods this way, it is error-prone. The safest option (and remember, we are already in unsafe territory) is to use a nested wrapper that guards against potential user errors.
In the general case you can use the recipe below. It always ensures that the algo object is cloned and returns a copy of the algorithm in any case.
Warning
While this is expected to work, caching an entire algorithm object as the return value can take a lot of storage space, as it usually stores a copy of the input data. Whenever possible, the cached function should return only the parts of the result you are really interested in.
def cached_call_method(_algo, _method_name: str, _memory: Memory, *args, **kwargs):
    """Call a method on the algo object and cache the output.

    Repeated calls to this function with the same algorithm and the same args and kwargs will return cached results
    saved on disk.

    .. warning ::
        This method will clone the algorithm object before calling the method.
        This ensures that the cache is not invalidated because of results stored on the object.

    Parameters
    ----------
    _algo
        The algorithm instance to use
    _method_name
        The name of the method to call
    _memory
        An instance of `joblib.Memory` used for caching
    args
        Positional arguments passed to the called method
    kwargs
        Keyword arguments passed to the called method.

    Returns
    -------
    method_return
        The return value of the called method, either calculated or loaded from the cache.

    See Also
    --------
    gaitmap.utils.caching.cached_call_action

    """

    def _call_method(_algo, _method_name, *args, **kwargs):
        return getattr(_algo, _method_name)(*args, **kwargs)

    # Clone first, so that only the params (and not stored results) enter the input hash.
    _algo = _algo.clone()
    return _memory.cache(_call_method)(_algo, _method_name, *args, **kwargs)
mem = Memory(tmp_dir.name, verbose=2)
cached_result = cached_call_method(
    BarthDtw(), _method_name="segment", _memory=mem, data=data, sampling_rate_hz=sampling_rate_hz
)
________________________________________________________________________________
[Memory] Calling __main__--home-docs-checkouts-readthedocs.org-user_builds-gaitmap-checkouts-v2.5.2-examples-advanced_features-caching.cached_call_method.<locals>._call_method...
_call_method(BarthDtw(conflict_resolution=True, find_matches_method='find_peaks', max_cost=4.0, max_match_length_s=3.0, max_signal_stretch_ms=None, max_template_stretch_ms=None, memory=None, min_match_length_s=0.6, resample_template=True, snap_to_min_axis='gyr_ml', snap_to_min_win_ms=300, template=BarthOriginalTemplate(scaling=FixedScaler(offset=0, scale=500.0), use_cols=None)),
'segment', data= left_sensor ... right_sensor
acc_pa acc_ml acc_si ... gyr_pa gyr_ml gyr_si
0.000000 0.880811 2.762208 -9.408650 ... -0.323037 -0.084604 -0.025288
0.004883 0.885007 2.746448 -9.465895 ... -0.075961 0.035851 0.152090
0.009766 0.865777 2.686106 -9.436033 ... -0.200378 -0.206538 -0.028626
0.014648 0.876128 2.771787 -9.403943 ... 0.347912 -0.075574 -0.390202
0.019531 0.928267 2.682286 -9.393766 ... -0.260534 -0.025164 0.093895
... ... ... ... ... ... ... ...
9.741211 0.445489 2.638538 -9.353027 ... -327.681622 -627.092147 196.539542
9.746094 0.679787 2.586746 -9.312401 ... -234.217050 -538.949858 78.782630
9.750977 0.828875 2.607384 -9.223829 ... -225.952242 -385.386142 71.534917
9.755859 0.743285 2.789532 -9.297119 ... -123.250826 -209.382027 -25.810329
9.760742 0.505652 2.906063 -9.214244 ... -0.886364 -45.132010 -76.377843
[2000 rows x 12 columns], sampling_rate_hz=204.8)
______________________________________________________call_method - 0.0s, 0.0min
And the second call will load the results.
cached_result = cached_call_method(
    BarthDtw(), _method_name="segment", _memory=mem, data=data, sampling_rate_hz=sampling_rate_hz
)
[Memory]0.4s, 0.0min : Loading _call_method...
Finally remove the tempdir
tmp_dir.cleanup()