tslearn.metrics.SoftDTWLossPyTorch
- tslearn.metrics.SoftDTWLossPyTorch(gamma=1.0, normalize=False, dist_func=None)
Soft-DTW loss function in PyTorch.
Soft-DTW was originally presented in [1] and is discussed in more detail in our user-guide page on DTW and its variants.
Soft-DTW is computed as:
\[\text{soft-DTW}_{\gamma}(X, Y) = \min_{\pi}{}^\gamma \sum_{(i, j) \in \pi} d \left( X_i, Y_j \right)\]
where \(d\) is a distance function or dissimilarity measure supporting PyTorch automatic differentiation and \(\min^\gamma\) is the soft-min operator of parameter \(\gamma\), defined as:
\[\min{}^\gamma \left( a_{1}, \dots, a_{n} \right) = - \gamma \log \sum_{i=1}^{n} e^{- a_{i} / \gamma}\]
In the limit case \(\gamma = 0\), \(\min^\gamma\) reduces to a hard-min operator, and soft-DTW is then the square of the DTW dissimilarity measure when \(d\) is the squared Euclidean distance.
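For intuition, here is a minimal sketch of the soft-min operator written with torch.logsumexp; the input values and the choices of gamma below are arbitrary:

import torch

def soft_min(a, gamma):
    # min^gamma(a_1, ..., a_n) = -gamma * log(sum_i exp(-a_i / gamma))
    return -gamma * torch.logsumexp(-a / gamma, dim=-1)

a = torch.tensor([1.0, 2.0, 3.0])
print(soft_min(a, gamma=1.0))   # smooth approximation, below the hard min (about 0.59)
print(soft_min(a, gamma=0.01))  # close to the hard min (about 1.0)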
Contrary to DTW, soft-DTW is not bounded below by zero, and we even have:
\[\text{soft-DTW}_{\gamma}(X, Y) \rightarrow - \infty \text{ when } \gamma \rightarrow + \infty\]
In [2], new dissimilarity measures relying on soft-DTW are defined. In particular, the soft-DTW divergence is introduced to counteract the non-positivity of soft-DTW:
\[D_{\gamma} \left( X, Y \right) = \text{soft-DTW}_{\gamma}(X, Y) - \frac{1}{2} \left( \text{soft-DTW}_{\gamma}(X, X) + \text{soft-DTW}_{\gamma}(Y, Y) \right)\]
This divergence has the advantage of being minimized for \(X = Y\), and of being exactly 0 in that case.
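As a minimal sketch of this property, the normalize=True option (described in the Parameters section below) switches the loss to the Soft-DTW divergence; the data here are random and only illustrative:

import torch
from tslearn.metrics import SoftDTWLossPyTorch

torch.manual_seed(0)
x = torch.randn(2, 10, 3)  # batch of 2 time series of length 10 and dimension 3

# Plain soft-DTW of a series with itself is negative for gamma > 0 ...
plain_loss = SoftDTWLossPyTorch(gamma=1.0)
print(plain_loss(x, x))  # negative values

# ... whereas the soft-DTW divergence is exactly 0 when X equals Y.
divergence_loss = SoftDTWLossPyTorch(gamma=1.0, normalize=True)
print(divergence_loss(x, x))  # zeros (up to numerical precision)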
- Parameters:
- gamma : float
Regularization parameter. It should be strictly positive. Lower values yield less smoothing (closer to true DTW).
- normalize : bool
If True, the Soft-DTW divergence is used instead of plain Soft-DTW. The Soft-DTW divergence is always non-negative. Optional, default: False.
- dist_func : callable
Distance function or dissimilarity measure. It takes two input arguments of shape (batch_size, ts_length, dim) and must support PyTorch automatic differentiation. Optional, default: None. If None, the squared Euclidean distance is used. A sketch of a custom distance function is given after this parameter list.
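The following is a hedged sketch of passing a custom dist_func: the pairwise L1 distance is a hypothetical choice, and it is assumed, by analogy with the default squared Euclidean distance, that dist_func returns a pairwise cost matrix of shape (batch_size, ts_length_x, ts_length_y):

import torch
from tslearn.metrics import SoftDTWLossPyTorch

def pairwise_l1_dist(x, y):
    # x: (batch_size, len_x, dim), y: (batch_size, len_y, dim)
    # Assumed to return a (batch_size, len_x, len_y) cost matrix;
    # torch.cdist keeps the operation differentiable.
    return torch.cdist(x, y, p=1)

loss_fn = SoftDTWLossPyTorch(gamma=0.1, dist_func=pairwise_l1_dist)

x = torch.randn(4, 8, 2, requires_grad=True)
y = torch.randn(4, 8, 2)
loss = loss_fn(x, y).mean()
loss.backward()  # gradients flow through the custom distance
print(x.grad.shape)  # torch.Size([4, 8, 2])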
See also
soft_dtw
Compute Soft-DTW metric between two time series.
cdist_soft_dtw
Compute cross-similarity matrix using Soft-DTW metric.
cdist_soft_dtw_normalized
Compute cross-similarity matrix using a normalized version of the Soft-DTW metric.
References
[1] Marco Cuturi & Mathieu Blondel. “Soft-DTW: a Differentiable Loss Function for Time-Series”, ICML 2017.
[2] Mathieu Blondel, Arthur Mensch & Jean-Philippe Vert. “Differentiable divergences between time series”, International Conference on Artificial Intelligence and Statistics, 2021.
Examples
>>> import torch
>>> from tslearn.metrics import SoftDTWLossPyTorch
>>> soft_dtw_loss = SoftDTWLossPyTorch(gamma=0.1)
>>> x = torch.zeros((4, 3, 2), requires_grad=True)
>>> y = torch.arange(0, 24).reshape(4, 3, 2)
>>> soft_dtw_loss_mean_value = soft_dtw_loss(x, y).mean()
>>> print(soft_dtw_loss_mean_value)
tensor(1081., grad_fn=<MeanBackward0>)
>>> soft_dtw_loss_mean_value.backward()
>>> print(x.grad.shape)
torch.Size([4, 3, 2])
>>> print(x.grad)
tensor([[[  0.0000,  -0.5000],
         [ -1.0000,  -1.5000],
         [ -2.0000,  -2.5000]],
        [[ -3.0000,  -3.5000],
         [ -4.0000,  -4.5000],
         [ -5.0000,  -5.5000]],
        [[ -6.0000,  -6.5000],
         [ -7.0000,  -7.5000],
         [ -8.0000,  -8.5000]],
        [[ -9.0000,  -9.5000],
         [-10.0000, -10.5000],
         [-11.0000, -11.5000]]])
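Beyond the doctest above, a minimal sketch of using the loss in a gradient-descent loop, here fitting a single free series to a batch of targets; the data, learning rate and number of iterations are arbitrary:

import torch
from tslearn.metrics import SoftDTWLossPyTorch

torch.manual_seed(0)
targets = torch.randn(8, 20, 1)                # batch of 8 univariate series of length 20
candidate = torch.zeros(1, 20, 1, requires_grad=True)

loss_fn = SoftDTWLossPyTorch(gamma=0.1, normalize=True)  # divergence keeps the loss non-negative
optimizer = torch.optim.Adam([candidate], lr=0.1)

for _ in range(100):
    optimizer.zero_grad()
    # Repeat the single candidate so it is compared to every series in the batch.
    loss = loss_fn(candidate.repeat(len(targets), 1, 1), targets).mean()
    loss.backward()
    optimizer.step()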
Examples using tslearn.metrics.SoftDTWLossPyTorch
Soft-DTW loss for PyTorch neural network