@@ -28,14 +28,13 @@ class CDnn;
 class NEOML_API CDnnSolver : virtual public IObject {
 public:
	// Stores the calculated values of layer parameters gradients for further use in Train method
-	// forSharedWeightsLayer=true should only be used within layers that share weights with other layers.
+	// sharedWeights=true should only be used within layers that share weights with other layers
	void AddDiff( CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramDiffBlobs,
		bool sharedWeights = false );
 
	// Modifies the trainable parameters of the network layers,
	// using the accumulated gradients and previous steps' history (moment, etc.)
	void Train( float distributedCoeff = 1.f );
-
	// Resets to the initial state
	void Reset();
@@ -62,11 +61,17 @@ class NEOML_API CDnnSolver : virtual public IObject {
 
	// Gets the reference to the math engine
	IMathEngine& MathEngine() const { return mathEngine; }
+	// Gets the blob that stores intermediate results
+	const CDnnBlob& TempBlob() const { return *temporaryBlob; }
+	// Gets the data of the blob that stores intermediate results
+	// The blob itself is hidden in the private section, because its allocated size may be greater than the actual data size
+	CFloatHandle TempData();
+	// Reinitializes the blob that stores intermediate results
+	bool ReInitTempBlob( int dataSize );
 
	// Called once on Reset method call
	// Resets the stats in the inheriting instances to the initial state
	virtual void OnReset() {}
-
	// On each training step the method is called once, before the call to TrainLayer for all layers
	virtual void OnTrain() {}
 
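
TempData()/ReInitTempBlob() give inheriting solvers a shared scratch buffer owned by the base class. A rough sketch of how a derived solver might use it inside TrainLayer (the solver name and the math performed are hypothetical):

	// Hypothetical derived solver reusing the base-class scratch blob.
	void CMySolver::TrainLayer( const CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramBlobs,
		const CObjectArray<CDnnBlob>& paramDiffBlobs, CObjectArray<CDnnBlob>& gradientHistory )
	{
		const int size = paramDiffBlobs[0]->GetDataSize();
		ReInitTempBlob( size ); // grows the shared blob if the currently allocated size is too small
		CFloatHandle scratch = TempData(); // valid for `size` elements during this call
		// ... per-layer computations on `scratch` via MathEngine() ...
	}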
@@ -78,13 +83,20 @@ class NEOML_API CDnnSolver : virtual public IObject {
 
 private:
	IMathEngine& mathEngine;
+	CPtr<CDnnBlob> gradParams;
+
+	// Variables for calculations stored in MathEngine memory
	float learningRate;
	float regularizationL2;
	float regularizationL1;
	float maxGradientNorm;
	float clipGradientMin;
	float clipGradientMax;
 
+	// Blob that stores intermediate results
+	// Kept private, because its allocated size may be greater than the actual data size
+	CPtr<CDnnBlob> temporaryBlob;
+
	// The blobs sum
	struct CDiffBlobSum final {
		const CBaseLayer* LayerOwner{}; // for the given layer
@@ -141,7 +153,7 @@ void NEOML_API SerializeSolver( CArchive& archive, CDnn& dnn, CPtr<CDnnSolver>&
 // ---------------------------------------------------------------------------------------------------------------------
 
 template<class T>
-class CSolverClassRegistrar {
+class CSolverClassRegistrar final {
 public:
	explicit CSolverClassRegistrar( const char* solverName );
	~CSolverClassRegistrar();
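
CSolverClassRegistrar is what lets SerializeSolver re-create a solver from its stored name; marking it final is reasonable since it is only meant to be instantiated directly. A sketch of registering a hypothetical custom solver by name (NeoML normally hides this behind its registration macros, so the direct form here is illustrative only):

	// Illustrative only: register CMyCustomSolver so it can be created by name during deserialization.
	static CSolverClassRegistrar<CMyCustomSolver> myCustomSolverRegistrar( "CMyCustomSolver" );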
@@ -168,40 +180,27 @@ inline CSolverClassRegistrar<T>::~CSolverClassRegistrar()
 class NEOML_API CDnnSimpleGradientSolver : public CDnnSolver {
	NEOML_DNN_SOLVER( CDnnSimpleGradientSolver )
 public:
-	CDnnSimpleGradientSolver( IMathEngine& mathEngine );
+	explicit CDnnSimpleGradientSolver( IMathEngine& mathEngine );
 
	// Moment decay rate (moment is a weighted sum of previous gradients)
	float GetMomentDecayRate() const { return momentDecayRate; }
	void SetMomentDecayRate( float decayRate ) { momentDecayRate = decayRate; }
-
+	// Backward compatibility mode
	bool IsInCompatibilityMode() const { return isInCompatibilityMode; }
	void SetCompatibilityMode( bool compatibilityMode ) { isInCompatibilityMode = compatibilityMode; }
 
	void Serialize( CArchive& archive, const CDnn& dnn ) override;
 
 protected:
+	// Updates the trainable weights of the layer
	void TrainLayer( const CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramBlobs,
		const CObjectArray<CDnnBlob>& paramDiffBlobs, CObjectArray<CDnnBlob>& gradientHistory ) override;
 
 private:
	// Moment decay rate (moment is a weighted sum of previous gradients)
	float momentDecayRate;
-
	// Backward compatibility mode
	bool isInCompatibilityMode;
-
-	// Temporary variables of Handle type, used for calculations
-	enum TTempVariable {
-		TV_MomentDecayRateVar = 0,
-		TV_OpMomentDecayRateVar,
-		TV_OpRegL2MomentDecayRateVar,
-		TV_RateVar,
-		TV_L1Threshold,
-		TV_L1Mult,
-		TV_Count
-	};
-
-	CPtr<CDnnBlob> tempVariables;
 };
 
 // ---------------------------------------------------------------------------------------------------------------------
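
With the per-solver TTempVariable enum and tempVariables blob removed, scratch values now go through the base-class temporary blob. For reference, the update this solver's comments describe is plain SGD with momentum; a hedged sketch of the rule (the exact in-library formula, including L1/L2 regularization and compatibility mode, may differ):

	$$ v_t = \beta\, v_{t-1} - \eta\, g_t, \qquad w_t = w_{t-1} + v_t $$

where $\beta$ is the moment decay rate, $\eta$ the learning rate, $g_t$ the accumulated gradient and $v_t$ the moment kept in gradientHistory.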
@@ -210,7 +209,7 @@ class NEOML_API CDnnSimpleGradientSolver : public CDnnSolver {
 class NEOML_API CDnnAdaptiveGradientSolver : public CDnnSolver {
	NEOML_DNN_SOLVER( CDnnAdaptiveGradientSolver )
 public:
-	CDnnAdaptiveGradientSolver( IMathEngine& mathEngine );
+	explicit CDnnAdaptiveGradientSolver( IMathEngine& mathEngine );
 
	// Retrieves and sets the moment decay rate (moment is a weighted sum of previous gradients)
	float GetMomentDecayRate() const { return momentDecayRate; }
@@ -222,7 +221,7 @@ class NEOML_API CDnnAdaptiveGradientSolver : public CDnnSolver {
	// Retrieves and sets the epsilon used to avoid division by zero when calculating the second moment
	float GetEpsilon() const { return epsilon; }
	void SetEpsilon( float newEpsilon ) { epsilon = newEpsilon; }
-
+	// Backward compatibility mode
	bool IsInCompatibilityMode() const { return isInCompatibilityMode; }
	void SetCompatibilityMode( bool compatibilityMode ) { isInCompatibilityMode = compatibilityMode; }
@@ -249,7 +248,7 @@ class NEOML_API CDnnAdaptiveGradientSolver : public CDnnSolver {
	// Prepares for the next training step
	void OnTrain() override;
	// Updates the trainable weights of the layer
-	virtual void TrainLayer( const CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramBlobs,
+	void TrainLayer( const CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramBlobs,
		const CObjectArray<CDnnBlob>& paramDiffBlobs, CObjectArray<CDnnBlob>& gradientHistory ) override;
 
 private:
@@ -284,27 +283,8 @@ class NEOML_API CDnnAdaptiveGradientSolver : public CDnnSolver {
	bool isAmsGradEnabled;
	// Perform weight decay after calculating the moving averages
	bool isDecoupledWeightDecay;
-
	// Backward compatibility mode
	bool isInCompatibilityMode;
-
-	enum TTempVariable {
-		TV_MomentDecayRateVar = 0,
-		TV_SecondMomentDecayRateVar,
-		TV_RegL2Var,
-		TV_OpMomentDecayRateVar,
-		TV_OpSecondMomentDecayRateVar,
-		TV_RateVar,
-		TV_L1Threshold,
-		TV_L1Mult,
-		TV_EpsilonVar,
-		TV_Count
-	};
-
-	// Temporary Handle variables for calculations
-	CPtr<CDnnBlob> tempVariables;
-
-	CPtr<CDnnBlob> temporaryBlob;
 };
 
 // ---------------------------------------------------------------------------------------------------------------------
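
For reference, the parameters of CDnnAdaptiveGradientSolver (moment decay rate, second moment decay rate, epsilon, AMSGrad) map onto the Adam family of updates; a hedged sketch of the standard formulas (the in-library code, including regularization and the decoupled weight decay option, may differ):

	$$ m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2, \qquad w_t = w_{t-1} - \eta \frac{m_t}{\sqrt{v_t} + \epsilon} $$

with $\beta_1$ the moment decay rate, $\beta_2$ the second moment decay rate, and, when AMSGrad is enabled, $v_t$ replaced by the running maximum $\hat v_t = \max( \hat v_{t-1}, v_t )$.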
@@ -389,26 +369,6 @@ class NEOML_API CDnnNesterovGradientSolver : public CDnnSolver {
	float muTPlusOne; // the mu coefficient for the next step
	float productMuT; // the product of the mu coefficients over all steps, including the current one
 
-	enum TTempVariable {
-		TV_MomentDecayRateVar = 0,
-		TV_SecondMomentDecayRateVar,
-		TV_RegL2Var,
-		TV_OpMomentDecayRateVar,
-		TV_OpSecondMomentDecayRateVar,
-		TV_RateVar,
-		TV_L1Threshold,
-		TV_L1Mult,
-		TV_EpsilonVar,
-		TV_InvOpSecondMomentDecayRateNVar, // 1 / (1 - secondMomentDecay ^ N)
-		TV_MBarGradMultVar, // the gradient coefficient in the total sum
-		TV_MBarMomentMultVar, // the moment coefficient in the total sum
-		TV_Count
-	};
-
-	// Temporary blobs for calculations
-	CPtr<CDnnBlob> tempVariables;
-
-	CPtr<CDnnBlob> temporaryBlob;
	// m with a stroke (from the paper referred to)
	// It is a weighted sum of the gradient and the first moment
	CPtr<CDnnBlob> mBarBlob;
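
mBarBlob corresponds to the m-bar value of the Nesterov-accelerated Adam formulation; a hedged sketch of how it is usually defined there (assuming the formulation of the paper the comment refers to; the in-library bias corrections may differ):

	$$ \bar m_t = (1 - \mu_t)\, \hat g_t + \mu_{t+1}\, \hat m_t $$

where $\mu_t$ is the moment decay schedule (productMuT accumulates $\prod_i \mu_i$ for the bias correction), $\hat g_t$ the bias-corrected gradient and $\hat m_t$ the bias-corrected first moment.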
@@ -492,11 +452,12 @@ class NEOML_API CDnnLambGradientSolver : public CDnnSolver {
	void Serialize( CArchive& archive, const CDnn& dnn ) override;
 
 protected:
+	// Prepares for the next training step
+	void OnTrain() override;
+	// Updates the trainable weights of the layer
	void TrainLayer( const CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramBlobs,
		const CObjectArray<CDnnBlob>& paramDiffBlobs, CObjectArray<CDnnBlob>& gradientHistory ) override;
 
-	void OnTrain() override;
-
 private:
	// The gradientHistory array stores the previous values of gradients of different types
	enum TGradientHistoryType {
@@ -519,48 +480,28 @@ class NEOML_API CDnnLambGradientSolver : public CDnnSolver {
	// Is NVLamb modification used
	bool useNvLamb;
 
-	enum TTempVariable {
-		TV_MomentDecayRateVar,
-		TV_SecondMomentDecayRateVar,
-		TV_OpMomentDecayRateVar,
-		TV_OpSecondMomentDecayRateVar,
-		TV_RateVar,
-		TV_EpsilonVar,
-		TV_WeightDecayVar,
-		TV_ClipMultiplierVar,
-		TV_LayerNormVar,
-		TV_TrustRatioVar,
-		TV_L2NormVar,
-
-		TV_Count
-	};
-
-	CPtr<CDnnBlob> tempVariables;
-
-	CPtr<CDnnBlob> tempBlob;
-
+	CPtr<CDnnBlob> normL2Var;
	CArray<float> layersGradientNormSquare;
	float totalGradientNorm;
 
	// Layer excluded from optimization
-	struct CExcludedLayer {
+	struct CExcludedLayer final {
		// Layer name (or substring)
		CString LayerName;
		// Match type (exact or substring)
-		TExcludeLayerNameMatchType MatchType;
+		TExcludeLayerNameMatchType MatchType{ ELNMT_Exact };
		// Parameter number
		// -1 if all parameters
-		int ParamIndex;
-
-		CExcludedLayer() : MatchType( ELNMT_Exact ), ParamIndex( NotFound ) {}
+		int ParamIndex{ NotFound };
	};
	// Layers excluded from weight decay
	CArray<CExcludedLayer> excludedLayers;
+	mutable CPtr<CDnnBlob> tempNormBlob;
 
	float calcL2NormAverage( const CConstFloatHandle& data, int dataSize ) const;
	void getWeightDecayIndices( const CBaseLayer& layer, int paramsCount, CHashTable<int>& indexes ) const;
 
-	void calcNormalizeMultiplier( const CDnnBlob& weights, const CDnnBlob& update, const CFloatHandle& multiplier ) const;
+	float calcNormalizeMultiplier( const CDnnBlob& weights, const CDnnBlob& update ) const;
 };
 
 template<typename TLayer>
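
calcNormalizeMultiplier now returns the multiplier directly instead of writing it through a handle. In the LAMB scheme it corresponds to the per-layer trust ratio; a hedged sketch of the usual definition (the exact clipping and safeguards in the implementation may differ):

	$$ r = \frac{\lVert w \rVert_2}{\lVert u \rVert_2} $$

where $w$ are the layer weights and $u$ the computed update; the layer's learning rate is scaled by $r$, with special handling when either norm is zero.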