|
9 | 9 | import ExecuTorchLLM
|
10 | 10 | import XCTest
|
11 | 11 |
|
extension UIImage {
  /// Resizes the image to `targetWidth` (preserving aspect ratio) and converts
  /// it to a planar RGB `Image` — all red bytes first, then all green, then all
  /// blue — for consumption by the multimodal runner.
  ///
  /// - Parameter targetWidth: Width in pixels to resize to. Defaults to 336,
  ///   the resolution this test's LLaVA model expects — TODO confirm against
  ///   the exported model's preprocessor config.
  /// - Returns: A 3-channel `Image` in planar (channel-major) byte order.
  func asImage(targetWidth: Int = 336) -> Image {
    // Scale height proportionally so the aspect ratio is preserved.
    let scaledHeight = Double(targetWidth) * Double(size.height) / Double(size.width)
    let targetHeight = Int(scaledHeight.rounded())
    let targetSize = CGSize(width: targetWidth, height: targetHeight)

    // scale = 1 forces a 1:1 point-to-pixel mapping regardless of screen scale.
    let format = UIGraphicsImageRendererFormat.default()
    format.scale = 1
    let resized = UIGraphicsImageRenderer(size: targetSize, format: format).image { _ in
      draw(in: CGRect(origin: .zero, size: targetSize))
    }

    // Force-unwraps are deliberate in this test helper: failing to obtain a
    // CGImage or CGContext here is a programmer/environment error worth crashing on.
    let cgImage = resized.cgImage!
    let width = cgImage.width
    let height = cgImage.height
    let pixelCount = width * height
    let bytesPerPixel = 4
    let bytesPerRow = bytesPerPixel * width

    // Render into an RGBA8888 big-endian buffer so the interleaved pixel
    // layout is known before the planar conversion below.
    var pixelBytes = [UInt8](repeating: 0, count: pixelCount * bytesPerPixel)
    let context = CGContext(
      data: &pixelBytes,
      width: width,
      height: height,
      bitsPerComponent: 8,
      bytesPerRow: bytesPerRow,
      space: CGColorSpaceCreateDeviceRGB(),
      bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
    )!
    context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))

    // Convert interleaved RGBA to planar RGB, dropping the alpha channel.
    var rgbBytes = [UInt8](repeating: 0, count: pixelCount * 3)
    for pixelIndex in 0..<pixelCount {
      let offset = pixelIndex * bytesPerPixel
      rgbBytes[pixelIndex] = pixelBytes[offset]                          // R plane
      rgbBytes[pixelIndex + pixelCount] = pixelBytes[offset + 1]         // G plane
      rgbBytes[pixelIndex + pixelCount * 2] = pixelBytes[offset + 2]     // B plane
    }
    return Image(data: Data(rgbBytes), width: width, height: height, channels: 3)
  }
}
| 49 | + |
class MultimodalRunnerTest: XCTestCase {
  /// End-to-end smoke test: runs the bundled LLaVA model on a bundled photo
  /// and checks that the generated description mentions "water".
  func test() {
    let bundle = Bundle(for: type(of: self))
    guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
          let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
          let imagePath = bundle.path(forResource: "IMG_0005", ofType: "JPG"),
          let image = UIImage(contentsOfFile: imagePath) else {
      // Message now covers every resource the guard can fail on — the old
      // text omitted the image even though the guard also binds it.
      XCTFail("Couldn't find model, tokenizer, or image files")
      return
    }
    let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
    var text = ""

    do {
      try runner.generate([
        MultimodalInput("What's on the picture?"),
        MultimodalInput(image.asImage()),
      ], sequenceLength: 256) { token in
        text += token
      }
    } catch {
      XCTFail("Failed to generate text with error \(error)")
    }

    // NOTE(review): assumes IMG_0005.JPG depicts a water scene; the expected
    // keyword is tied to the bundled fixture — update together if it changes.
    XCTAssertTrue(text.lowercased().contains("water"))
  }
}
|
0 commit comments