@@ -33,10 +33,8 @@ public class IntentClassifier
33
33
private bool _isModelReady ;
34
34
/// <summary>
/// Read-only view of the private <c>_isModelReady</c> flag.
/// </summary>
public bool isModelReady
{
    get { return _isModelReady; }
}
35
35
private ClassifierSetting _settings ;
36
-
37
36
private string [ ] _labels ;
38
37
/// <summary>
/// Label array obtained by calling <c>GetLabels()</c> on every access.
/// </summary>
public string[] Labels
{
    get { return GetLabels(); }
}
39
-
40
38
private int _numLabels
41
39
{
42
40
get
@@ -111,7 +109,6 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
111
109
batch_size : trainingParams . BatchSize ,
112
110
epochs : trainingParams . Epochs ,
113
111
callbacks : callbacks ,
114
- // validation_split: 0.1f,
115
112
shuffle : true ) ;
116
113
117
114
_model . save_weights ( weights ) ;
@@ -155,9 +152,18 @@ public NDArray GetTextEmbedding(string text)
155
152
156
153
public ( NDArray , NDArray ) PrepareLoadData ( )
157
154
{
158
- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
159
- string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . RAW_DATA_DIR ) ;
160
- string saveLabelDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , _settings . LABEL_FILE_NAME ) ;
155
+ var agentService = _services . CreateScope ( )
156
+ . ServiceProvider
157
+ . GetRequiredService < IAgentService > ( ) ;
158
+ string rootDirectory = Path . Combine (
159
+ agentService . GetDataDir ( ) ,
160
+ _settings . RAW_DATA_DIR
161
+ ) ;
162
+ string saveLabelDirectory = Path . Combine (
163
+ agentService . GetDataDir ( ) ,
164
+ _settings . MODEL_DIR ,
165
+ _settings . LABEL_FILE_NAME
166
+ ) ;
161
167
162
168
if ( ! Directory . Exists ( rootDirectory ) )
163
169
{
@@ -171,7 +177,10 @@ public NDArray GetTextEmbedding(string text)
171
177
172
178
foreach ( var filePath in GetFiles ( ) )
173
179
{
174
- var texts = File . ReadAllLines ( filePath , Encoding . UTF8 ) . Select ( x => TextClean ( x ) ) . ToList ( ) ;
180
+ var texts = File . ReadAllLines ( filePath , Encoding . UTF8 )
181
+ . Select ( x => TextClean ( x ) )
182
+ . ToList ( ) ;
183
+
175
184
vectorList . AddRange ( vector . GetVectors ( texts ) ) ;
176
185
string fileName = Path . GetFileNameWithoutExtension ( filePath ) ;
177
186
labelList . AddRange ( Enumerable . Repeat ( fileName , texts . Count ) . ToList ( ) ) ;
@@ -187,16 +196,19 @@ public NDArray GetTextEmbedding(string text)
187
196
for ( int i = 0 ; i < vectorList . Count ; i ++ )
188
197
{
189
198
x [ i ] = vectorList [ i ] ;
190
- // y[i] = (float)uniqueLabelList.IndexOf(labelList[i]);
191
199
y [ i ] = ( float ) Array . IndexOf ( uniqueLabelList , labelList [ i ] ) ;
192
200
}
201
+
193
202
return ( x , y ) ;
194
203
}
195
204
196
205
public string [ ] GetFiles ( string prefix = "intent" )
197
206
{
198
- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
207
+ var agentService = _services . CreateScope ( )
208
+ . ServiceProvider
209
+ . GetRequiredService < IAgentService > ( ) ;
199
210
string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . RAW_DATA_DIR ) ;
211
+
200
212
return Directory . GetFiles ( rootDirectory )
201
213
. Where ( x => Path . GetFileNameWithoutExtension ( x )
202
214
. StartsWith ( prefix ) )
@@ -208,8 +220,15 @@ public string[] GetLabels()
208
220
{
209
221
if ( _labels == null )
210
222
{
211
- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
212
- string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , _settings . LABEL_FILE_NAME ) ;
223
+ var agentService = _services . CreateScope ( )
224
+ . ServiceProvider
225
+ . GetRequiredService < IAgentService > ( ) ;
226
+ string rootDirectory = Path . Combine (
227
+ agentService . GetDataDir ( ) ,
228
+ _settings . MODEL_DIR ,
229
+ _settings . LABEL_FILE_NAME
230
+ ) ;
231
+
213
232
var labelText = File . ReadAllLines ( rootDirectory ) ;
214
233
_labels = labelText . OrderBy ( x => x ) . ToArray ( ) ;
215
234
}
@@ -223,9 +242,11 @@ public string TextClean(string text)
223
242
// Remove digits
224
243
// To lowercase
225
244
var processedText = Regex . Replace ( text , "[AB0-9]" , " " ) ;
226
- processedText = string . Join ( "" , processedText . Select ( c => char . IsPunctuation ( c ) ? ' ' : c ) . ToList ( ) ) ;
227
- processedText = processedText . Replace ( " " , " " ) . ToLower ( ) ;
228
- return processedText ;
245
+ var replacedTextList = processedText . Select ( c => char . IsPunctuation ( c ) ? ' ' : c ) . ToList ( ) ;
246
+
247
+ return string . Join ( "" , replacedTextList )
248
+ . Replace ( " " , " " )
249
+ . ToLower ( ) ;
229
250
}
230
251
231
252
public string Predict ( NDArray vector , float confidenceScore = 0.9f )
@@ -235,8 +256,8 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
235
256
InitClassifer ( ) ;
236
257
}
237
258
259
+ // Generate and post-process prediction
238
260
var prob = _model . predict ( vector ) . numpy ( ) ;
239
-
240
261
var probLabel = tf . arg_max ( prob , - 1 ) . numpy ( ) . ToArray < long > ( ) ;
241
262
prob = np . squeeze ( prob , axis : 0 ) ;
242
263
@@ -245,9 +266,9 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
245
266
return string . Empty ;
246
267
}
247
268
248
- var prediction = _labels [ probLabel [ 0 ] ] ;
269
+ var labelIndex = probLabel [ 0 ] ;
249
270
250
- return prediction ;
271
+ return _labels [ labelIndex ] ;
251
272
}
252
273
public void InitClassifer ( bool inference = true )
253
274
{
0 commit comments