@@ -108,6 +108,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
108108 ///
109109 /// - Parameters:
110110 /// - prompt: Text prompt to guide sampling
111+ /// - negativePrompt: Negative text prompt to guide sampling
111112 /// - stepCount: Number of inference steps to perform
112113 /// - imageCount: Number of samples/images to generate for the input prompt
113114 /// - seed: Random seed which
@@ -117,6 +118,7 @@ public struct StableDiffusionPipeline: ResourceManaging {
117118 /// The images will be nil if safety checks were performed and found the result to be un-safe
118119 public func generateImages(
119120 prompt: String ,
121+ negativePrompt: String = " " ,
120122 imageCount: Int = 1 ,
121123 stepCount: Int = 50 ,
122124 seed: UInt32 = 0 ,
@@ -125,17 +127,18 @@ public struct StableDiffusionPipeline: ResourceManaging {
125127 progressHandler: ( Progress ) -> Bool = { _ in true }
126128 ) throws -> [ CGImage ? ] {
127129
128- // Encode the input prompt as well as a blank unconditioned input
130+ // Encode the input prompt and negative prompt
129131 let promptEmbedding = try textEncoder. encode ( prompt)
130- let blankEmbedding = try textEncoder. encode ( " " )
132+ let negativePromptEmbedding = try textEncoder. encode ( negativePrompt )
131133
132134 if reduceMemory {
133135 textEncoder. unloadResources ( )
134136 }
135137
136138 // Convert to Unet hidden state representation
139+ // Concatenate the prompt and negative prompt embeddings
137140 let concatEmbedding = MLShapedArray < Float32 > (
138- concatenating: [ blankEmbedding , promptEmbedding] ,
141+ concatenating: [ negativePromptEmbedding , promptEmbedding] ,
139142 alongAxis: 0
140143 )
141144
0 commit comments