@@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock {
6262 int num_blocks = 3 ;
6363
6464public:
65- TinyEncoder () {
65+ TinyEncoder (int z_channels = 4 )
66+ : z_channels(z_channels) {
6667 int index = 0 ;
6768 blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
6869 blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock (channels, channels));
@@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock {
106107 int num_blocks = 3 ;
107108
108109public:
109- TinyDecoder (int index = 0 ) {
110+ TinyDecoder (int z_channels = 4 )
111+ : z_channels(z_channels) {
112+ int index = 0 ;
113+
110114 blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (z_channels, channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
111115 index++; // nn.ReLU()
112116
@@ -163,12 +167,16 @@ class TAESD : public GGMLBlock {
163167 bool decode_only;
164168
165169public:
166- TAESD (bool decode_only = true )
170+ TAESD (bool decode_only = true , SDVersion version = VERSION_SD1 )
167171 : decode_only(decode_only) {
168- blocks[" decoder.layers" ] = std::shared_ptr<GGMLBlock>(new TinyDecoder ());
172+ int z_channels = 4 ;
173+ if (sd_version_is_dit (version)) {
174+ z_channels = 16 ;
175+ }
176+ blocks[" decoder.layers" ] = std::shared_ptr<GGMLBlock>(new TinyDecoder (z_channels));
169177
170178 if (!decode_only) {
171- blocks[" encoder.layers" ] = std::shared_ptr<GGMLBlock>(new TinyEncoder ());
179+ blocks[" encoder.layers" ] = std::shared_ptr<GGMLBlock>(new TinyEncoder (z_channels ));
172180 }
173181 }
174182
@@ -190,9 +198,10 @@ struct TinyAutoEncoder : public GGMLRunner {
190198 TinyAutoEncoder (ggml_backend_t backend,
191199 std::map<std::string, enum ggml_type>& tensor_types,
192200 const std::string prefix,
193- bool decoder_only = true )
201+ bool decoder_only = true ,
202+ SDVersion version = VERSION_SD1)
194203 : decode_only(decoder_only),
195- taesd (decode_only),
204+ taesd (decode_only, version ),
196205 GGMLRunner(backend) {
197206 taesd.init (params_ctx, tensor_types, prefix);
198207 }
0 commit comments