Skip to content

Commit 5c0b5a7

Browse files
Kriyszigcaisq
authored andcommitted
Allow squeeze to use negative index (tensorflow#1430)
Issue in focus: tensorflow/tfjs#940 #### Description `squeeze()` didn't support the use of negative index which is supported in tensorflow for Python. Added tests for the same in `util_test` and `tensor_test`. The previous implementation didn't allow for random arrangement of values of axis. For example if axes is set to `[2, 3]`, it would remove them both but if you pass it as `[3, 2]` it would only remove 3 and ignore 2 entirely. I tried the same with tensorflow for python in this [Google Colab Notebook](https://colab.research.google.com/drive/1jzyuRFXfIVrK50jI7vDrVuW5dYpiYw0V) and this too was inconsistent with tfjs as tensorflow python allows for removal of axis irrespective of order of input if it's allowed. This inconsistency was removed when adding the support for negative index ( it was totally unintentional but resulted from my implementation to allow negative index ). Clarification required for following topics: * I am not sure whether the test messages are satisfactory for the new tests added - Please let me know if this can be improved and I'll implement the same as soon as possible. * The random input order support was a result of my judgement where i thought `[-1, -2]` was intuitively much better for a programmer but if we just replaced it with the positive index of same, it would have been `[5, 4]` which due to it's lack of ascending order would have ignored 4 completely. If this needs to changed, please let me know and I'll come up with another implementation to change this behavior. Tests Ran: * [x] yarn test * [x] yarn lint BUG
1 parent 08e92ed commit 5c0b5a7

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

src/tensor_test.ts

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1356,16 +1356,43 @@ describeWithFlags('tensor', ALL_ENVS, () => {
13561356
expect(b.shape).toEqual([3, 1]);
13571357
});
13581358

1359+
it('squeeze with negative axis', () => {
1360+
const a = tf.tensor3d([4, 2, 1], [3, 1, 1], 'bool');
1361+
const b = a.squeeze([-1]);
1362+
expect(b.shape).toEqual([3, 1]);
1363+
});
1364+
1365+
it('squeeze with multiple negative axis', () => {
1366+
const a = tf.tensor3d([4, 2, 1], [3, 1, 1], 'bool');
1367+
const b = a.squeeze([-1, -2]);
1368+
expect(b.shape).toEqual([3]);
1369+
});
1370+
13591371
it('squeeze wrong axis', () => {
13601372
const a = tf.tensor3d([4, 2, 1], [3, 1, 1], 'bool');
13611373
expect(() => a.squeeze([0, 1])).toThrowError();
13621374
});
13631375

1376+
it('squeeze wrong negative axis', () => {
1377+
const a = tf.tensor3d([4, 2, 1], [3, 1, 1], 'bool');
1378+
expect(() => a.squeeze([-3, -2])).toThrowError();
1379+
});
1380+
1381+
it('squeeze axis out of range', () => {
1382+
const a = tf.tensor3d([4, 2, 1], [3, 1, 1], 'bool');
1383+
expect(() => a.squeeze([10, 11])).toThrowError();
1384+
});
1385+
1386+
it('squeeze negative axis out of range', () => {
1387+
const a = tf.tensor3d([4, 2, 1], [3, 1, 1], 'bool');
1388+
expect(() => a.squeeze([-13, -12])).toThrowError();
1389+
});
1390+
13641391
it('squeeze throws when passed a non-tensor', () => {
13651392
expect(() => tf.squeeze({} as tf.Tensor))
13661393
.toThrowError(/Argument 'x' passed to 'squeeze' must be a Tensor/);
13671394
});
1368-
1395+
13691396
it('squeeze accepts a tensor-like object', () => {
13701397
const res = tf.squeeze([[[4]], [[2]], [[1]]] /* shape is [3, 1, 1] */);
13711398
expect(res.shape).toEqual([3]);

src/util.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,19 @@ export function squeezeShape(shape: number[], axis?: number[]):
275275
{newShape: number[], keptDims: number[]} {
276276
const newShape: number[] = [];
277277
const keptDims: number[] = [];
278+
if (axis != null) {
279+
for (let i = 0; i < axis.length; ++i) {
280+
if (axis[i] < -shape.length || axis[i] >= shape.length) {
281+
throw new Error(
282+
`Can't squeeze axis ${axis[i]} since its not in ` +
283+
`[-${shape.length}, ${shape.length}) for shape ${shape}`);
284+
}
285+
if (axis[i] < 0) {
286+
axis[i] = shape.length + axis[i];
287+
}
288+
}
289+
axis.sort();
290+
}
278291
let j = 0;
279292
for (let i = 0; i < shape.length; ++i) {
280293
if (axis != null) {

src/util_test.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,25 @@ describe('util.squeezeShape', () => {
262262
expect(newShape).toEqual([1, 1, 4]);
263263
expect(keptDims).toEqual([0, 3, 4]);
264264
});
265+
it('should only reduce dimensions specified by negative axis', () => {
266+
const {newShape, keptDims} = util.squeezeShape([1, 1, 1, 1, 4], [-2, -3]);
267+
expect(newShape).toEqual([1, 1, 4]);
268+
expect(keptDims).toEqual([0, 1, 4]);
269+
});
265270
it('throws error when specified axis is not squeezable', () => {
266271
expect(() => util.squeezeShape([1, 1, 2, 1, 4], [1, 2])).toThrowError();
267272
});
273+
it('throws error when specified negative axis is not squeezable', () => {
274+
expect(() => util.squeezeShape([1, 1, 2, 1, 4], [-1, -2])).toThrowError();
275+
});
276+
it('throws error when specified axis is out of range', () => {
277+
expect(
278+
() => util.squeezeShape([1, 1, 2, 1, 4], [11, 22])).toThrowError();
279+
});
280+
it('throws error when specified negative axis is out of range', () => {
281+
expect(
282+
() => util.squeezeShape([1, 1, 2, 1, 4], [-11, -22])).toThrowError();
283+
});
268284
});
269285
});
270286

0 commit comments

Comments
 (0)