Skip to content

Commit 8d51eee

Browse files
authored
Image classification using DNNs and Transfer Learning. (#4057)
* Image classification using DNNs and Transfer Learning. * Undo sample changes. * Disable ONNX conversion test because of protobuf version conflict. * Disable ONNX conversion test because of protobuf version conflict. * Upgrade to latest TF.NET and enable ONNX test. * stop buffer reuse. comment out buggy tests. * enable buggy tests. * enable buggy tests. * pass string directly as tensor instead of converting to utf8. * pass string directly as tensor instead of converting to utf8. * fix session dispose call in DNNTransformer. * fix session dispose call in DNNTransformer. * fix crash issue when testing TF.NET. * fix for premature deallocation by TF.NET. * fix for destructor. * fix for destructor. * Upgrade to latest TF.NET that contains fix for default graph. * TDV-TF mapping. * Add test for transfer learning. * Add Inception V3 model. * Add Inception V3 model and unit-test * Add Inception V3 model and unit-test * Add Inception V3 model and unit-test * Add Inception V3 model and unit-test * Add Inception V3 model and unit-test * Update transfer learning test. * Remove hard coded paths from test. * tests * Add more samples and minor refactoring. * Take dependency on SciSharp TensorFlow redist. * PR feedback. * PR feedback. * Cleanup.
1 parent 5d90c11 commit 8d51eee

File tree

30 files changed

+3957
-4938
lines changed

30 files changed

+3957
-4938
lines changed

Microsoft.ML.sln

Lines changed: 657 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Linq;
5+
using Microsoft.ML;
6+
using Microsoft.ML.Data;
7+
using Microsoft.ML.Transforms;
8+
9+
namespace Samples.Dynamic
10+
{
11+
public static class InceptionV3TransferLearning
{
    /// <summary>
    /// Example use of the image classification API in an ML.NET pipeline.
    /// Trains an image classifier with the Inception V3 architecture on the
    /// downloaded sample images, prints evaluation metrics, and then scores
    /// a single image through a prediction engine.
    /// </summary>
    public static void Example()
    {
        var mlContext = new MLContext(seed: 1);

        // DownloadImages() returns a path; the directory part of it is the
        // folder that gets scanned for image files below.
        var imagesDataFile = Path.GetDirectoryName(
            Microsoft.ML.SamplesUtils.DatasetUtils.DownloadImages());

        // Every image is yielded four times (repeat: 4) to enlarge the
        // otherwise tiny sample data set.
        var data = mlContext.Data.LoadFromEnumerable(
            ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4));

        data = mlContext.Data.ShuffleRows(data, 5);

        // Label -> key, load image from path, resize to 299x299, extract
        // interleaved pixels, then train the classifier on those pixels.
        var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
            .Append(mlContext.Transforms.LoadImages("ImageObject", null,
                "ImagePath"))
            .Append(mlContext.Transforms.ResizeImages("Image",
                inputColumnName: "ImageObject", imageWidth: 299,
                imageHeight: 299))
            .Append(mlContext.Transforms.ExtractPixels("Image",
                interleavePixelColors: true))
            .Append(mlContext.Model.ImageClassification("Image",
                "Label", arch: DnnEstimator.Architecture.InceptionV3, epoch: 4,
                batchSize: 4));

        var trainedModel = pipeline.Fit(data);

        // Note: the metrics are computed on the same data the model was
        // fitted on, so they describe training accuracy only.
        var predicted = trainedModel.Transform(data);
        var metrics = mlContext.MulticlassClassification.Evaluate(predicted);

        Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," +
            $"macro-accuracy = {metrics.MacroAccuracy}");

        // Create the prediction function and test a prediction. The engine is
        // disposable, so it is released as soon as the prediction is printed.
        using (var predictFunction = mlContext.Model
            .CreatePredictionEngine<ImageNetData, ImagePrediction>(trainedModel))
        {
            var prediction = predictFunction
                .Predict(ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4)
                .First());

            Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
                $"Predicted Label : {prediction.PredictedLabel}");
        }
    }
}
59+
60+
/// <summary>
/// Input row for the image classification samples: an image path plus its
/// class label.
/// </summary>
public class ImageNetData
{
    /// <summary>Path of the image file.</summary>
    [LoadColumn(0)]
    public string ImagePath;

    /// <summary>Class label derived from the file name or parent folder name.</summary>
    [LoadColumn(1)]
    public string Label;

    /// <summary>
    /// Lazily enumerates every ".jpg" file (any letter casing) under
    /// <paramref name="folder"/>, including sub-directories, and yields one
    /// <see cref="ImageNetData"/> row per file, <paramref name="repeat"/> times.
    /// </summary>
    /// <param name="folder">Root directory that is searched recursively.</param>
    /// <param name="repeat">
    /// Number of identical rows yielded per image; values below 1 yield nothing.
    /// </param>
    /// <param name="useFolderNameasLabel">
    /// When true the label is the name of the file's parent directory;
    /// otherwise it is the leading run of letters of the file name
    /// (everything before the first non-letter character).
    /// </param>
    /// <returns>A deferred sequence of labelled image rows.</returns>
    public static IEnumerable<ImageNetData> LoadImagesFromDirectory(
        string folder, int repeat = 1, bool useFolderNameasLabel = false)
    {
        // EnumerateFiles streams results instead of building the full array
        // of every file in the tree before the first row can be yielded.
        var files = Directory.EnumerateFiles(folder, "*",
            SearchOption.AllDirectories);

        foreach (var file in files)
        {
            // Compare the extension case-insensitively so "IMG.JPG" is not
            // silently skipped.
            if (!string.Equals(Path.GetExtension(file), ".jpg",
                    StringComparison.OrdinalIgnoreCase))
            {
                continue;
            }

            string label;
            if (useFolderNameasLabel)
            {
                label = Directory.GetParent(file).Name;
            }
            else
            {
                // Keep only the leading letters of the file name; the file
                // name includes the extension, so the '.' always ends the run.
                label = new string(Path.GetFileName(file)
                    .TakeWhile(c => char.IsLetter(c))
                    .ToArray());
            }

            for (int copy = 0; copy < repeat; copy++)
            {
                yield return new ImageNetData { ImagePath = file, Label = label };
            }
        }
    }
}
100+
101+
/// <summary>
/// Output row of the image classification samples, bound by column name to
/// the "Score" and "PredictedLabel" columns.
/// </summary>
public class ImagePrediction
{
    /// <summary>Values read from the "Score" column.</summary>
    [ColumnName("Score")]
    public float[] Score;

    /// <summary>Value read from the "PredictedLabel" column.</summary>
    [ColumnName("PredictedLabel")]
    public long PredictedLabel;
}
109+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.IO;
5+
using System.Linq;
6+
using Microsoft.ML;
7+
using Microsoft.ML.Data;
8+
using Microsoft.ML.Transforms;
9+
10+
namespace Samples.Dynamic
11+
{
12+
public static class ResnetV2101TransferLearning
{
    /// <summary>
    /// Example use of the image classification API in an ML.NET pipeline.
    /// Trains an image classifier with the ResNet V2 101 architecture on the
    /// downloaded sample images, prints evaluation metrics, saves the model to
    /// "model.zip", reloads it, and scores a single image with the reloaded
    /// model.
    /// </summary>
    public static void Example()
    {
        // Single definition of the file the model is saved to and loaded from.
        const string modelPath = "model.zip";

        var mlContext = new MLContext(seed: 1);

        // DownloadImages() returns a path; the directory part of it is the
        // folder that gets scanned for image files below.
        var imagesDataFile = Path.GetDirectoryName(
            Microsoft.ML.SamplesUtils.DatasetUtils.DownloadImages());

        // Every image is yielded four times (repeat: 4) to enlarge the
        // otherwise tiny sample data set.
        var data = mlContext.Data.LoadFromEnumerable(
            ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4));

        data = mlContext.Data.ShuffleRows(data, 5);

        // Label -> key, load image from path, resize to 299x299, extract
        // interleaved pixels, then train the classifier on those pixels.
        var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
            .Append(mlContext.Transforms.LoadImages("ImageObject", null,
                "ImagePath"))
            .Append(mlContext.Transforms.ResizeImages("Image",
                inputColumnName: "ImageObject", imageWidth: 299,
                imageHeight: 299))
            .Append(mlContext.Transforms.ExtractPixels("Image",
                interleavePixelColors: true))
            .Append(mlContext.Model.ImageClassification("Image",
                "Label", arch: DnnEstimator.Architecture.ResnetV2101, epoch: 4,
                batchSize: 4));

        var trainedModel = pipeline.Fit(data);

        // Note: the metrics are computed on the same data the model was
        // fitted on, so they describe training accuracy only.
        var predicted = trainedModel.Transform(data);
        var metrics = mlContext.MulticlassClassification.Evaluate(predicted);

        Console.WriteLine($"Micro-accuracy: {metrics.MicroAccuracy}," +
            $"macro-accuracy = {metrics.MacroAccuracy}");

        // Round-trip the model through disk to show that a saved model can be
        // used for prediction.
        mlContext.Model.Save(trainedModel, data.Schema, modelPath);

        ITransformer loadedModel;
        using (var file = File.OpenRead(modelPath))
        {
            // The input schema stored with the model is not needed here.
            loadedModel = mlContext.Model.Load(file, out DataViewSchema _);
        }

        // Create the prediction function and test a prediction. The engine is
        // disposable, so it is released as soon as the prediction is printed.
        using (var predictFunction = mlContext.Model
            .CreatePredictionEngine<ImageNetData, ImagePrediction>(loadedModel))
        {
            var prediction = predictFunction
                .Predict(ImageNetData.LoadImagesFromDirectory(imagesDataFile, 4)
                .First());

            Console.WriteLine($"Scores : [{string.Join(",", prediction.Score)}], " +
                $"Predicted Label : {prediction.PredictedLabel}");
        }
    }

    /// <summary>
    /// Input row for this sample: an image path plus its class label.
    /// </summary>
    public class ImageNetData
    {
        /// <summary>Path of the image file.</summary>
        [LoadColumn(0)]
        public string ImagePath;

        /// <summary>Class label derived from the file name or parent folder name.</summary>
        [LoadColumn(1)]
        public string Label;

        /// <summary>
        /// Lazily enumerates every ".jpg" file (any letter casing) under
        /// <paramref name="folder"/>, including sub-directories, and yields one
        /// <see cref="ImageNetData"/> row per file, <paramref name="repeat"/> times.
        /// </summary>
        /// <param name="folder">Root directory that is searched recursively.</param>
        /// <param name="repeat">
        /// Number of identical rows yielded per image; values below 1 yield nothing.
        /// </param>
        /// <param name="useFolderNameasLabel">
        /// When true the label is the name of the file's parent directory;
        /// otherwise it is the leading run of letters of the file name
        /// (everything before the first non-letter character).
        /// </param>
        /// <returns>A deferred sequence of labelled image rows.</returns>
        public static IEnumerable<ImageNetData> LoadImagesFromDirectory(
            string folder, int repeat = 1, bool useFolderNameasLabel = false)
        {
            // EnumerateFiles streams results instead of building the full
            // array of every file in the tree before the first row is yielded.
            var files = Directory.EnumerateFiles(folder, "*",
                SearchOption.AllDirectories);

            foreach (var file in files)
            {
                // Compare the extension case-insensitively so "IMG.JPG" is
                // not silently skipped.
                if (!string.Equals(Path.GetExtension(file), ".jpg",
                        StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }

                string label;
                if (useFolderNameasLabel)
                {
                    label = Directory.GetParent(file).Name;
                }
                else
                {
                    // Keep only the leading letters of the file name; the
                    // file name includes the extension, so the '.' always
                    // ends the run.
                    label = new string(Path.GetFileName(file)
                        .TakeWhile(c => char.IsLetter(c))
                        .ToArray());
                }

                for (int copy = 0; copy < repeat; copy++)
                {
                    yield return new ImageNetData
                    {
                        ImagePath = file,
                        Label = label
                    };
                }
            }
        }
    }

    /// <summary>
    /// Output row of this sample, bound by column name to the "Score" and
    /// "PredictedLabel" columns.
    /// </summary>
    public class ImagePrediction
    {
        /// <summary>Values read from the "Score" column.</summary>
        [ColumnName("Score")]
        public float[] Score;

        /// <summary>Value read from the "PredictedLabel" column.</summary>
        [ColumnName("PredictedLabel")]
        public Int64 PredictedLabel;
    }
}
123+
}

docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
</PropertyGroup>
1111

1212
<ItemGroup>
13+
<ProjectReference Include="..\..\..\src\Microsoft.ML.Dnn\Microsoft.ML.Dnn.csproj" />
1314
<ProjectReference Include="..\..\..\src\Microsoft.ML.LightGbm\Microsoft.ML.LightGbm.csproj" />
1415
<ProjectReference Include="..\..\..\src\Microsoft.ML.Mkl.Components\Microsoft.ML.Mkl.Components.csproj" />
1516
<ProjectReference Include="..\..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard2.0</TargetFramework>
5+
<PackageDescription>Microsoft.ML.Dnn contains APIs to do high level DNN training such as image classification.</PackageDescription>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
10+
</ItemGroup>
11+
12+
</Project>
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<Project DefaultTargets="Pack">
2+
3+
<Import Project="Microsoft.ML.Dnn.nupkgproj" />
4+
5+
</Project>

src/Microsoft.ML.Core/Properties/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardTrainers" + PublicKey.Value)]
3838
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)]
3939
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)]
40+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Dnn" + PublicKey.Value)]
4041
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]
4142
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)]
4243

src/Microsoft.ML.Data/Properties/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StandardTrainers" + PublicKey.Value)]
3737
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Sweeper" + PublicKey.Value)]
3838
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)]
39+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Dnn" + PublicKey.Value)]
3940
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]
4041
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)]
4142
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.AlexNet" + PublicKey.Value)]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Runtime.CompilerServices;
6+
using Microsoft.ML;
7+
8+
// Friend assemblies that may access this assembly's internal members.
// "Microsoft.ML.TensorFlow" is spelled with a capital F to match the friend
// declarations in the Microsoft.ML.Core and Microsoft.ML.Data AssemblyInfo files.
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)]

0 commit comments

Comments
 (0)