Skip to content

Commit 5370692

Browse files
Optimize SR Cnn algorithm (#5374)
* adjust the expected value for not anomaly points * enhance deseasonality code * enrich test cases * update test case * update test case * fix bug * load dataview from file in an elegant manner Co-authored-by: [email protected] <[email protected]>
1 parent 66df428 commit 5370692

File tree

5 files changed

+60
-10
lines changed

5 files changed

+60
-10
lines changed

src/Microsoft.ML.TimeSeries/Deseasonality.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,12 @@ public void Deseasonality(ref double[] values, int period, ref double[] results)
104104
internal sealed class StlDeseasonality : IDeseasonality
105105
{
106106
private readonly InnerStl _stl;
107+
private readonly IDeseasonality _backupFunc;
107108

108109
public StlDeseasonality()
109110
{
110111
_stl = new InnerStl(true);
112+
_backupFunc = new MedianDeseasonality();
111113
}
112114

113115
public void Deseasonality(ref double[] values, int period, ref double[] results)
@@ -120,12 +122,10 @@ public void Deseasonality(ref double[] values, int period, ref double[] results)
120122
results[i] = _stl.Residual[i];
121123
}
122124
}
125+
// invoke the back up deseasonality method if stl decompose fails.
123126
else
124127
{
125-
for (int i = 0; i < values.Length; ++i)
126-
{
127-
results[i] = values[i];
128-
}
128+
_backupFunc.Deseasonality(ref values, period, ref results);
129129
}
130130
}
131131
}

src/Microsoft.ML.TimeSeries/STL/InnerStl.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ internal class InnerStl
2525
private double[] _c;
2626
private double[] _deseasonSeries;
2727

28+
/// <summary>
29+
/// The minimum length of a valid time series. A time series with length equals 2 is so trivial and meaningless less than 2.
30+
/// </summary>
31+
public const int MinTimeSeriesLength = 3;
32+
2833
/// <summary>
2934
/// The smoothing parameter for the seasonal component.
3035
/// This parameter should be odd, and at least 7.
@@ -114,8 +119,8 @@ public bool Decomposition(IReadOnlyList<double> yValues, int np)
114119
Contracts.CheckValue(yValues, nameof(yValues));
115120
Contracts.CheckParam(np > 0, nameof(np));
116121

117-
if (yValues.Count == 0)
118-
throw Contracts.Except("input data structure cannot be 0-length: innerSTL");
122+
if (yValues.Count < MinTimeSeriesLength)
123+
throw Contracts.Except(string.Format("input time series length for InnerStl is below {0}", MinTimeSeriesLength));
119124

120125
int length = yValues.Count;
121126
Array.Resize(ref _seasonalComponent, length);

src/Microsoft.ML.TimeSeries/STL/Loess.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public Loess(IReadOnlyList<double> xValues, IReadOnlyList<double> yValues, bool
5050
Contracts.CheckValue(yValues, nameof(yValues));
5151

5252
if (xValues.Count < MinTimeSeriesLength || yValues.Count < MinTimeSeriesLength)
53-
throw Contracts.Except("input data structure cannot be 0-length: lowess");
53+
throw Contracts.Except(string.Format("input time series length for Loess is below {0}", MinTimeSeriesLength));
5454

5555
if (xValues.Count != yValues.Count)
5656
throw Contracts.Except("the x-axis length should be equal to y-axis length!: lowess");

src/Microsoft.ML.TimeSeries/SrCnnEntireAnomalyDetector.cs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ internal sealed class SrCnnEntireModeler
349349
private static readonly int _judgementWindowSize = 40;
350350
private static readonly double _eps = 1e-8;
351351
private static readonly double _deanomalyThreshold = 0.35;
352+
private static readonly double _boundSensitivity = 70.0;
352353

353354
// A fixed lookup table which returns factor using sensitivity as index.
354355
// Since Margin = BoundaryUnit * factor, this factor is calculated to make sure Margin == Boundary when sensitivity is 50,
@@ -732,12 +733,20 @@ private void GetMarginPeriod(double[] values, double[][] results, IReadOnlyList<
732733
}
733734
}
734735

735-
//Step 11: Update Anomaly Score
736+
//Step 11: Update Anomaly Score, Expected Value and Boundaries
736737
for (int i = 0; i < results.Length; ++i)
737738
{
738739
results[i][1] = CalculateAnomalyScore(values[i], _ifftRe[i], _units[i], results[i][0] > 0);
739-
}
740740

741+
// adjust the expected value if the point is not anomaly
742+
if (results[i][0] == 0)
743+
{
744+
double margin = results[i][5] - results[i][3];
745+
results[i][3] = AdjustExpectedValueBasedOnBound(values[i], results[i][3], _units[i]);
746+
results[i][5] = results[i][3] + margin;
747+
results[i][6] = results[i][3] - margin;
748+
}
749+
}
741750
}
742751

743752
private void GetMargin(double[] values, double[][] results, double sensitivity)
@@ -763,9 +772,24 @@ private void GetMargin(double[] values, double[][] results, double sensitivity)
763772

764773
//Step 12: Update IsAnomaly
765774
results[i][0] = results[i][0] > 0 && (values[i] < results[i][6] || values[i] > results[i][5]) ? 1 : 0;
775+
776+
//Step 13: Update Expected Value, LowerBound and UpperBound for not anomaly points.
777+
if (results[i][0] == 0)
778+
{
779+
results[i][3] = AdjustExpectedValueBasedOnBound(values[i], results[i][3], _units[i]);
780+
results[i][5] = results[i][3] + margin;
781+
results[i][6] = results[i][3] - margin;
782+
}
766783
}
767784
}
768785

786+
// Adjust the expected value so that it is within the bound margin of value
787+
private double AdjustExpectedValueBasedOnBound(double value, double expectedValue, double unit)
788+
{
789+
var boundMargin = CalculateMargin(unit, _boundSensitivity);
790+
return Math.Max(Math.Min(expectedValue, value + boundMargin), value - boundMargin);
791+
}
792+
769793
private int[] GetAnomalyIndex(double[] scores)
770794
{
771795
List<int> anomalyIdxList = new List<int>();

test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Data;
78
using System.IO;
89
using System.Linq;
910
using Microsoft.ML.Data;
@@ -586,17 +587,20 @@ public void TestSrCnnBatchAnomalyDetector(
586587
{
587588
var ml = new MLContext(1);
588589
IDataView dataView;
590+
List<TimeSeriesDataDouble> data;
591+
589592
if (loadDataFromFile)
590593
{
591594
var dataPath = GetDataPath("Timeseries", "anomaly_detection.csv");
592595

593596
// Load data from file into the dataView
594597
dataView = ml.Data.LoadFromTextFile<TimeSeriesDataDouble>(dataPath, hasHeader: true);
598+
data = ml.Data.CreateEnumerable<TimeSeriesDataDouble>(dataView, reuseRowObject: false).ToList();
595599
}
596600
else
597601
{
602+
data = new List<TimeSeriesDataDouble>();
598603
// Generate sample series data with an anomaly
599-
var data = new List<TimeSeriesDataDouble>();
600604
for (int index = 0; index < 20; index++)
601605
{
602606
data.Add(new TimeSeriesDataDouble { Value = 5 });
@@ -624,6 +628,7 @@ public void TestSrCnnBatchAnomalyDetector(
624628
outputDataView, reuseRowObject: false);
625629

626630
int k = 0;
631+
627632
foreach (var prediction in predictionColumn)
628633
{
629634
switch (mode)
@@ -654,9 +659,12 @@ public void TestSrCnnBatchAnomalyDetector(
654659
Assert.Equal(5.00, prediction.Prediction[4], 2);
655660
Assert.Equal(5.01, prediction.Prediction[5], 2);
656661
Assert.Equal(4.99, prediction.Prediction[6], 2);
662+
Assert.True(prediction.Prediction[6] > data[k].Value || data[k].Value > prediction.Prediction[5]);
657663
}
658664
else
665+
{
659666
Assert.Equal(0, prediction.Prediction[0]);
667+
}
660668
break;
661669
}
662670
k += 1;
@@ -669,10 +677,13 @@ public void TestSrCnnAnomalyDetectorWithSeasonalData(
669677
{
670678
var ml = new MLContext(1);
671679
IDataView dataView;
680+
List<TimeSeriesDataDouble> data;
681+
672682
var dataPath = GetDataPath("Timeseries", "period_no_anomaly.csv");
673683

674684
// Load data from file into the dataView
675685
dataView = ml.Data.LoadFromTextFile<TimeSeriesDataDouble>(dataPath, hasHeader: true);
686+
data = ml.Data.CreateEnumerable<TimeSeriesDataDouble>(dataView, reuseRowObject: false).ToList();
676687

677688
// Setup the detection arguments
678689
string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction);
@@ -695,10 +706,14 @@ public void TestSrCnnAnomalyDetectorWithSeasonalData(
695706
var predictionColumn = ml.Data.CreateEnumerable<SrCnnAnomalyDetection>(
696707
outputDataView, reuseRowObject: false);
697708

709+
var index = 0;
698710
foreach (var prediction in predictionColumn)
699711
{
700712
Assert.Equal(7, prediction.Prediction.Length);
701713
Assert.Equal(0, prediction.Prediction[0]);
714+
Assert.True(prediction.Prediction[6] <= data[index].Value);
715+
Assert.True(data[index].Value <= prediction.Prediction[5]);
716+
++index;
702717
}
703718
}
704719

@@ -709,10 +724,13 @@ public void TestSrCnnAnomalyDetectorWithSeasonalAnomalyData(
709724
{
710725
var ml = new MLContext(1);
711726
IDataView dataView;
727+
List<TimeSeriesDataDouble> data;
728+
712729
var dataPath = GetDataPath("Timeseries", "period_anomaly.csv");
713730

714731
// Load data from file into the dataView
715732
dataView = ml.Data.LoadFromTextFile<TimeSeriesDataDouble>(dataPath, hasHeader: true);
733+
data = ml.Data.CreateEnumerable<TimeSeriesDataDouble>(dataView, reuseRowObject: false).ToList();
716734

717735
// Setup the detection arguments
718736
string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction);
@@ -745,10 +763,13 @@ public void TestSrCnnAnomalyDetectorWithSeasonalAnomalyData(
745763
if (anomalyStartIndex <= k && k <= anomalyEndIndex)
746764
{
747765
Assert.Equal(1, prediction.Prediction[0]);
766+
Assert.True(prediction.Prediction[6] > data[k].Value || data[k].Value > prediction.Prediction[5]);
748767
}
749768
else
750769
{
751770
Assert.Equal(0, prediction.Prediction[0]);
771+
Assert.True(prediction.Prediction[6] <= data[k].Value);
772+
Assert.True(data[k].Value <= prediction.Prediction[5]);
752773
}
753774

754775
++k;

0 commit comments

Comments
 (0)