Skip to content

Commit 13bddec

Browse files
Update 2020-08-18-pytorch-1.6-now-includes-stochastic-weight-averaging.md
1 parent 7af9c04 commit 13bddec

File tree

1 file changed

+17
-61
lines changed

1 file changed

+17
-61
lines changed

_posts/2020-08-18-pytorch-1.6-now-includes-stochastic-weight-averaging.md

+17-61
Original file line numberDiff line numberDiff line change
@@ -144,28 +144,12 @@ We expect solutions that are centered in the flat region of the loss to generali
144144
We release a GitHub [repo](https://github.com/izmailovpavel/torch_swa_examples) with examples using the PyTorch implementation of SWA for training DNNs. For example, these examples can be used to achieve the following results on CIFAR-100:
145145

146146

147-
<table width="700" border="1" cellspacing="5" cellpadding="5">
148-
<tbody>
149-
<tr>
150-
<td>&nbsp;</td>
151-
<td>VGG-16</td>
152-
<td>ResNet-164</td>
153-
<td>WideResNet-28x10</td>
154-
</tr>
155-
<tr>
156-
<td>SGD</td>
157-
<td>72.8 ± 0.3</td>
158-
<td>78.4 ± 0.3</td>
159-
<td>81.0 ± 0.3</td>
160-
</tr>
161-
<tr>
162-
<td>SWA</td>
163-
<td>74.4 ± 0.3</td>
164-
<td>79.8 ± 0.4</td>
165-
<td>82.5 ± 0.2</td>
166-
</tr>
167-
</tbody>
168-
</table>
147+
{:.table.table-striped.table-bordered}
148+
| | VGG-16 | ResNet-164 | WideResNet-28x10 |
149+
| ------------- | ------------- | ------------- | ------------- |
150+
| SGD | 72.8 ± 0.3 | 78.4 ± 0.3 | 81.0 ± 0.3 |
151+
| SWA | 74.4 ± 0.3 | 79.8 ± 0.4 | 82.5 ± 0.2 |
152+
169153

170154
## Semi-Supervised Learning
171155

@@ -180,45 +164,17 @@ In a follow-up [paper](https://arxiv.org/abs/1806.05594) SWA was applied to semi
180164

181165
In another follow-up [paper](http://www.gatsby.ucl.ac.uk/~balaji/udl-camera-ready/UDL-24.pdf) SWA was shown to improve the performance of policy gradient methods A2C and DDPG on several Atari games and MuJoCo environments [3]. This application is also an instance of where SWA is used with Adam. Recall that SWA is not specific to SGD and can benefit essentially any optimizer.
182166

183-
<table width="700" border="1" cellspacing="5" cellpadding="5">
184-
<tbody>
185-
<tr>
186-
<td>Environment Name</td>
187-
<td>A2C</td>
188-
<td>A2C + SWA</td>
189-
</tr>
190-
<tr>
191-
<td>Breakout</td>
192-
<td>522 ± 34</td>
193-
<td>703 ± 60</td>
194-
</tr>
195-
<tr>
196-
<td>Qbert</td>
197-
<td>18777 ± 778</td>
198-
<td>21272 ± 655</td>
199-
</tr>
200-
<tr>
201-
<td>SpaceInvaders</td>
202-
<td>7727 ± 1121</td>
203-
<td>21676 ± 8897</td>
204-
</tr>
205-
<tr>
206-
<td>Seaquest</td>
207-
<td>1779 ± 4</td>
208-
<td>1795 ± 4</td>
209-
</tr>
210-
<tr>
211-
<td>BeamRider</td>
212-
<td>9999 ± 402</td>
213-
<td>11321 ± 1065</td>
214-
</tr>
215-
<tr>
216-
<td>CrazyClimber</td>
217-
<td>147030 ± 10239</td>
218-
<td>139752 ± 11618</td>
219-
</tr>
220-
</tbody>
221-
</table>
167+
168+
{:.table.table-striped.table-bordered}
169+
| Environment Name | A2C | A2C + SWA |
170+
| ------------- | ------------- | ------------- |
171+
| Breakout | 522 ± 34 | 703 ± 60 |
172+
| Qbert | 18777 ± 778 | 21272 ± 655 |
173+
| SpaceInvaders | 7727 ± 1121 | 21676 ± 8897 |
174+
| Seaquest | 1779 ± 4 | 1795 ± 4 |
175+
| BeamRider | 9999 ± 402 | 11321 ± 1065 |
176+
| CrazyClimber | 147030 ± 10239 | 139752 ± 11618 |
177+
222178

223179
## Low Precision Training
224180

0 commit comments

Comments
 (0)