diff --git a/llm/android/LlamaDemo/.gitignore b/llm/android/LlamaDemo/.gitignore
new file mode 100644
index 00000000..41853c04
--- /dev/null
+++ b/llm/android/LlamaDemo/.gitignore
@@ -0,0 +1,12 @@
+*.iml
+.gradle
+/local.properties
+.idea
+.DS_Store
+/build
+/captures
+.externalNativeBuild
+.cxx
+local.properties
+*.so
+*.aar
diff --git a/llm/android/LlamaDemo/README.md b/llm/android/LlamaDemo/README.md
new file mode 100644
index 00000000..cab10dcf
--- /dev/null
+++ b/llm/android/LlamaDemo/README.md
@@ -0,0 +1,170 @@
+# ExecuTorch Llama Android Demo App
+
+This app serves as a valuable resource to inspire your creativity and provide foundational code that you can customize and adapt for your particular use case.
+
+Please dive in and start exploring our demo app today! We look forward to any feedback and are excited to see your innovative ideas.
+
+
+## Key Concepts
+From this demo app, you will learn many key concepts, such as:
+* How to prepare Llama models, build the ExecuTorch library, and run model inference across delegates
+* How to expose the ExecuTorch library via a JNI layer
+* The current ExecuTorch app-facing capabilities
+
+The goal is for you to see the type of support ExecuTorch provides and feel comfortable leveraging it for your use cases.
+
+## Supported Models
+The app supports the following models (availability varies by delegate):
+* Llama 3.2 Quantized 1B/3B
+* Llama 3.2 1B/3B in BF16
+* Llama Guard 3 1B
+* Llama 3.1 8B
+* Llama 3 8B
+* Llama 2 7B
+* LLaVA-1.5 vision model (XNNPACK only)
+* Qwen 3 0.6B, 1.7B, and 4B
+
+
+## Building the APK
+ExecuTorch currently provides support for three delegates. Once you identify the delegate of your choice, follow the linked README for complete end-to-end instructions, from environment setup and model export to building the ExecuTorch libraries and the app to run on device:
+
+| Delegate | Resource |
+| ------------- | ------------- |
+| XNNPACK (CPU-based library) | [link](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md) |
+| QNN (Qualcomm AI Accelerators) | [link](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md) |
+| MediaTek (MediaTek AI Accelerators) | [link](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/LlamaDemo/docs/delegates/mediatek_README.md) |
+
+
+## How to Use the App
+
+This section covers the main steps to use the app, along with code snippets of the ExecuTorch API.
+
+For loading the app, development, and running on device, we recommend Android Studio:
+1. Open Android Studio and select "Open an existing Android Studio project" to open examples/demo-apps/android/LlamaDemo.
+2. Run the app (^R). This builds and launches the app on the phone.
+
+### Opening the App
+
+Below are the UI features of the app.
+
+Select the settings widget to get started with picking a model, its parameters, and any prompts.
+
+### Select Models and Parameters
+
+Once you've selected the model, tokenizer, and model type, you are ready to click on "Load Model" to have the app load the model and go back to the main Chat activity.
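+
+When "Load Model" is pressed, the app constructs an `LlmModule` and calls `load()` (the exact call is shown in the API snippet below). `load()` returns 0 on success and a non-zero error code otherwise. A minimal sketch of reacting to that result, following what `MainActivity` in this diff does (the dialog title text is illustrative):
+
+```java
+int loadResult = mModule.load();
+if (loadResult != 0) {
+  // Keep the send button disabled and surface the error code to the user
+  runOnUiThread(() ->
+      new AlertDialog.Builder(this)
+          .setTitle("Load failed: " + loadResult)
+          .show());
+} else {
+  runOnUiThread(() -> mSendButton.setEnabled(true));
+}
+```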
+
+Optional Parameters:
+* Temperature: Defaults to 0; you can adjust the temperature for the model as well. The model will reload upon any adjustment.
+* System Prompt: Without any formatting, you can enter a system prompt. For example, "you are a travel assistant" or "give me a response in a few sentences".
+* User Prompt: Intended for more advanced users. If you would like to manually input a prompt, you can do so by modifying the `{{user prompt}}`. You can also modify the special tokens. Once changed, go back to the main Chat activity to send.
+
+#### ExecuTorch App API
+
+```java
+// Upon returning to the Main Chat Activity
+mModule = new LlmModule(
+            ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()),
+            modelPath,
+            tokenizerPath,
+            temperature);
+int loadResult = mModule.load();
+```
+
+* `modelCategory`: Indicates whether it's a text-only or vision model
+* `modelPath`: path to the .pte file
+* `tokenizerPath`: path to the tokenizer file
+* `temperature`: model parameter to adjust the randomness of the model's output
+
+
+### User Prompt
+Once the model is successfully loaded, enter any prompt and click the send (i.e. generate) button to send it to the model.
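+
+When send is tapped, the demo shows your raw text in the chat thread but passes a formatted prompt (system prompt plus the model's special tokens) to the runner, so the chat view stays free of template tokens. Roughly, following `MainActivity` in this diff:
+
+```java
+String rawPrompt = mEditTextMessage.getText().toString();
+// The raw text is what appears in the chat UI...
+mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
+// ...while the model receives the fully formatted prompt
+String finalPrompt =
+    mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, /* thinkMode */ false);
+```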
+
+You can ask it follow-up questions as well.
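+
+For follow-ups, the demo does not send the new question alone: it prepends the most recent turns of the conversation, formatted with the model's chat template (see `getConversationHistory()` and `PromptFormat` in this diff). A simplified sketch of how the final prompt is assembled:
+
+```java
+String history = getConversationHistory(); // recent user/assistant turns, already formatted
+String finalPrompt;
+if (history.isEmpty()) {
+  finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode);
+} else {
+  finalPrompt = mCurrentSettingsFields.getFormattedSystemPrompt()
+      + history
+      + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt, mThinkMode);
+}
+```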
+
+#### ExecuTorch App API
+
+```java
+mModule.generate(prompt, sequence_length, MainActivity.this);
+```
+
+* `prompt`: User-formatted prompt
+* `sequence_length`: Number of tokens to generate in response to a prompt
+* `MainActivity.this`: Indicates that the callback functions (`onResult()`, `onStats()`) are implemented in this class.
+
+[*LLaVA-1.5: Only for XNNPACK delegate*]
+
+For LLaVA-1.5, select the exported LLaVA .pte file and tokenizer file in the Settings menu and load the model. After this, you can send an image from your gallery or take a live picture, along with a text prompt, to the model.
+
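+Under the hood, the demo prefills the selected image as soon as it is chosen, so that pressing send only has to process the text prompt. `ETImage` (in this diff) resizes the picked image and flattens it to RGB values; the prefill then runs on a background executor. A condensed sketch:
+
+```java
+// Convert the picked Uri into resized, flattened RGB pixel data (see ETImage)
+ETImage img = new ETImage(getContentResolver(), imageUri);
+
+// Prefill off the UI thread so the app stays responsive
+executor.execute(() ->
+    mModule.prefillImages(
+        img.getInts(),   // RGB values as ints
+        img.getWidth(),
+        img.getHeight(),
+        ModelUtils.VISION_MODEL_IMAGE_CHANNELS));
+```
+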
+
+### Output Generated
+To show completion of the follow-up question, here is the complete, detailed response from the model.
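+
+The `onStats()` callback described in the next snippet delivers a JSON string; the demo derives a tokens-per-second figure from it using the `generated_tokens`, `inference_end_ms`, and `prompt_eval_end_ms` fields (see `PerfTest` and `MainActivity` in this diff). A sketch of that calculation:
+
+```java
+@Override
+public void onStats(String stats) {
+  try {
+    JSONObject json = new JSONObject(stats);
+    int generatedTokens = json.getInt("generated_tokens");
+    int inferenceEndMs = json.getInt("inference_end_ms");
+    int promptEvalEndMs = json.getInt("prompt_eval_end_ms");
+    // Tokens produced divided by the decode time, scaled to tokens per second
+    float tps = (float) generatedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
+    Log.d("LlamaDemo", "tokens/sec: " + tps);
+  } catch (JSONException e) {
+    Log.e("LlamaDemo", "Could not parse stats JSON", e);
+  }
+}
+```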
+
+#### ExecuTorch App API
+
+Ensure you have the following functions in the callback class that you provided to `mModule.generate()`. For this example, it is `MainActivity.this`.
+```java
+  @Override
+  public void onResult(String result) {
+    // ...result contains a token from the response
+    // onResult will continue to be invoked until the response is complete
+  }
+
+  @Override
+  public void onStats(String stats) {
+    // ...will be a JSON string. See extension/llm/stats.h for the field definitions
+  }
+
+```
+
+## Instrumentation Test
+You can run the instrumentation test as a sanity check. The test loads a model .pte file and a tokenizer.bin file
+under `/data/local/tmp/llama`.
+
+### Model preparation
+From the ExecuTorch root:
+```sh
+curl -C - -Ls "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt" --output stories110M.pt
+curl -C - -Ls "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model" --output tokenizer.model
+# Create params.json file
+touch params.json
+echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
+python -m extension.llm.export.export_llm base.checkpoint=stories110M.pt base.params=params.json model.dtype_override="fp16" export.output_name=stories110m_h.pte model.use_kv_cache=True
+python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin
+```
+### Push model
+```sh
+adb shell mkdir -p /data/local/tmp/llama
+adb push stories110m_h.pte /data/local/tmp/llama
+adb push tokenizer.bin /data/local/tmp/llama
+```
+
+### Run test
+From `examples/demo-apps/android/LlamaDemo`:
+```sh
+./gradlew connectedAndroidTest
+```
+
+## Reporting Issues
+If you encounter any bugs or issues while following this tutorial, please file an issue on [GitHub](https://github.com/pytorch/executorch/issues/new) or join our [Discord](https://lnkd.in/gWCM4ViK).
diff --git a/llm/android/LlamaDemo/SDK-quick-setup-guide.md b/llm/android/LlamaDemo/SDK-quick-setup-guide.md
new file mode 100644
index 00000000..9ae79e96
--- /dev/null
+++ b/llm/android/LlamaDemo/SDK-quick-setup-guide.md
@@ -0,0 +1,94 @@
+# Guide to set up Java/SDK/NDK for Android
+
+Follow this doc if you haven't already set up Java, the Android SDK, and the NDK for Android development.
+It provides a CLI tutorial for setting them up. Alternatively, you can do the same through the Android Studio GUI.
+
+## Set up Java 17
+1. Download the archive from the Oracle website.
+Make sure you have read and agree with the terms and conditions from the website before downloading.
+```bash
+export DEV_HOME=
+cd $DEV_HOME
+```
+Linux:
+```bash
+curl https://download.oracle.com/java/17/archive/jdk-17.0.10_linux-x64_bin.tar.gz -o jdk-17.0.10.tar.gz
+```
+macOS:
+```bash
+curl https://download.oracle.com/java/17/archive/jdk-17.0.10_macos-aarch64_bin.tar.gz -o jdk-17.0.10.tar.gz
+```
+2. Extract the archive. The directory named `jdk-17.0.10` is the Java root directory.
+```bash
+tar xf jdk-17.0.10.tar.gz
+```
+3. Set `JAVA_HOME` and update `PATH`.
+ +Linux: +```bash +export JAVA_HOME="$DEV_HOME"/jdk-17.0.10 +export PATH="$JAVA_HOME/bin:$PATH" +``` +macOS: +```bash +export JAVA_HOME="$DEV_HOME"/jdk-17.0.10.jdk/Contents/Home +export PATH="$JAVA_HOME/bin:$PATH" +``` + +Note: Oracle has tutorials for installing Java on +[Linux](https://docs.oracle.com/en/java/javase/17/install/installation-jdk-linux-platforms.html#GUID-4A6BD592-1840-4BB4-A758-4CD49E9EE88B) +and [macOS](https://docs.oracle.com/en/java/javase/17/install/installation-jdk-macos.html#GUID-E8A251B6-D9A9-4276-ABC8-CC0DAD62EA33). +Some Linux distributions has JDK package in package manager. For example, Debian users can install +openjdk-17-jdk package. + +## Set up Android SDK/NDK +Android has a command line tool [sdkmanager](https://developer.android.com/tools/sdkmanager) which +helps users managing SDK and other tools related to Android development. + +1. Go to https://developer.android.com/studio and download the archive from "Command line tools +only" section. Make sure you have read and agree with the terms and conditions from the website. + +Linux: +```bash +curl https://dl.google.com/android/repository/commandlinetools-linux-11076708_latest.zip -o commandlinetools.zip +``` +macOS: +```bash +curl https://dl.google.com/android/repository/commandlinetools-mac-11076708_latest.zip -o commandlinetools.zip +``` +2. Unzip. +```bash +unzip commandlinetools.zip +``` +3. Specify a root for Android SDK. For example, we can put it under `$DEV_HOME/sdk`. + +``` +mkdir -p $DEV_HOME/sdk +export ANDROID_HOME="$(realpath $DEV_HOME/sdk)" +# Install SDK 34 +./cmdline-tools/bin/sdkmanager --sdk_root="${ANDROID_HOME}" --install "platforms;android-34" +# Install NDK +./cmdline-tools/bin/sdkmanager --sdk_root="${ANDROID_HOME}" --install "ndk;26.3.11579264" +# The NDK root is then under `ndk/`. +export ANDROID_NDK="$ANDROID_HOME/ndk/26.3.11579264" +``` + +### (Optional) Android Studio Setup +If you want to use Android Studio and never set up Java/SDK/NDK before, or if +you use the newly installed ones, follow these steps to set Android Studio to use +them. + +Copy these output paths to be used by Android Studio +```bash +echo $ANDROID_HOME +echo $ANDROID_NDK +echo $JAVA_HOME +``` + +Open a project in Android Studio. In Project Structure (File -> Project +Structure, or `⌘;`) -> SDK Location, +* Set Android SDK Location to the path of $ANDROID_HOME +* Set Android NDK Location to the path of $ANDROID_NDK +* Set JDK location (Click Gradle Settings link) -> Gradle JDK -> Add JDK... to the path of $JAVA_HOME diff --git a/llm/android/LlamaDemo/app/.gitignore b/llm/android/LlamaDemo/app/.gitignore new file mode 100644 index 00000000..796b96d1 --- /dev/null +++ b/llm/android/LlamaDemo/app/.gitignore @@ -0,0 +1 @@ +/build diff --git a/llm/android/LlamaDemo/app/build.gradle.kts b/llm/android/LlamaDemo/app/build.gradle.kts new file mode 100644 index 00000000..53227740 --- /dev/null +++ b/llm/android/LlamaDemo/app/build.gradle.kts @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +plugins { + id("com.android.application") + id("org.jetbrains.kotlin.android") +} + +val qnnVersion: String? = project.findProperty("qnnVersion") as? 
String + +android { + namespace = "com.example.executorchllamademo" + compileSdk = 34 + + defaultConfig { + applicationId = "com.example.executorchllamademo" + minSdk = 28 + targetSdk = 33 + versionCode = 1 + versionName = "1.0" + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + vectorDrawables { useSupportLibrary = true } + externalNativeBuild { cmake { cppFlags += "" } } + } + + buildTypes { + release { + isMinifyEnabled = false + proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro") + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 + } + kotlinOptions { jvmTarget = "1.8" } + buildFeatures { compose = true } + composeOptions { kotlinCompilerExtensionVersion = "1.4.3" } + packaging { resources { excludes += "/META-INF/{AL2.0,LGPL2.1}" } } +} + +dependencies { + implementation("androidx.core:core-ktx:1.9.0") + implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.1") + implementation("androidx.activity:activity-compose:1.7.0") + implementation(platform("androidx.compose:compose-bom:2023.03.00")) + implementation("androidx.compose.ui:ui") + implementation("androidx.compose.ui:ui-graphics") + implementation("androidx.compose.ui:ui-tooling-preview") + implementation("androidx.compose.material3:material3") + implementation("androidx.appcompat:appcompat:1.6.1") + implementation("androidx.camera:camera-core:1.3.0-rc02") + implementation("androidx.constraintlayout:constraintlayout:2.2.0-alpha12") + implementation("com.facebook.fbjni:fbjni:0.5.1") + implementation("com.google.code.gson:gson:2.8.6") + implementation("org.pytorch:executorch-android:1.0.0") + implementation("com.google.android.material:material:1.12.0") + implementation("androidx.activity:activity:1.9.0") + implementation("org.json:json:20250107") + testImplementation("junit:junit:4.13.2") + androidTestImplementation("androidx.test.ext:junit:1.1.5") + androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") + androidTestImplementation(platform("androidx.compose:compose-bom:2023.03.00")) + androidTestImplementation("androidx.compose.ui:ui-test-junit4") + debugImplementation("androidx.compose.ui:ui-tooling") + debugImplementation("androidx.compose.ui:ui-test-manifest") +} diff --git a/llm/android/LlamaDemo/app/proguard-rules.pro b/llm/android/LlamaDemo/app/proguard-rules.pro new file mode 100644 index 00000000..481bb434 --- /dev/null +++ b/llm/android/LlamaDemo/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/llm/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java b/llm/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java new file mode 100644 index 00000000..32ec24a0 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import android.os.Bundle; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.json.JSONException; +import org.json.JSONObject; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.pytorch.executorch.extension.llm.LlmCallback; +import org.pytorch.executorch.extension.llm.LlmModule; + +@RunWith(AndroidJUnit4.class) +public class PerfTest implements LlmCallback { + + private static final String RESOURCE_PATH = "/data/local/tmp/llama/"; + private static final String TOKENIZER_BIN = "tokenizer.bin"; + + private final List results = new ArrayList<>(); + private final List tokensPerSecond = new ArrayList<>(); + + @Test + public void testTokensPerSecond() { + String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN; + // Find out the model name + File directory = new File(RESOURCE_PATH); + Arrays.stream(directory.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .forEach( + model -> { + LlmModule mModule = new LlmModule(model.getPath(), tokenizerPath, 0.8f); + // Print the model name because there might be more than one of them + report("ModelName", model.getName()); + + int loadResult = mModule.load(); + // Check that the model can be load successfully + assertEquals(0, loadResult); + + // Run a testing prompt + mModule.generate("How do you do! 
I'm testing llama2 on mobile device", PerfTest.this); + assertFalse(tokensPerSecond.isEmpty()); + + final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); + report("TPS", tps); + }); + } + + @Override + public void onResult(String result) { + results.add(result); + } + + @Override + public void onStats(String result) { + try { + JSONObject jsonObject = new JSONObject(result); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + float tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + tokensPerSecond.add(tps); + } catch (JSONException e) { + } + } + + private void report(final String metric, final Float value) { + Bundle bundle = new Bundle(); + bundle.putFloat(metric, value); + InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); + } + + private void report(final String key, final String value) { + Bundle bundle = new Bundle(); + bundle.putString(key, value); + InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); + } +} diff --git a/llm/android/LlamaDemo/app/src/main/AndroidManifest.xml b/llm/android/LlamaDemo/app/src/main/AndroidManifest.xml new file mode 100644 index 00000000..7096a7d4 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/AndroidManifest.xml @@ -0,0 +1,85 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java new file mode 100644 index 00000000..36d07419 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/AppLog.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; + +public class AppLog { + private final Long timestamp; + private final String message; + + public AppLog(String message) { + this.timestamp = getCurrentTimeStamp(); + this.message = message; + } + + public Long getTimestamp() { + return timestamp; + } + + public String getMessage() { + return message; + } + + public String getFormattedLog() { + return "[" + getFormattedTimeStamp() + "] " + message; + } + + private Long getCurrentTimeStamp() { + return System.currentTimeMillis(); + } + + private String getFormattedTimeStamp() { + return formatDate(timestamp); + } + + private String formatDate(long milliseconds) { + SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault()); + Date date = new Date(milliseconds); + return formatter.format(date); + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/BackendType.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/BackendType.java new file mode 100644 index 00000000..7c847997 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/BackendType.java @@ -0,0 +1,7 @@ +package com.example.executorchllamademo; + +public enum BackendType { + XNNPACK, + QUALCOMM, + MEDIATEK +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java new file mode 100644 index 00000000..99a94c00 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/DemoSharedPreferences.java @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import android.content.Context; +import android.content.SharedPreferences; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import java.lang.reflect.Type; +import java.util.ArrayList; + +public class DemoSharedPreferences { + Context context; + SharedPreferences sharedPreferences; + + public DemoSharedPreferences(Context context) { + this.context = context; + this.sharedPreferences = getSharedPrefs(); + } + + private SharedPreferences getSharedPrefs() { + return context.getSharedPreferences( + context.getString(R.string.demo_pref_file_key), Context.MODE_PRIVATE); + } + + public String getSavedMessages() { + return sharedPreferences.getString(context.getString(R.string.saved_messages_json_key), ""); + } + + public void addMessages(MessageAdapter messageAdapter) { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String msgJSON = gson.toJson(messageAdapter.getSavedMessages()); + editor.putString(context.getString(R.string.saved_messages_json_key), msgJSON); + editor.apply(); + } + + public void removeExistingMessages() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.remove(context.getString(R.string.saved_messages_json_key)); + editor.apply(); + } + + public void addSettings(SettingsFields settingsFields) { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String settingsJSON = gson.toJson(settingsFields); + editor.putString(context.getString(R.string.settings_json_key), settingsJSON); + editor.apply(); + } + + public String getSettings() { + return sharedPreferences.getString(context.getString(R.string.settings_json_key), ""); + } + + public void saveLogs() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + Gson gson = new Gson(); + String msgJSON = gson.toJson(ETLogging.getInstance().getLogs()); + editor.putString(context.getString(R.string.logs_json_key), msgJSON); + editor.apply(); + } + + public void removeExistingLogs() { + SharedPreferences.Editor editor = sharedPreferences.edit(); + editor.remove(context.getString(R.string.logs_json_key)); + editor.apply(); + } + + public ArrayList getSavedLogs() { + String logsJSONString = + sharedPreferences.getString(context.getString(R.string.logs_json_key), null); + if (logsJSONString == null || logsJSONString.isEmpty()) { + return new ArrayList<>(); + } + Gson gson = new Gson(); + Type type = new TypeToken>() {}.getType(); + ArrayList appLogs = gson.fromJson(logsJSONString, type); + if (appLogs == null) { + return new ArrayList<>(); + } + return appLogs; + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java new file mode 100644 index 00000000..e68c8472 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETImage.java @@ -0,0 +1,126 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import android.content.ContentResolver; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Color; +import android.net.Uri; +import androidx.annotation.Nullable; +import java.io.FileNotFoundException; +import java.io.InputStream; + +public class ETImage { + private int width; + private int height; + private final byte[] bytes; + private final Uri uri; + private final ContentResolver contentResolver; + + ETImage(ContentResolver contentResolver, Uri uri) { + this.contentResolver = contentResolver; + this.uri = uri; + bytes = getBytesFromImageURI(uri); + } + + public int getWidth() { + return width; + } + + public int getHeight() { + return height; + } + + public Uri getUri() { + return uri; + } + + public byte[] getBytes() { + return bytes; + } + + public int[] getInts() { + // We need to convert the byte array to an int array because + // the runner expects an int array as input. + int[] intArray = new int[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + intArray[i] = (bytes[i++] & 0xFF); + } + return intArray; + } + + private byte[] getBytesFromImageURI(Uri uri) { + try { + int RESIZED_IMAGE_WIDTH = 336; + Bitmap bitmap = resizeImage(uri, RESIZED_IMAGE_WIDTH); + + if (bitmap == null) { + ETLogging.getInstance().log("Unable to get bytes from Image URI. Bitmap is null"); + return new byte[0]; + } + + width = bitmap.getWidth(); + height = bitmap.getHeight(); + + byte[] rgbValues = new byte[width * height * 3]; + + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + // Get the color of the current pixel + int color = bitmap.getPixel(x, y); + + // Extract the RGB values from the color + int red = Color.red(color); + int green = Color.green(color); + int blue = Color.blue(color); + + // Store the RGB values in the byte array + rgbValues[y * width + x] = (byte) red; + rgbValues[(y * width + x) + height * width] = (byte) green; + rgbValues[(y * width + x) + 2 * height * width] = (byte) blue; + } + } + return rgbValues; + } catch (FileNotFoundException e) { + throw new RuntimeException(e); + } + } + + @Nullable + private Bitmap resizeImage(Uri uri, int maxLength) throws FileNotFoundException { + InputStream inputStream = contentResolver.openInputStream(uri); + if (inputStream == null) { + ETLogging.getInstance().log("Unable to resize image, input streams is null"); + return null; + } + Bitmap bitmap = BitmapFactory.decodeStream(inputStream); + if (bitmap == null) { + ETLogging.getInstance().log("Unable to resize image, bitmap during decode stream is null"); + return null; + } + + float aspectRatio; + int finalWidth, finalHeight; + + if (bitmap.getWidth() > bitmap.getHeight()) { + // width > height --> width = maxLength, height scale with aspect ratio + aspectRatio = bitmap.getWidth() / (float) bitmap.getHeight(); + finalWidth = maxLength; + finalHeight = Math.round(maxLength / aspectRatio); + } else { + // height >= width --> height = maxLength, width scale with aspect ratio + aspectRatio = bitmap.getHeight() / (float) bitmap.getWidth(); + finalHeight = maxLength; + finalWidth = Math.round(maxLength / aspectRatio); + } + + return Bitmap.createScaledBitmap(bitmap, finalWidth, finalHeight, false); + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java new file mode 100644 index 00000000..e5953489 --- /dev/null +++ 
b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ETLogging.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.app.Application; +import android.util.Log; +import java.util.ArrayList; + +public class ETLogging extends Application { + private static ETLogging singleton; + + private ArrayList logs; + private DemoSharedPreferences mDemoSharedPreferences; + + @Override + public void onCreate() { + super.onCreate(); + singleton = this; + mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); + logs = mDemoSharedPreferences.getSavedLogs(); + if (logs == null) { // We don't have existing sharedPreference stored + logs = new ArrayList<>(); + } + } + + public static ETLogging getInstance() { + return singleton; + } + + public void log(String message) { + AppLog appLog = new AppLog(message); + logs.add(appLog); + Log.d("ETLogging", appLog.getMessage()); + } + + public ArrayList getLogs() { + return logs; + } + + public void clearLogs() { + logs.clear(); + mDemoSharedPreferences.removeExistingLogs(); + } + + public void saveLogs() { + mDemoSharedPreferences.saveLogs(); + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java new file mode 100644 index 00000000..8c2d6025 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java @@ -0,0 +1,223 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import android.app.Activity; +import android.app.ActivityManager; +import android.content.Intent; +import android.os.Build; +import android.os.Bundle; +import android.util.Log; +import android.widget.TextView; +import androidx.annotation.NonNull; +import com.google.gson.Gson; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback { + ModelRunner mModelRunner; + + String mPrompt; + TextView mTextView; + StatsDump mStatsDump; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_benchmarking); + mTextView = findViewById(R.id.log_view); + + Intent intent = getIntent(); + + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); + + float temperature = intent.getFloatExtra("temperature", 0.8f); + mPrompt = intent.getStringExtra("prompt"); + if (mPrompt == null) { + mPrompt = "The ultimate answer"; + } + + mStatsDump = new StatsDump(); + mStatsDump.modelName = model.getName().replace(".pte", ""); + mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); + mStatsDump.loadStart = System.nanoTime(); + } + + @Override + public void onModelLoaded(int status) { + mStatsDump.loadEnd = System.nanoTime(); + mStatsDump.loadStatus = status; + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); + onGenerationStopped(); + return; + } + mStatsDump.generateStart = System.nanoTime(); + mModelRunner.generate(mPrompt); + } + + @Override + public void onTokenGenerated(String token) { + runOnUiThread( + () -> { + mTextView.append(token); + }); + } + + @Override + public void onStats(String stats) { + mStatsDump.tokens = stats; + } + + @Override + public void onGenerationStopped() { + mStatsDump.generateEnd = System.nanoTime(); + runOnUiThread( + () -> { + mTextView.append(mStatsDump.toString()); + }); + + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(mStatsDump.modelName); + final List results = new ArrayList<>(); + // The list of metrics we have atm includes: + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, + "model_load_time(ms)", + (mStatsDump.loadEnd - mStatsDump.loadStart) * 1e-6, + 0.0f)); + // LLM generate time + results.add( + new BenchmarkMetric( + benchmarkModel, + "generate_time(ms)", + (mStatsDump.generateEnd - mStatsDump.generateStart) * 1e-6, + 0.0f)); + // Token per second + results.add( + new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f)); + + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(results)); + } catch (IOException e) { + e.printStackTrace(); + } + } + + private double extractTPS(final String tokens) { + final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens); + if (m.find()) { + return Double.parseDouble(m.group()); + } else { 
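+      // No parsable number found in the stats string; report 0 tokens/sec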
+ return 0.0f; + } + } +} + +class BenchmarkMetric { + public static class BenchmarkModel { + // The model name, i.e. stories110M + String name; + String backend; + String quantization; + + public BenchmarkModel(final String name, final String backend, final String quantization) { + this.name = name; + this.backend = backend; + this.quantization = quantization; + } + } + + BenchmarkModel benchmarkModel; + + // The metric name, i.e. TPS + String metric; + + // The actual value and the option target value + double actualValue; + double targetValue; + + public static class DeviceInfo { + // Let's see which information we want to include here + final String device = Build.BRAND; + // The phone model and Android release version + final String arch = Build.MODEL; + final String os = "Android " + Build.VERSION.RELEASE; + final long totalMem = new ActivityManager.MemoryInfo().totalMem; + final long availMem = new ActivityManager.MemoryInfo().availMem; + } + + DeviceInfo deviceInfo = new DeviceInfo(); + + public BenchmarkMetric( + final BenchmarkModel benchmarkModel, + final String metric, + final double actualValue, + final double targetValue) { + this.benchmarkModel = benchmarkModel; + this.metric = metric; + this.actualValue = actualValue; + this.targetValue = targetValue; + } + + // TODO (huydhn): Figure out a way to extract the backend and quantization information from + // the .pte model itself instead of parsing its name + public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { + final Matcher m = + Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); + if (m.matches()) { + return new BenchmarkMetric.BenchmarkModel( + m.group("name"), m.group("backend"), m.group("quantization")); + } else { + return new BenchmarkMetric.BenchmarkModel(model, "", ""); + } + } +} + +class StatsDump { + int loadStatus; + long loadStart; + long loadEnd; + long generateStart; + long generateEnd; + String tokens; + String modelName; + + @NonNull + @Override + public String toString() { + return "loadStart: " + + loadStart + + "\nloadEnd: " + + loadEnd + + "\ngenerateStart: " + + generateStart + + "\ngenerateEnd: " + + generateEnd + + "\n" + + tokens; + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java new file mode 100644 index 00000000..7777b275 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsActivity.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import android.app.AlertDialog; +import android.content.DialogInterface; +import android.os.Build; +import android.os.Bundle; +import android.widget.ImageButton; +import android.widget.ListView; +import androidx.appcompat.app.AppCompatActivity; +import androidx.core.content.ContextCompat; +import androidx.core.graphics.Insets; +import androidx.core.view.ViewCompat; +import androidx.core.view.WindowInsetsCompat; + +public class LogsActivity extends AppCompatActivity { + + private LogsAdapter mLogsAdapter; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_logs); + if (Build.VERSION.SDK_INT >= 21) { + getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); + getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); + } + ViewCompat.setOnApplyWindowInsetsListener( + requireViewById(R.id.main), + (v, insets) -> { + Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars()); + v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom); + return insets; + }); + + setupLogs(); + setupClearLogsButton(); + } + + @Override + public void onResume() { + super.onResume(); + mLogsAdapter.clear(); + mLogsAdapter.addAll(ETLogging.getInstance().getLogs()); + mLogsAdapter.notifyDataSetChanged(); + } + + private void setupLogs() { + ListView mLogsListView = requireViewById(R.id.logsListView); + mLogsAdapter = new LogsAdapter(this, R.layout.logs_message); + + mLogsListView.setAdapter(mLogsAdapter); + mLogsAdapter.addAll(ETLogging.getInstance().getLogs()); + mLogsAdapter.notifyDataSetChanged(); + } + + private void setupClearLogsButton() { + ImageButton clearLogsButton = requireViewById(R.id.clearLogsButton); + clearLogsButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Delete Logs History") + .setMessage("Do you really want to delete logs history?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + ETLogging.getInstance().clearLogs(); + mLogsAdapter.clear(); + mLogsAdapter.notifyDataSetChanged(); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + @Override + protected void onDestroy() { + super.onDestroy(); + ETLogging.getInstance().saveLogs(); + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java new file mode 100644 index 00000000..76c6a1aa --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LogsAdapter.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; +import android.widget.ArrayAdapter; +import android.widget.TextView; +import androidx.annotation.NonNull; +import java.util.Objects; + +public class LogsAdapter extends ArrayAdapter { + public LogsAdapter(android.content.Context context, int resource) { + super(context, resource); + } + + static class ViewHolder { + private TextView logTextView; + } + + @NonNull + @Override + public View getView(int position, View convertView, @NonNull ViewGroup parent) { + ViewHolder mViewHolder = null; + + String logMessage = Objects.requireNonNull(getItem(position)).getFormattedLog(); + + if (convertView == null || convertView.getTag() == null) { + mViewHolder = new ViewHolder(); + convertView = LayoutInflater.from(getContext()).inflate(R.layout.logs_message, parent, false); + mViewHolder.logTextView = convertView.requireViewById(R.id.logsTextView); + } else { + mViewHolder = (ViewHolder) convertView.getTag(); + } + mViewHolder.logTextView.setText(logMessage); + return convertView; + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java new file mode 100644 index 00000000..f995c5bc --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -0,0 +1,847 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +import android.Manifest; +import android.app.ActivityManager; +import android.app.AlertDialog; +import android.content.ContentResolver; +import android.content.ContentValues; +import android.content.Intent; +import android.content.pm.PackageManager; +import android.net.Uri; +import android.os.Build; +import android.os.Bundle; +import android.os.Handler; +import android.os.Looper; +import android.os.Process; +import android.provider.MediaStore; +import android.system.ErrnoException; +import android.system.Os; +import android.util.Log; +import android.view.View; +import android.view.inputmethod.InputMethodManager; +import android.widget.EditText; +import android.widget.ImageButton; +import android.widget.ImageView; +import android.widget.LinearLayout; +import android.widget.ListView; +import android.widget.TextView; +import android.widget.Toast; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.PickVisualMediaRequest; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.annotation.NonNull; +import androidx.appcompat.app.AppCompatActivity; +import androidx.constraintlayout.widget.ConstraintLayout; +import androidx.core.app.ActivityCompat; +import androidx.core.content.ContextCompat; +import androidx.core.content.res.ResourcesCompat; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import org.json.JSONException; +import org.json.JSONObject; +import org.pytorch.executorch.extension.llm.LlmCallback; +import org.pytorch.executorch.extension.llm.LlmModule; + +public class MainActivity extends 
AppCompatActivity implements Runnable, LlmCallback { + private EditText mEditTextMessage; + private ImageButton mThinkModeButton; + private ImageButton mSendButton; + private ImageButton mGalleryButton; + private ImageButton mCameraButton; + private ListView mMessagesView; + private MessageAdapter mMessageAdapter; + private LlmModule mModule = null; + private Message mResultMessage = null; + private ImageButton mSettingsButton; + private TextView mMemoryView; + private ActivityResultLauncher mPickGallery; + private ActivityResultLauncher mCameraRoll; + private List mSelectedImageUri; + private ConstraintLayout mMediaPreviewConstraintLayout; + private LinearLayout mAddMediaLayout; + private static final int MAX_NUM_OF_IMAGES = 5; + private static final int REQUEST_IMAGE_CAPTURE = 1; + private Uri cameraImageUri; + private DemoSharedPreferences mDemoSharedPreferences; + private SettingsFields mCurrentSettingsFields; + private Handler mMemoryUpdateHandler; + private Runnable memoryUpdater; + private boolean mThinkMode = false; + private int promptID = 0; + private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2; + private Executor executor; + + @Override + public void onResult(String result) { + if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) { + return; + } + result = PromptFormat.replaceSpecialToken(mCurrentSettingsFields.getModelType(), result); + if (result.equals("\n\n") || result.equals("\n")) { + if (!mResultMessage.getText().isEmpty()) { + mResultMessage.appendText(result); + run(); + } + } else { + mResultMessage.appendText(result); + run(); + } + } + + @Override + public void onStats(String stats) { + runOnUiThread( + () -> { + if (mResultMessage != null) { + float tps = 0; + try { + JSONObject jsonObject = new JSONObject(stats); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + } catch (JSONException e) { + Log.e("LLM", "Error parsing JSON: " + e.getMessage()); + } + mResultMessage.setTokensPerSecond(tps); + mMessageAdapter.notifyDataSetChanged(); + } + }); + } + + private void setLocalModel(String modelPath, String tokenizerPath, float temperature) { + Message modelLoadingMessage = new Message("Loading model...", false, MessageType.SYSTEM, 0); + ETLogging.getInstance().log("Loading model " + modelPath + " with tokenizer " + tokenizerPath); + runOnUiThread( + () -> { + mSendButton.setEnabled(false); + mMessageAdapter.add(modelLoadingMessage); + mMessageAdapter.notifyDataSetChanged(); + }); + if (mModule != null) { + ETLogging.getInstance().log("Start deallocating existing module instance"); + mModule.resetNative(); + mModule = null; + ETLogging.getInstance().log("Completed deallocating existing module instance"); + } + long runStartTime = System.currentTimeMillis(); + mModule = + new LlmModule( + ModelUtils.getModelCategory( + mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()), + modelPath, + tokenizerPath, + temperature); + int loadResult = mModule.load(); + long loadDuration = System.currentTimeMillis() - runStartTime; + String modelLoadError = ""; + String modelInfo = ""; + if (loadResult != 0) { + // TODO: Map the error code to a reason to let the user know why model loading failed + modelInfo = "*Model could not load (Error Code: " + loadResult + ")*" + "\n"; + loadDuration = 0; + 
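+          // Loading failed: surface the error code to the user in an alert dialog (built below and shown on the UI thread)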
AlertDialog.Builder builder = new AlertDialog.Builder(this); + builder.setTitle("Load failed: " + loadResult); + runOnUiThread( + () -> { + AlertDialog alert = builder.create(); + alert.show(); + }); + } else { + String[] segments = modelPath.split("/"); + String pteName = segments[segments.length - 1]; + segments = tokenizerPath.split("/"); + String tokenizerName = segments[segments.length - 1]; + modelInfo = + "Successfully loaded model. " + + pteName + + " and tokenizer " + + tokenizerName + + " in " + + (float) loadDuration / 1000 + + " sec." + + " You can send text or image for inference"; + + if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { + ETLogging.getInstance().log("Llava start prefill prompt"); + mModule.resetContext(); + mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt()); + ETLogging.getInstance().log("Llava completes prefill prompt"); + } + } + + Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0); + + String modelLoggingInfo = + modelLoadError + + "Model path: " + + modelPath + + "\nTokenizer path: " + + tokenizerPath + + "\nBackend: " + + mCurrentSettingsFields.getBackendType().toString() + + "\nModelType: " + + ModelUtils.getModelCategory( + mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()) + + "\nTemperature: " + + temperature + + "\nModel loaded time: " + + loadDuration + + " ms"; + ETLogging.getInstance().log("Load complete. " + modelLoggingInfo); + + runOnUiThread( + () -> { + mSendButton.setEnabled(true); + mMessageAdapter.remove(modelLoadingMessage); + mMessageAdapter.add(modelLoadedMessage); + mMessageAdapter.notifyDataSetChanged(); + }); + } + + private void loadLocalModelAndParameters( + String modelFilePath, String tokenizerFilePath, float temperature) { + Runnable runnable = + new Runnable() { + @Override + public void run() { + setLocalModel(modelFilePath, tokenizerFilePath, temperature); + } + }; + new Thread(runnable).start(); + } + + private void populateExistingMessages(String existingMsgJSON) { + Gson gson = new Gson(); + Type type = new TypeToken>() {}.getType(); + ArrayList savedMessages = gson.fromJson(existingMsgJSON, type); + for (Message msg : savedMessages) { + mMessageAdapter.add(msg); + } + mMessageAdapter.notifyDataSetChanged(); + } + + private int setPromptID() { + + return mMessageAdapter.getMaxPromptID() + 1; + } + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + if (Build.VERSION.SDK_INT >= 21) { + getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); + getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); + } + + try { + Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); + Os.setenv("LD_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); + } catch (ErrnoException e) { + finish(); + } + + mThinkModeButton = requireViewById(R.id.thinkModeButton); + mEditTextMessage = requireViewById(R.id.editTextMessage); + mSendButton = requireViewById(R.id.sendButton); + mSendButton.setEnabled(false); + mMessagesView = requireViewById(R.id.messages_view); + mMessageAdapter = new MessageAdapter(this, R.layout.sent_message, new ArrayList()); + mMessagesView.setAdapter(mMessageAdapter); + mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); + String existingMsgJSON = mDemoSharedPreferences.getSavedMessages(); + if (!existingMsgJSON.isEmpty()) { + 
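+      // Restore chat messages saved by a previous session and continue prompt IDs after the highest saved one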
populateExistingMessages(existingMsgJSON); + promptID = setPromptID(); + } + mSettingsButton = requireViewById(R.id.settings); + mSettingsButton.setOnClickListener( + view -> { + Intent myIntent = new Intent(MainActivity.this, SettingsActivity.class); + MainActivity.this.startActivity(myIntent); + }); + + mThinkModeButton.setOnClickListener( + view -> { + if (mThinkMode) { + mThinkMode = false; + mThinkModeButton.setImageDrawable( + ResourcesCompat.getDrawable( + getResources(), R.drawable.baseline_lightbulb_24, null)); + } else { + mThinkMode = true; + mThinkModeButton.setImageDrawable( + ResourcesCompat.getDrawable(getResources(), R.drawable.blue_lightbulb_24, null)); + } + runOnUiThread( + () -> { + String thinkingModeText = mThinkMode ? "on" : "off"; + mMessageAdapter.add( + new Message( + "Thinking mode is " + thinkingModeText, false, MessageType.SYSTEM, 0)); + mMessageAdapter.notifyDataSetChanged(); + }); + }); + + mCurrentSettingsFields = new SettingsFields(); + mMemoryUpdateHandler = new Handler(Looper.getMainLooper()); + onModelRunStopped(); + setupMediaButton(); + setupGalleryPicker(); + setupCameraRoll(); + startMemoryUpdate(); + setupShowLogsButton(); + executor = Executors.newSingleThreadExecutor(); + } + + @Override + protected void onPause() { + super.onPause(); + mDemoSharedPreferences.addMessages(mMessageAdapter); + } + + @Override + protected void onResume() { + super.onResume(); + // Check for if settings parameters have changed + Gson gson = new Gson(); + String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); + if (!settingsFieldsJSON.isEmpty()) { + SettingsFields updatedSettingsFields = + gson.fromJson(settingsFieldsJSON, SettingsFields.class); + if (updatedSettingsFields == null) { + // Added this check, because gson.fromJson can return null + askUserToSelectModel(); + return; + } + boolean isUpdated = !mCurrentSettingsFields.equals(updatedSettingsFields); + boolean isLoadModel = updatedSettingsFields.getIsLoadModel(); + setBackendMode(updatedSettingsFields.getBackendType()); + if (isUpdated) { + if (isLoadModel) { + // If users change the model file, but not pressing loadModelButton, we won't load the new + // model + checkForUpdateAndReloadModel(updatedSettingsFields); + } else { + askUserToSelectModel(); + } + + checkForClearChatHistory(updatedSettingsFields); + // Update current to point to the latest + mCurrentSettingsFields = new SettingsFields(updatedSettingsFields); + } + } else { + askUserToSelectModel(); + } + } + + private void setBackendMode(BackendType backendType) { + if (backendType.equals(BackendType.XNNPACK) || backendType.equals(BackendType.QUALCOMM)) { + setXNNPACKMode(); + } else if (backendType.equals(BackendType.MEDIATEK)) { + setMediaTekMode(); + } + } + + private void setXNNPACKMode() { + requireViewById(R.id.addMediaButton).setVisibility(View.VISIBLE); + } + + private void setMediaTekMode() { + requireViewById(R.id.addMediaButton).setVisibility(View.GONE); + } + + private void checkForClearChatHistory(SettingsFields updatedSettingsFields) { + if (updatedSettingsFields.getIsClearChatHistory()) { + mMessageAdapter.clear(); + mMessageAdapter.notifyDataSetChanged(); + mDemoSharedPreferences.removeExistingMessages(); + // changing to false since chat history has been cleared. 
+ updatedSettingsFields.saveIsClearChatHistory(false); + mDemoSharedPreferences.addSettings(updatedSettingsFields); + } + } + + private void checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields) { + // TODO need to add 'load model' in settings and queue loading based on that + String modelPath = updatedSettingsFields.getModelFilePath(); + String tokenizerPath = updatedSettingsFields.getTokenizerFilePath(); + double temperature = updatedSettingsFields.getTemperature(); + if (!modelPath.isEmpty() && !tokenizerPath.isEmpty()) { + if (updatedSettingsFields.getIsLoadModel() + || !modelPath.equals(mCurrentSettingsFields.getModelFilePath()) + || !tokenizerPath.equals(mCurrentSettingsFields.getTokenizerFilePath()) + || temperature != mCurrentSettingsFields.getTemperature()) { + loadLocalModelAndParameters( + updatedSettingsFields.getModelFilePath(), + updatedSettingsFields.getTokenizerFilePath(), + (float) updatedSettingsFields.getTemperature()); + updatedSettingsFields.saveLoadModelAction(false); + mDemoSharedPreferences.addSettings(updatedSettingsFields); + } + } else { + askUserToSelectModel(); + } + } + + private void askUserToSelectModel() { + String askLoadModel = + "To get started, select your desired model and tokenizer " + "from the top right corner"; + Message askLoadModelMessage = new Message(askLoadModel, false, MessageType.SYSTEM, 0); + ETLogging.getInstance().log(askLoadModel); + runOnUiThread( + () -> { + mMessageAdapter.add(askLoadModelMessage); + mMessageAdapter.notifyDataSetChanged(); + }); + } + + private void setupShowLogsButton() { + ImageButton showLogsButton = requireViewById(R.id.showLogsButton); + showLogsButton.setOnClickListener( + view -> { + Intent myIntent = new Intent(MainActivity.this, LogsActivity.class); + MainActivity.this.startActivity(myIntent); + }); + } + + private void setupMediaButton() { + mAddMediaLayout = requireViewById(R.id.addMediaLayout); + mAddMediaLayout.setVisibility(View.GONE); // We hide this initially + + ImageButton addMediaButton = requireViewById(R.id.addMediaButton); + addMediaButton.setOnClickListener( + view -> { + mAddMediaLayout.setVisibility(View.VISIBLE); + }); + + mGalleryButton = requireViewById(R.id.galleryButton); + mGalleryButton.setOnClickListener( + view -> { + // Launch the photo picker and let the user choose only images. + mPickGallery.launch( + new PickVisualMediaRequest.Builder() + .setMediaType(ActivityResultContracts.PickVisualMedia.ImageOnly.INSTANCE) + .build()); + }); + mCameraButton = requireViewById(R.id.cameraButton); + mCameraButton.setOnClickListener( + view -> { + Log.d("CameraRoll", "Check permission"); + if (ContextCompat.checkSelfPermission(MainActivity.this, Manifest.permission.CAMERA) + != PackageManager.PERMISSION_GRANTED) { + ActivityCompat.requestPermissions( + MainActivity.this, + new String[] {Manifest.permission.CAMERA}, + REQUEST_IMAGE_CAPTURE); + } else { + launchCamera(); + } + }); + } + + private void setupCameraRoll() { + // Registers a camera roll activity launcher. 
+ mCameraRoll = + registerForActivityResult( + new ActivityResultContracts.TakePicture(), + result -> { + if (result && cameraImageUri != null) { + Log.d("CameraRoll", "Photo saved to uri: " + cameraImageUri); + mAddMediaLayout.setVisibility(View.GONE); + List uris = new ArrayList<>(); + uris.add(cameraImageUri); + showMediaPreview(uris); + } else { + // Delete the temp image file based on the url since the photo is not successfully + // taken + if (cameraImageUri != null) { + ContentResolver contentResolver = MainActivity.this.getContentResolver(); + contentResolver.delete(cameraImageUri, null, null); + Log.d("CameraRoll", "No photo taken. Delete temp uri"); + } + } + }); + mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout); + ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton); + mediaPreviewCloseButton.setOnClickListener( + view -> { + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + mSelectedImageUri = null; + }); + + ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton); + addMoreImageButton.setOnClickListener( + view -> { + Log.d("addMore", "clicked"); + mMediaPreviewConstraintLayout.setVisibility(View.GONE); + // Direct user to select type of input + mCameraButton.callOnClick(); + }); + } + + private String updateMemoryUsage() { + ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo(); + ActivityManager activityManager = (ActivityManager) getSystemService(ACTIVITY_SERVICE); + if (activityManager == null) { + return "---"; + } + activityManager.getMemoryInfo(memoryInfo); + long totalMem = memoryInfo.totalMem / (1024 * 1024); + long availableMem = memoryInfo.availMem / (1024 * 1024); + long usedMem = totalMem - availableMem; + return usedMem + "MB"; + } + + private void startMemoryUpdate() { + mMemoryView = requireViewById(R.id.ram_usage_live); + memoryUpdater = + new Runnable() { + @Override + public void run() { + mMemoryView.setText(updateMemoryUsage()); + mMemoryUpdateHandler.postDelayed(this, 1000); + } + }; + mMemoryUpdateHandler.post(memoryUpdater); + } + + @Override + public void onRequestPermissionsResult( + int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (requestCode == REQUEST_IMAGE_CAPTURE && grantResults.length != 0) { + if (grantResults[0] == PackageManager.PERMISSION_GRANTED) { + launchCamera(); + } else if (grantResults[0] == PackageManager.PERMISSION_DENIED) { + Log.d("CameraRoll", "Permission denied"); + } + } + } + + private void launchCamera() { + ContentValues values = new ContentValues(); + values.put(MediaStore.Images.Media.TITLE, "New Picture"); + values.put(MediaStore.Images.Media.DESCRIPTION, "From Camera"); + values.put(MediaStore.Images.Media.RELATIVE_PATH, "DCIM/Camera/"); + cameraImageUri = + MainActivity.this + .getContentResolver() + .insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values); + mCameraRoll.launch(cameraImageUri); + } + + private void setupGalleryPicker() { + // Registers a photo picker activity launcher in single-select mode. 
+    mPickGallery =
+        registerForActivityResult(
+            new ActivityResultContracts.PickMultipleVisualMedia(MAX_NUM_OF_IMAGES),
+            uris -> {
+              if (!uris.isEmpty()) {
+                Log.d("PhotoPicker", "Selected URIs: " + uris);
+                mAddMediaLayout.setVisibility(View.GONE);
+                for (Uri uri : uris) {
+                  MainActivity.this
+                      .getContentResolver()
+                      .takePersistableUriPermission(uri, Intent.FLAG_GRANT_READ_URI_PERMISSION);
+                }
+                showMediaPreview(uris);
+              } else {
+                Log.d("PhotoPicker", "No media selected");
+              }
+            });
+
+    mMediaPreviewConstraintLayout = requireViewById(R.id.mediaPreviewConstraintLayout);
+    ImageButton mediaPreviewCloseButton = requireViewById(R.id.mediaPreviewCloseButton);
+    mediaPreviewCloseButton.setOnClickListener(
+        view -> {
+          mMediaPreviewConstraintLayout.setVisibility(View.GONE);
+          mSelectedImageUri = null;
+        });
+
+    ImageButton addMoreImageButton = requireViewById(R.id.addMoreImageButton);
+    addMoreImageButton.setOnClickListener(
+        view -> {
+          Log.d("addMore", "clicked");
+          mMediaPreviewConstraintLayout.setVisibility(View.GONE);
+          mGalleryButton.callOnClick();
+        });
+  }
+
+  private List<ETImage> getProcessedImagesForModel(List<Uri> uris) {
+    List<ETImage> imageList = new ArrayList<>();
+    if (uris != null) {
+      uris.forEach(
+          (uri) -> {
+            imageList.add(new ETImage(this.getContentResolver(), uri));
+          });
+    }
+    return imageList;
+  }
+
+  private void showMediaPreview(List<Uri> uris) {
+    if (mSelectedImageUri == null) {
+      mSelectedImageUri = uris;
+    } else {
+      mSelectedImageUri.addAll(uris);
+    }
+
+    if (mSelectedImageUri.size() > MAX_NUM_OF_IMAGES) {
+      mSelectedImageUri = mSelectedImageUri.subList(0, MAX_NUM_OF_IMAGES);
+      Toast.makeText(
+              this, "Only max " + MAX_NUM_OF_IMAGES + " images are allowed", Toast.LENGTH_SHORT)
+          .show();
+    }
+    Log.d("mSelectedImageUri", mSelectedImageUri.size() + " " + mSelectedImageUri);
+
+    mMediaPreviewConstraintLayout.setVisibility(View.VISIBLE);
+
+    List<ImageView> imageViews = new ArrayList<>();
+
+    // Pre-populate all the image views that are available from the layout (currently max 5)
+    imageViews.add(requireViewById(R.id.mediaPreviewImageView1));
+    imageViews.add(requireViewById(R.id.mediaPreviewImageView2));
+    imageViews.add(requireViewById(R.id.mediaPreviewImageView3));
+    imageViews.add(requireViewById(R.id.mediaPreviewImageView4));
+    imageViews.add(requireViewById(R.id.mediaPreviewImageView5));
+
+    // Hide all the image views (reset state)
+    for (int i = 0; i < imageViews.size(); i++) {
+      imageViews.get(i).setVisibility(View.GONE);
+    }
+
+    // Only show/render those that have proper image URIs
+    for (int i = 0; i < mSelectedImageUri.size(); i++) {
+      imageViews.get(i).setVisibility(View.VISIBLE);
+      imageViews.get(i).setImageURI(mSelectedImageUri.get(i));
+    }
+
+    // For LLaVA, we want to start image prefill as soon as an image is selected.
+    // LLaVA only supports one image for now.
+    if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
+      List<ETImage> processedImageList = getProcessedImagesForModel(mSelectedImageUri);
+      if (!processedImageList.isEmpty()) {
+        mMessageAdapter.add(
+            new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0));
+        mMessageAdapter.notifyDataSetChanged();
+        Runnable runnable =
+            () -> {
+              Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE);
+              ETLogging.getInstance().log("Starting runnable prefill image");
+              ETImage img = processedImageList.get(0);
+              ETLogging.getInstance().log("Llava start prefill image");
+              mModule.prefillImages(
+                  img.getInts(),
+                  img.getWidth(),
+                  img.getHeight(),
+                  ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
+            };
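+        // Run image prefill on a background thread so the UI stays responsive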
+        executor.execute(runnable);
+      }
+    }
+  }
+
+  private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
+    if (selectedImageUri == null) {
+      return;
+    }
+    mMediaPreviewConstraintLayout.setVisibility(View.GONE);
+    for (int i = 0; i < selectedImageUri.size(); i++) {
+      Uri imageURI = selectedImageUri.get(i);
+      Log.d("image uri", "Adding image to chat thread: " + imageURI.getPath());
+      mMessageAdapter.add(new Message(imageURI.toString(), true, MessageType.IMAGE, 0));
+    }
+    mMessageAdapter.notifyDataSetChanged();
+  }
+
+  private String getConversationHistory() {
+    String conversationHistory = "";
+
+    ArrayList<Message> conversations =
+        mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK);
+    if (conversations.isEmpty()) {
+      return conversationHistory;
+    }
+
+    int prevPromptID = conversations.get(0).getPromptID();
+    String conversationFormat =
+        PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType());
+    String format = conversationFormat;
+    for (int i = 0; i < conversations.size(); i++) {
+      Message conversation = conversations.get(i);
+      int currentPromptID = conversation.getPromptID();
+      if (currentPromptID != prevPromptID) {
+        conversationHistory = conversationHistory + format;
+        format = conversationFormat;
+        prevPromptID = currentPromptID;
+      }
+      if (conversation.getIsSent()) {
+        format =
+            format
+                .replace(PromptFormat.USER_PLACEHOLDER, conversation.getText())
+                .replace(PromptFormat.THINKING_MODE_PLACEHOLDER, "");
+      } else {
+        format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText());
+      }
+    }
+    conversationHistory = conversationHistory + format;
+
+    return conversationHistory;
+  }
+
+  private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) {
+    if (conversationHistory.isEmpty()) {
+      return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode);
+    }
+
+    return mCurrentSettingsFields.getFormattedSystemPrompt()
+        + conversationHistory
+        + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt, mThinkMode);
+  }
+
+  private void onModelRunStarted() {
+    mSendButton.setClickable(false);
+    mSendButton.setImageResource(R.drawable.baseline_stop_24);
+    mSendButton.setOnClickListener(
+        view -> {
+          mModule.stop();
+        });
+  }
+
+  private void onModelRunStopped() {
+    mSendButton.setClickable(true);
+    mSendButton.setImageResource(R.drawable.baseline_send_24);
+    mSendButton.setOnClickListener(
+        view -> {
+          try {
+            InputMethodManager imm = (InputMethodManager) getSystemService(INPUT_METHOD_SERVICE);
+            imm.hideSoftInputFromWindow(getCurrentFocus().getWindowToken(), 0);
+          } catch (Exception e) {
+            ETLogging.getInstance().log("Keyboard dismissal error: " + e.getMessage());
+          }
+          addSelectedImagesToChatThread(mSelectedImageUri);
+          String finalPrompt;
+          String rawPrompt = mEditTextMessage.getText().toString();
+          if (ModelUtils.getModelCategory(
+                  mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType())
+              == ModelUtils.VISION_MODEL) {
+            finalPrompt =
+                mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode);
+          } else {
+            finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt);
+          }
+          // We store the raw prompt in the message adapter because we don't want to show the
+          // extra tokens from the system prompt
+          mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
+          mMessageAdapter.notifyDataSetChanged();
+          mEditTextMessage.setText("");
+          mResultMessage = new Message("", false, MessageType.TEXT, promptID);
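+          // The empty result message is filled in incrementally as generated tokens stream back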
mMessageAdapter.add(mResultMessage); + // Scroll to bottom of the list + mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1); + // After images are added to prompt and chat thread, we clear the imageURI list + // Note: This has to be done after imageURIs are no longer needed by LlmModule + mSelectedImageUri = null; + promptID++; + Runnable runnable = + new Runnable() { + @Override + public void run() { + Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); + ETLogging.getInstance().log("starting runnable generate()"); + runOnUiThread( + new Runnable() { + @Override + public void run() { + onModelRunStarted(); + } + }); + long generateStartTime = System.currentTimeMillis(); + if (ModelUtils.getModelCategory( + mCurrentSettingsFields.getModelType(), + mCurrentSettingsFields.getBackendType()) + == ModelUtils.VISION_MODEL) { + mModule.generate( + finalPrompt, ModelUtils.VISION_MODEL_SEQ_LEN, MainActivity.this, false); + } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) { + String llamaGuardPromptForClassification = + PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt); + ETLogging.getInstance() + .log("Running inference.. prompt=" + llamaGuardPromptForClassification); + mModule.generate( + llamaGuardPromptForClassification, + llamaGuardPromptForClassification.length() + 64, + MainActivity.this, + false); + } else { + ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); + mModule.generate( + finalPrompt, + (int) (finalPrompt.length() * 0.75) + 64, + MainActivity.this, + false); + } + + long generateDuration = System.currentTimeMillis() - generateStartTime; + mResultMessage.setTotalGenerationTime(generateDuration); + runOnUiThread( + new Runnable() { + @Override + public void run() { + onModelRunStopped(); + } + }); + ETLogging.getInstance().log("Inference completed"); + } + }; + executor.execute(runnable); + }); + mMessageAdapter.notifyDataSetChanged(); + } + + @Override + public void run() { + runOnUiThread( + new Runnable() { + @Override + public void run() { + mMessageAdapter.notifyDataSetChanged(); + } + }); + } + + @Override + public void onBackPressed() { + super.onBackPressed(); + if (mAddMediaLayout != null && mAddMediaLayout.getVisibility() == View.VISIBLE) { + mAddMediaLayout.setVisibility(View.GONE); + } else { + // Default behavior of back button + finish(); + } + } + + @Override + protected void onDestroy() { + super.onDestroy(); + mMemoryUpdateHandler.removeCallbacks(memoryUpdater); + // This is to cover the case where the app is shutdown when user is on MainActivity but + // never clicked on the logsActivity + ETLogging.getInstance().saveLogs(); + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java new file mode 100644 index 00000000..b2e5380e --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/Message.java @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+package com.example.executorchllamademo;
+
+import java.text.SimpleDateFormat;
+import java.util.Date;
+import java.util.Locale;
+
+public class Message {
+  private String text;
+  private final boolean isSent;
+  private float tokensPerSecond;
+  private long totalGenerationTime;
+  private final long timestamp;
+  private final MessageType messageType;
+  private String imagePath;
+  private final int promptID;
+
+  private static final String TIMESTAMP_FORMAT = "hh:mm a"; // example: 02:23 PM
+
+  public Message(String text, boolean isSent, MessageType messageType, int promptID) {
+    this.isSent = isSent;
+    this.messageType = messageType;
+    this.promptID = promptID;
+
+    if (messageType == MessageType.IMAGE) {
+      this.imagePath = text;
+    } else {
+      this.text = text;
+    }
+
+    if (messageType != MessageType.SYSTEM) {
+      this.timestamp = System.currentTimeMillis();
+    } else {
+      this.timestamp = 0L;
+    }
+  }
+
+  public int getPromptID() {
+    return promptID;
+  }
+
+  public MessageType getMessageType() {
+    return messageType;
+  }
+
+  public String getImagePath() {
+    return imagePath;
+  }
+
+  public String getText() {
+    return text;
+  }
+
+  public void appendText(String text) {
+    this.text += text;
+  }
+
+  public boolean getIsSent() {
+    return isSent;
+  }
+
+  public void setTokensPerSecond(float tokensPerSecond) {
+    this.tokensPerSecond = tokensPerSecond;
+  }
+
+  public void setTotalGenerationTime(long totalGenerationTime) {
+    this.totalGenerationTime = totalGenerationTime;
+  }
+
+  public float getTokensPerSecond() {
+    return tokensPerSecond;
+  }
+
+  public long getTotalGenerationTime() {
+    return totalGenerationTime;
+  }
+
+  public long getTimestamp() {
+    return timestamp;
+  }
+
+  public String getFormattedTimestamp() {
+    SimpleDateFormat formatter = new SimpleDateFormat(TIMESTAMP_FORMAT, Locale.getDefault());
+    Date date = new Date(timestamp);
+    return formatter.format(date);
+  }
+}
diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java
new file mode 100644
index 00000000..31aaa9a1
--- /dev/null
+++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchllamademo;
+
+import android.net.Uri;
+import android.view.LayoutInflater;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.ArrayAdapter;
+import android.widget.ImageView;
+import android.widget.TextView;
+import java.util.ArrayList;
+import java.util.Collections;
+
+public class MessageAdapter extends ArrayAdapter<Message> {
+
+  private final ArrayList<Message> savedMessages;
+
+  public MessageAdapter(
+      android.content.Context context, int resource, ArrayList<Message> savedMessages) {
+    super(context, resource);
+    this.savedMessages = savedMessages;
+  }
+
+  @Override
+  public View getView(int position, View convertView, ViewGroup parent) {
+    Message currentMessage = getItem(position);
+    int layoutIdForListItem;
+
+    if (currentMessage.getMessageType() == MessageType.SYSTEM) {
+      layoutIdForListItem = R.layout.system_message;
+    } else {
+      layoutIdForListItem =
+          currentMessage.getIsSent() ?
R.layout.sent_message : R.layout.received_message; + } + View listItemView = + LayoutInflater.from(getContext()).inflate(layoutIdForListItem, parent, false); + if (currentMessage.getMessageType() == MessageType.IMAGE) { + ImageView messageImageView = listItemView.requireViewById(R.id.message_image); + messageImageView.setImageURI(Uri.parse(currentMessage.getImagePath())); + TextView messageTextView = listItemView.requireViewById(R.id.message_text); + messageTextView.setVisibility(View.GONE); + } else { + TextView messageTextView = listItemView.requireViewById(R.id.message_text); + messageTextView.setText(currentMessage.getText()); + } + + String metrics = ""; + TextView tokensView; + if (currentMessage.getTokensPerSecond() > 0) { + metrics = String.format("%.2f", currentMessage.getTokensPerSecond()) + "t/s "; + } + + if (currentMessage.getTotalGenerationTime() > 0) { + metrics = metrics + (float) currentMessage.getTotalGenerationTime() / 1000 + "s "; + } + + if (currentMessage.getTokensPerSecond() > 0 || currentMessage.getTotalGenerationTime() > 0) { + tokensView = listItemView.requireViewById(R.id.generation_metrics); + tokensView.setText(metrics); + TextView separatorView = listItemView.requireViewById(R.id.bar); + separatorView.setVisibility(View.VISIBLE); + } + + if (currentMessage.getTimestamp() > 0) { + TextView timestampView = listItemView.requireViewById(R.id.timestamp); + timestampView.setText(currentMessage.getFormattedTimestamp()); + } + + return listItemView; + } + + @Override + public void add(Message msg) { + super.add(msg); + savedMessages.add(msg); + } + + @Override + public void clear() { + super.clear(); + savedMessages.clear(); + } + + public ArrayList getSavedMessages() { + return savedMessages; + } + + public ArrayList getRecentSavedTextMessages(int numOfLatestPromptMessages) { + ArrayList recentMessages = new ArrayList(); + int lastIndex = savedMessages.size() - 1; + // In most cases lastIndex >=0 . + // A situation where the user clears chat history and enters prompt. Causes lastIndex=-1 . + if (lastIndex >= 0) { + Message messageToAdd = savedMessages.get(lastIndex); + int oldPromptID = messageToAdd.getPromptID(); + + for (int i = 0; i < savedMessages.size(); i++) { + messageToAdd = savedMessages.get(lastIndex - i); + if (messageToAdd.getMessageType() != MessageType.SYSTEM) { + if (messageToAdd.getPromptID() != oldPromptID) { + numOfLatestPromptMessages--; + oldPromptID = messageToAdd.getPromptID(); + } + if (numOfLatestPromptMessages > 0) { + if (messageToAdd.getMessageType() == MessageType.TEXT) { + recentMessages.add(messageToAdd); + } + } else { + break; + } + } + } + // To place the order in [input1, output1, input2, output2...] + Collections.reverse(recentMessages); + } + + return recentMessages; + } + + public int getMaxPromptID() { + int maxPromptID = -1; + for (Message msg : savedMessages) { + + maxPromptID = Math.max(msg.getPromptID(), maxPromptID); + } + return maxPromptID; + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java new file mode 100644 index 00000000..6042acb5 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageType.java @@ -0,0 +1,15 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+package com.example.executorchllamademo;
+
+public enum MessageType {
+  TEXT,
+  IMAGE,
+  SYSTEM
+}
diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java
new file mode 100644
index 00000000..a1bc205c
--- /dev/null
+++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchllamademo;
+
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
+import android.os.Message;
+import androidx.annotation.NonNull;
+import org.json.JSONException;
+import org.json.JSONObject;
+import org.pytorch.executorch.extension.llm.LlmCallback;
+import org.pytorch.executorch.extension.llm.LlmModule;
+
+/** A helper class that handles all model running logic on a dedicated worker thread. */
+public class ModelRunner implements LlmCallback {
+  LlmModule mModule = null;
+
+  String mModelFilePath = "";
+  String mTokenizerFilePath = "";
+
+  ModelRunnerCallback mCallback = null;
+
+  HandlerThread mHandlerThread = null;
+  Handler mHandler = null;
+
+  /**
+   * Helper class that separates UI logic from model runner logic and automatically handles
+   * generate() requests on a worker thread.
+   *
+   * @param modelFilePath path to the exported .pte model file
+   * @param tokenizerFilePath path to the tokenizer file
+   * @param temperature sampling temperature passed to the LlmModule
+   * @param callback receives model load, token, stats, and completion events
+   */
+  ModelRunner(
+      String modelFilePath,
+      String tokenizerFilePath,
+      float temperature,
+      ModelRunnerCallback callback) {
+    mModelFilePath = modelFilePath;
+    mTokenizerFilePath = tokenizerFilePath;
+    mCallback = callback;
+
+    mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, temperature);
+    mHandlerThread = new HandlerThread("ModelRunner");
+    mHandlerThread.start();
+    mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);
+
+    mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
+  }
+
+  int generate(String prompt) {
+    Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
+    msg.sendToTarget();
+    return 0;
+  }
+
+  void stop() {
+    mModule.stop();
+  }
+
+  @Override
+  public void onResult(String result) {
+    mCallback.onTokenGenerated(result);
+  }
+
+  @Override
+  public void onStats(String stats) {
+    float tps = 0;
+    try {
+      JSONObject jsonObject = new JSONObject(stats);
+      int numGeneratedTokens = jsonObject.getInt("generated_tokens");
+      int inferenceEndMs = jsonObject.getInt("inference_end_ms");
+      int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
+      tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
+    } catch (JSONException e) {
+      ETLogging.getInstance().log("Error parsing stats JSON: " + e.getMessage());
+    }
+    mCallback.onStats("tokens/second: " + tps);
+  }
+}
+
+class ModelRunnerHandler extends Handler {
+  public static final int MESSAGE_LOAD_MODEL = 1;
+  public static final int MESSAGE_GENERATE = 2;
+
+  private final ModelRunner mModelRunner;
+
+  public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
+    super(looper);
+    mModelRunner = modelRunner;
+  }
+
+  @Override
+  public void handleMessage(@NonNull android.os.Message msg) {
+    if (msg.what == MESSAGE_LOAD_MODEL) {
+      int status = mModelRunner.mModule.load();
+      mModelRunner.mCallback.onModelLoaded(status);
+    } else if (msg.what == MESSAGE_GENERATE) {
+      mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
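+      // generate() runs synchronously on this worker thread; once it returns, generation is done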
mModelRunner.mCallback.onGenerationStopped(); + } + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java new file mode 100644 index 00000000..5e8b6f00 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +/** + * A helper interface within the app for MainActivity and Benchmarking to handle callback from + * ModelRunner. + */ +public interface ModelRunnerCallback { + + void onModelLoaded(int status); + + void onTokenGenerated(String token); + + void onStats(String stats); + + void onGenerationStopped(); +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java new file mode 100644 index 00000000..9f813250 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java @@ -0,0 +1,18 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +public enum ModelType { + LLAMA_3, + LLAMA_3_1, + LLAMA_3_2, + LLAVA_1_5, + LLAMA_GUARD_3, + QWEN_3, +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java new file mode 100644 index 00000000..cf7ab175 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +public class ModelUtils { + // XNNPACK or QNN + static final int TEXT_MODEL = 1; + + // XNNPACK + static final int VISION_MODEL = 2; + static final int VISION_MODEL_IMAGE_CHANNELS = 3; + static final int VISION_MODEL_SEQ_LEN = 768; + static final int TEXT_MODEL_SEQ_LEN = 256; + + // MediaTek + static final int MEDIATEK_TEXT_MODEL = 3; + + // QNN static llama + static final int QNN_TEXT_MODEL = 4; + + public static int getModelCategory(ModelType modelType, BackendType backendType) { + if (backendType.equals(BackendType.XNNPACK)) { + switch (modelType) { + case LLAVA_1_5: + return VISION_MODEL; + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + case QWEN_3: + default: + return TEXT_MODEL; + } + } else if (backendType.equals(BackendType.MEDIATEK)) { + return MEDIATEK_TEXT_MODEL; + } else if (backendType.equals(BackendType.QUALCOMM)) { + return QNN_TEXT_MODEL; + } + + return TEXT_MODEL; // default + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java new file mode 100644 index 00000000..524ad7cb --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo; + +public class PromptFormat { + + public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}"; + public static final String USER_PLACEHOLDER = "{{ user_prompt }}"; + public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}"; + public static final String THINKING_MODE_PLACEHOLDER = "{{ thinking_mode }}"; + public static final String DEFAULT_SYSTEM_PROMPT = "Answer the questions in a few sentences"; + + public static String getSystemPromptTemplate(ModelType modelType) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n" + + SYSTEM_PLACEHOLDER + + "<|eot_id|>"; + case LLAVA_1_5: + return "USER: "; + case QWEN_3: + return "<|im_start|>system\n" + "You are a helpful assistant.\n" + "<|im_end|>\n"; + default: + return SYSTEM_PLACEHOLDER; + } + } + + public static String getUserPromptTemplate(ModelType modelType, boolean thinkingMode) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + case LLAMA_GUARD_3: + return "<|start_header_id|>user<|end_header_id|>\n" + + USER_PLACEHOLDER + + "<|eot_id|>" + + "<|start_header_id|>assistant<|end_header_id|>"; + + case QWEN_3: + return "<|im_start|>user\n" + + USER_PLACEHOLDER + + "\n<|im_end|>\n" + + "<|im_start|>assistant\n" + + THINKING_MODE_PLACEHOLDER; + case LLAVA_1_5: + default: + return USER_PLACEHOLDER; + } + } + + public static String getConversationFormat(ModelType modelType) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + case LLAMA_3_2: + return getUserPromptTemplate(modelType, false) + + "\n" + + ASSISTANT_PLACEHOLDER + + "<|eot_id|>"; + case LLAVA_1_5: + return USER_PLACEHOLDER + " ASSISTANT:"; + case QWEN_3: + return getUserPromptTemplate(modelType, false) + "<|im_end|>\n"; + default: + return USER_PLACEHOLDER; + } + } + + public static String getStopToken(ModelType modelType) { + switch (modelType) { 
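+      // Stop tokens are model-family specific; generation ends when the runner emits one.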
+      case LLAMA_3:
+      case LLAMA_3_1:
+      case LLAMA_3_2:
+      case LLAMA_GUARD_3:
+        return "<|eot_id|>";
+      case LLAVA_1_5:
+        return "</s>";
+      case QWEN_3:
+        return "<|endoftext|>";
+      default:
+        return "";
+    }
+  }
+
+  public static String getThinkingModeToken(ModelType modelType, boolean thinkingMode) {
+    switch (modelType) {
+      case QWEN_3:
+        // When thinking mode is off, pre-fill an empty think block so the model skips reasoning
+        return thinkingMode ? "" : "<think>\n\n</think>\n\n\n";
+      default:
+        return "";
+    }
+  }
+
+  public static String replaceSpecialToken(ModelType modelType, String token) {
+    switch (modelType) {
+      case QWEN_3:
+        switch (token) {
+          case "<|im_end|>":
+            return "";
+          case "<think>":
+            return "Thinking...\n";
+          case "</think>":
+            return "\nDone thinking";
+        }
+      default:
+        return token;
+    }
+  }
+
+  public static String getLlavaPresetPrompt() {
+    return "A chat between a curious human and an artificial intelligence assistant. The assistant"
+        + " gives helpful, detailed, and polite answers to the human's questions. USER: ";
+  }
+
+  public static String getFormattedLlamaGuardPrompt(String userPrompt) {
+    return getUserPromptTemplate(ModelType.LLAMA_GUARD_3, false)
+        .replace(
+            USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt));
+  }
+
+  public static String getLlamaGuardPresetPrompt() {
+    return "Task: Check if there is unsafe content in 'User' messages in conversations"
+        + " according our safety policy with the below categories.\n"
+        + "\n"
+        + "<BEGIN UNSAFE CONTENT CATEGORIES>\n"
+        + "S1: Violent Crimes. \n"
+        + "S2: Non-Violent Crimes. \n"
+        + "S3: Sex Crimes. \n"
+        + "S4: Child Exploitation. \n"
+        + "S5: Defamation. \n"
+        + "S6: Specialized Advice. \n"
+        + "S7: Privacy. \n"
+        + "S8: Intellectual Property. \n"
+        + "S9: Indiscriminate Weapons. \n"
+        + "S10: Hate. \n"
+        + "S11: Self-Harm. \n"
+        + "S12: Sexual Content. \n"
+        + "S13: Elections. \n"
+        + "S14: Code Interpreter Abuse.\n"
+        + "<END UNSAFE CONTENT CATEGORIES>\n"
+        + "\n"
+        + "<BEGIN CONVERSATION>\n"
+        + "\n"
+        + "User: "
+        + USER_PLACEHOLDER
+        + "\n"
+        + "\n"
+        + "<END CONVERSATION>\n"
+        + "\n"
+        + "Provide your safety assessment for ONLY THE LAST User message in the above"
+        + " conversation:\n"
+        + " - First line must read 'safe' or 'unsafe'.\n"
+        + " - If unsafe, a second line must include a comma-separated list of violated"
+        + " categories.";
+  }
+}
diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java
new file mode 100644
index 00000000..0e388a5b
--- /dev/null
+++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java
@@ -0,0 +1,463 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +package com.example.executorchllamademo; + +import android.app.AlertDialog; +import android.content.DialogInterface; +import android.os.Build; +import android.os.Bundle; +import android.text.Editable; +import android.text.TextWatcher; +import android.view.View; +import android.widget.Button; +import android.widget.EditText; +import android.widget.ImageButton; +import android.widget.TextView; +import androidx.appcompat.app.AppCompatActivity; +import androidx.core.content.ContextCompat; +import androidx.core.graphics.Insets; +import androidx.core.view.ViewCompat; +import androidx.core.view.WindowInsetsCompat; +import com.google.gson.Gson; +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class SettingsActivity extends AppCompatActivity { + + private String mModelFilePath = ""; + private String mTokenizerFilePath = ""; + private TextView mBackendTextView; + private TextView mModelTextView; + private TextView mTokenizerTextView; + private TextView mModelTypeTextView; + private EditText mSystemPromptEditText; + private EditText mUserPromptEditText; + private Button mLoadModelButton; + private double mSetTemperature; + private String mSystemPrompt; + private String mUserPrompt; + private BackendType mBackendType; + private ModelType mModelType; + public SettingsFields mSettingsFields; + + private DemoSharedPreferences mDemoSharedPreferences; + public static double TEMPERATURE_MIN_VALUE = 0.0; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_settings); + if (Build.VERSION.SDK_INT >= 21) { + getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar)); + getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar)); + } + ViewCompat.setOnApplyWindowInsetsListener( + requireViewById(R.id.main), + (v, insets) -> { + Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars()); + v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom); + return insets; + }); + mDemoSharedPreferences = new DemoSharedPreferences(getBaseContext()); + mSettingsFields = new SettingsFields(); + setupSettings(); + } + + private void setupSettings() { + mBackendTextView = requireViewById(R.id.backendTextView); + mModelTextView = requireViewById(R.id.modelTextView); + mTokenizerTextView = requireViewById(R.id.tokenizerTextView); + mModelTypeTextView = requireViewById(R.id.modelTypeTextView); + ImageButton backendImageButton = requireViewById(R.id.backendImageButton); + ImageButton modelImageButton = requireViewById(R.id.modelImageButton); + ImageButton tokenizerImageButton = requireViewById(R.id.tokenizerImageButton); + ImageButton modelTypeImageButton = requireViewById(R.id.modelTypeImageButton); + mSystemPromptEditText = requireViewById(R.id.systemPromptText); + mUserPromptEditText = requireViewById(R.id.userPromptText); + loadSettings(); + + // TODO: The two setOnClickListeners will be removed after file path issue is resolved + backendImageButton.setOnClickListener( + view -> { + setupBackendSelectorDialog(); + }); + modelImageButton.setOnClickListener( + view -> { + setupModelSelectorDialog(); + }); + tokenizerImageButton.setOnClickListener( + view -> { + setupTokenizerSelectorDialog(); + }); + modelTypeImageButton.setOnClickListener( + view -> { + setupModelTypeSelectorDialog(); + }); + mModelFilePath = mSettingsFields.getModelFilePath(); + if (!mModelFilePath.isEmpty()) { + 
mModelTextView.setText(getFilenameFromPath(mModelFilePath)); + } + mTokenizerFilePath = mSettingsFields.getTokenizerFilePath(); + if (!mTokenizerFilePath.isEmpty()) { + mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); + } + mModelType = mSettingsFields.getModelType(); + ETLogging.getInstance().log("mModelType from settings " + mModelType); + if (mModelType != null) { + mModelTypeTextView.setText(mModelType.toString()); + } + mBackendType = mSettingsFields.getBackendType(); + ETLogging.getInstance().log("mBackendType from settings " + mBackendType); + if (mBackendType != null) { + mBackendTextView.setText(mBackendType.toString()); + setBackendSettingMode(); + } + + setupParameterSettings(); + setupPromptSettings(); + setupClearChatHistoryButton(); + setupLoadModelButton(); + } + + private void setupLoadModelButton() { + mLoadModelButton = requireViewById(R.id.loadModelButton); + mLoadModelButton.setEnabled(true); + mLoadModelButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Load Model") + .setMessage("Do you really want to load the new model?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + mSettingsFields.saveLoadModelAction(true); + mLoadModelButton.setEnabled(false); + onBackPressed(); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupClearChatHistoryButton() { + Button clearChatButton = requireViewById(R.id.clearChatButton); + clearChatButton.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Delete Chat History") + .setMessage("Do you really want to delete chat history?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + mSettingsFields.saveIsClearChatHistory(true); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupParameterSettings() { + setupTemperatureSettings(); + } + + private void setupTemperatureSettings() { + mSetTemperature = mSettingsFields.getTemperature(); + EditText temperatureEditText = requireViewById(R.id.temperatureEditText); + temperatureEditText.setText(String.valueOf(mSetTemperature)); + temperatureEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + + @Override + public void afterTextChanged(Editable s) { + mSetTemperature = Double.parseDouble(s.toString()); + // This is needed because temperature is changed together with model loading + // Once temperature is no longer in LlmModule constructor, we can remove this + mSettingsFields.saveLoadModelAction(true); + saveSettings(); + } + }); + } + + private void setupPromptSettings() { + setupSystemPromptSettings(); + setupUserPromptSettings(); + } + + private void setupSystemPromptSettings() { + mSystemPrompt = mSettingsFields.getSystemPrompt(); + mSystemPromptEditText.setText(mSystemPrompt); + mSystemPromptEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + 
+ @Override + public void afterTextChanged(Editable s) { + mSystemPrompt = s.toString(); + } + }); + + ImageButton resetSystemPrompt = requireViewById(R.id.resetSystemPrompt); + resetSystemPrompt.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Reset System Prompt") + .setMessage("Do you really want to reset system prompt?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + mSystemPromptEditText.setText(PromptFormat.DEFAULT_SYSTEM_PROMPT); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private void setupUserPromptSettings() { + mUserPrompt = mSettingsFields.getUserPrompt(); + mUserPromptEditText.setText(mUserPrompt); + mUserPromptEditText.addTextChangedListener( + new TextWatcher() { + @Override + public void beforeTextChanged(CharSequence s, int start, int count, int after) {} + + @Override + public void onTextChanged(CharSequence s, int start, int before, int count) {} + + @Override + public void afterTextChanged(Editable s) { + if (isValidUserPrompt(s.toString())) { + mUserPrompt = s.toString(); + } else { + showInvalidPromptDialog(); + } + } + }); + + ImageButton resetUserPrompt = requireViewById(R.id.resetUserPrompt); + resetUserPrompt.setOnClickListener( + view -> { + new AlertDialog.Builder(this) + .setTitle("Reset Prompt Template") + .setMessage("Do you really want to reset the prompt template?") + .setIcon(android.R.drawable.ic_dialog_alert) + .setPositiveButton( + android.R.string.yes, + new DialogInterface.OnClickListener() { + public void onClick(DialogInterface dialog, int whichButton) { + // Clear the messageAdapter and sharedPreference + mUserPromptEditText.setText( + PromptFormat.getUserPromptTemplate(mModelType, false)); + } + }) + .setNegativeButton(android.R.string.no, null) + .show(); + }); + } + + private boolean isValidUserPrompt(String userPrompt) { + return userPrompt.contains(PromptFormat.USER_PLACEHOLDER); + } + + private void showInvalidPromptDialog() { + new AlertDialog.Builder(this) + .setTitle("Invalid Prompt Format") + .setMessage( + "Prompt format must contain " + + PromptFormat.USER_PLACEHOLDER + + ". 
Do you want to reset prompt format?")
+        .setIcon(android.R.drawable.ic_dialog_alert)
+        .setPositiveButton(
+            android.R.string.yes,
+            (dialog, whichButton) -> {
+              mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType, false));
+            })
+        .setNegativeButton(android.R.string.no, null)
+        .show();
+  }
+
+  private void setupBackendSelectorDialog() {
+    // Convert enum to list
+    List<String> backendTypesList = new ArrayList<>();
+    for (BackendType backendType : BackendType.values()) {
+      backendTypesList.add(backendType.toString());
+    }
+    // AlertDialog.Builder takes an array of strings instead of a list
+    String[] backendTypes = backendTypesList.toArray(new String[0]);
+    AlertDialog.Builder backendTypeBuilder = new AlertDialog.Builder(this);
+    backendTypeBuilder.setTitle("Select backend type");
+    backendTypeBuilder.setSingleChoiceItems(
+        backendTypes,
+        -1,
+        (dialog, item) -> {
+          mBackendTextView.setText(backendTypes[item]);
+          mBackendType = BackendType.valueOf(backendTypes[item]);
+          setBackendSettingMode();
+          dialog.dismiss();
+        });
+
+    backendTypeBuilder.create().show();
+  }
+
+  private void setupModelSelectorDialog() {
+    String[] pteFiles = listLocalFile("/data/local/tmp/llama/", new String[] {".pte"});
+    AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this);
+    modelPathBuilder.setTitle("Select model path");
+
+    modelPathBuilder.setSingleChoiceItems(
+        pteFiles,
+        -1,
+        (dialog, item) -> {
+          mModelFilePath = pteFiles[item];
+          mModelTextView.setText(getFilenameFromPath(mModelFilePath));
+          mLoadModelButton.setEnabled(true);
+          dialog.dismiss();
+        });
+
+    modelPathBuilder.create().show();
+  }
+
+  private static boolean fileHasExtension(String file, String[] suffix) {
+    return Arrays.stream(suffix).anyMatch(entry -> file.endsWith(entry));
+  }
+
+  private static String[] listLocalFile(String path, String[] suffix) {
+    File directory = new File(path);
+    if (directory.exists() && directory.isDirectory()) {
+      File[] files = directory.listFiles((dir, name) -> (fileHasExtension(name, suffix)));
+      if (files == null) {
+        return new String[] {};
+      }
+      String[] result = new String[files.length];
+      for (int i = 0; i < files.length; i++) {
+        if (files[i].isFile() && fileHasExtension(files[i].getName(), suffix)) {
+          result[i] = files[i].getAbsolutePath();
+        }
+      }
+      return result;
+    }
+    return new String[] {};
+  }
+
+  private void setupModelTypeSelectorDialog() {
+    // Convert enum to list
+    List<String> modelTypesList = new ArrayList<>();
+    for (ModelType modelType : ModelType.values()) {
+      modelTypesList.add(modelType.toString());
+    }
+    // AlertDialog.Builder takes an array of strings instead of a list
+    String[] modelTypes = modelTypesList.toArray(new String[0]);
+    AlertDialog.Builder modelTypeBuilder = new AlertDialog.Builder(this);
+    modelTypeBuilder.setTitle("Select model type");
+    modelTypeBuilder.setSingleChoiceItems(
+        modelTypes,
+        -1,
+        (dialog, item) -> {
+          mModelTypeTextView.setText(modelTypes[item]);
+          mModelType = ModelType.valueOf(modelTypes[item]);
+          mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType, false));
+          dialog.dismiss();
+        });
+
+    modelTypeBuilder.create().show();
+  }
+
+  private void setupTokenizerSelectorDialog() {
+    String[] tokenizerFiles =
+        listLocalFile("/data/local/tmp/llama/", new String[] {".bin", ".json", ".model"});
+    AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this);
+    tokenizerPathBuilder.setTitle("Select tokenizer path");
+    tokenizerPathBuilder.setSingleChoiceItems(
+        tokenizerFiles,
+        -1,
+        (dialog, item) -> {
+          mTokenizerFilePath = tokenizerFiles[item];
mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); + mLoadModelButton.setEnabled(true); + dialog.dismiss(); + }); + + tokenizerPathBuilder.create().show(); + } + + private String getFilenameFromPath(String uriFilePath) { + String[] segments = uriFilePath.split("/"); + if (segments.length > 0) { + return segments[segments.length - 1]; // get last element (aka filename) + } + return ""; + } + + private void setBackendSettingMode() { + if (mBackendType.equals(BackendType.XNNPACK) || mBackendType.equals(BackendType.QUALCOMM)) { + setXNNPACKSettingMode(); + } else if (mBackendType.equals(BackendType.MEDIATEK)) { + setMediaTekSettingMode(); + } + } + + private void setXNNPACKSettingMode() { + requireViewById(R.id.modelLayout).setVisibility(View.VISIBLE); + requireViewById(R.id.tokenizerLayout).setVisibility(View.VISIBLE); + requireViewById(R.id.parametersView).setVisibility(View.VISIBLE); + requireViewById(R.id.temperatureLayout).setVisibility(View.VISIBLE); + mModelFilePath = ""; + mTokenizerFilePath = ""; + } + + private void setMediaTekSettingMode() { + requireViewById(R.id.modelLayout).setVisibility(View.GONE); + requireViewById(R.id.tokenizerLayout).setVisibility(View.GONE); + requireViewById(R.id.parametersView).setVisibility(View.GONE); + requireViewById(R.id.temperatureLayout).setVisibility(View.GONE); + mModelFilePath = "/in/mtk/llama/runner"; + mTokenizerFilePath = "/in/mtk/llama/runner"; + } + + private void loadSettings() { + Gson gson = new Gson(); + String settingsFieldsJSON = mDemoSharedPreferences.getSettings(); + if (!settingsFieldsJSON.isEmpty()) { + mSettingsFields = gson.fromJson(settingsFieldsJSON, SettingsFields.class); + } + } + + private void saveSettings() { + mSettingsFields.saveModelPath(mModelFilePath); + mSettingsFields.saveTokenizerPath(mTokenizerFilePath); + mSettingsFields.saveParameters(mSetTemperature); + mSettingsFields.savePrompts(mSystemPrompt, mUserPrompt); + mSettingsFields.saveModelType(mModelType); + mSettingsFields.saveBackendType(mBackendType); + mDemoSharedPreferences.addSettings(mSettingsFields); + } + + @Override + public void onBackPressed() { + super.onBackPressed(); + saveSettings(); + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java new file mode 100644 index 00000000..94036f43 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java @@ -0,0 +1,148 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package com.example.executorchllamademo; + +public class SettingsFields { + + public String getModelFilePath() { + return modelFilePath; + } + + public String getTokenizerFilePath() { + return tokenizerFilePath; + } + + public double getTemperature() { + return temperature; + } + + public String getSystemPrompt() { + return systemPrompt; + } + + public ModelType getModelType() { + return modelType; + } + + public BackendType getBackendType() { + return backendType; + } + + public String getUserPrompt() { + return userPrompt; + } + + public String getFormattedSystemAndUserPrompt(String prompt, boolean thinkingMode) { + return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt, thinkingMode); + } + + public String getFormattedSystemPrompt() { + return PromptFormat.getSystemPromptTemplate(modelType) + .replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt); + } + + public String getFormattedUserPrompt(String prompt, boolean thinkingMode) { + return userPrompt + .replace(PromptFormat.USER_PLACEHOLDER, prompt) + .replace( + PromptFormat.THINKING_MODE_PLACEHOLDER, + PromptFormat.getThinkingModeToken(modelType, thinkingMode)); + } + + public boolean getIsClearChatHistory() { + return isClearChatHistory; + } + + public boolean getIsLoadModel() { + return isLoadModel; + } + + private String modelFilePath; + private String tokenizerFilePath; + private double temperature; + private String systemPrompt; + private String userPrompt; + private boolean isClearChatHistory; + private boolean isLoadModel; + private ModelType modelType; + private BackendType backendType; + + public SettingsFields() { + ModelType DEFAULT_MODEL = ModelType.LLAMA_3; + BackendType DEFAULT_BACKEND = BackendType.XNNPACK; + + modelFilePath = ""; + tokenizerFilePath = ""; + temperature = SettingsActivity.TEMPERATURE_MIN_VALUE; + systemPrompt = ""; + userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL, false); + isClearChatHistory = false; + isLoadModel = false; + modelType = DEFAULT_MODEL; + backendType = DEFAULT_BACKEND; + } + + public SettingsFields(SettingsFields settingsFields) { + this.modelFilePath = settingsFields.modelFilePath; + this.tokenizerFilePath = settingsFields.tokenizerFilePath; + this.temperature = settingsFields.temperature; + this.systemPrompt = settingsFields.getSystemPrompt(); + this.userPrompt = settingsFields.getUserPrompt(); + this.isClearChatHistory = settingsFields.getIsClearChatHistory(); + this.isLoadModel = settingsFields.getIsLoadModel(); + this.modelType = settingsFields.modelType; + this.backendType = settingsFields.backendType; + } + + public void saveModelPath(String modelFilePath) { + this.modelFilePath = modelFilePath; + } + + public void saveTokenizerPath(String tokenizerFilePath) { + this.tokenizerFilePath = tokenizerFilePath; + } + + public void saveModelType(ModelType modelType) { + this.modelType = modelType; + } + + public void saveBackendType(BackendType backendType) { + this.backendType = backendType; + } + + public void saveParameters(Double temperature) { + this.temperature = temperature; + } + + public void savePrompts(String systemPrompt, String userPrompt) { + this.systemPrompt = systemPrompt; + this.userPrompt = userPrompt; + } + + public void saveIsClearChatHistory(boolean needToClear) { + this.isClearChatHistory = needToClear; + } + + public void saveLoadModelAction(boolean shouldLoadModel) { + this.isLoadModel = shouldLoadModel; + } + + public boolean equals(SettingsFields anotherSettingsFields) { + if (this == anotherSettingsFields) return true; + 
return modelFilePath.equals(anotherSettingsFields.modelFilePath) + && tokenizerFilePath.equals(anotherSettingsFields.tokenizerFilePath) + && temperature == anotherSettingsFields.temperature + && systemPrompt.equals(anotherSettingsFields.systemPrompt) + && userPrompt.equals(anotherSettingsFields.userPrompt) + && isClearChatHistory == anotherSettingsFields.isClearChatHistory + && isLoadModel == anotherSettingsFields.isLoadModel + && modelType == anotherSettingsFields.modelType + && backendType == anotherSettingsFields.backendType; + } +} diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/banner_shape.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/banner_shape.xml new file mode 100644 index 00000000..0868ffff --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/banner_shape.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_add_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_add_24.xml new file mode 100644 index 00000000..2ae27b84 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_add_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml new file mode 100644 index 00000000..7077fedd --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_article_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_article_24.xml new file mode 100644 index 00000000..a6837b9c --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_article_24.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_close_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_close_24.xml new file mode 100644 index 00000000..fb902d43 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_close_24.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_delete_forever_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_delete_forever_24.xml new file mode 100644 index 00000000..4680bc66 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_delete_forever_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_lightbulb_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_lightbulb_24.xml new file mode 100644 index 00000000..aa045396 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_lightbulb_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_restart_alt_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_restart_alt_24.xml new file mode 100644 index 00000000..860470ab --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_restart_alt_24.xml @@ -0,0 +1,6 @@ + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_send_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_send_24.xml new file mode 100644 index 00000000..2de1f642 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_send_24.xml @@ -0,0 +1,6 @@ + + + diff --git 
a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_settings_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_settings_24.xml new file mode 100644 index 00000000..c51d84b9 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_settings_24.xml @@ -0,0 +1,11 @@ + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_stop_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_stop_24.xml new file mode 100644 index 00000000..832e2585 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/baseline_stop_24.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/blue_lightbulb_24.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/blue_lightbulb_24.xml new file mode 100644 index 00000000..585cd3b1 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/blue_lightbulb_24.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/btn.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/btn.xml new file mode 100644 index 00000000..ceb3ac56 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/btn.xml @@ -0,0 +1,8 @@ + + + + + + + \ No newline at end of file diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/chat_background.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/chat_background.xml new file mode 100644 index 00000000..eb8b9d1f --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/chat_background.xml @@ -0,0 +1,21 @@ + + + + + + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/custom_button_round.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/custom_button_round.xml new file mode 100644 index 00000000..87c82d2a --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/custom_button_round.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/expand_circle_down.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/expand_circle_down.xml new file mode 100644 index 00000000..0a7a71f0 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/expand_circle_down.xml @@ -0,0 +1,9 @@ + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_background.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 00000000..07d5da9c --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_foreground.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_foreground.xml new file mode 100644 index 00000000..7706ab9e --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ + + + + + + + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/input_text_shape.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/input_text_shape.xml new file mode 100644 index 00000000..35c778a4 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/input_text_shape.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/logo.png b/llm/android/LlamaDemo/app/src/main/res/drawable/logo.png new file mode 100644 index 00000000..60e3e517 Binary files /dev/null and 
b/llm/android/LlamaDemo/app/src/main/res/drawable/logo.png differ diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/outline_add_box_48.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/outline_add_box_48.xml new file mode 100644 index 00000000..bb45d63d --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/outline_add_box_48.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/outline_camera_alt_48.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/outline_camera_alt_48.xml new file mode 100644 index 00000000..c7b4b2e4 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/outline_camera_alt_48.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/outline_image_48.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/outline_image_48.xml new file mode 100644 index 00000000..a8bb4b2f --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/outline_image_48.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/prompt_shape.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/prompt_shape.xml new file mode 100644 index 00000000..5f81396e --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/prompt_shape.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/received_message.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/received_message.xml new file mode 100644 index 00000000..c2288b5b --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/received_message.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml new file mode 100644 index 00000000..e8d13ca4 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/sent_message.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml b/llm/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml new file mode 100644 index 00000000..afbe22da --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/drawable/three_dots.xml @@ -0,0 +1,5 @@ + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/layout/activity_benchmarking.xml b/llm/android/LlamaDemo/app/src/main/res/layout/activity_benchmarking.xml new file mode 100644 index 00000000..6e48b5de --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/layout/activity_benchmarking.xml @@ -0,0 +1,16 @@ + + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/layout/activity_logs.xml b/llm/android/LlamaDemo/app/src/main/res/layout/activity_logs.xml new file mode 100644 index 00000000..b327a544 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/layout/activity_logs.xml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/llm/android/LlamaDemo/app/src/main/res/layout/activity_main.xml b/llm/android/LlamaDemo/app/src/main/res/layout/activity_main.xml new file mode 100644 index 00000000..52bf5335 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,241 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/llm/android/LlamaDemo/app/src/main/res/layout/activity_settings.xml b/llm/android/LlamaDemo/app/src/main/res/layout/activity_settings.xml new file mode 100644 
index 00000000..0ec551ae --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/res/layout/activity_settings.xml @@ -0,0 +1,338 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +