@@ -952,11 +952,7 @@ <h2 id="model-definition" class="anchor">Model definition </h2>
952952        < span  style ="color: #008000 "> self</ span > < span  style ="color: #666666 "> .</ span > dec2 < span  style ="color: #666666 "> =</ span >  nn< span  style ="color: #666666 "> .</ span > ConvTranspose2d(< span  style ="color: #666666 "> 64</ span > , c, < span  style ="color: #666666 "> 3</ span > , padding< span  style ="color: #666666 "> =1</ span > )
953953        < span  style ="color: #008000 "> self</ span > < span  style ="color: #666666 "> .</ span > act  < span  style ="color: #666666 "> =</ span >  nn< span  style ="color: #666666 "> .</ span > ReLU()
954954        < span  style ="color: #408080; font-style: italic "> # timestep embedding to condition on t</ span > 
955-         < span  style ="color: #008000 "> self</ span > < span  style ="color: #666666 "> .</ span > time_mlp < span  style ="color: #666666 "> =</ span >  nn< span  style ="color: #666666 "> .</ span > Sequential(
956-             nn< span  style ="color: #666666 "> .</ span > Linear(< span  style ="color: #666666 "> 1</ span > , < span  style ="color: #666666 "> 128</ span > ), < span  style ="color: #408080; font-style: italic "> # Changed from 64 to 128</ span > 
957-             nn< span  style ="color: #666666 "> .</ span > ReLU(),
958-             nn< span  style ="color: #666666 "> .</ span > Linear(< span  style ="color: #666666 "> 128</ span > , < span  style ="color: #666666 "> 128</ span > ), < span  style ="color: #408080; font-style: italic "> # Changed from 64 to 128</ span > 
959-         )
955+         < span  style ="color: #008000 "> self</ span > < span  style ="color: #666666 "> .</ span > time_mlp < span  style ="color: #666666 "> =</ span >  nn< span  style ="color: #666666 "> .</ span > Sequential(nn< span  style ="color: #666666 "> .</ span > Linear(< span  style ="color: #666666 "> 1</ span > , < span  style ="color: #666666 "> 128</ span > ), nn< span  style ="color: #666666 "> .</ span > ReLU(),nn< span  style ="color: #666666 "> .</ span > Linear(< span  style ="color: #666666 "> 128</ span > , < span  style ="color: #666666 "> 128</ span > ))
960956
961957    < span  style ="color: #008000; font-weight: bold "> def</ span >  < span  style ="color: #0000FF "> forward</ span > (< span  style ="color: #008000 "> self</ span > , x, t):
962958        < span  style ="color: #408080; font-style: italic "> # x: [B, C, H, W], t: [B]</ span > 
@@ -965,7 +961,7 @@ <h2 id="model-definition" class="anchor">Model definition </h2>
965961        < span  style ="color: #408080; font-style: italic "> # add time embedding</ span > 
966962        t < span  style ="color: #666666 "> =</ span >  t< span  style ="color: #666666 "> .</ span > unsqueeze(< span  style ="color: #666666 "> -1</ span > )                             
967963        temb < span  style ="color: #666666 "> =</ span >  < span  style ="color: #008000 "> self</ span > < span  style ="color: #666666 "> .</ span > time_mlp(t)
968-         temb < span  style ="color: #666666 "> =</ span >  temb< span  style ="color: #666666 "> .</ span > view(< span  style ="color: #666666 "> -1</ span > , < span  style ="color: #666666 "> 128</ span > , < span  style ="color: #666666 "> 1</ span > , < span  style ="color: #666666 "> 1</ span > )  < span   style =" color: #408080; font-style: italic " > # Changed from 64 to 128 </ span > 
964+         temb < span  style ="color: #666666 "> =</ span >  temb< span  style ="color: #666666 "> .</ span > view(< span  style ="color: #666666 "> -1</ span > , < span  style ="color: #666666 "> 128</ span > , < span  style ="color: #666666 "> 1</ span > , < span  style ="color: #666666 "> 1</ span > )
969965        h < span  style ="color: #666666 "> =</ span >  h < span  style ="color: #666666 "> +</ span >  temb
970966        h < span  style ="color: #666666 "> =</ span >  < span  style ="color: #008000 "> self</ span > < span  style ="color: #666666 "> .</ span > act(< span  style ="color: #008000 "> self</ span > < span  style ="color: #666666 "> .</ span > dec1(h))
971967        < span  style ="color: #008000; font-weight: bold "> return</ span >  < span  style ="color: #008000 "> self</ span > < span  style ="color: #666666 "> .</ span > dec2(h)
0 commit comments