2
2
using System . IO ;
3
3
using System . Text ;
4
4
using System . Collections . Generic ;
5
- using Tensorflow ;
6
5
using static Tensorflow . KerasApi ;
7
6
using Tensorflow . Keras . Engine ;
8
7
using Tensorflow . NumPy ;
16
15
using Microsoft . Extensions . DependencyInjection ;
17
16
using System . Linq ;
18
17
using Tensorflow . Keras ;
19
- using System . Numerics ;
20
- using Newtonsoft . Json ;
21
- using Tensorflow . Keras . Layers ;
22
18
using BotSharp . Abstraction . Agents ;
23
- using BotSharp . Abstraction . Knowledges ;
24
19
25
20
namespace BotSharp . Plugin . RoutingSpeeder . Providers ;
26
21
@@ -33,11 +28,8 @@ public class IntentClassifier
33
28
private bool _isModelReady ;
34
29
public bool isModelReady => _isModelReady ;
35
30
private ClassifierSetting _settings ;
36
-
37
31
private string [ ] _labels ;
38
-
39
32
public string [ ] Labels => GetLabels ( ) ;
40
-
41
33
private int _numLabels
42
34
{
43
35
get
@@ -67,7 +59,7 @@ private void Build()
67
59
}
68
60
69
61
var vector = _services . GetServices < ITextEmbedding > ( )
70
- . FirstOrDefault ( x => x . GetType ( ) . FullName . EndsWith ( _knowledgeBaseSettings . TextEmbedding ) ) ;
62
+ . FirstOrDefault ( x => x . GetType ( ) . FullName . EndsWith ( _knowledgeBaseSettings . TextEmbedding ) ) ;
71
63
72
64
var layers = new List < ILayer >
73
65
{
@@ -89,28 +81,29 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
89
81
{
90
82
_model . compile ( optimizer : keras . optimizers . Adam ( trainingParams . LearningRate ) ,
91
83
loss : keras . losses . SparseCategoricalCrossentropy ( ) ,
92
- metrics : new [ ] { "accuracy" }
93
- ) ;
84
+ metrics : new [ ] { "accuracy" } ) ;
94
85
95
- CallbackParams callback_parameters = new CallbackParams
86
+ var callback_parameters = new CallbackParams
96
87
{
97
88
Model = _model ,
98
89
Epochs = trainingParams . Epochs ,
99
90
Verbose = 1 ,
100
91
Steps = 10
101
92
} ;
102
93
103
- ICallback earlyStop = new EarlyStopping ( callback_parameters , "accuracy" ) ;
94
+ var earlyStop = new EarlyStopping ( callback_parameters , "accuracy" ) ;
104
95
105
- var callbacks = new List < ICallback > ( ) { earlyStop } ;
96
+ var callbacks = new List < ICallback > ( )
97
+ {
98
+ earlyStop
99
+ } ;
106
100
107
101
var weights = LoadWeights ( trainingParams . Inference ) ;
108
102
109
103
_model . fit ( x , y ,
110
104
batch_size : trainingParams . BatchSize ,
111
105
epochs : trainingParams . Epochs ,
112
106
callbacks : callbacks ,
113
- // validation_split: 0.1f,
114
107
shuffle : true ) ;
115
108
116
109
_model . save_weights ( weights ) ;
@@ -120,7 +113,9 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
120
113
121
114
public string LoadWeights ( bool inference = true )
122
115
{
123
- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
116
+ var agentService = _services . CreateScope ( )
117
+ . ServiceProvider
118
+ . GetRequiredService < IAgentService > ( ) ;
124
119
125
120
var weightsFile = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , $ "intent-classifier.h5") ;
126
121
@@ -129,13 +124,13 @@ public string LoadWeights(bool inference = true)
129
124
_model . load_weights ( weightsFile ) ;
130
125
_isModelReady = true ;
131
126
Console . WriteLine ( $ "Successfully load the weights!") ;
132
-
133
127
}
134
128
else
135
129
{
136
130
var logInfo = inference ? "No available weights." : "Will implement model training process and write trained weights into local" ;
137
131
Console . WriteLine ( logInfo ) ;
138
132
}
133
+
139
134
return weightsFile ;
140
135
}
141
136
@@ -152,24 +147,33 @@ public NDArray GetTextEmbedding(string text)
152
147
153
148
public ( NDArray , NDArray ) PrepareLoadData ( )
154
149
{
155
- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
156
- string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . RAW_DATA_DIR ) ;
157
- string saveLabelDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , _settings . LABEL_FILE_NAME ) ;
150
+ var agentService = _services . CreateScope ( )
151
+ . ServiceProvider
152
+ . GetRequiredService < IAgentService > ( ) ;
153
+ string rootDirectory = Path . Combine (
154
+ agentService . GetDataDir ( ) ,
155
+ _settings . RAW_DATA_DIR ) ;
156
+ string saveLabelDirectory = Path . Combine (
157
+ agentService . GetDataDir ( ) ,
158
+ _settings . MODEL_DIR ,
159
+ _settings . LABEL_FILE_NAME ) ;
158
160
159
161
if ( ! Directory . Exists ( rootDirectory ) )
160
162
{
161
163
throw new Exception ( $ "No training data found! Please put training data in this path: { rootDirectory } ") ;
162
164
}
163
165
166
+ // Do embedding and store results
164
167
var vector = _services . GetRequiredService < ITextEmbedding > ( ) ;
165
-
166
168
var vectorList = new List < float [ ] > ( ) ;
167
-
168
169
var labelList = new List < string > ( ) ;
169
170
170
171
foreach ( var filePath in GetFiles ( ) )
171
172
{
172
- var texts = File . ReadAllLines ( filePath , Encoding . UTF8 ) . Select ( x => TextClean ( x ) ) . ToList ( ) ;
173
+ var texts = File . ReadAllLines ( filePath , Encoding . UTF8 )
174
+ . Select ( x => TextClean ( x ) )
175
+ . ToList ( ) ;
176
+
173
177
vectorList . AddRange ( vector . GetVectors ( texts ) ) ;
174
178
string fileName = Path . GetFileNameWithoutExtension ( filePath ) ;
175
179
labelList . AddRange ( Enumerable . Repeat ( fileName , texts . Count ) . ToList ( ) ) ;
@@ -185,25 +189,39 @@ public NDArray GetTextEmbedding(string text)
185
189
for ( int i = 0 ; i < vectorList . Count ; i ++ )
186
190
{
187
191
x [ i ] = vectorList [ i ] ;
188
- // y[i] = (float)uniqueLabelList.IndexOf(labelList[i]);
189
192
y [ i ] = ( float ) Array . IndexOf ( uniqueLabelList , labelList [ i ] ) ;
190
193
}
194
+
191
195
return ( x , y ) ;
192
196
}
193
197
194
198
public string [ ] GetFiles ( string prefix = "intent" )
195
199
{
196
- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
200
+ var agentService = _services . CreateScope ( )
201
+ . ServiceProvider
202
+ . GetRequiredService < IAgentService > ( ) ;
197
203
string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . RAW_DATA_DIR ) ;
198
- return Directory . GetFiles ( rootDirectory ) . Where ( x => Path . GetFileNameWithoutExtension ( x ) . StartsWith ( prefix ) ) . OrderBy ( x => x ) . ToArray ( ) ;
204
+
205
+ return Directory . GetFiles ( rootDirectory )
206
+ . Where ( x => Path . GetFileNameWithoutExtension ( x )
207
+ . StartsWith ( prefix ) )
208
+ . OrderBy ( x => x )
209
+ . ToArray ( ) ;
199
210
}
200
211
201
212
public string [ ] GetLabels ( )
202
213
{
203
214
if ( _labels == null )
204
215
{
205
- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
206
- string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , _settings . LABEL_FILE_NAME ) ;
216
+ var agentService = _services . CreateScope ( )
217
+ . ServiceProvider
218
+ . GetRequiredService < IAgentService > ( ) ;
219
+ string rootDirectory = Path . Combine (
220
+ agentService . GetDataDir ( ) ,
221
+ _settings . MODEL_DIR ,
222
+ _settings . LABEL_FILE_NAME
223
+ ) ;
224
+
207
225
var labelText = File . ReadAllLines ( rootDirectory ) ;
208
226
_labels = labelText . OrderBy ( x => x ) . ToArray ( ) ;
209
227
}
@@ -217,9 +235,11 @@ public string TextClean(string text)
217
235
// Remove digits
218
236
// To lowercase
219
237
var processedText = Regex . Replace ( text , "[AB0-9]" , " " ) ;
220
- processedText = string . Join ( "" , processedText . Select ( c => char . IsPunctuation ( c ) ? ' ' : c ) . ToList ( ) ) ;
221
- processedText = processedText . Replace ( " " , " " ) . ToLower ( ) ;
222
- return processedText ;
238
+ var replacedTextList = processedText . Select ( c => char . IsPunctuation ( c ) ? ' ' : c ) . ToList ( ) ;
239
+
240
+ return string . Join ( "" , replacedTextList )
241
+ . Replace ( " " , " " )
242
+ . ToLower ( ) ;
223
243
}
224
244
225
245
public string Predict ( NDArray vector , float confidenceScore = 0.9f )
@@ -229,8 +249,8 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
229
249
InitClassifer ( ) ;
230
250
}
231
251
252
+ // Generate and post-process prediction
232
253
var prob = _model . predict ( vector ) . numpy ( ) ;
233
-
234
254
var probLabel = tf . arg_max ( prob , - 1 ) . numpy ( ) . ToArray < long > ( ) ;
235
255
prob = np . squeeze ( prob , axis : 0 ) ;
236
256
@@ -239,9 +259,9 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
239
259
return string . Empty ;
240
260
}
241
261
242
- var prediction = _labels [ probLabel [ 0 ] ] ;
262
+ var labelIndex = probLabel [ 0 ] ;
243
263
244
- return prediction ;
264
+ return _labels [ labelIndex ] ;
245
265
}
246
266
public void InitClassifer ( bool inference = true )
247
267
{
0 commit comments