4
4
5
5
using System ;
6
6
using System . Collections . Generic ;
7
+ using System . Data ;
7
8
using System . IO ;
8
9
using System . Linq ;
9
10
using Microsoft . ML . Data ;
@@ -586,17 +587,20 @@ public void TestSrCnnBatchAnomalyDetector(
586
587
{
587
588
var ml = new MLContext ( 1 ) ;
588
589
IDataView dataView ;
590
+ List < TimeSeriesDataDouble > data ;
591
+
589
592
if ( loadDataFromFile )
590
593
{
591
594
var dataPath = GetDataPath ( "Timeseries" , "anomaly_detection.csv" ) ;
592
595
593
596
// Load data from file into the dataView
594
597
dataView = ml . Data . LoadFromTextFile < TimeSeriesDataDouble > ( dataPath , hasHeader : true ) ;
598
+ data = ml . Data . CreateEnumerable < TimeSeriesDataDouble > ( dataView , reuseRowObject : false ) . ToList ( ) ;
595
599
}
596
600
else
597
601
{
602
+ data = new List < TimeSeriesDataDouble > ( ) ;
598
603
// Generate sample series data with an anomaly
599
- var data = new List < TimeSeriesDataDouble > ( ) ;
600
604
for ( int index = 0 ; index < 20 ; index ++ )
601
605
{
602
606
data . Add ( new TimeSeriesDataDouble { Value = 5 } ) ;
@@ -624,6 +628,7 @@ public void TestSrCnnBatchAnomalyDetector(
624
628
outputDataView , reuseRowObject : false ) ;
625
629
626
630
int k = 0 ;
631
+
627
632
foreach ( var prediction in predictionColumn )
628
633
{
629
634
switch ( mode )
@@ -654,9 +659,12 @@ public void TestSrCnnBatchAnomalyDetector(
654
659
Assert . Equal ( 5.00 , prediction . Prediction [ 4 ] , 2 ) ;
655
660
Assert . Equal ( 5.01 , prediction . Prediction [ 5 ] , 2 ) ;
656
661
Assert . Equal ( 4.99 , prediction . Prediction [ 6 ] , 2 ) ;
662
+ Assert . True ( prediction . Prediction [ 6 ] > data [ k ] . Value || data [ k ] . Value > prediction . Prediction [ 5 ] ) ;
657
663
}
658
664
else
665
+ {
659
666
Assert . Equal ( 0 , prediction . Prediction [ 0 ] ) ;
667
+ }
660
668
break ;
661
669
}
662
670
k += 1 ;
@@ -669,10 +677,13 @@ public void TestSrCnnAnomalyDetectorWithSeasonalData(
669
677
{
670
678
var ml = new MLContext ( 1 ) ;
671
679
IDataView dataView ;
680
+ List < TimeSeriesDataDouble > data ;
681
+
672
682
var dataPath = GetDataPath ( "Timeseries" , "period_no_anomaly.csv" ) ;
673
683
674
684
// Load data from file into the dataView
675
685
dataView = ml . Data . LoadFromTextFile < TimeSeriesDataDouble > ( dataPath , hasHeader : true ) ;
686
+ data = ml . Data . CreateEnumerable < TimeSeriesDataDouble > ( dataView , reuseRowObject : false ) . ToList ( ) ;
676
687
677
688
// Setup the detection arguments
678
689
string outputColumnName = nameof ( SrCnnAnomalyDetection . Prediction ) ;
@@ -695,10 +706,14 @@ public void TestSrCnnAnomalyDetectorWithSeasonalData(
695
706
var predictionColumn = ml . Data . CreateEnumerable < SrCnnAnomalyDetection > (
696
707
outputDataView , reuseRowObject : false ) ;
697
708
709
+ var index = 0 ;
698
710
foreach ( var prediction in predictionColumn )
699
711
{
700
712
Assert . Equal ( 7 , prediction . Prediction . Length ) ;
701
713
Assert . Equal ( 0 , prediction . Prediction [ 0 ] ) ;
714
+ Assert . True ( prediction . Prediction [ 6 ] <= data [ index ] . Value ) ;
715
+ Assert . True ( data [ index ] . Value <= prediction . Prediction [ 5 ] ) ;
716
+ ++ index ;
702
717
}
703
718
}
704
719
@@ -709,10 +724,13 @@ public void TestSrCnnAnomalyDetectorWithSeasonalAnomalyData(
709
724
{
710
725
var ml = new MLContext ( 1 ) ;
711
726
IDataView dataView ;
727
+ List < TimeSeriesDataDouble > data ;
728
+
712
729
var dataPath = GetDataPath ( "Timeseries" , "period_anomaly.csv" ) ;
713
730
714
731
// Load data from file into the dataView
715
732
dataView = ml . Data . LoadFromTextFile < TimeSeriesDataDouble > ( dataPath , hasHeader : true ) ;
733
+ data = ml . Data . CreateEnumerable < TimeSeriesDataDouble > ( dataView , reuseRowObject : false ) . ToList ( ) ;
716
734
717
735
// Setup the detection arguments
718
736
string outputColumnName = nameof ( SrCnnAnomalyDetection . Prediction ) ;
@@ -745,10 +763,13 @@ public void TestSrCnnAnomalyDetectorWithSeasonalAnomalyData(
745
763
if ( anomalyStartIndex <= k && k <= anomalyEndIndex )
746
764
{
747
765
Assert . Equal ( 1 , prediction . Prediction [ 0 ] ) ;
766
+ Assert . True ( prediction . Prediction [ 6 ] > data [ k ] . Value || data [ k ] . Value > prediction . Prediction [ 5 ] ) ;
748
767
}
749
768
else
750
769
{
751
770
Assert . Equal ( 0 , prediction . Prediction [ 0 ] ) ;
771
+ Assert . True ( prediction . Prediction [ 6 ] <= data [ k ] . Value ) ;
772
+ Assert . True ( data [ k ] . Value <= prediction . Prediction [ 5 ] ) ;
752
773
}
753
774
754
775
++ k ;
0 commit comments