1- /* Copyright © 2017-2020 ABBYY Production LLC
1+ /* Copyright © 2017-2024 ABBYY
22
33Licensed under the Apache License, Version 2.0 (the "License");
44you may not use this file except in compliance with the License.
@@ -27,31 +27,46 @@ class CSourceLayer;
2727class CSinkLayer ;
2828class CDnnTrainingModelWrapper ;
2929
// Abstract interface of a batch index generator.
// In this file it is stored by CProblemSourceLayer (member `shuffled`) and queried from
// CProblemSourceLayer::HasUnseenElements(); no implementation is visible here.
30+ struct IShuffledBatchGenerator {
// Virtual destructor: the layer holds the generator through a pointer to this interface
// (CPtrOwner<IShuffledBatchGenerator>).
31+ virtual ~IShuffledBatchGenerator () = default ;
32+
// Returns a const reference to an array of int indexes for a batch of `batchSize` elements.
// NOTE(review): presumably these are indexes of problem elements forming the next batch, and the
// referenced array is owned by the generator — confirm against the implementation.
// NOTE(review): the exact effect of `batchShuffled` is not visible here — confirm.
33+ virtual const CArray<int >& GenerateBatchIndexes ( int batchSize, bool batchShuffled ) = 0;
// presumably true while some elements have not yet been produced in the current epoch — confirm.
34+ virtual bool HasUnseenElements () const = 0;
// presumably removes the element with the given index from the set of unseen elements — confirm.
35+ virtual void DeleteUnseenElement ( int index ) = 0;
36+ };
37+
38+ // ---------------------------------------------------------------------------------------------------------------------
39+
3040// CProblemSourceLayer is a wrapper over the IProblem interface.
3141// On each iteration, it passes BatchSize vectors into the network for processing.
3242class NEOML_API CProblemSourceLayer : public CBaseLayer {
3343 NEOML_DNN_LAYER ( CProblemSourceLayer )
3444public:
// Constructor. The out-of-line declaration (removed line) is replaced by an inline definition with
// an empty body that passes mathEngine, a string literal (presumably the layer name — confirm) and
// isLearnable = false to CBaseLayer.
// NOTE(review): the string literal below starts with a space; this may be an artifact of how this
// text was extracted — confirm the exact string against the repository before relying on it.
35- explicit CProblemSourceLayer ( IMathEngine& mathEngine );
45+ explicit CProblemSourceLayer ( IMathEngine& mathEngine ) :
46+ CBaseLayer( mathEngine, " CCnnProblemSourceLayer" , /* isLearnable*/ false ) {}
3647
3748 void Serialize ( CArchive& archive ) override ;
3849
// Returns the current value of the member `batchSize`.
3950 int GetBatchSize () const { return batchSize; }
// Sets the batch size; the definition is not in this chunk.
// The only change between the two versions below is the parameter name (`_batchSize` -> `batchSize`).
// NOTE(review): the new parameter name is the same as the member `batchSize`; if the definition uses
// this name too, the parameter hides the member inside the function body — confirm the definition.
40- void SetBatchSize (int _batchSize );
51+ void SetBatchSize ( int batchSize );
4152
4253 // The filler for empty values that are not present in a sparse vector
// Returns the current value of the member `emptyFill`.
4354 float GetEmptyFill () const { return emptyFill; }
// Stores a new filler value in `emptyFill`. NeoAssert first checks that GetDnn() returns null,
// presumably so the value can only be changed while the layer is not attached to a network — confirm.
// The two versions below differ only in spacing and in comparing against `nullptr` instead of `0`.
44- void SetEmptyFill (float _emptyFill) { NeoAssert (GetDnn () == 0 ); emptyFill = _emptyFill; }
55+ void SetEmptyFill ( float _emptyFill ) { NeoAssert ( GetDnn () == nullptr ); emptyFill = _emptyFill; }
4556
4657 // You may only change the problem for the layer that is connected to a network
4758 // if the number of classes and the number of input vectors stay the same
// Returns the stored problem pointer (member `problem`).
4859 CPtr<const IProblem> GetProblem () const { return problem; }
// Sets the problem; the definition is not in this chunk.
// The new declaration adds `shuffle` (default false) and `seed` (default 42); because both have
// defaults, existing single-argument calls still compile.
// NOTE(review): presumably `shuffle` enables the IShuffledBatchGenerator held in `shuffled` and
// `seed` initializes its random order — confirm against the definition.
49- void SetProblem (const CPtr<const IProblem>& _problem );
60+ void SetProblem ( const CPtr<const IProblem>& problem, bool shuffle = false , unsigned seed = 42 );
5061
5162 // Retrieves and sets the data type for class labels
5263 TBlobType GetLabelType () const { return labelType; }
5364 void SetLabelType ( TBlobType newLabelType );
5465
66+ // Returns true while the end of the epoch has not been reached:
// either the `shuffled` generator is set and reports unseen elements,
// or nextProblemIndex is still less than problem->GetVectorCount().
// NOTE(review): `problem` is dereferenced without a null check, so this presumably must only be
// called after a problem has been set — confirm with callers.
67+ bool HasUnseenElements () const
68+ { return ( shuffled && shuffled->HasUnseenElements () ) || nextProblemIndex < problem->GetVectorCount (); }
69+
5570protected:
5671 ~CProblemSourceLayer () override = default ;
5772
@@ -60,34 +75,43 @@ class NEOML_API CProblemSourceLayer : public CBaseLayer {
6075 void BackwardOnce () override ;
6176
6277private:
// Previous member declarations (removed lines): no in-class initializers, fixed-size array of 3 buffers.
63- float emptyFill; // the empty values filler (for values not represented in a sparse vector)
64- int batchSize; // the size of the batch passed to the network
65- int nextProblemIndex; // the index of the next element in the problem to be passed
66- CPtr<const IProblem> problem; // the classification problem the network is solving
67- TBlobType labelType; // the data type for labels
68- CArray<float > exchangeBufs[3 ];
// New member declarations: the scalar members now carry in-class initializers, which goes together
// with the constructor becoming an inline definition with an empty body.
78+ float emptyFill = 0 ; // the empty values filler (for values not represented in a sparse vector)
79+ int batchSize = 1 ; // the size of the batch passed to the network
80+ int nextProblemIndex = NotFound; // the index of the next element in the problem to be passed
81+ TBlobType labelType = CT_Float; // the data type for labels
82+ CPtr<const IProblem> problem; // the classification problem the network is solving
83+ CPtrOwner<IShuffledBatchGenerator> shuffled; // owned batch index generator; may be null (HasUnseenElements tests it before use)
84+
// Names for the exchange buffers; EB_Count_ is the last enumerator (value 3) and sizes the array,
// matching the literal 3 of the removed declaration. The other enumerators are presumably used as
// indexes into exchangeBufs — confirm in the implementation.
85+ enum { EB_Data, EB_Label, EB_Weight, EB_Count_ };
86+ CArray<float > exchangeBufs[EB_Count_]{};
87+
// Private helper; the definition is not in this chunk.
// NOTE(review): presumably copies the problem element `index` into the exchange buffers at
// position `shift` within the batch — confirm against the definition.
88+ void fillExchangeBuffers ( int shift, int index );
6989};
7090
71- // /////////////////////////////////////////////////////////////////////////////////////////////////////
91+ // Creates a CProblemSourceLayer with the given name.
// Declaration only; the definition is not in this chunk. The parameters after `name` correspond to
// the layer's settings (label type, batch size, problem, shuffle flag and seed); `shuffle` and `seed`
// have the same defaults as in CProblemSourceLayer::SetProblem (false and 42).
// NOTE(review): presumably the new layer is also added to `dnn`, and the returned raw pointer is
// non-owning — confirm against the definition.
92+ NEOML_API CProblemSourceLayer* ProblemSource ( CDnn& dnn, const char * name,
93+ TBlobType labelType, int batchSize, const CPtr<const IProblem>& problem, bool shuffle = false , unsigned seed = 42 );
94+
95+ // ---------------------------------------------------------------------------------------------------------------------
7296
7397// CDnnModelWrapper is the base class wrapping the trained neural network into the IModel interface
7498class NEOML_API CDnnModelWrapper : public IModel {
7599public:
76100 explicit CDnnModelWrapper (IMathEngine& mathEngine, unsigned int seed = 0xDEADFACE );
77101
// Returns the value of the member ClassCount.
// The out-of-line declaration (removed line) is replaced by this inline definition.
78- int GetClassCount () const override ;
102+ int GetClassCount () const override { return ClassCount; }
79103 bool Classify (const CFloatVectorDesc& data, CClassificationResult& result) const override ;
80104 void Serialize (CArchive& archive) override ;
81105
82106protected:
83107 int ClassCount;
84108 float SourceEmptyFill;
85109 mutable CRandom Random;
86- mutable CDnn Dnn; // the network
87- CPtr<CSourceLayer> SourceLayer; // the reference to the source layer
88- CPtr<CSinkLayer> SinkLayer; // the reference to the terminator layer
89- CPtr<CDnnBlob> SourceBlob; // the source data blob
90- mutable CArray<float > tempExp; // the temporary array for exponent values to calculate softmax
110+ mutable CDnn Dnn; // the network
111+ CPtr<CSourceLayer> SourceLayer; // the reference to the source layer
112+ CPtr<CSinkLayer> SinkLayer; // the reference to the terminator layer
113+ CPtr<CDnnBlob> SourceBlob; // the source data blob
114+ mutable CArray<float > tempExp; // the temporary array for exponent values to calculate softmax
91115
92116 static const char * const SourceLayerName;
93117 static const char * const SinkLayerName;
@@ -101,7 +125,7 @@ class NEOML_API CDnnModelWrapper : public IModel {
101125 bool classify ( CClassificationResult& result ) const ;
102126};
103127
104- // /////////////////////////////////////////////////////////////////////////////////////////////////////
128+ // ---------------------------------------------------------------------------------------------------------------------
105129
106130// CDnnTrainingModelWrapper is the base class wrapping the neural network
107131// into an ITrainingModel interface so the network can be trained using the Train method
0 commit comments