1use crate::{error::LlamaCoreError, ARCHIVES_DIR, SD_IMAGE_TO_IMAGE, SD_TEXT_TO_IMAGE};
4use base64::{engine::general_purpose, Engine as _};
5use endpoints::images::{
6 ImageCreateRequest, ImageEditRequest, ImageObject, ImageVariationRequest, ListImagesResponse,
7 ResponseFormat, SamplingMethod,
8};
9use std::{
10 fs::{self, File},
11 io::{self, Read},
12 path::Path,
13};
14use wasmedge_stable_diffusion::{stable_diffusion_interface::ImageType, BaseFunction};
15
16pub async fn image_generation(
18 req: &mut ImageCreateRequest,
19) -> Result<ListImagesResponse, LlamaCoreError> {
20 #[cfg(feature = "logging")]
21 info!(target: "stdout", "Processing the image generation request.");
22
23 let text_to_image_ctx = match SD_TEXT_TO_IMAGE.get() {
24 Some(sd) => sd,
25 None => {
26 let err_msg = "Fail to get the underlying value of `SD_TEXT_TO_IMAGE`.";
27
28 #[cfg(feature = "logging")]
29 error!(target: "stdout", "{}", &err_msg);
30
31 return Err(LlamaCoreError::Operation(err_msg.into()));
32 }
33 };
34
35 let mut context = text_to_image_ctx.lock().map_err(|e| {
36 let err_msg = format!("Fail to acquire the lock of `SD_TEXT_TO_IMAGE`. {}", e);
37
38 #[cfg(feature = "logging")]
39 error!(target: "stdout", "{}", &err_msg);
40
41 LlamaCoreError::Operation(err_msg)
42 })?;
43
44 let mut ctx = &mut *context;
45
46 let id = format!("file_{}", uuid::Uuid::new_v4());
48
49 let archives_path = Path::new("archives");
51 if !archives_path.exists() {
52 fs::create_dir(archives_path).unwrap();
53 }
54 let file_path = archives_path.join(&id);
55 if !file_path.exists() {
56 fs::create_dir(&file_path).unwrap();
57 }
58 let filename = "output.png";
59 let output_image_file = file_path.join(filename);
60 let output_image_file = output_image_file.to_str().unwrap();
61
62 #[cfg(feature = "logging")]
64 info!(target: "stdout", "prompt: {}", &req.prompt);
65
66 let negative_prompt = req.negative_prompt.clone().unwrap_or_default();
68 #[cfg(feature = "logging")]
69 info!(target: "stdout", "negative prompt: {}", &negative_prompt);
70
71 let n = req.n.unwrap_or(1);
73 #[cfg(feature = "logging")]
74 info!(target: "stdout", "number of images to generate: {}", n);
75
76 let cfg_scale = req.cfg_scale.unwrap_or(7.0);
78 #[cfg(feature = "logging")]
79 info!(target: "stdout", "cfg_scale: {}", cfg_scale);
80
81 let sample_method = req.sample_method.unwrap_or(SamplingMethod::EulerA);
83 #[cfg(feature = "logging")]
84 info!(target: "stdout", "sample_method: {}", sample_method);
85
86 let sample_method = match sample_method {
88 SamplingMethod::Euler => {
89 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::EULER
90 }
91 SamplingMethod::EulerA => {
92 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::EULERA
93 }
94 SamplingMethod::Heun => {
95 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::HEUN
96 }
97 SamplingMethod::Dpm2 => {
98 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPM2
99 }
100 SamplingMethod::DpmPlusPlus2sA => {
101 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2SA
102 }
103 SamplingMethod::DpmPlusPlus2m => {
104 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2M
105 }
106 SamplingMethod::DpmPlusPlus2mv2 => {
107 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2Mv2
108 }
109 SamplingMethod::Ipndm => {
110 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::IPNDM
111 }
112 SamplingMethod::IpndmV => {
113 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::IPNDMV
114 }
115 SamplingMethod::Lcm => {
116 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::LCM
117 }
118 };
119
120 let steps = req.steps.unwrap_or(20);
122 #[cfg(feature = "logging")]
123 info!(target: "stdout", "steps: {}", steps);
124
125 let height = req.height.unwrap_or(512);
127 let width = req.width.unwrap_or(512);
128 #[cfg(feature = "logging")]
129 info!(target: "stdout", "height: {}, width: {}", height, width);
130
131 let control_strength = req.control_strength.unwrap_or(0.9);
133 #[cfg(feature = "logging")]
134 info!(target: "stdout", "control_strength: {}", control_strength);
135
136 let seed = req.seed.unwrap_or(42);
138 #[cfg(feature = "logging")]
139 info!(target: "stdout", "seed: {}", seed);
140
141 let apply_canny_preprocessor = req.apply_canny_preprocessor.unwrap_or(false);
143 #[cfg(feature = "logging")]
144 info!(target: "stdout", "apply_canny_preprocessor: {}", apply_canny_preprocessor);
145
146 let style_ratio = req.style_ratio.unwrap_or(0.2);
148 #[cfg(feature = "logging")]
149 info!(target: "stdout", "style_ratio: {}", style_ratio);
150
151 ctx = ctx
152 .set_prompt(&req.prompt)
153 .set_negative_prompt(negative_prompt)
154 .set_output_path(output_image_file)
155 .set_cfg_scale(cfg_scale)
156 .set_sample_method(sample_method)
157 .set_sample_steps(steps as i32)
158 .set_height(height as i32)
159 .set_width(width as i32)
160 .set_batch_count(n as i32)
161 .set_cfg_scale(cfg_scale)
162 .set_sample_method(sample_method)
163 .set_height(height as i32)
164 .set_width(width as i32)
165 .set_control_strength(control_strength)
166 .set_seed(seed)
167 .enable_canny_preprocess(apply_canny_preprocessor)
168 .set_style_ratio(style_ratio)
169 .set_output_path(output_image_file);
170
171 if let Some(control_image) = &req.control_image {
173 #[cfg(feature = "logging")]
174 info!(target: "stdout", "control_image: {:?}", control_image);
175
176 let control_image_file = Path::new("archives")
177 .join(&control_image.id)
178 .join(&control_image.filename);
179 if !control_image_file.exists() {
180 let err_msg = format!(
181 "The control image file does not exist: {:?}",
182 &control_image_file
183 );
184
185 #[cfg(feature = "logging")]
186 error!(target: "stdout", "{}", &err_msg);
187
188 return Err(LlamaCoreError::Operation(err_msg));
189 }
190
191 let path_control_image = match control_image_file.to_str() {
192 Some(path) => path,
193 None => {
194 let err_msg = "Fail to get the path of the control image.";
195
196 #[cfg(feature = "logging")]
197 error!(target: "stdout", "{}", &err_msg);
198
199 return Err(LlamaCoreError::Operation(err_msg.into()));
200 }
201 };
202
203 ctx = ctx.set_control_image(ImageType::Path(path_control_image.into()));
204 }
205
206 #[cfg(feature = "logging")]
207 info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);
208
209 #[cfg(feature = "logging")]
211 info!(target: "stdout", "generate image");
212
213 ctx.generate().map_err(|e| {
214 let err_msg = format!("Fail to dump the image. {}", e);
215
216 #[cfg(feature = "logging")]
217 error!(target: "stdout", "{}", &err_msg);
218
219 LlamaCoreError::Operation(err_msg)
220 })?;
221
222 #[cfg(feature = "logging")]
224 info!(target: "stdout", "file_id: {}, file_name: {}", &id, &filename);
225
226 let image = match req.response_format {
227 Some(ResponseFormat::B64Json) => {
228 let base64_string = match image_to_base64(output_image_file) {
230 Ok(base64_string) => base64_string,
231 Err(e) => {
232 let err_msg = format!("Fail to convert the image to base64 string. {}", e);
233
234 #[cfg(feature = "logging")]
235 error!(target: "stdout", "{}", &err_msg);
236
237 return Err(LlamaCoreError::Operation(err_msg));
238 }
239 };
240
241 #[cfg(feature = "logging")]
243 info!(target: "stdout", "base64 string: {}", &base64_string.chars().take(10).collect::<String>());
244
245 ImageObject {
247 b64_json: Some(base64_string),
248 url: None,
249 prompt: Some(req.prompt.clone()),
250 }
251 }
252 Some(ResponseFormat::Url) | None => {
253 ImageObject {
255 b64_json: None,
256 url: Some(format!("/{}/{}/{}", ARCHIVES_DIR, &id, &filename)),
257 prompt: Some(req.prompt.clone()),
258 }
259 }
260 };
261
262 let created: u64 = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
263 Ok(n) => n.as_secs(),
264 Err(_) => {
265 let err_msg = "Failed to get the current time.";
266
267 #[cfg(feature = "logging")]
269 error!(target: "stdout", "{}", &err_msg);
270
271 return Err(LlamaCoreError::Operation(err_msg.into()));
272 }
273 };
274
275 let res = ListImagesResponse {
276 created,
277 data: vec![image],
278 };
279
280 #[cfg(feature = "logging")]
281 info!(target: "stdout", "End of the image generation.");
282
283 Ok(res)
284}
285
286pub async fn image_edit(req: &mut ImageEditRequest) -> Result<ListImagesResponse, LlamaCoreError> {
288 let image_to_image_ctx = match SD_IMAGE_TO_IMAGE.get() {
289 Some(sd) => sd,
290 None => {
291 let err_msg = "Fail to get the underlying value of `SD_IMAGE_TO_IMAGE`.";
292
293 #[cfg(feature = "logging")]
294 error!(target: "stdout", "{}", &err_msg);
295
296 return Err(LlamaCoreError::Operation(err_msg.into()));
297 }
298 };
299
300 let mut context = image_to_image_ctx.lock().map_err(|e| {
301 let err_msg = format!("Fail to acquire the lock of `SD_IMAGE_TO_IMAGE`. {}", e);
302
303 #[cfg(feature = "logging")]
304 error!(target: "stdout", "{}", &err_msg);
305
306 LlamaCoreError::Operation(err_msg)
307 })?;
308
309 let mut ctx = &mut *context;
310
311 let id = format!("file_{}", uuid::Uuid::new_v4());
313
314 let path = Path::new("archives");
316 if !path.exists() {
317 fs::create_dir(path).unwrap();
318 }
319 let file_path = path.join(&id);
320 if !file_path.exists() {
321 fs::create_dir(&file_path).unwrap();
322 }
323 let filename = "output.png";
324 let output_image_file = file_path.join(filename);
325 let output_image_file = output_image_file.to_str().unwrap();
326
327 let origin_image_file = Path::new("archives")
329 .join(&req.image.id)
330 .join(&req.image.filename);
331 let path_origin_image = origin_image_file.to_str().ok_or(LlamaCoreError::Operation(
332 "Fail to get the path of the original image.".into(),
333 ))?;
334
335 let n = req.n.unwrap_or(1);
337 #[cfg(feature = "logging")]
338 info!(target: "stdout", "number of images to generate: {}", n);
339
340 let cfg_scale = req.cfg_scale.unwrap_or(7.0);
342 #[cfg(feature = "logging")]
343 info!(target: "stdout", "cfg_scale: {}", cfg_scale);
344
345 let sample_method = req.sample_method.unwrap_or(SamplingMethod::EulerA);
347 #[cfg(feature = "logging")]
348 info!(target: "stdout", "sample_method: {:?}", sample_method);
349 let sample_method = match sample_method {
351 SamplingMethod::Euler => {
352 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::EULER
353 }
354 SamplingMethod::EulerA => {
355 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::EULERA
356 }
357 SamplingMethod::Heun => {
358 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::HEUN
359 }
360 SamplingMethod::Dpm2 => {
361 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPM2
362 }
363 SamplingMethod::DpmPlusPlus2sA => {
364 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2SA
365 }
366 SamplingMethod::DpmPlusPlus2m => {
367 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2M
368 }
369 SamplingMethod::DpmPlusPlus2mv2 => {
370 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::DPMPP2Mv2
371 }
372 SamplingMethod::Ipndm => {
373 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::IPNDM
374 }
375 SamplingMethod::IpndmV => {
376 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::IPNDMV
377 }
378 SamplingMethod::Lcm => {
379 wasmedge_stable_diffusion::stable_diffusion_interface::SampleMethodT::LCM
380 }
381 };
382
383 let steps = req.steps.unwrap_or(20);
385 #[cfg(feature = "logging")]
386 info!(target: "stdout", "steps: {}", steps);
387
388 let height = req.height.unwrap_or(512);
390 let width = req.width.unwrap_or(512);
391 #[cfg(feature = "logging")]
392 info!(target: "stdout", "height: {}, width: {}", height, width);
393
394 let control_strength = req.control_strength.unwrap_or(0.9);
396 #[cfg(feature = "logging")]
397 info!(target: "stdout", "control_strength: {}", control_strength);
398
399 let seed = req.seed.unwrap_or(42);
401 #[cfg(feature = "logging")]
402 info!(target: "stdout", "seed: {}", seed);
403
404 let strength = req.strength.unwrap_or(0.75);
406 #[cfg(feature = "logging")]
407 info!(target: "stdout", "strength: {}", strength);
408
409 let apply_canny_preprocessor = req.apply_canny_preprocessor.unwrap_or(false);
411 #[cfg(feature = "logging")]
412 info!(target: "stdout", "apply_canny_preprocessor: {}", apply_canny_preprocessor);
413
414 let style_ratio = req.style_ratio.unwrap_or(0.2);
416 #[cfg(feature = "logging")]
417 info!(target: "stdout", "style_ratio: {}", style_ratio);
418
419 ctx = ctx
421 .set_prompt(&req.prompt)
422 .set_image(ImageType::Path(path_origin_image.into()))
423 .set_batch_count(n as i32)
424 .set_cfg_scale(cfg_scale)
425 .set_sample_method(sample_method)
426 .set_sample_steps(steps as i32)
427 .set_height(height as i32)
428 .set_width(width as i32)
429 .set_control_strength(control_strength)
430 .set_seed(seed)
431 .set_strength(strength)
432 .enable_canny_preprocess(apply_canny_preprocessor)
433 .set_style_ratio(style_ratio)
434 .set_output_path(output_image_file);
435
436 #[cfg(feature = "logging")]
437 info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);
438
439 #[cfg(feature = "logging")]
441 info!(target: "stdout", "generate image");
442
443 ctx.generate().map_err(|e| {
444 let err_msg = format!("Fail to dump the image. {}", e);
445
446 #[cfg(feature = "logging")]
447 error!(target: "stdout", "{}", &err_msg);
448
449 LlamaCoreError::Operation(err_msg)
450 })?;
451
452 #[cfg(feature = "logging")]
454 info!(target: "stdout", "file_id: {}, file_name: {}", &id, &filename);
455
456 let image = match req.response_format {
457 Some(ResponseFormat::B64Json) => {
458 let base64_string = match image_to_base64(output_image_file) {
460 Ok(base64_string) => base64_string,
461 Err(e) => {
462 let err_msg = format!("Fail to convert the image to base64 string. {}", e);
463
464 #[cfg(feature = "logging")]
465 error!(target: "stdout", "{}", &err_msg);
466
467 return Err(LlamaCoreError::Operation(err_msg));
468 }
469 };
470
471 #[cfg(feature = "logging")]
473 info!(target: "stdout", "base64 string: {}", &base64_string.chars().take(10).collect::<String>());
474
475 ImageObject {
477 b64_json: Some(base64_string),
478 url: None,
479 prompt: Some(req.prompt.clone()),
480 }
481 }
482 Some(ResponseFormat::Url) | None => {
483 ImageObject {
485 b64_json: None,
486 url: Some(format!("/{}/{}/{}", ARCHIVES_DIR, &id, &filename)),
487 prompt: Some(req.prompt.clone()),
488 }
489 }
490 };
491
492 let created: u64 = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
493 Ok(n) => n.as_secs(),
494 Err(_) => {
495 let err_msg = "Failed to get the current time.";
496
497 #[cfg(feature = "logging")]
499 error!(target: "stdout", "{}", &err_msg);
500
501 return Err(LlamaCoreError::Operation(err_msg.into()));
502 }
503 };
504
505 Ok(ListImagesResponse {
506 created,
507 data: vec![image],
508 })
509}
510
511pub async fn image_variation(
513 _req: &mut ImageVariationRequest,
514) -> Result<ListImagesResponse, LlamaCoreError> {
515 unimplemented!("image_variation")
516}
517
518fn image_to_base64(image_path: &str) -> io::Result<String> {
520 let mut image_file = File::open(image_path)?;
522
523 let mut buffer = Vec::new();
525 image_file.read_to_end(&mut buffer)?;
526
527 Ok(general_purpose::STANDARD.encode(&buffer))
528}