Skip to content

Commit b4df53a

Browse files
committed
[NeoML] IShuffledBatchGenerator in CProblemSourceLayer (neoml-lib#1104)
Signed-off-by: Kirill Golikov <[email protected]>
1 parent 86bbd5c commit b4df53a

File tree

8 files changed

+511
-186
lines changed

8 files changed

+511
-186
lines changed

NeoML/include/NeoML/Dnn/Layers/FullyConnectedSourceLayer.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2020 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -78,7 +78,9 @@ class NEOML_API CFullyConnectedSourceLayer : public CFullyConnectedLayer {
7878
bool isBatchLoaded( int index ) const;
7979
};
8080

81-
NEOML_API CLayerWrapper<CFullyConnectedSourceLayer> FullyConnectedSource(
82-
TBlobType labelType, int batchSize, int maxBatchCount, IProblem* problem );
81+
// Creates CFullyConnectedSourceLayer with the name
82+
NEOML_API CFullyConnectedSourceLayer* FullyConnectedSource( CDnn& dnn, const char* name,
83+
TBlobType labelType, int batchSize, int maxBatchCount, IProblem* problem,
84+
int numberOfElements = 1, bool isZeroFreeTerm = false );
8385

8486
} // namespace NeoML

NeoML/include/NeoML/Dnn/Layers/ModelWrapperLayer.h

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2020 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -27,31 +27,46 @@ class CSourceLayer;
2727
class CSinkLayer;
2828
class CDnnTrainingModelWrapper;
2929

30+
struct IShuffledBatchGenerator {
31+
virtual ~IShuffledBatchGenerator() = default;
32+
33+
virtual const CArray<int>& GenerateBatchIndexes( int batchSize, bool batchShuffled ) = 0;
34+
virtual bool HasUnseenElements() const = 0;
35+
virtual void DeleteUnseenElement( int index ) = 0;
36+
};
37+
38+
//---------------------------------------------------------------------------------------------------------------------
39+
3040
// CProblemSourceLayer is a wrapper over the IProblem interface.
3141
// On each iteration, it passes BatchSize vectors into the network for processing.
3242
class NEOML_API CProblemSourceLayer : public CBaseLayer {
3343
NEOML_DNN_LAYER( CProblemSourceLayer )
3444
public:
35-
explicit CProblemSourceLayer( IMathEngine& mathEngine );
45+
explicit CProblemSourceLayer( IMathEngine& mathEngine ) :
46+
CBaseLayer( mathEngine, "CCnnProblemSourceLayer", /*isLearnable*/false ) {}
3647

3748
void Serialize( CArchive& archive ) override;
3849

3950
int GetBatchSize() const { return batchSize; }
40-
void SetBatchSize(int _batchSize);
51+
void SetBatchSize( int batchSize );
4152

4253
// The filler for empty values that are not present in a sparse vector
4354
float GetEmptyFill() const { return emptyFill; }
44-
void SetEmptyFill(float _emptyFill) { NeoAssert(GetDnn() == 0); emptyFill = _emptyFill; }
55+
void SetEmptyFill( float _emptyFill ) { NeoAssert( GetDnn() == nullptr ); emptyFill = _emptyFill; }
4556

4657
// You may only change the problem for the layer that is connected to a network
4758
// if the number of classes and the number of input vectors stay the same
4859
CPtr<const IProblem> GetProblem() const { return problem; }
49-
void SetProblem(const CPtr<const IProblem>& _problem);
60+
void SetProblem( const CPtr<const IProblem>& problem, bool shuffle = false, unsigned seed = 42 );
5061

5162
// Retrieves and sets the data type for class labels
5263
TBlobType GetLabelType() const { return labelType; }
5364
void SetLabelType( TBlobType newLabelType );
5465

66+
// Still not the end of an epoch
67+
bool HasUnseenElements() const
68+
{ return ( shuffled && shuffled->HasUnseenElements() ) || nextProblemIndex < problem->GetVectorCount(); }
69+
5570
protected:
5671
~CProblemSourceLayer() override = default;
5772

@@ -60,34 +75,43 @@ class NEOML_API CProblemSourceLayer : public CBaseLayer {
6075
void BackwardOnce() override;
6176

6277
private:
63-
float emptyFill; // the empty values filler (for values not represented in a sparse vector)
64-
int batchSize; // the size of the batch passed to the network
65-
int nextProblemIndex; // the index of the next element in the problem to be passed
66-
CPtr<const IProblem> problem; // the classification problem the network is solving
67-
TBlobType labelType; // the data type for labels
68-
CArray<float> exchangeBufs[3];
78+
float emptyFill = 0; // the empty values filler (for values not represented in a sparse vector)
79+
int batchSize = 1; // the size of the batch passed to the network
80+
int nextProblemIndex = NotFound; // the index of the next element in the problem to be passed
81+
TBlobType labelType = CT_Float; // the data type for labels
82+
CPtr<const IProblem> problem; // the classification problem the network is solving
83+
CPtrOwner<IShuffledBatchGenerator> shuffled; // if a shuffled batch input
84+
85+
enum { EB_Data, EB_Label, EB_Weight, EB_Count_ };
86+
CArray<float> exchangeBufs[EB_Count_]{};
87+
88+
void fillExchangeBuffers( int shift, int index );
6989
};
7090

71-
///////////////////////////////////////////////////////////////////////////////////////////////////////
91+
// Creates CProblemSourceLayer with the name
92+
NEOML_API CProblemSourceLayer* ProblemSource( CDnn& dnn, const char* name,
93+
TBlobType labelType, int batchSize, const CPtr<const IProblem>& problem, bool shuffle = false, unsigned seed = 42 );
94+
95+
//---------------------------------------------------------------------------------------------------------------------
7296

7397
// CDnnModelWrapper is the base class wrapping the trained neural network into the IModel interface
7498
class NEOML_API CDnnModelWrapper : public IModel {
7599
public:
76100
explicit CDnnModelWrapper(IMathEngine& mathEngine, unsigned int seed = 0xDEADFACE);
77101

78-
int GetClassCount() const override;
102+
int GetClassCount() const override { return ClassCount; }
79103
bool Classify(const CFloatVectorDesc& data, CClassificationResult& result) const override;
80104
void Serialize(CArchive& archive) override;
81105

82106
protected:
83107
int ClassCount;
84108
float SourceEmptyFill;
85109
mutable CRandom Random;
86-
mutable CDnn Dnn; // the network
87-
CPtr<CSourceLayer> SourceLayer; // the reference to the source layer
88-
CPtr<CSinkLayer> SinkLayer; // the reference to the terminator layer
89-
CPtr<CDnnBlob> SourceBlob; // the source data blob
90-
mutable CArray<float> tempExp; // the temporary array for exponent values to calculate softmax
110+
mutable CDnn Dnn; // the network
111+
CPtr<CSourceLayer> SourceLayer; // the reference to the source layer
112+
CPtr<CSinkLayer> SinkLayer; // the reference to the terminator layer
113+
CPtr<CDnnBlob> SourceBlob; // the source data blob
114+
mutable CArray<float> tempExp; // the temporary array for exponent values to calculate softmax
91115

92116
static const char* const SourceLayerName;
93117
static const char* const SinkLayerName;
@@ -101,7 +125,7 @@ class NEOML_API CDnnModelWrapper : public IModel {
101125
bool classify( CClassificationResult& result ) const;
102126
};
103127

104-
///////////////////////////////////////////////////////////////////////////////////////////////////////
128+
//---------------------------------------------------------------------------------------------------------------------
105129

106130
// CDnnTrainingModelWrapper is the base class wrapping the neural network
107131
// into an ITrainingModel interface so the network can be trained using the Train method

NeoML/include/NeoML/TraditionalML/Shuffler.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2020 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ namespace NeoML {
2323
// The shuffler class
2424
// Uses the standard shuffling algorithm, not all at once but sequentially; as a result, the first N positions are only shuffled among themselves
2525
// For example, you can use it to get the random indices in an array
26-
class NEOML_API CShuffler {
26+
class NEOML_API CShuffler final {
2727
public:
2828
CShuffler( CRandom& _random, int count );
2929

@@ -34,6 +34,10 @@ class NEOML_API CShuffler {
3434
int SetNext( int index );
3535
// Finishes shuffling and returns all indices
3636
const CArray<int>& GetAllIndices();
37+
// Is shuffling finished
38+
bool IsFinished() const { return nextIndex == indices.Size(); }
39+
// Reset state to use shuffler again for the same array
40+
void Reset() { nextIndex = 0; }
3741

3842
private:
3943
CRandom& random;

NeoML/src/Dnn/Layers/FullyConnectedSourceLayer.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2020 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -310,15 +310,20 @@ bool CFullyConnectedSourceLayer::isBatchLoaded( int index ) const
310310
return ( batchFirstLoadedIndex <= index && index <= batchLastLoadedIndex );
311311
}
312312

313-
CLayerWrapper<CFullyConnectedSourceLayer> FullyConnectedSource( TBlobType labelType,
314-
int batchSize, int maxBatchCount, IProblem* problem )
313+
// Creates CFullyConnectedSourceLayer with the name
314+
CFullyConnectedSourceLayer* FullyConnectedSource( CDnn& dnn, const char* name,
315+
TBlobType labelType, int batchSize, int maxBatchCount, IProblem* problem, int numberOfElements, bool isZeroFreeTerm )
315316
{
316-
return CLayerWrapper<CFullyConnectedSourceLayer>( "FullyConnectedSource", [=, &problem]( CFullyConnectedSourceLayer* result ) {
317-
result->SetLabelType( labelType );
318-
result->SetBatchSize( batchSize );
319-
result->SetMaxBatchCount( maxBatchCount );
320-
result->SetProblem( problem );
321-
} );
317+
CPtr<CFullyConnectedSourceLayer> result = new CFullyConnectedSourceLayer( dnn.GetMathEngine() );
318+
result->SetLabelType( labelType );
319+
result->SetBatchSize( batchSize );
320+
result->SetMaxBatchCount( maxBatchCount );
321+
result->SetProblem( problem );
322+
result->SetNumberOfElements( numberOfElements );
323+
result->SetZeroFreeTerm( isZeroFreeTerm );
324+
result->SetName( name );
325+
dnn.AddLayer( *result );
326+
return result;
322327
}
323328

324329
} // namespace NeoML

0 commit comments

Comments
 (0)