llama_core/
images.rs

//! Define APIs for image generation, editing, and variation.
2
3use crate::{error::LlamaCoreError, ARCHIVES_DIR, SD_IMAGE_TO_IMAGE, SD_TEXT_TO_IMAGE};
4use base64::{engine::general_purpose, Engine as _};
5use endpoints::images::{
6    ImageCreateRequest, ImageEditRequest, ImageObject, ImageVariationRequest, ListImagesResponse,
7    ResponseFormat, SamplingMethod,
8};
9use std::{
10    fs::{self, File},
11    io::{self, Read},
12    path::Path,
13};
14use wasmedge_stable_diffusion::{stable_diffusion_interface::ImageType, BaseFunction};
15
16/// Create an image given a prompt.
17pub async fn image_generation(
18    req: &mut ImageCreateRequest,
19) -> Result<ListImagesResponse, LlamaCoreError> {
20    #[cfg(feature = "logging")]
21    info!(target: "stdout", "Processing the image generation request.");
22
23    let text_to_image_ctx = match SD_TEXT_TO_IMAGE.get() {
24        Some(sd) => sd,
25        None => {
26            let err_msg = "Fail to get the underlying value of `SD_TEXT_TO_IMAGE`.";
27
28            #[cfg(feature = "logging")]
29            error!(target: "stdout", "{}", &err_msg);
30
31            return Err(LlamaCoreError::Operation(err_msg.into()));
32        }
33    };
34
35    let mut context = text_to_image_ctx.lock().map_err(|e| {
36        let err_msg = format!("Fail to acquire the lock of `SD_TEXT_TO_IMAGE`. {}", e);
37
38        #[cfg(feature = "logging")]
39        error!(target: "stdout", "{}", &err_msg);
40
41        LlamaCoreError::Operation(err_msg)
42    })?;
43
44    let mut ctx = &mut *context;
45
46    // create a unique file id
47    let id = format!("file_{}", uuid::Uuid::new_v4());
48
49    // save the file
50    let archives_path = Path::new("archives");
51    if !archives_path.exists() {
52        fs::create_dir(archives_path).unwrap();
53    }
54    let file_path = archives_path.join(&id);
55    if !file_path.exists() {
56        fs::create_dir(&file_path).unwrap();
57    }
58    let filename = "output.png";
59    let output_image_file = file_path.join(filename);
60    let output_image_file = output_image_file.to_str().unwrap();
61
62    // log
63    #[cfg(feature = "logging")]
64    info!(target: "stdout", "prompt: {}", &req.prompt);
65
66    // negative prompt
67    let negative_prompt = req.negative_prompt.clone().unwrap_or_default();
68    #[cfg(feature = "logging")]
69    info!(target: "stdout", "negative prompt: {}", &negative_prompt);
70
71    // n
72    let n = req.n.unwrap_or(1);
73    #[cfg(feature = "logging")]
74    info!(target: "stdout", "number of images to generate: {}", n);
75
76    // cfg_scale
77    let cfg_scale = req.cfg_scale.unwrap_or(7.0);
78    #[cfg(feature = "logging")]
79    info!(target: "stdout", "cfg_scale: {}", cfg_scale);
80
81    // sampling method
82    let sample_method = req.sample_method.unwrap_or(SamplingMethod::EulerA);
83    #[cfg(feature = "logging")]
84    info!(target: "stdout", "sample_method: {}", sample_method);
85
86    // convert sample method to value of `SampleMethodT` type
87    let sample_method = match sample_method {
88        SamplingMethod::Euler => {
89            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::EULER
90        }
91        SamplingMethod::EulerA => {
92            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::EULERA
93        }
94        SamplingMethod::Heun => {
95            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::HEUN
96        }
97        SamplingMethod::Dpm2 => {
98            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPM2
99        }
100        SamplingMethod::DpmPlusPlus2sA => {
101            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2SA
102        }
103        SamplingMethod::DpmPlusPlus2m => {
104            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2M
105        }
106        SamplingMethod::DpmPlusPlus2mv2 => {
107            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2Mv2
108        }
109        SamplingMethod::Ipndm => {
110            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::IPNDM
111        }
112        SamplingMethod::IpndmV => {
113            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::IPNDMV
114        }
115        SamplingMethod::Lcm => {
116            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::LCM
117        }
118    };
119
120    // steps
121    let steps = req.steps.unwrap_or(20);
122    #[cfg(feature = "logging")]
123    info!(target: "stdout", "steps: {}", steps);
124
125    // size
126    let height = req.height.unwrap_or(512);
127    let width = req.width.unwrap_or(512);
128    #[cfg(feature = "logging")]
129    info!(target: "stdout", "height: {}, width: {}", height, width);
130
131    // control_strength
132    let control_strength = req.control_strength.unwrap_or(0.9);
133    #[cfg(feature = "logging")]
134    info!(target: "stdout", "control_strength: {}", control_strength);
135
136    // seed
137    let seed = req.seed.unwrap_or(42);
138    #[cfg(feature = "logging")]
139    info!(target: "stdout", "seed: {}", seed);
140
141    // apply canny preprocessor
142    let apply_canny_preprocessor = req.apply_canny_preprocessor.unwrap_or(false);
143    #[cfg(feature = "logging")]
144    info!(target: "stdout", "apply_canny_preprocessor: {}", apply_canny_preprocessor);
145
146    // style ratio
147    let style_ratio = req.style_ratio.unwrap_or(0.2);
148    #[cfg(feature = "logging")]
149    info!(target: "stdout", "style_ratio: {}", style_ratio);
150
151    ctx = ctx
152        .set_prompt(&req.prompt)
153        .set_negative_prompt(negative_prompt)
154        .set_output_path(output_image_file)
155        .set_cfg_scale(cfg_scale)
156        .set_sample_method(sample_method)
157        .set_sample_steps(steps as i32)
158        .set_height(height as i32)
159        .set_width(width as i32)
160        .set_batch_count(n as i32)
161        .set_cfg_scale(cfg_scale)
162        .set_sample_method(sample_method)
163        .set_height(height as i32)
164        .set_width(width as i32)
165        .set_control_strength(control_strength)
166        .set_seed(seed)
167        .enable_canny_preprocess(apply_canny_preprocessor)
168        .set_style_ratio(style_ratio)
169        .set_output_path(output_image_file);
170
171    // control_image
172    if let Some(control_image) = &req.control_image {
173        #[cfg(feature = "logging")]
174        info!(target: "stdout", "control_image: {:?}", control_image);
175
176        let control_image_file = Path::new("archives")
177            .join(&control_image.id)
178            .join(&control_image.filename);
179        if !control_image_file.exists() {
180            let err_msg = format!(
181                "The control image file does not exist: {:?}",
182                &control_image_file
183            );
184
185            #[cfg(feature = "logging")]
186            error!(target: "stdout", "{}", &err_msg);
187
188            return Err(LlamaCoreError::Operation(err_msg));
189        }
190
191        let path_control_image = match control_image_file.to_str() {
192            Some(path) => path,
193            None => {
194                let err_msg = "Fail to get the path of the control image.";
195
196                #[cfg(feature = "logging")]
197                error!(target: "stdout", "{}", &err_msg);
198
199                return Err(LlamaCoreError::Operation(err_msg.into()));
200            }
201        };
202
203        ctx = ctx.set_control_image(ImageType::Path(path_control_image.into()));
204    }
205
206    #[cfg(feature = "logging")]
207    info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
208
209    // log
210    #[cfg(feature = "logging")]
211    info!(target: "stdout", "generate image");
212
213    ctx.generate().map_err(|e| {
214        let err_msg = format!("Fail to dump the image. {}", e);
215
216        #[cfg(feature = "logging")]
217        error!(target: "stdout", "{}", &err_msg);
218
219        LlamaCoreError::Operation(err_msg)
220    })?;
221
222    // log
223    #[cfg(feature = "logging")]
224    info!(target: "stdout", "file_id: {}, file_name: {}", &id, &filename);
225
226    let image = match req.response_format {
227        Some(ResponseFormat::B64Json) => {
228            // convert the image to base64 string
229            let base64_string = match image_to_base64(output_image_file) {
230                Ok(base64_string) => base64_string,
231                Err(e) => {
232                    let err_msg = format!("Fail to convert the image to base64 string. {}", e);
233
234                    #[cfg(feature = "logging")]
235                    error!(target: "stdout", "{}", &err_msg);
236
237                    return Err(LlamaCoreError::Operation(err_msg));
238                }
239            };
240
241            // log
242            #[cfg(feature = "logging")]
243            info!(target: "stdout", "base64 string: {}", &base64_string.chars().take(10).collect::<String>());
244
245            // create an image object
246            ImageObject {
247                b64_json: Some(base64_string),
248                url: None,
249                prompt: Some(req.prompt.clone()),
250            }
251        }
252        Some(ResponseFormat::Url) | None => {
253            // create an image object
254            ImageObject {
255                b64_json: None,
256                url: Some(format!("/{}/{}/{}", ARCHIVES_DIR, &id, &filename)),
257                prompt: Some(req.prompt.clone()),
258            }
259        }
260    };
261
262    let created: u64 = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
263        Ok(n) => n.as_secs(),
264        Err(_) => {
265            let err_msg = "Failed to get the current time.";
266
267            // log
268            #[cfg(feature = "logging")]
269            error!(target: "stdout", "{}", &err_msg);
270
271            return Err(LlamaCoreError::Operation(err_msg.into()));
272        }
273    };
274
275    let res = ListImagesResponse {
276        created,
277        data: vec![image],
278    };
279
280    #[cfg(feature = "logging")]
281    info!(target: "stdout", "End of the image generation.");
282
283    Ok(res)
284}
285
286/// Create an edited or extended image given an original image and a prompt.
287pub async fn image_edit(req: &mut ImageEditRequest) -> Result<ListImagesResponse, LlamaCoreError> {
288    let image_to_image_ctx = match SD_IMAGE_TO_IMAGE.get() {
289        Some(sd) => sd,
290        None => {
291            let err_msg = "Fail to get the underlying value of `SD_IMAGE_TO_IMAGE`.";
292
293            #[cfg(feature = "logging")]
294            error!(target: "stdout", "{}", &err_msg);
295
296            return Err(LlamaCoreError::Operation(err_msg.into()));
297        }
298    };
299
300    let mut context = image_to_image_ctx.lock().map_err(|e| {
301        let err_msg = format!("Fail to acquire the lock of `SD_IMAGE_TO_IMAGE`. {}", e);
302
303        #[cfg(feature = "logging")]
304        error!(target: "stdout", "{}", &err_msg);
305
306        LlamaCoreError::Operation(err_msg)
307    })?;
308
309    let mut ctx = &mut *context;
310
311    // create a unique file id
312    let id = format!("file_{}", uuid::Uuid::new_v4());
313
314    // save the file
315    let path = Path::new("archives");
316    if !path.exists() {
317        fs::create_dir(path).unwrap();
318    }
319    let file_path = path.join(&id);
320    if !file_path.exists() {
321        fs::create_dir(&file_path).unwrap();
322    }
323    let filename = "output.png";
324    let output_image_file = file_path.join(filename);
325    let output_image_file = output_image_file.to_str().unwrap();
326
327    // get the path of the original image
328    let origin_image_file = Path::new("archives")
329        .join(&req.image.id)
330        .join(&req.image.filename);
331    let path_origin_image = origin_image_file.to_str().ok_or(LlamaCoreError::Operation(
332        "Fail to get the path of the original image.".into(),
333    ))?;
334
335    // n
336    let n = req.n.unwrap_or(1);
337    #[cfg(feature = "logging")]
338    info!(target: "stdout", "number of images to generate: {}", n);
339
340    // cfg scale
341    let cfg_scale = req.cfg_scale.unwrap_or(7.0);
342    #[cfg(feature = "logging")]
343    info!(target: "stdout", "cfg_scale: {}", cfg_scale);
344
345    // sample method
346    let sample_method = req.sample_method.unwrap_or(SamplingMethod::EulerA);
347    #[cfg(feature = "logging")]
348    info!(target: "stdout", "sample_method: {:?}", sample_method);
349    // convert sample method to value of `SampleMethodT` type
350    let sample_method = match sample_method {
351        SamplingMethod::Euler => {
352            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::EULER
353        }
354        SamplingMethod::EulerA => {
355            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::EULERA
356        }
357        SamplingMethod::Heun => {
358            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::HEUN
359        }
360        SamplingMethod::Dpm2 => {
361            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPM2
362        }
363        SamplingMethod::DpmPlusPlus2sA => {
364            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2SA
365        }
366        SamplingMethod::DpmPlusPlus2m => {
367            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2M
368        }
369        SamplingMethod::DpmPlusPlus2mv2 => {
370            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2Mv2
371        }
372        SamplingMethod::Ipndm => {
373            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::IPNDM
374        }
375        SamplingMethod::IpndmV => {
376            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::IPNDMV
377        }
378        SamplingMethod::Lcm => {
379            wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::LCM
380        }
381    };
382
383    // steps
384    let steps = req.steps.unwrap_or(20);
385    #[cfg(feature = "logging")]
386    info!(target: "stdout", "steps: {}", steps);
387
388    // size
389    let height = req.height.unwrap_or(512);
390    let width = req.width.unwrap_or(512);
391    #[cfg(feature = "logging")]
392    info!(target: "stdout", "height: {}, width: {}", height, width);
393
394    // control_strength
395    let control_strength = req.control_strength.unwrap_or(0.9);
396    #[cfg(feature = "logging")]
397    info!(target: "stdout", "control_strength: {}", control_strength);
398
399    // seed
400    let seed = req.seed.unwrap_or(42);
401    #[cfg(feature = "logging")]
402    info!(target: "stdout", "seed: {}", seed);
403
404    // strength
405    let strength = req.strength.unwrap_or(0.75);
406    #[cfg(feature = "logging")]
407    info!(target: "stdout", "strength: {}", strength);
408
409    // apply canny preprocessor
410    let apply_canny_preprocessor = req.apply_canny_preprocessor.unwrap_or(false);
411    #[cfg(feature = "logging")]
412    info!(target: "stdout", "apply_canny_preprocessor: {}", apply_canny_preprocessor);
413
414    // style ratio
415    let style_ratio = req.style_ratio.unwrap_or(0.2);
416    #[cfg(feature = "logging")]
417    info!(target: "stdout", "style_ratio: {}", style_ratio);
418
419    // create and dump the generated image
420    ctx = ctx
421        .set_prompt(&req.prompt)
422        .set_image(ImageType::Path(path_origin_image.into()))
423        .set_batch_count(n as i32)
424        .set_cfg_scale(cfg_scale)
425        .set_sample_method(sample_method)
426        .set_sample_steps(steps as i32)
427        .set_height(height as i32)
428        .set_width(width as i32)
429        .set_control_strength(control_strength)
430        .set_seed(seed)
431        .set_strength(strength)
432        .enable_canny_preprocess(apply_canny_preprocessor)
433        .set_style_ratio(style_ratio)
434        .set_output_path(output_image_file);
435
436    #[cfg(feature = "logging")]
437    info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
438
439    // log
440    #[cfg(feature = "logging")]
441    info!(target: "stdout", "generate image");
442
443    ctx.generate().map_err(|e| {
444        let err_msg = format!("Fail to dump the image. {}", e);
445
446        #[cfg(feature = "logging")]
447        error!(target: "stdout", "{}", &err_msg);
448
449        LlamaCoreError::Operation(err_msg)
450    })?;
451
452    // log
453    #[cfg(feature = "logging")]
454    info!(target: "stdout", "file_id: {}, file_name: {}", &id, &filename);
455
456    let image = match req.response_format {
457        Some(ResponseFormat::B64Json) => {
458            // convert the image to base64 string
459            let base64_string = match image_to_base64(output_image_file) {
460                Ok(base64_string) => base64_string,
461                Err(e) => {
462                    let err_msg = format!("Fail to convert the image to base64 string. {}", e);
463
464                    #[cfg(feature = "logging")]
465                    error!(target: "stdout", "{}", &err_msg);
466
467                    return Err(LlamaCoreError::Operation(err_msg));
468                }
469            };
470
471            // log
472            #[cfg(feature = "logging")]
473            info!(target: "stdout", "base64 string: {}", &base64_string.chars().take(10).collect::<String>());
474
475            // create an image object
476            ImageObject {
477                b64_json: Some(base64_string),
478                url: None,
479                prompt: Some(req.prompt.clone()),
480            }
481        }
482        Some(ResponseFormat::Url) | None => {
483            // create an image object
484            ImageObject {
485                b64_json: None,
486                url: Some(format!("/{}/{}/{}", ARCHIVES_DIR, &id, &filename)),
487                prompt: Some(req.prompt.clone()),
488            }
489        }
490    };
491
492    let created: u64 = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
493        Ok(n) => n.as_secs(),
494        Err(_) => {
495            let err_msg = "Failed to get the current time.";
496
497            // log
498            #[cfg(feature = "logging")]
499            error!(target: "stdout", "{}", &err_msg);
500
501            return Err(LlamaCoreError::Operation(err_msg.into()));
502        }
503    };
504
505    Ok(ListImagesResponse {
506        created,
507        data: vec![image],
508    })
509}
510
511/// Create a variation of a given image.
512pub async fn image_variation(
513    _req: &mut ImageVariationRequest,
514) -> Result<ListImagesResponse, LlamaCoreError> {
515    unimplemented!("image_variation")
516}
517
518// convert an image file to a base64 string
519fn image_to_base64(image_path: &str) -> io::Result<String> {
520    // Open the file
521    let mut image_file = File::open(image_path)?;
522
523    // Read the file into a byte array
524    let mut buffer = Vec::new();
525    image_file.read_to_end(&mut buffer)?;
526
527    Ok(general_purpose::STANDARD.encode(&buffer))
528}