|
17 | 17 | #define STB_IMAGE_WRITE_STATIC |
18 | 18 | #include "stb_image_write.h" |
19 | 19 |
|
| 20 | +#include "stb_image_resize.h" |
| 21 | + |
20 | 22 | const char* rng_type_to_str[] = { |
21 | 23 | "std_default", |
22 | 24 | "cuda", |
@@ -663,21 +665,47 @@ int main(int argc, const char* argv[]) { |
663 | 665 | fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str()); |
664 | 666 | return 1; |
665 | 667 | } |
666 | | - if (c != 3) { |
667 | | - fprintf(stderr, "input image must be a 3 channels RGB image, but got %d channels\n", c); |
| 668 | + if (c < 3) { |
| 669 | + fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); |
668 | 670 | free(input_image_buffer); |
669 | 671 | return 1; |
670 | 672 | } |
671 | | - if (params.width <= 0 || params.width % 64 != 0) { |
672 | | - fprintf(stderr, "error: the width of image must be a multiple of 64\n"); |
| 673 | + if (params.width <= 0) { |
| 674 | + fprintf(stderr, "error: the width of image must be greater than 0\n"); |
673 | 675 | free(input_image_buffer); |
674 | 676 | return 1; |
675 | 677 | } |
676 | | - if (params.height <= 0 || params.height % 64 != 0) { |
677 | | - fprintf(stderr, "error: the height of image must be a multiple of 64\n"); |
| 678 | + if (params.height <= 0) { |
| 679 | + fprintf(stderr, "error: the height of image must be greater than 0\n"); |
678 | 680 | free(input_image_buffer); |
679 | 681 | return 1; |
680 | 682 | } |
| 683 | + |
| 684 | + // Resize input image ... |
| 685 | + if (params.height % 64 != 0 || params.width % 64 != 0) { |
| 686 | + int resized_height = params.height + (64 - params.height % 64); |
| 687 | + int resized_width = params.width + (64 - params.width % 64); |
| 688 | + |
| 689 | + uint8_t *resized_image_buffer = (uint8_t *)malloc(resized_height * resized_width * 3); |
| 690 | + if (resized_image_buffer == NULL) { |
| 691 | + fprintf(stderr, "error: allocate memory for resize input image\n"); |
| 692 | + free(input_image_buffer); |
| 693 | + return 1; |
| 694 | + } |
| 695 | + stbir_resize(input_image_buffer, params.width, params.height, 0, |
| 696 | + resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, |
| 697 | + 3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0, |
| 698 | + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, |
| 699 | + STBIR_FILTER_BOX, STBIR_FILTER_BOX, |
| 700 | + STBIR_COLORSPACE_SRGB, nullptr |
| 701 | + ); |
| 702 | + |
| 703 | + // Save resized result |
| 704 | + free(input_image_buffer); |
| 705 | + input_image_buffer = resized_image_buffer; |
| 706 | + params.height = resized_height; |
| 707 | + params.width = resized_width; |
| 708 | + } |
681 | 709 | } |
682 | 710 |
|
683 | 711 | sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), |
|
0 commit comments