Commit a209037

[VAE] fix the downsample block in Encoder. (huggingface#156)
* pass downsample_padding in encoder
* update tests
1 parent c4a3b09 · commit a209037

File tree: 2 files changed (+9, -19 lines)
src/diffusers/models/vae.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def __init__(
                 out_channels=output_channel,
                 add_downsample=not is_final_block,
                 resnet_eps=1e-6,
+                downsample_padding=0,
                 resnet_act_fn=act_fn,
                 attn_num_head_channels=None,
                 temb_channels=None,
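
For context, downsample_padding is forwarded to the encoder's downsample layers. With a padding of 0 these layers follow the original CompVis VAE convention: pad the input asymmetrically (one pixel on the right and bottom) and then apply a stride-2 convolution, rather than using a symmetrically padded convolution. A minimal sketch of that behavior in plain PyTorch, illustrative only and not the library's exact internals:

import torch
import torch.nn.functional as F
from torch import nn

# Asymmetric pad-then-downsample, as used when downsample_padding=0:
# one pixel of zero padding on the right and bottom, then a stride-2 conv.
conv = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0)

x = torch.randn(1, 64, 32, 32)
x = F.pad(x, (0, 1, 0, 1))  # pad order: (left, right, top, bottom)
y = conv(x)
print(y.shape)  # torch.Size([1, 64, 16, 16]) -- spatial size halved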

tests/test_modeling_utils.py

Lines changed: 8 additions & 19 deletions
@@ -555,11 +555,11 @@ def output_shape(self):
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": [64],
+            "block_out_channels": [32, 64],
             "in_channels": 3,
             "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D"],
+            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
             "latent_channels": 3,
         }
         inputs_dict = self.dummy_input
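
The wider test config matters because the Encoder only inserts a downsampler on non-final blocks (add_downsample=not is_final_block in the hunk above): with a single down block that block is also the final one, so the downsample path, and hence the downsample_padding fix, would never be exercised. Two blocks force an actual downsample step. A small walk-through of the shape bookkeeping, as a hypothetical sketch mirroring the Encoder's construction loop:

# Each entry in block_out_channels pairs with one down block; every block
# except the last ends with a stride-2 downsample.
block_out_channels = [32, 64]
down_block_types = ["DownEncoderBlock2D", "DownEncoderBlock2D"]
assert len(block_out_channels) == len(down_block_types)

resolution = 32  # assumed spatial size of the test's dummy input
for i, (channels, block_type) in enumerate(zip(block_out_channels, down_block_types)):
    is_final_block = i == len(block_out_channels) - 1
    if not is_final_block:
        resolution //= 2
    print(f"{block_type}: {channels} channels -> {resolution}x{resolution}")
# With only one block there is no non-final block, so nothing ever downsamples.
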
@@ -595,7 +595,7 @@ def test_output_pretrained(self):
 
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
+        expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
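
Passing downsample_padding=0 changes the encoder's numerics, so the hard-coded reference values must be regenerated. The test's pattern, pinning a small deterministic slice of the output and comparing it with a relative tolerance, looks roughly like this self-contained sketch (a random tensor stands in for the real model forward):

import torch

torch.manual_seed(0)
output = torch.randn(4, 3, 32, 32)                # stand-in for the model output
output_slice = output[0, -1, -3:, -3:].flatten()  # nine corner values
expected_output_slice = output_slice.clone()      # the real test hard-codes these
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-2)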

@@ -623,22 +623,11 @@ def output_shape(self):
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "ch": 64,
-            "ch_mult": (1,),
-            "embed_dim": 4,
-            "in_channels": 3,
-            "attn_resolutions": [],
-            "num_res_blocks": 1,
-            "out_ch": 3,
-            "resolution": 32,
-            "z_channels": 4,
-        }
-        init_dict = {
-            "block_out_channels": [64],
+            "block_out_channels": [32, 64],
             "in_channels": 3,
             "out_channels": 3,
-            "down_block_types": ["DownEncoderBlock2D"],
-            "up_block_types": ["UpDecoderBlock2D"],
+            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
             "latent_channels": 4,
         }
         inputs_dict = self.dummy_input
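
Besides widening the config to two blocks, this hunk deletes a leftover dict that was immediately shadowed: init_dict was assigned twice, so the first literal (the old CompVis-style ch / ch_mult / z_channels arguments) was built and then silently discarded. A minimal illustration of that dead-assignment pattern:

# Python rebinds the name; the first dict is never read, which is why
# deleting those lines changes no behavior.
init_dict = {"ch": 64, "z_channels": 4}       # dead value, never used
init_dict = {"block_out_channels": [32, 64]}  # the binding actually used
print(init_dict)  # {'block_out_channels': [32, 64]}
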
@@ -674,7 +663,7 @@ def test_output_pretrained(self):
 
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-0.3900, -0.2800, 0.1281, -0.4449, -0.4890, -0.0207, 0.0784, -0.1258, -0.0409])
+        expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
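
One note on the new reference values: they are written in scientific notation, and one entry (-3.8304e-04) sits close to zero. torch.allclose uses a tolerance of atol + rtol * |expected|, so rtol=1e-2 leaves very little absolute slack for near-zero entries; a quick check:

import torch

expected = torch.tensor([-3.8304e-04])
print(torch.allclose(torch.tensor([-3.8500e-04]), expected, rtol=1e-2))  # True, ~0.5% off
print(torch.allclose(torch.tensor([-3.9000e-04]), expected, rtol=1e-2))  # False, ~1.8% off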
