1- /*  Copyright © 2017-2020  ABBYY Production LLC 
1+ /*  Copyright © 2017-2024  ABBYY
22
33Licensed under the Apache License, Version 2.0 (the "License"); 
44you may not use this file except in compliance with the License. 
@@ -27,31 +27,46 @@ class CSourceLayer;
2727class  CSinkLayer ;
2828class  CDnnTrainingModelWrapper ;
2929
30+ struct  IShuffledBatchGenerator  {
31+ 	virtual  ~IShuffledBatchGenerator () = default ;
32+ 
33+ 	virtual  const  CArray<int >& GenerateBatchIndexes ( int  batchSize, bool  batchShuffled ) = 0;
34+ 	virtual  bool  HasUnseenElements () const  = 0;
35+ 	virtual  void  DeleteUnseenElement ( int  index ) = 0;
36+ };
37+ 
38+ // ---------------------------------------------------------------------------------------------------------------------
39+ 
3040//  CProblemSourceLayer is a wrapper over the IProblem interface. 
3141//  On each iteration, it passes BatchSize vectors into the network for processing.
3242class  NEOML_API  CProblemSourceLayer : public CBaseLayer {
3343	NEOML_DNN_LAYER ( CProblemSourceLayer )
3444public: 
35- 	explicit  CProblemSourceLayer ( IMathEngine& mathEngine );
45+ 	explicit  CProblemSourceLayer ( IMathEngine& mathEngine ) :
46+ 		CBaseLayer( mathEngine, " CCnnProblemSourceLayer"  , /* isLearnable*/ false  ) {}
3647
3748	void  Serialize ( CArchive& archive ) override ;
3849
3950	int  GetBatchSize () const  { return  batchSize; }
40- 	void  SetBatchSize (int  _batchSize );
51+ 	void  SetBatchSize (  int  batchSize  );
4152
4253	//  The filler for empty values that are not present in a sparse vector
4354	float  GetEmptyFill () const  { return  emptyFill; }
44- 	void  SetEmptyFill (float  _emptyFill) { NeoAssert (GetDnn () == 0 ); emptyFill = _emptyFill; }
55+ 	void  SetEmptyFill (  float  _emptyFill  ) { NeoAssert (  GetDnn () == nullptr   ); emptyFill = _emptyFill; }
4556
4657	//  You may only change the problem for the layer that is connected to a network
4758	//  if the number of classes and the number of input vectors stay the same
4859	CPtr<const  IProblem> GetProblem () const  { return  problem; }
49- 	void  SetProblem (const  CPtr<const  IProblem>& _problem );
60+ 	void  SetProblem (  const  CPtr<const  IProblem>& problem,  bool  shuffle =  false ,  unsigned  seed =  42   );
5061
5162	//  Retrieves and sets the data type for class labels
5263	TBlobType GetLabelType () const  { return  labelType; }
5364	void  SetLabelType ( TBlobType newLabelType );
5465
66+ 	//  Still not the end of an epoch
67+ 	bool  HasUnseenElements () const 
68+ 		{ return  ( shuffled && shuffled->HasUnseenElements () ) || nextProblemIndex < problem->GetVectorCount (); }
69+ 
5570protected: 
5671	~CProblemSourceLayer () override  = default ;
5772
@@ -60,34 +75,43 @@ class NEOML_API CProblemSourceLayer : public CBaseLayer {
6075	void  BackwardOnce () override ;
6176
6277private: 
63- 	float  emptyFill;		//  the empty values filler (for values not represented in a sparse vector)
64- 	int  batchSize;			//  the size of the batch passed to the network
65- 	int  nextProblemIndex;	//  the index of the next element in the problem to be passed
66- 	CPtr<const  IProblem> problem;	//  the classification problem the network is solving
67- 	TBlobType labelType;		//  the data type for labels
68- 	CArray<float > exchangeBufs[3 ];
78+ 	float  emptyFill = 0 ; //  the empty values filler (for values not represented in a sparse vector)
79+ 	int  batchSize = 1 ; //  the size of the batch passed to the network
80+ 	int  nextProblemIndex = NotFound; //  the index of the next element in the problem to be passed
81+ 	TBlobType labelType = CT_Float; //  the data type for labels
82+ 	CPtr<const  IProblem> problem; //  the classification problem the network is solving
83+ 	CPtrOwner<IShuffledBatchGenerator> shuffled; //  if a shuffled batch input
84+ 
85+ 	enum  { EB_Data, EB_Label, EB_Weight, EB_Count_ };
86+ 	CArray<float > exchangeBufs[EB_Count_]{};
87+ 
88+ 	void  fillExchangeBuffers ( int  shift, int  index );
6989};
7090
71- // /////////////////////////////////////////////////////////////////////////////////////////////////////
91+ //  Creates CProblemSourceLayer with the name
92+ NEOML_API CProblemSourceLayer* ProblemSource ( CDnn& dnn, const  char * name,
93+ 	TBlobType labelType, int  batchSize, const  CPtr<const  IProblem>& problem, bool  shuffle = false , unsigned  seed = 42  );
94+ 
95+ // ---------------------------------------------------------------------------------------------------------------------
7296
7397//  CDnnModelWrapper is the base class wrapping the trained neural network into the IModel interface
7498class  NEOML_API  CDnnModelWrapper : public IModel {
7599public: 
76100	explicit  CDnnModelWrapper (IMathEngine& mathEngine, unsigned  int  seed = 0xDEADFACE );
77101
78- 	int  GetClassCount () const  override ; 
102+ 	int  GetClassCount () const  override  {  return  ClassCount; } 
79103	bool  Classify (const  CFloatVectorDesc& data, CClassificationResult& result) const  override ;
80104	void  Serialize (CArchive& archive) override ;
81105
82106protected: 
83107	int  ClassCount;
84108	float  SourceEmptyFill;
85109	mutable  CRandom Random;
86- 	mutable  CDnn Dnn;	 //  the network
87- 	CPtr<CSourceLayer> SourceLayer;	 //  the reference to the source layer
88- 	CPtr<CSinkLayer> SinkLayer;		 //  the reference to the terminator layer
89- 	CPtr<CDnnBlob> SourceBlob;			 //  the source data blob
90- 	mutable  CArray<float > tempExp;		 //  the temporary array for exponent values to calculate softmax
110+ 	mutable  CDnn Dnn;  //  the network
111+ 	CPtr<CSourceLayer> SourceLayer;  //  the reference to the source layer
112+ 	CPtr<CSinkLayer> SinkLayer;  //  the reference to the terminator layer
113+ 	CPtr<CDnnBlob> SourceBlob;  //  the source data blob
114+ 	mutable  CArray<float > tempExp;  //  the temporary array for exponent values to calculate softmax
91115
92116	static  const  char * const  SourceLayerName;
93117	static  const  char * const  SinkLayerName;
@@ -101,7 +125,7 @@ class NEOML_API CDnnModelWrapper : public IModel {
101125	bool  classify ( CClassificationResult& result ) const ;
102126};
103127
104- // ///////////////////////////////////////////////////////////////////////////////////////////////////// 
128+ // --------------------------------------------------------------------------------------------------------------------- 
105129
106130//  CDnnTrainingModelWrapper is the base class wrapping the neural network 
107131//  into an ITrainingModel interface so the network can be trained using the Train method
0 commit comments