Skip to content

Commit 41505c2

Browse files
authored
Merge pull request victoresque#55 from christopherbate/master
Add pytorch 1.1 utils.tensorboard support.
2 parents 83d9038 + c04075a commit 41505c2

File tree

4 files changed

+73
-27
lines changed

4 files changed

+73
-27
lines changed

README.md

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ PyTorch deep learning project made easy.
2424
* [Additional logging](#additional-logging)
2525
* [Validation data](#validation-data)
2626
* [Checkpoints](#checkpoints)
27-
* [TensorboardX Visualization](#tensorboardx-visualization)
27+
* [Tensorboard Visualization](#tensorboard-visualization)
2828
* [Contributing](#contributing)
2929
* [TODOs](#todos)
3030
* [License](#license)
@@ -36,8 +36,8 @@ PyTorch deep learning project made easy.
3636
* Python >= 3.5 (3.6 recommended)
3737
* PyTorch >= 0.4
3838
* tqdm (Optional for `test.py`)
39-
* tensorboard >= 1.7.0 (Optional for TensorboardX)
40-
* tensorboardX >= 1.2 (Optional for TensorboardX)
39+
* tensorboard >= 1.7.0 (Optional for TensorboardX) or tensorboard >= 1.14 (Optional for pytorch.utils.tensorboard)
40+
* tensorboardX >= 1.2 (Optional for TensorboardX), see [Tensorboard Visualization][#tensorboardx-visualization]
4141

4242
## Features
4343
* Clear folder structure which is suitable for many deep learning projects.
@@ -329,8 +329,11 @@ A copy of config file will be saved in the same folder.
329329
}
330330
```
331331

332-
### TensorboardX Visualization
333-
This template supports [TensorboardX](https://github.com/lanpa/tensorboardX) visualization.
332+
### Tensorboard Visualization
333+
This template supports Tensorboard visualization using either Pytorch 1.1's `torch.utils.tensorboard` capabilities or [TensorboardX](https://github.com/lanpa/tensorboardX).
334+
335+
The template attempts to choose a writing module from a list of modules specified in the config file under "tensorboard.modules". It load the modules in the order specified, only moving on to the next one if the previous one failed.
336+
334337
* **TensorboardX Usage**
335338

336339
1. **Install**
@@ -339,17 +342,44 @@ This template supports [TensorboardX](https://github.com/lanpa/tensorboardX) vis
339342

340343
2. **Run training**
341344

342-
Set `tensorboardX` option in config file true.
345+
Set `tensorboard` option in config file to:
346+
Set the "tensorboard" entry in the config to:
347+
```
348+
"tensorboard" :{
349+
"enabled": true,
350+
"modules": ["tensorboardX", "torch.utils.tensorboard"]
351+
}
352+
```
343353
344-
3. **Open tensorboard server**
354+
3. **Open Tensorboard server**
345355
346356
Type `tensorboard --logdir saved/log/` at the project root, then server will open at `http://localhost:6006`
347357
358+
* **Pytorch 1.1 torch.utils.tensorboard Usage**
359+
360+
1. **Install**
361+
362+
Must have Pytorch 1.1 installed and `tensorboard >= 1.14` (`pip install tb-nightly`).
363+
364+
2. **Run training**
365+
366+
Set the "tensorboard" entry in the config to:
367+
```
368+
"tensorboard" :{
369+
"enabled": true,
370+
"modules": ["torch.utils.tensorboard", "tensorboardX"]
371+
}
372+
```
373+
374+
3. **Open Tensorboard server**
375+
376+
Same as above.
377+
348378
By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged.
349379
If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc in the `trainer._train_epoch` method.
350-
`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` module.
380+
`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules.
351381
352-
**Note**: You don't have to specify current steps, since `WriterTensorboardX` class defined at `logger/visualization.py` will track current steps.
382+
**Note**: You don't have to specify current steps, since `WriterTensorboard` class defined at `logger/visualization.py` will track current steps.
353383
354384
## Contributing
355385
Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8

base/base_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from abc import abstractmethod
33
from numpy import inf
4-
from logger import WriterTensorboardX
4+
from logger import TensorboardWriter
55

66

77
class BaseTrainer:
@@ -41,8 +41,9 @@ def __init__(self, model, loss, metrics, optimizer, config):
4141
self.start_epoch = 1
4242

4343
self.checkpoint_dir = config.save_dir
44-
# setup visualization writer instance
45-
self.writer = WriterTensorboardX(config.log_dir, self.logger, cfg_trainer['tensorboardX'])
44+
45+
# setup visualization writer instance
46+
self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
4647

4748
if config.resume is not None:
4849
self._resume_checkpoint(config.resume)

config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
"monitor": "min val_loss",
4646
"early_stop": 10,
47-
48-
"tensorboardX": true
47+
48+
"tensorboard": true
4949
}
5050
}

logger/visualization.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,41 @@
22
from utils import Timer
33

44

5-
class WriterTensorboardX():
6-
def __init__(self, log_dir, logger, enable):
5+
class TensorboardWriter():
6+
def __init__(self, log_dir, logger, enabled):
77
self.writer = None
8-
if enable:
8+
self.selected_module = ""
9+
10+
if enabled:
911
log_dir = str(log_dir)
10-
try:
11-
self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_dir)
12-
except ImportError:
13-
message = "Warning: TensorboardX visualization is configured to use, but currently not installed on " \
14-
"this machine. Please install the package by 'pip install tensorboardx' command or turn " \
15-
"off the option in the 'config.json' file."
12+
13+
# Retrieve vizualization writer.
14+
succeeded = False
15+
for module in ["torch.utils.tensorboard", "tensorboardX"]:
16+
try:
17+
self.writer = importlib.import_module(module).SummaryWriter(log_dir)
18+
succeeded = True
19+
break
20+
except ImportError:
21+
succeeded = False
22+
self.selected_module = module
23+
24+
if not succeeded:
25+
message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
26+
"this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
27+
"PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
28+
"the 'config.json' file."
1629
logger.warning(message)
30+
1731
self.step = 0
1832
self.mode = ''
1933

20-
self.tb_writer_ftns = [
34+
self.tb_writer_ftns = {
2135
'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
2236
'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
23-
]
24-
self.tag_mode_exceptions = ['add_histogram', 'add_embedding']
37+
}
38+
self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
39+
2540
self.timer = Timer()
2641

2742
def set_step(self, step, mode='train'):
@@ -55,5 +70,5 @@ def wrapper(tag, data, *args, **kwargs):
5570
try:
5671
attr = object.__getattr__(name)
5772
except AttributeError:
58-
raise AttributeError("type object 'WriterTensorboardX' has no attribute '{}'".format(name))
73+
raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
5974
return attr

0 commit comments

Comments
 (0)