Skip to content

Commit a238ae5

Browse files
authored
Merge pull request #132 from evan-cao-wb/re-format-the-code
re format the code
2 parents a2bae74 + f7f9e96 commit a238ae5

File tree

1 file changed

+54
-34
lines changed

1 file changed

+54
-34
lines changed

src/Plugins/BotSharp.Plugin.RoutingSpeeder/Providers/IntentClassifier.cs

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using System.IO;
33
using System.Text;
44
using System.Collections.Generic;
5-
using Tensorflow;
65
using static Tensorflow.KerasApi;
76
using Tensorflow.Keras.Engine;
87
using Tensorflow.NumPy;
@@ -16,11 +15,7 @@
1615
using Microsoft.Extensions.DependencyInjection;
1716
using System.Linq;
1817
using Tensorflow.Keras;
19-
using System.Numerics;
20-
using Newtonsoft.Json;
21-
using Tensorflow.Keras.Layers;
2218
using BotSharp.Abstraction.Agents;
23-
using BotSharp.Abstraction.Knowledges;
2419

2520
namespace BotSharp.Plugin.RoutingSpeeder.Providers;
2621

@@ -33,11 +28,8 @@ public class IntentClassifier
3328
private bool _isModelReady;
3429
public bool isModelReady => _isModelReady;
3530
private ClassifierSetting _settings;
36-
3731
private string[] _labels;
38-
3932
public string[] Labels => GetLabels();
40-
4133
private int _numLabels
4234
{
4335
get
@@ -67,7 +59,7 @@ private void Build()
6759
}
6860

6961
var vector = _services.GetServices<ITextEmbedding>()
70-
.FirstOrDefault(x => x.GetType().FullName.EndsWith(_knowledgeBaseSettings.TextEmbedding));
62+
.FirstOrDefault(x => x.GetType().FullName.EndsWith(_knowledgeBaseSettings.TextEmbedding));
7163

7264
var layers = new List<ILayer>
7365
{
@@ -89,28 +81,29 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
8981
{
9082
_model.compile(optimizer: keras.optimizers.Adam(trainingParams.LearningRate),
9183
loss: keras.losses.SparseCategoricalCrossentropy(),
92-
metrics: new[] { "accuracy" }
93-
);
84+
metrics: new[] { "accuracy" });
9485

95-
CallbackParams callback_parameters = new CallbackParams
86+
var callback_parameters = new CallbackParams
9687
{
9788
Model = _model,
9889
Epochs = trainingParams.Epochs,
9990
Verbose = 1,
10091
Steps = 10
10192
};
10293

103-
ICallback earlyStop = new EarlyStopping(callback_parameters, "accuracy");
94+
var earlyStop = new EarlyStopping(callback_parameters, "accuracy");
10495

105-
var callbacks = new List<ICallback>() { earlyStop };
96+
var callbacks = new List<ICallback>()
97+
{
98+
earlyStop
99+
};
106100

107101
var weights = LoadWeights(trainingParams.Inference);
108102

109103
_model.fit(x, y,
110104
batch_size: trainingParams.BatchSize,
111105
epochs: trainingParams.Epochs,
112106
callbacks: callbacks,
113-
// validation_split: 0.1f,
114107
shuffle: true);
115108

116109
_model.save_weights(weights);
@@ -120,7 +113,9 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
120113

121114
public string LoadWeights(bool inference = true)
122115
{
123-
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();
116+
var agentService = _services.CreateScope()
117+
.ServiceProvider
118+
.GetRequiredService<IAgentService>();
124119

125120
var weightsFile = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, $"intent-classifier.h5");
126121

@@ -129,13 +124,13 @@ public string LoadWeights(bool inference = true)
129124
_model.load_weights(weightsFile);
130125
_isModelReady = true;
131126
Console.WriteLine($"Successfully load the weights!");
132-
133127
}
134128
else
135129
{
136130
var logInfo = inference ? "No available weights." : "Will implement model training process and write trained weights into local";
137131
Console.WriteLine(logInfo);
138132
}
133+
139134
return weightsFile;
140135
}
141136

@@ -152,24 +147,33 @@ public NDArray GetTextEmbedding(string text)
152147

153148
public (NDArray, NDArray) PrepareLoadData()
154149
{
155-
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();
156-
string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.RAW_DATA_DIR);
157-
string saveLabelDirectory = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.LABEL_FILE_NAME);
150+
var agentService = _services.CreateScope()
151+
.ServiceProvider
152+
.GetRequiredService<IAgentService>();
153+
string rootDirectory = Path.Combine(
154+
agentService.GetDataDir(),
155+
_settings.RAW_DATA_DIR);
156+
string saveLabelDirectory = Path.Combine(
157+
agentService.GetDataDir(),
158+
_settings.MODEL_DIR,
159+
_settings.LABEL_FILE_NAME);
158160

159161
if (!Directory.Exists(rootDirectory))
160162
{
161163
throw new Exception($"No training data found! Please put training data in this path: {rootDirectory}");
162164
}
163165

166+
// Do embedding and store results
164167
var vector = _services.GetRequiredService<ITextEmbedding>();
165-
166168
var vectorList = new List<float[]>();
167-
168169
var labelList = new List<string>();
169170

170171
foreach (var filePath in GetFiles())
171172
{
172-
var texts = File.ReadAllLines(filePath, Encoding.UTF8).Select(x => TextClean(x)).ToList();
173+
var texts = File.ReadAllLines(filePath, Encoding.UTF8)
174+
.Select(x => TextClean(x))
175+
.ToList();
176+
173177
vectorList.AddRange(vector.GetVectors(texts));
174178
string fileName = Path.GetFileNameWithoutExtension(filePath);
175179
labelList.AddRange(Enumerable.Repeat(fileName, texts.Count).ToList());
@@ -185,25 +189,39 @@ public NDArray GetTextEmbedding(string text)
185189
for (int i = 0; i < vectorList.Count; i++)
186190
{
187191
x[i] = vectorList[i];
188-
// y[i] = (float)uniqueLabelList.IndexOf(labelList[i]);
189192
y[i] = (float)Array.IndexOf(uniqueLabelList, labelList[i]);
190193
}
194+
191195
return (x, y);
192196
}
193197

194198
/// <summary>
/// Lists the files in the raw-data directory whose file name (without extension)
/// starts with <paramref name="prefix"/>, sorted by path.
/// </summary>
/// <param name="prefix">File-name prefix to match; defaults to "intent".</param>
/// <returns>The matching file paths in ascending order.</returns>
public string[] GetFiles(string prefix = "intent")
{
    // The scope is only needed to resolve IAgentService; dispose it when the method
    // exits instead of leaking it (the previous code never released the scope).
    using var scope = _services.CreateScope();
    var agentService = scope.ServiceProvider.GetRequiredService<IAgentService>();
    string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.RAW_DATA_DIR);

    // File-name prefixes are identifiers, not linguistic text, so compare ordinally
    // rather than with the current culture.
    return Directory.GetFiles(rootDirectory)
        .Where(x => Path.GetFileNameWithoutExtension(x).StartsWith(prefix, StringComparison.Ordinal))
        .OrderBy(x => x)
        .ToArray();
}
200211

201212
public string[] GetLabels()
202213
{
203214
if (_labels == null)
204215
{
205-
var agentService = _services.CreateScope().ServiceProvider.GetRequiredService<IAgentService>();
206-
string rootDirectory = Path.Combine(agentService.GetDataDir(), _settings.MODEL_DIR, _settings.LABEL_FILE_NAME);
216+
var agentService = _services.CreateScope()
217+
.ServiceProvider
218+
.GetRequiredService<IAgentService>();
219+
string rootDirectory = Path.Combine(
220+
agentService.GetDataDir(),
221+
_settings.MODEL_DIR,
222+
_settings.LABEL_FILE_NAME
223+
);
224+
207225
var labelText = File.ReadAllLines(rootDirectory);
208226
_labels = labelText.OrderBy(x => x).ToArray();
209227
}
@@ -217,9 +235,11 @@ public string TextClean(string text)
217235
// Remove digits
218236
// To lowercase
219237
var processedText = Regex.Replace(text, "[AB0-9]", " ");
220-
processedText = string.Join("", processedText.Select(c => char.IsPunctuation(c) ? ' ' : c).ToList());
221-
processedText = processedText.Replace(" ", " ").ToLower();
222-
return processedText;
238+
var replacedTextList = processedText.Select(c => char.IsPunctuation(c) ? ' ' : c).ToList();
239+
240+
return string.Join("", replacedTextList)
241+
.Replace(" ", " ")
242+
.ToLower();
223243
}
224244

225245
public string Predict(NDArray vector, float confidenceScore = 0.9f)
@@ -229,8 +249,8 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
229249
InitClassifer();
230250
}
231251

252+
// Generate and post-process prediction
232253
var prob = _model.predict(vector).numpy();
233-
234254
var probLabel = tf.arg_max(prob, -1).numpy().ToArray<long>();
235255
prob = np.squeeze(prob, axis: 0);
236256

@@ -239,9 +259,9 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
239259
return string.Empty;
240260
}
241261

242-
var prediction = _labels[probLabel[0]];
262+
var labelIndex = probLabel[0];
243263

244-
return prediction;
264+
return _labels[labelIndex];
245265
}
246266
public void InitClassifer(bool inference = true)
247267
{

0 commit comments

Comments
 (0)