Skip to content

Commit 99ef32e

Browse files
authored
make sure shallowSlice use the tensor dtype instead of dataId dtype (tensorflow#1526)
* make sure slice use the tensor dtype instead of dataId dtype * added test
1 parent 0dbd779 commit 99ef32e

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/kernels/backend_webgl.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ export class MathBackendWebGL implements KernelBackend {
655655

656656
private shallowSlice(x: Tensor, begin: number[], size: number[]): Tensor {
657657
const xTexData = this.texData.get(x.dataId);
658-
const t = Tensor.make(size, {}, xTexData.dtype);
658+
const t = Tensor.make(size, {}, x.dtype);
659659
const newTexData = this.texData.get(t.dataId);
660660
// Copy texture data from the original tensor.
661661
Object.assign(newTexData, xTexData);

src/ops/slice_test.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,4 +497,11 @@ describeWithFlags('slice ergonomics', ALL_ENVS, () => {
497497
expect(result.shape).toEqual([2, 1, 1]);
498498
expectArraysClose(result, [4, 8]);
499499
});
500+
501+
it('should match source tensor dtype', () => {
502+
const a = tf.tensor1d([1, 2, 3, 4, 5], 'int32');
503+
const b = a.asType('float32');
504+
505+
expect(tf.slice(b, 0).dtype).toEqual('float32');
506+
});
500507
});

0 commit comments

Comments
 (0)