import ExecuTorchLLM
+import UIKit
import XCTest

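+// Scales the image down to a 336-pixel width (preserving aspect ratio) and
+// converts it to the planar, channels-first RGB bytes that ExecuTorchLLM's
+// Image type wraps.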
+extension UIImage {
+  func asImage() -> Image {
+    let targetWidth = 336
+    let scaledHeight = Int((Double(targetWidth) * Double(size.height) / Double(size.width)).rounded())
+    let format = UIGraphicsImageRendererFormat.default()
+    format.scale = 1
+    let resizedImage = UIGraphicsImageRenderer(size: CGSize(width: targetWidth, height: scaledHeight), format: format).image { _ in
+      draw(in: CGRect(origin: .zero, size: CGSize(width: targetWidth, height: scaledHeight)))
+    }
+    let resizedCGImage = resizedImage.cgImage!
+    let imageWidth = resizedCGImage.width
+    let imageHeight = resizedCGImage.height
+    let pixelCount = imageWidth * imageHeight
+    var rgbaBuffer = [UInt8](repeating: 0, count: pixelCount * 4)
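+    // The CGContext must not outlive the pointer it is given, so both the
+    // context and the draw call stay inside withUnsafeMutableBytes.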
+    rgbaBuffer.withUnsafeMutableBytes { buffer in
+      let context = CGContext(
+        data: buffer.baseAddress,
+        width: imageWidth,
+        height: imageHeight,
+        bitsPerComponent: 8,
+        bytesPerRow: imageWidth * 4,
+        space: CGColorSpaceCreateDeviceRGB(),
+        bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
+      )!
+      context.draw(resizedCGImage, in: CGRect(x: 0, y: 0, width: imageWidth, height: imageHeight))
+    }
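+    // Deinterleave RGBA into planar RGB (all R bytes, then all G, then all B),
+    // dropping alpha: [R0 G0 B0 A0, R1 G1 B1 A1] becomes [R0 R1, G0 G1, B0 B1].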
+    var planarRGB = [UInt8](repeating: 0, count: pixelCount * 3)
+    for pixelIndex in 0..<pixelCount {
+      let sourceOffset = pixelIndex * 4
+      planarRGB[pixelIndex] = rgbaBuffer[sourceOffset]
+      planarRGB[pixelIndex + pixelCount] = rgbaBuffer[sourceOffset + 1]
+      planarRGB[pixelIndex + pixelCount * 2] = rgbaBuffer[sourceOffset + 2]
+    }
+    return Image(data: Data(planarRGB), width: targetWidth, height: scaledHeight, channels: 3)
+  }
+}
+
class MultimodalRunnerTest: XCTestCase {
  func test() {
    let bundle = Bundle(for: type(of: self))
    guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
-          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin") else {
+          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
+          let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
+          let image = UIImage(contentsOfFile: imagePath) else {
-      XCTFail("Couldn't find model or tokenizer files")
+      XCTFail("Couldn't find model, tokenizer, or image files")
      return
    }
-    return
    let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
    var text = ""

    do {
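+      // Conversation-style prompt: system preamble, the image, then the user
+      // question. A 768-token sequence leaves room for a detailed answer.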
-      try runner.generate([MultimodalInput("hello")], sequenceLength: 2) { token in
+      try runner.generate([
+        MultimodalInput("A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "),
+        MultimodalInput(image.asImage()),
+        MultimodalInput("What's on the picture? ASSISTANT: "),
+      ], sequenceLength: 768) { token in
        text += token
      }
    } catch {
      XCTFail("Failed to generate text with error \(error)")
    }
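+    // The test image is expected to contain a waterfall; a substring check
+    // keeps the assertion robust to wording differences across runs.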
-    XCTAssertEqual("hello,", text.lowercased())
+    XCTAssertTrue(text.lowercased().contains("waterfall"))
  }
}