Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

backend: Add ops onesLike and zerosLike #1585

Merged
merged 6 commits into from
Feb 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/kernels/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,14 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
throw new Error('Not yet implemented.');
}

onesLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
throw new Error('Not yet implemented');
}

zerosLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
throw new Error('Not yet implemented');
}

/**
* Sets the data mover for this backend. Backends should use the mover to
* move data from other backends to this backend.
Expand Down
14 changes: 14 additions & 0 deletions src/kernels/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3338,6 +3338,20 @@ export class MathBackendCPU implements KernelBackend {
return Tensor.make(shape, {values}, dtype);
}

onesLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
if (x.dtype === 'string') {
throw new Error('onesLike is not supported for string tensors');
} else {
return this.fill(x.shape, 1, x.dtype);
}
}

zerosLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
const values =
getArrayFromDType(x.dtype, sizeFromShape(x.shape)) as TypedArray;
return Tensor.make(x.shape, {values}, x.dtype);
}

private scatter<R extends Rank>(
indices: Tensor, updates: Tensor, shape: ShapeMap[R], outputSize: number,
sliceSize: number, numUpdates: number, sliceRank: number,
Expand Down
14 changes: 14 additions & 0 deletions src/kernels/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2149,6 +2149,20 @@ export class MathBackendWebGL implements KernelBackend {
}
}

onesLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
if (x.dtype === 'string') {
throw new Error('onesLike is not supported under string dtype');
} else {
// TODO(cais, smilkov): Add WebGL shader for onesLike:
// https://github.com/tensorflow/tfjs/issues/1293
return this.fill(x.shape, 1, x.dtype);
}
}

zerosLike<R extends Rank>(x: Tensor<R>): Tensor<R> {
return this.fill(x.shape, x.dtype === 'string' ? '' : 0, x.dtype);
}

private makeOutputArray<T extends Tensor>(shape: number[], dtype: DataType):
T {
return Tensor.make(shape, {}, dtype, this) as T;
Expand Down
5 changes: 3 additions & 2 deletions src/ops/tensor_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ function fill<R extends Rank>(
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function onesLike_<T extends Tensor>(x: T|TensorLike): T {
const $x = convertToTensor(x, 'x', 'onesLike');
return ones($x.shape, $x.dtype) as T;
return ENV.engine.runKernel(backend => backend.onesLike($x), {$x}, null) as T;
}

/**
Expand All @@ -464,7 +464,8 @@ function onesLike_<T extends Tensor>(x: T|TensorLike): T {
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function zerosLike_<T extends Tensor>(x: T|TensorLike): T {
const $x = convertToTensor(x, 'x', 'zerosLike');
return zeros($x.shape, $x.dtype) as T;
return ENV.engine.runKernel(backend => backend.zerosLike($x), {$x}, null) as
T;
}

/**
Expand Down