Skip to content

Commit b9ea664

Browse files
authored
Improve disposal of tensors during training (tensorflow#1604)
BUG Improve memory management of tensors during training. Op authors now explicitly save intermediate tensors that are needed for backwards mode. This allows the engine to optimally dispose memory during forward pass, and keep only tensors needed for the backward pass. Based on the [layers benchmark](https://github.com/tensorflow/tfjs-layers/tree/v1.0.0/integration_tests/benchmarks), this change along with tensorflow#1621 led to: ## 2-3X memory reduction **Before** ![before-mem](https://user-images.githubusercontent.com/2294279/54299829-13b9db80-4592-11e9-97a4-04a95012b5c4.png) **After** ![after-mem](https://user-images.githubusercontent.com/2294279/54299854-203e3400-4592-11e9-8887-62165d0bfbbb.png) ## 1.5-1.7x improvement in fit() for GRU and LSTM ops. **Before** ![before](https://user-images.githubusercontent.com/2294279/54299715-d5242100-4591-11e9-8a4c-b5944e991f57.png) **After** ![after](https://user-images.githubusercontent.com/2294279/54299724-d9e8d500-4591-11e9-8ea9-e340e10a41ce.png) - When the user writes the forward pass of an op, they are given a `save` function that allows them to save inputs or intermediate tensors to be reused for the backwards pass - Before this change, the `save` function was a no-op, a placeholder for when we decide to optimize disposal of tensors in the training process in the future. - However, `save` being a no-op caused a bug for existing users who rely on it (e.g. Magenta.js). - After this change, the `save` function makes a shallow copy of the tensor, and keeps it until the backwards pass is done. - `save` used to take an array of tensors. Now it takes a `NamedTensorMap` which improves code readability, and reduces chances of off-by-one index bugs. Fixes tensorflow/tfjs#1320 PERF BUG
1 parent 134616d commit b9ea664

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+1911
-560
lines changed

karma.conf.js

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ const karmaTypescriptConfig = {
1919
tsconfig: 'tsconfig.json',
2020
// Disable coverage reports and instrumentation by default for tests
2121
coverageOptions: {instrumentation: false},
22-
reports: {}
22+
reports: {},
23+
bundlerOptions: {sourceMap: true}
2324
};
2425

2526
// Enable coverage reports and instrumentation under KARMA_COVERAGE=1 env
@@ -48,7 +49,7 @@ module.exports = function(config) {
4849
exclude: ['src/test_node.ts'],
4950
preprocessors: {'**/*.ts': ['karma-typescript']},
5051
karmaTypescriptConfig,
51-
reporters: ['progress', 'karma-typescript'],
52+
reporters: ['dots', 'karma-typescript'],
5253
browsers: ['Chrome'],
5354
browserStack: {
5455
username: process.env.BROWSERSTACK_USERNAME,

0 commit comments

Comments
 (0)