@@ -555,11 +555,11 @@ def output_shape(self):
555555
556556 def prepare_init_args_and_inputs_for_common (self ):
557557 init_dict = {
558- "block_out_channels" : [64 ],
558+ "block_out_channels" : [32 , 64 ],
559559 "in_channels" : 3 ,
560560 "out_channels" : 3 ,
561- "down_block_types" : ["DownEncoderBlock2D" ],
562- "up_block_types" : ["UpDecoderBlock2D" ],
561+ "down_block_types" : ["DownEncoderBlock2D" , "DownEncoderBlock2D" ],
562+ "up_block_types" : ["UpDecoderBlock2D" , "UpDecoderBlock2D" ],
563563 "latent_channels" : 3 ,
564564 }
565565 inputs_dict = self .dummy_input
@@ -595,7 +595,7 @@ def test_output_pretrained(self):
595595
596596 output_slice = output [0 , - 1 , - 3 :, - 3 :].flatten ()
597597 # fmt: off
598- expected_output_slice = torch .tensor ([- 1.1321 , 0.1056 , 0.3505 , - 0.6461 , - 0.2014 , 0.0419 , - 0.5763 , - 0.8462 , - 0.4218 ])
598+ expected_output_slice = torch .tensor ([- 0.0153 , - 0.4044 , - 0.1880 , - 0.5161 , - 0.2418 , - 0.4072 , - 0.1612 , - 0.0633 , - 0.0143 ])
599599 # fmt: on
600600 self .assertTrue (torch .allclose (output_slice , expected_output_slice , rtol = 1e-2 ))
601601
@@ -623,22 +623,11 @@ def output_shape(self):
623623
624624 def prepare_init_args_and_inputs_for_common (self ):
625625 init_dict = {
626- "ch" : 64 ,
627- "ch_mult" : (1 ,),
628- "embed_dim" : 4 ,
629- "in_channels" : 3 ,
630- "attn_resolutions" : [],
631- "num_res_blocks" : 1 ,
632- "out_ch" : 3 ,
633- "resolution" : 32 ,
634- "z_channels" : 4 ,
635- }
636- init_dict = {
637- "block_out_channels" : [64 ],
626+ "block_out_channels" : [32 , 64 ],
638627 "in_channels" : 3 ,
639628 "out_channels" : 3 ,
640- "down_block_types" : ["DownEncoderBlock2D" ],
641- "up_block_types" : ["UpDecoderBlock2D" ],
629+ "down_block_types" : ["DownEncoderBlock2D" , "DownEncoderBlock2D" ],
630+ "up_block_types" : ["UpDecoderBlock2D" , "UpDecoderBlock2D" ],
642631 "latent_channels" : 4 ,
643632 }
644633 inputs_dict = self .dummy_input
@@ -674,7 +663,7 @@ def test_output_pretrained(self):
674663
675664 output_slice = output [0 , - 1 , - 3 :, - 3 :].flatten ()
676665 # fmt: off
677- expected_output_slice = torch .tensor ([- 0.3900 , - 0.2800 , 0.1281 , - 0.4449 , - 0.4890 , - 0.0207 , 0.0784 , - 0.1258 , - 0.0409 ])
666+ expected_output_slice = torch .tensor ([- 4.0078e-01 , - 3.8304e-04 , - 1.2681e-01 , - 1.1462e-01 , 2.0095e-01 , 1.0893e-01 , - 8.8248e-02 , - 3.0361e-01 , - 9.8646e-03 ])
678667 # fmt: on
679668 self .assertTrue (torch .allclose (output_slice , expected_output_slice , rtol = 1e-2 ))
680669
0 commit comments