Skip to content

Commit cb0bcc4

Browse files
authored
backend: Add ops onesLike and zerosLike (tensorflow#1585)
- Also add CPU and WebGL backend implementations of these based on fill() - This will be used in tfjs-node to bind to the libtensorflow implementations of onesLike and zerosLike to save memory copying between CPU and GPU, which will lead to performance improvements. PERF
1 parent e07d50d commit cb0bcc4

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

src/kernels/backend.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,14 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
589589
throw new Error('Not yet implemented.');
590590
}
591591

592+
onesLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
593+
throw new Error('Not yet implemented');
594+
}
595+
596+
zerosLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
597+
throw new Error('Not yet implemented');
598+
}
599+
592600
/**
593601
* Sets the data mover for this backend. Backends should use the mover to
594602
* move data from other backends to this backend.

src/kernels/backend_cpu.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3338,6 +3338,20 @@ export class MathBackendCPU implements KernelBackend {
33383338
return Tensor.make(shape, {values}, dtype);
33393339
}
33403340

3341+
onesLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
3342+
if (x.dtype === 'string') {
3343+
throw new Error('onesLike is not supported for string tensors');
3344+
} else {
3345+
return this.fill(x.shape, 1, x.dtype);
3346+
}
3347+
}
3348+
3349+
zerosLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
3350+
const values =
3351+
getArrayFromDType(x.dtype, sizeFromShape(x.shape)) as TypedArray;
3352+
return Tensor.make(x.shape, {values}, x.dtype);
3353+
}
3354+
33413355
private scatter<R extends Rank>(
33423356
indices: Tensor, updates: Tensor, shape: ShapeMap[R], outputSize: number,
33433357
sliceSize: number, numUpdates: number, sliceRank: number,

src/kernels/backend_webgl.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,6 +2149,20 @@ export class MathBackendWebGL implements KernelBackend {
21492149
}
21502150
}
21512151

2152+
onesLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
2153+
if (x.dtype === 'string') {
2154+
throw new Error('onesLike is not supported under string dtype');
2155+
} else {
2156+
// TODO(cais, smilkov): Add WebGL shader for onesLike:
2157+
// https://github.com/tensorflow/tfjs/issues/1293
2158+
return this.fill(x.shape, 1, x.dtype);
2159+
}
2160+
}
2161+
2162+
zerosLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
2163+
return this.fill(x.shape, x.dtype === 'string' ? '' : 0, x.dtype);
2164+
}
2165+
21522166
private makeOutputArray<T extends Tensor>(shape: number[], dtype: DataType):
21532167
T {
21542168
return Tensor.make(shape, {}, dtype, this) as T;

src/ops/tensor_ops.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ function fill<R extends Rank>(
447447
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
448448
function onesLike_<T extends Tensor>(x: T|TensorLike): T {
449449
const $x = convertToTensor(x, 'x', 'onesLike');
450-
return ones($x.shape, $x.dtype) as T;
450+
return ENV.engine.runKernel(backend => backend.onesLike($x), {$x}, null) as T;
451451
}
452452

453453
/**
@@ -464,7 +464,8 @@ function onesLike_<T extends Tensor>(x: T|TensorLike): T {
464464
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
465465
function zerosLike_<T extends Tensor>(x: T|TensorLike): T {
466466
const $x = convertToTensor(x, 'x', 'zerosLike');
467-
return zeros($x.shape, $x.dtype) as T;
467+
return ENV.engine.runKernel(backend => backend.zerosLike($x), {$x}, null) as
468+
T;
468469
}
469470

470471
/**

0 commit comments

Comments
 (0)