[GenAI] Use BitsAndBytes for 4bit quantization. #7406
The samples project file bumps TorchSharp-cuda-windows from 0.102.5 to 0.105.0 and adds a dependency on LittleLittleCloud.TorchSharp.BitsAndBytes, which supplies the 4-bit quantization kernels:

```diff
@@ -16,10 +16,11 @@
   </ItemGroup>

   <ItemGroup>
-    <PackageReference Include="TorchSharp-cuda-windows" Version="0.102.5" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
+    <PackageReference Include="TorchSharp-cuda-windows" Version="0.105.0" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
     <PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
     <PackageReference Include="AutoGen.SourceGenerator" Version="$(AutoGenVersion)" />
     <PackageReference Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
+    <PackageReference Include="LittleLittleCloud.TorchSharp.BitsAndBytes" Version="0.0.4" />
   </ItemGroup>

 </Project>
```
The samples entry point now runs the Llama 3.2 3B sample by default, with the SFT and Phi-3 samples left commented out:

```diff
@@ -1,6 +1,6 @@
 // See https://aka.ms/new-console-template for more information
 using Microsoft.ML.GenAI.Samples.Llama;
 using Microsoft.ML.GenAI.Samples.MEAI;

-await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
+await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct");
+//await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
 //await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
```
The core API change replaces ToInt4QuantizeModule with ToQuantize4BitModule in the quantization extension methods. The new method takes a quantized data type ("fp4" or "nf4") and a block size, wraps them in a Quantize4BitConfig, and applies that config recursively to every child module implementing IQuantizeModule:

```diff
@@ -90,13 +90,18 @@ public static void ToInt8QuantizeModule<T>(
     /// </summary>
     /// <typeparam name="T"></typeparam>
     /// <param name="model"></param>
-    public static void ToInt4QuantizeModule<T>(
-        this T model)
+    /// <param name="quantizedDType">Quantized data type, can be "fp4" or "nf4".</param>
+    /// <param name="blockSize">Block size for quantization, can be [64, 128, 256, 512, 1024]. The larger the size, the faster the speed and the lower the precision.</param>
+    public static void ToQuantize4BitModule<T>(
+        this T model,
+        string quantizedDType = "fp4",
+        int blockSize = 64)
         where T : nn.Module
     {
+        var config = new Quantize4BitConfig(quantizedDType, blockSize);
         if (model is IQuantizeModule quantized)
         {
-            quantized.Int4();
+            quantized.Quantize4Bit(config);
             return;
         }

@@ -105,11 +110,11 @@ public static void ToInt4QuantizeModule<T>(
         {
             if (value is IQuantizeModule quantizeModule)
             {
-                quantizeModule.Int4();
+                quantizeModule.Quantize4Bit(config);
             }
             else
             {
-                value.ToInt4QuantizeModule();
+                value.ToQuantize4BitModule(quantizedDType, blockSize);
             }
         }
     }
```
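For orientation, here is a minimal usage sketch of the renamed extension method. Only `ToQuantize4BitModule`, its `quantizedDType`/`blockSize` parameters, and their allowed values come from this diff; the `using` directives and the toy module tree are assumptions made for illustration.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Toy module tree standing in for a GenAI model. In the real code path the
// children would implement IQuantizeModule; plain Linear layers do not, so
// the recursive walk would simply skip them.
var model = nn.Sequential(
    ("fc1", nn.Linear(4096, 4096)),
    ("fc2", nn.Linear(4096, 4096)));

// Recursively 4-bit-quantize every IQuantizeModule child: NF4 storage,
// 128-element quantization blocks. Defaults are "fp4" and 64.
model.ToQuantize4BitModule(quantizedDType: "nf4", blockSize: 128);
```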
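The `blockSize` doc comment frames the choice as speed versus precision; block size also sets the scale-constant overhead. Assuming one fp32 absmax scale per block, the usual bitsandbytes-style layout (an assumption here, not something this diff states), the per-tensor storage cost works out as in this sketch:

```csharp
using System;

// Back-of-envelope storage for a 4-bit quantized weight tensor, assuming
// two 4-bit codes packed per byte plus one fp32 absmax scale per block
// (hypothetical layout for illustration; not spelled out in this PR).
static long Quantized4BitBytes(long elements, int blockSize) =>
    elements / 2 + (elements / blockSize) * sizeof(float);

// Overhead for a 4096x4096 weight matrix at each supported block size:
// the packed 4-bit payload is a constant 8 MiB, while the scale storage
// shrinks from 1 MiB (blockSize 64) to 64 KiB (blockSize 1024).
foreach (var blockSize in new[] { 64, 128, 256, 512, 1024 })
{
    var bytes = Quantized4BitBytes(4096L * 4096, blockSize);
    Console.WriteLine($"blockSize={blockSize,4}: {bytes / (1024.0 * 1024.0):F2} MiB");
}
```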