diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 18d9d3867e..517aaf2cfc 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -103,8 +103,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTe
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTests.netcoreapp", "test\Microsoft.ML.CpuMath.UnitTests.netcoreapp\Microsoft.ML.CpuMath.UnitTests.netcoreapp.csproj", "{5F81A2A4-73AD-494C-B387-07D605EC8826}"
EndProject
-
-Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
+Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}"
EndProject
@@ -426,11 +425,11 @@ Global
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
+ {B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
+ {3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{7333EDEF-4144-405C-A5EC-6F42201857D8} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{A0E562A9-0E6D-470D-B180-6EB44BA84D60} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{5F81A2A4-73AD-494C-B387-07D605EC8826} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
- {B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
- {3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
diff --git a/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj b/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj
index 8aa272922c..d3c5419cb0 100644
--- a/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj
+++ b/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj
@@ -1,9 +1,10 @@
-
+
netstandard2.0
Microsoft.ML
CORECLR
+ true
@@ -56,6 +57,11 @@
True
Resources.resx
+
+ True
+ True
+ TensorGeneric.tt
+
@@ -65,4 +71,15 @@
+
+
+ TextTemplatingFileGenerator
+ TensorGeneric.cs
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.Transforms/TensorFlow/Buffer.cs b/src/Microsoft.ML.Transforms/TensorFlow/Buffer.cs
new file mode 100644
index 0000000000..eedbf9b16d
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/TensorFlow/Buffer.cs
@@ -0,0 +1,211 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Text;
+using size_t = System.UIntPtr;
+
+#pragma warning disable MSML_GeneralName
+#pragma warning disable MSML_ParameterLocalVarName
+
+namespace Microsoft.ML.Transforms.TensorFlow
+{
+ ///
+ /// This attribute can be applied to callback functions that will be invoked
+ /// from unmanaged code to managed code.
+ ///
+ ///
+ ///
+ /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
+ /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
+ ///
+ ///
+ internal sealed class MonoPInvokeCallbackAttribute : Attribute
+ {
+ ///
+ /// Use this constructor to annotate the type of the callback function that
+ /// will be invoked from unmanaged code.
+ ///
+ /// T.
+ public MonoPInvokeCallbackAttribute (Type t) { }
+ }
+
+ [StructLayout (LayoutKind.Sequential)]
+ internal struct LLBuffer
+ {
+ internal IntPtr data;
+ internal size_t length;
+ internal IntPtr data_deallocator;
+ }
+
+ ///
+ /// Holds a block of data, suitable to pass, or retrieve from TensorFlow.
+ ///
+ ///
+ ///
+ /// Use the TFBuffer to pass blobs of data into TensorFlow, or to retrieve blocks
+ /// of data out of TensorFlow.
+ ///
+ ///
+ /// There are two constructors to wrap existing data, one to wrap blocks that are
+ /// pointed to by an IntPtr and one that takes a byte array that we want to wrap.
+ ///
+ ///
+ /// The empty constructor can be used to create a new TFBuffer that can be populated
+ /// by the TensorFlow library and returned to user code.
+ ///
+ ///
+ /// Typically, the data consists of a serialized protocol buffer, but other data
+ /// may also be held in a buffer.
+ ///
+ ///
+ // TODO: the string ctor
+ // TODO: perhaps we should have an implicit byte [] conversion that just calls ToArray?
+ internal class TFBuffer : TFDisposable
+ {
+ // extern TF_Buffer * TF_NewBufferFromString (const void *proto, size_t proto_len);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe LLBuffer* TF_NewBufferFromString (IntPtr proto, IntPtr proto_len);
+
+ // extern TF_Buffer * TF_NewBuffer ();
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe LLBuffer* TF_NewBuffer ();
+
+ internal TFBuffer (IntPtr handle) : base (handle) { }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public unsafe TFBuffer () : base ((IntPtr)TF_NewBuffer ())
+ {
+ }
+
+ ///
+ /// Signature of the method that is invoked to release the data.
+ ///
+ ///
+ /// Methods of this signature are invoked with the data pointer and the
+ /// length pointer when the TFBuffer no longer needs to hold on to the
+ /// data. If you are using this on platforms with static compilation
+ /// like iOS, you need to annotate your callback with the MonoPInvokeCallbackAttribute,
+ /// like this:
+ ///
+ ///
+ /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
+ /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
+ ///
+ ///
+ public delegate void BufferReleaseFunc (IntPtr data, IntPtr length);
+
+ ///
+ /// Initializes a new instance of the by wrapping the unmanaged resource pointed by the buffer.
+ ///
+ /// Pointer to the data that will be wrapped.
+ /// The size of the buffer to wrap.
+ /// Optional, if not null, this method will be invoked to release the block.
+ ///
+ /// This constructor wraps the buffer as a the data to be held by the ,
+ /// if the release parameter is null, then you must ensure that the data is not released before the TFBuffer
+ /// is no longer in use. If the value is not null, the provided method will be invoked to release
+ /// the data when the TFBuffer is disposed, or the contents of the buffer replaced.
+ ///
+ public unsafe TFBuffer (IntPtr buffer, long size, BufferReleaseFunc release) : base ((IntPtr)TF_NewBuffer ())
+ {
+ LLBuffer* buf = (LLBuffer*)handle;
+ buf->data = buffer;
+ buf->length = (size_t)size;
+ if (release == null)
+ buf->data_deallocator = IntPtr.Zero;
+ else
+ buf->data_deallocator = Marshal.GetFunctionPointerForDelegate (release);
+ }
+
+ [MonoPInvokeCallback (typeof (BufferReleaseFunc))]
+ internal static void FreeBlock (IntPtr data, IntPtr length)
+ {
+ Marshal.FreeHGlobal (data);
+ }
+
+ internal static IntPtr FreeBufferFunc;
+ internal static BufferReleaseFunc FreeBlockDelegate;
+
+ static TFBuffer ()
+ {
+ FreeBlockDelegate = FreeBlock;
+ FreeBufferFunc = Marshal.GetFunctionPointerForDelegate (FreeBlockDelegate);
+ }
+
+ ///
+ /// Initializes a new instance of the by making a copy of the provided byte array.
+ ///
+ /// Buffer of data that will be wrapped.
+ ///
+ /// This constructor makes a copy of the data into an unmanaged buffer,
+ /// so the byte array is not pinned.
+ ///
+ public TFBuffer (byte [] buffer) : this (buffer, 0, buffer.Length) { }
+
+ ///
+ /// Initializes a new instance of the by making a copy of the provided byte array.
+ ///
+ /// Buffer of data that will be wrapped.
+ /// Starting offset into the buffer to wrap.
+ /// Number of bytes from the buffer to keep.
+ ///
+ /// This constructor makes a copy of the data into an unmanaged buffer,
+ /// so the byte array is not pinned.
+ ///
+ public TFBuffer (byte [] buffer, int start, int count) : this ()
+ {
+ if (start < 0 || start > buffer.Length)
+ throw new ArgumentException ("start");
+ if (count < 0 || count > buffer.Length - start)
+ throw new ArgumentException ("count");
+ unsafe
+ {
+ LLBuffer* buf = LLBuffer;
+ buf->data = Marshal.AllocHGlobal (count);
+ Marshal.Copy (buffer, start, buf->data, count);
+ buf->length = (size_t)count;
+ buf->data_deallocator = FreeBufferFunc;
+ }
+ }
+
+ internal unsafe LLBuffer* LLBuffer => (LLBuffer*)handle;
+
+ // extern void TF_DeleteBuffer (TF_Buffer *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_DeleteBuffer (LLBuffer* buffer);
+
+ internal override void NativeDispose (IntPtr handle)
+ {
+ unsafe { TF_DeleteBuffer ((LLBuffer*)handle); }
+ }
+
+ // extern TF_Buffer TF_GetBuffer (TF_Buffer *buffer);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe LLBuffer TF_GetBuffer (LLBuffer* buffer);
+
+ ///
+ /// Returns a byte array representing the data wrapped by this buffer.
+ ///
+ /// The array.
+ public byte [] ToArray ()
+ {
+ if (handle == IntPtr.Zero)
+ return null;
+
+ unsafe
+ {
+ var lb = (LLBuffer*)handle;
+
+ var result = new byte [(int)lb->length];
+ Marshal.Copy (lb->data, result, 0, (int)lb->length);
+
+ return result;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Transforms/TensorFlow/Tensor.cs b/src/Microsoft.ML.Transforms/TensorFlow/Tensor.cs
new file mode 100644
index 0000000000..83b861fcd0
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/TensorFlow/Tensor.cs
@@ -0,0 +1,952 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Numerics;
+using System.Runtime.InteropServices;
+using System.Text;
+using size_t = System.UIntPtr;
+using TF_Tensor = System.IntPtr;
+
+#pragma warning disable MSML_ParameterLocalVarName
+
+namespace Microsoft.ML.Transforms.TensorFlow
+{
+
+ ///
+ /// TFTensor holds a multi-dimensional array of elements of a single data type.
+ ///
+ ///
+ ///
+ /// You can create tensors with the various constructors in this class, or using
+ /// the implicit conversions from various data types into a TFTensor, including
+ /// the creation of tensors from simple constants (returning a tensor that represents
+ /// a scalar, that is, it is a 0D tensor), arrays (returning a tensor of a single
+ /// dimension, 1D) or arbitrary multidimensional arrays.
+ ///
+ ///
+ /// Given a tensor, you can retrieve the number of dimensions in it via the
+ /// NumDims property, or you can retrieve the shape of a tensor, that is how many
+ /// elements on each dimension the tensor has, by fetching the Shape property.
+ ///
+ ///
+ /// The implicit conversions for basic types produce tensors of one dimension with
+ /// a single element, while the implicit conversion from an array, expects a multi-dimensional
+ /// array that is converted into a tensor of the right dimensions.
+ ///
+ ///
+ /// The special "String" tensor data type that you will find in TensorFlow documentation
+ /// really represents a byte array. You can create string tensors by using the
+ /// method that takes a byte array buffer as input.
+ ///
+ ///
+ ///
+ /// TFTensor scalar = 1; // Creates a 0D tensor, for the integer value 1
+ /// int d = scalar.NumDims; // d will be equal to zero, as it is a 0D tensor
+ /// long [] shape = scalar.Shape // returns an empty array, as it is a 0D tensor
+ ///
+ /// TFTensor list = new [] {1,2,3} // Creates a 1D tensor, or vector, for the values 1, 2, 3
+ /// d = list.NumDims; // d will be one
+ /// shape = list.Shape; // shape will be an array with a single value 3, representing that the dimension 0 has 3 elements
+ ///
+ /// // Creates a 3D tensor,
+ /// TFTensor cube = new [,,] { {{1,2,3},{4,5,6}}}
+ /// d = cube.NumDims // d will be 3
+ /// shape = cube.Shape // shape will be [1,2,3] which is the shape of the above 3D array
+ ///
+ ///
+ ///
+ internal partial class TFTensor : TFDisposableThreadSafe
+ {
+ ///
+ /// Signature that methods must conform to to be used to release memory that was passed to a manually allocated TFTensor
+ ///
+ public delegate void Deallocator (IntPtr data, IntPtr size, IntPtr deallocatorData);
+
+ // extern TF_Tensor * TF_NewTensor (TF_DataType, const int64_t *dims, int num_dims, void *data, size_t len, void (* deallocator)(void *, size_t, void *), void *deallocator_arg);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Tensor TF_NewTensor (TFDataType dataType, long [] dims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg);
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Tensor TF_NewTensor (TFDataType dataType, IntPtr zeroDims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg);
+
+ internal TFTensor (IntPtr handle) : base (handle) { }
+
+ internal static Deallocator FreeTensorDataDelegate = FreeTensorData;
+ internal static Deallocator FreeTensorHandleDelegate = FreeTensorHandle;
+
+ [MonoPInvokeCallback (typeof (Deallocator))]
+ internal static void FreeTensorData (IntPtr data, IntPtr len, IntPtr closure)
+ {
+ Marshal.FreeHGlobal (data);
+ }
+
+ [MonoPInvokeCallback (typeof (Deallocator))]
+ internal static void FreeTensorHandle (IntPtr data, IntPtr len, IntPtr closure)
+ {
+ var gch = GCHandle.FromIntPtr (closure);
+ gch.Free ();
+ }
+
+ // TODO: Other overloads we could add: String, Complex (float), Bool, QInt8, QUInt8, QInt32, Bfloat16,
+ // QInt16, QUint16, Half, Resource
+ // TODO: not clear that this is very useful (the dims versions), perhaps to reduce the surface of
+ // constructors these rarer blobs should be "FromSpec" or something like that
+
+ ///
+ /// Creates a new tensor from a portion of an array of sbytes
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, sbyte [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.Int8, shape, data, start, count, size: 1));
+ }
+
+ ///
+ /// Creates a new tensor from a portion of an array of bytes
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, byte [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.UInt8, shape, data, start, count, size: 1));
+ }
+
+ ///
+ /// Creates a new tensor from a portion of an array of shorts
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, short [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.Int16, shape, data, start, count, size: 2));
+ }
+
+ ///
+ /// Creates a new tensor from a portion of an array of ushorts
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, ushort [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.UInt16, shape, data, start, count, size: 2));
+ }
+
+ ///
+ /// Creates a new tensor from a portion of an array of ints
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, int [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.Int32, shape, data, start, count, size: 4));
+ }
+
+ ///
+ /// Creates a new tensor from a portion of an array of floats
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, float [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.Float, shape, data, start, count, size: 4));
+ }
+
+ ///
+ /// Creates a new tensor from a portion of an array of doubles
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, double [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.Double, shape, data, start, count, size: 8));
+ }
+
+ ///
+ /// Creates a new tensor from a portion of an array of longs
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, long [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.Int64, shape, data, start, count, size: 8));
+ }
+
+ ///
+ /// Creates a new tensor from a portion of an array of Complex numbers
+ ///
+ /// Represents the tensor shape.
+ /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions.
+ /// The offset into the provided data array where the data resides.
+ /// The number of bytes to copy from count into the tensor.
+ ///
+ /// Use the FromBuffer method to create a tensor that has the specified dimensions
+ /// and is initialized with data from the data array. The data is copied starting
+ /// at the start offset, for count bytes and is laid out into the tensor following the
+ /// specified dimensions.
+ ///
+ public static TFTensor FromBuffer (TFShape shape, Complex [] data, int start, int count)
+ {
+ return new TFTensor (SetupTensor (TFDataType.Complex128, shape, data, start, count, size: 16));
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from an integer value.
+ ///
+ public unsafe TFTensor (int value)
+ {
+ var v = (int*)Marshal.AllocHGlobal (sizeof (int));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.Int32, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (int), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from a boolean value.
+ ///
+ public unsafe TFTensor (bool value)
+ {
+ var v = (bool*)Marshal.AllocHGlobal (sizeof (bool));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.Bool, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (bool), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from an sbyte value.
+ ///
+ public unsafe TFTensor (sbyte value)
+ {
+ var v = (sbyte*)Marshal.AllocHGlobal (sizeof (sbyte));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.Int8, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (sbyte), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from a short value.
+ ///
+ public unsafe TFTensor (short value)
+ {
+ var v = (short*)Marshal.AllocHGlobal (sizeof (short));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.Int16, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (short), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from an ushort value.
+ ///
+ public unsafe TFTensor (ushort value)
+ {
+ var v = (ushort*)Marshal.AllocHGlobal (sizeof (ushort));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.UInt16, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (ushort), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from an byte value.
+ ///
+ public unsafe TFTensor (byte value)
+ {
+ var v = (byte*)Marshal.AllocHGlobal (sizeof (byte));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.UInt8, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (byte), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from a Complex value.
+ ///
+ public unsafe TFTensor (Complex value)
+ {
+ var v = (Complex*)Marshal.AllocHGlobal (sizeof (Complex));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.Complex128, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (Complex), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from a float value.
+ ///
+ public unsafe TFTensor (float value)
+ {
+ var v = (float*)Marshal.AllocHGlobal (sizeof (float));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.Float, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (float), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from a double value.
+ ///
+ public unsafe TFTensor (double value)
+ {
+ var v = (double*)Marshal.AllocHGlobal (sizeof (double));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.Double, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (double), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+
+ ///
+ /// Creates a constant tensor with a single dimension from a long value.
+ ///
+ public unsafe TFTensor (long value)
+ {
+ var v = (long*)Marshal.AllocHGlobal (sizeof (long));
+ *v = value;
+ handle = TF_NewTensor (TFDataType.Int64, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof (long), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
+ }
+ ///
+ /// Creates a 1 dimensional tensor from an array of booleans.
+ ///
+ /// Data.
+ public TFTensor (bool [] data) : base (SetupTensor (TFDataType.Bool, data, size: 1)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of sbytes.
+ ///
+ /// Data.
+ public TFTensor (sbyte [] data) : base (SetupTensor (TFDataType.Int8, data, size: 1)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of bytes.
+ ///
+ /// Data.
+ public TFTensor (byte [] data) : base (SetupTensor (TFDataType.UInt8, data, size: 1)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of shorts.
+ ///
+ /// Data.
+ public TFTensor (short [] data) : base (SetupTensor (TFDataType.Int16, data, size: 2)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of ushorts
+ ///
+ /// Data.
+ public TFTensor (ushort [] data) : base (SetupTensor (TFDataType.UInt16, data, size: 2)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of ints.
+ ///
+ /// Data.
+ public TFTensor (int [] data) : base (SetupTensor (TFDataType.Int32, data, size: 4)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of floats.
+ ///
+ /// Data.
+ public TFTensor (float [] data) : base (SetupTensor (TFDataType.Float, data, size: 4)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of doubles.
+ ///
+ /// Data.
+ public TFTensor (double [] data) : base (SetupTensor (TFDataType.Double, data, size: 8)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of longs.
+ ///
+ /// Data.
+ public TFTensor (long [] data) : base (SetupTensor (TFDataType.Int64, data, size: 8)) { }
+ ///
+ /// Creates a 1 dimensional tensor from an array of complex numbers.
+ ///
+ /// Data.
+ public TFTensor (Complex [] data) : base (SetupTensor (TFDataType.Complex128, data, size: 16)) { }
+
+ // Convenience function to factor out the setup of a new tensor from an array
+ internal static IntPtr SetupTensor (TFDataType dt, long [] dims, Array data, int size)
+ {
+ return SetupTensor (dt, dims, data, start: 0, count: data.Length, size: size);
+ }
+
+ // Convenience function to factor out the setup of a new tensor from an array
+ internal static IntPtr SetupTensor (TFDataType dt, Array data, int size)
+ {
+ long [] dims = new long [data.Rank];
+ for (int i = 0; i < dims.Length; i++)
+ dims [i] = data.GetLength (i);
+
+ return SetupTensor (dt, dims, data, start: 0, count: data.Length, size: size);
+ }
+
+ // Use for single dimension arrays
+ internal static IntPtr SetupTensor (TFDataType dt, TFShape shape, Array data, int start, int count, int size)
+ {
+ if (shape == null)
+ throw new ArgumentNullException (nameof (shape));
+ return SetupTensor (dt, shape.dims, data, start, count, size);
+ }
+
+ // Use for single dimension arrays
+ internal static IntPtr SetupTensor (TFDataType dt, long [] dims, Array data, int start, int count, int size)
+ {
+ if (start < 0 || start > data.Length - count)
+ throw new ArgumentException ("start + count > Array size");
+
+ var dataHandle = GCHandle.Alloc (data, GCHandleType.Pinned);
+
+ if (dims == null)
+ return TF_NewTensor (dt, IntPtr.Zero, 0, dataHandle.AddrOfPinnedObject () + start * size, (UIntPtr)(count * size), FreeTensorHandleDelegate, GCHandle.ToIntPtr (dataHandle));
+ else
+ return TF_NewTensor (dt, dims, dims.Length, dataHandle.AddrOfPinnedObject () + start * size, (UIntPtr)(count * size), FreeTensorHandleDelegate, GCHandle.ToIntPtr (dataHandle));
+ }
+
+ // General purpose constructor, specifies data type and gets pointer to buffer
+ // Is the default good, one where we let the user provide their own deallocator, or should we make a copy in that case?
+ ///
+ /// Low-level tensor constructor that creates a tensor from a buffer pointed to by an IntPtr.
+ ///
+ /// Specifies the data type held by the tensor, as well as how to interpret the provided data.
+ /// Describes the tensor shape, an array that indicates .
+ /// Pointer to the raw data that will be used to initialize the tensor.
+ /// The size of the data being passed in.
+ /// Deallocator method, it is invoked when the tensor is destroyed to release the data pointed to by . On platforms like iOS (or other static compilation platforms), you must annotate the method specified in the deallocator with a .
+ /// An optional argument of data that is passed to the deallocator method when the tensor is destroyed, you can use this to pass context information.
+ public TFTensor (TFDataType dataType, long [] dims, IntPtr data, size_t dataSize, Deallocator deallocator, IntPtr deallocatorData) : base (IntPtr.Zero)
+ {
+ if (dims == null)
+ throw new ArgumentNullException ("dims");
+
+ handle = TF_NewTensor (dataType, dims, dims.Length, data, dataSize, deallocator, deallocatorData);
+
+ }
+
+ internal override void NativeDispose (IntPtr handle)
+ {
+ TF_DeleteTensor (handle);
+ }
+
+ // extern TF_Tensor * TF_AllocateTensor (TF_DataType, const int64_t *dims, int num_dims, size_t len);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Tensor TF_AllocateTensor (TFDataType dataType, long [] dims, int num_dims, size_t len);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Tensor TF_AllocateTensor (TFDataType dataType, IntPtr zeroDim, int num_dims, size_t len);
+
+ ///
+ /// Low-level: Creates an empty tensor of the specified type and shape, with the specified number of elements
+ ///
+ /// Data type.
+ /// Tensor shape.
+ /// Size in bytes of the tensor, this will be the actual memory allocated.
+ ///
+ /// It is the responsibility of the caller to ensure that the size is correct given the data type size
+ /// and the tensor dimension specified in dims.
+ ///
+ public TFTensor (TFDataType dataType, long [] dims, int size) : base (IntPtr.Zero)
+ {
+ if (dims == null)
+ throw new ArgumentNullException ("dims");
+ handle = TF_AllocateTensor (dataType, dims, dims.Length, (size_t)size);
+ }
+
+ // extern void TF_DeleteTensor (TF_Tensor *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_DeleteTensor (TF_Tensor tensor);
+
+ // extern TF_DataType TF_TensorType (const TF_Tensor *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TFDataType TF_TensorType (TF_Tensor tensor);
+
+ ///
+ /// Returns the data type for the tensor.
+ ///
+ /// The type of the tensor.
+ public TFDataType TensorType => TF_TensorType (handle);
+
+ // extern int TF_NumDims (const TF_Tensor *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe int TF_NumDims (TF_Tensor tensor);
+
+ ///
+ /// Returns the number of dimensions in the tensor.
+ ///
+ ///
+ /// For single-dimension tensors the return is 1, 2 dimensions is 2 and so on.
+ ///
+ public int NumDims => TF_NumDims (handle);
+
+ // extern int64_t TF_Dim (const TF_Tensor *tensor, int dim_index);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe long TF_Dim (TF_Tensor tensor, int dim_index);
+
+ ///
+ /// Returns the number of elements on a specific dimension in the tensor.
+ ///
+ /// The tensor dimension.
+ /// Dimension that you are querying.
+ ///
+ /// If you have a tensor of 3 elements by 5, represented by [3 5],
+ /// the GetTensorDimension(0) will return 3, the GetTensorDimension(1)
+ /// will return 5.
+ ///
+ public long GetTensorDimension (int dimIndex)
+ {
+ return TF_Dim (handle, dimIndex);
+ }
+
+ // extern size_t TF_TensorByteSize (const TF_Tensor *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe size_t TF_TensorByteSize (TF_Tensor tensor);
+
+ public size_t TensorByteSize => TF_TensorByteSize (handle);
+
+ // extern void * TF_TensorData (const TF_Tensor *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe IntPtr TF_TensorData (TF_Tensor tensor);
+
+ ///
+ /// Returns a pointer to the raw data in the tensor.
+ ///
+ ///
+ /// The contents of the Data must be interpreted according to the type of the
+ /// data as described by the DataType property. The amount of data
+ /// is given by the TensorByteSize property.
+ ///
+ public IntPtr Data => TF_TensorData (handle);
+
+ ///
+ /// Returns the tensor shape, this is an array whose size determines the number of dimensions on the tensor, and each element is the size of the dimension
+ ///
+ ///
+ /// An array of size 0 is used for constants, an array of size 1 is used
+ /// for single-dimension arrays, where the dimension is the value of the
+ /// first element. And so on.
+ ///
+ public long [] Shape {
+ get {
+ var dims = new long [TF_NumDims (handle)];
+ for (int i = 0; i < dims.Length; i++)
+ dims [i] = (int)TF_Dim (handle, i);
+
+ return dims;
+ }
+ }
+
+ ///
+ /// Converts a to a system type.
+ ///
+ /// The to be converted.
+ /// The system type corresponding to the given .
+ public static Type TypeFromTensorType (TFDataType type)
+ {
+ switch (type) {
+ case TFDataType.Float:
+ return typeof (float);
+ case TFDataType.Double:
+ return typeof (double);
+ case TFDataType.Int32:
+ return typeof (int);
+ case TFDataType.UInt8:
+ return typeof (byte);
+ case TFDataType.Int16:
+ return typeof (short);
+ case TFDataType.Int8:
+ return typeof (sbyte);
+ case TFDataType.String:
+ throw new NotSupportedException();
+ case TFDataType.Int64:
+ return typeof (long);
+ case TFDataType.Bool:
+ return typeof (bool);
+ case TFDataType.UInt16:
+ return typeof (ushort);
+ case TFDataType.Complex128:
+ return typeof (Complex);
+ default:
+ return null;
+ }
+ }
+
+ ///
+ /// Converts a system type to a .
+ ///
+ /// The system type to be converted.
+ /// The corresponding to the given type.
+ public static TFDataType TensorTypeFromType(Type type)
+ {
+ if (type == typeof(float))
+ return TFDataType.Float;
+ if (type == typeof(double))
+ return TFDataType.Double;
+ if (type == typeof(int))
+ return TFDataType.Int32;
+ if (type == typeof(byte))
+ return TFDataType.UInt8;
+ if (type == typeof(short))
+ return TFDataType.Int16;
+ if (type == typeof(sbyte))
+ return TFDataType.Int8;
+ if (type == typeof(string))
+ return TFDataType.String;
+ if (type == typeof(long))
+ return TFDataType.Int64;
+ if (type == typeof(bool))
+ return TFDataType.Bool;
+ if (type == typeof(ushort))
+ return TFDataType.UInt16;
+ if (type == typeof(Complex))
+ return TFDataType.Complex128;
+
+ throw new ArgumentOutOfRangeException(nameof(type), $"The given type could not be mapped to an existing {nameof(TFDataType)}.");
+ }
+
+ private static unsafe object FetchSimple (TFDataType dt, IntPtr data)
+ {
+ switch (dt) {
+ case TFDataType.Float:
+ return *(float*)data;
+ case TFDataType.Double:
+ return *(double*)data;
+ case TFDataType.Int32:
+ return *(int*)data;
+ case TFDataType.UInt8:
+ return *(byte*)data;
+ case TFDataType.Int16:
+ return *(short*)data;
+ case TFDataType.Int8:
+ return *(sbyte*)data;
+ case TFDataType.String:
+ throw new NotImplementedException ();
+ case TFDataType.Int64:
+ return *(long*)data;
+ case TFDataType.Bool:
+ return *(bool*)data;
+ case TFDataType.UInt16:
+ return *(ushort*)data;
+ case TFDataType.Complex128:
+ return *(Complex*)data;
+ default:
+ return null;
+ }
+ }
+
+ internal static unsafe void Copy (IntPtr src, void* target, int size)
+ {
+ Buffer.MemoryCopy ((void*)src, target, size, size);
+ }
+
+ internal static unsafe void FetchFlatArray (Array target, TFDataType dt, IntPtr data)
+ {
+ int len = target.Length;
+ switch (dt) {
+ case TFDataType.Int8:
+ var asbyte = (sbyte [])target;
+ fixed (sbyte* p = &asbyte [0])
+ Copy (data, p, len);
+ return;
+ case TFDataType.Bool:
+ var abool = (bool [])target;
+ fixed (bool* p = &abool [0])
+ Copy (data, p, len);
+ return;
+ case TFDataType.UInt16:
+ var aushort = (ushort [])target;
+ fixed (ushort* p = &aushort [0])
+ Copy (data, p, len * 2);
+ return;
+ case TFDataType.Complex128:
+ var acomplex = (Complex [])target;
+ fixed (Complex* p = &acomplex [0])
+ Copy (data, p, len * sizeof (Complex));
+ return;
+ case TFDataType.Float:
+ var afloat = (float [])target;
+ fixed (float* p = &afloat [0])
+ Copy (data, p, len * sizeof (float));
+ return;
+ case TFDataType.Double:
+ var adouble = (double [])target;
+ fixed (double* p = &adouble [0])
+ Copy (data, p, len * sizeof (double));
+ return;
+ case TFDataType.Int32:
+ var aint = (int [])target;
+ fixed (int* p = &aint [0])
+ Copy (data, p, len * sizeof (int)); // BUG FIX: was sizeof (double), which copied twice the Int32 buffer size and overran the target array
+ return;
+ case TFDataType.UInt8:
+ var abyte = (byte [])target;
+ fixed (byte* p = &abyte [0])
+ Copy (data, p, len * sizeof (byte));
+ return;
+ case TFDataType.Int16:
+ var ashort = (short [])target;
+ fixed (short* p = &ashort [0])
+ Copy (data, p, len * sizeof (short));
+ return;
+ case TFDataType.Int64:
+ var along = (long [])target;
+ fixed (long* p = &along [0])
+ Copy (data, p, len * sizeof (long));
+ return;
+ case TFDataType.String:
+ // need to return an array of TFStrings []
+ throw new NotImplementedException ();
+ default:
+ throw new NotImplementedException ();
+ }
+ }
+
+ private static unsafe object FetchJaggedArray (Type t, TFDataType dt, ref IntPtr data, long [] shape, int level = 0)
+ {
+ Array target;
+
+ // If we are at the last node
+ if (level == shape.Length - 1) {
+ target = Array.CreateInstance (t, shape [level]);
+
+ for (long l = 0; l < shape [level]; l++)
+ switch (dt) {
+ case TFDataType.Float:
+ target.SetValue ((*(float*)data), l);
+ data += 4;
+ break;
+ case TFDataType.Double:
+ target.SetValue ((*(double*)data), l);
+ data += 8;
+ break;
+ case TFDataType.Int32:
+ target.SetValue ((*(int*)data), l);
+ data += 4;
+ break;
+ case TFDataType.UInt8:
+ target.SetValue ((*(byte*)data), l);
+ data += 1;
+ break;
+ case TFDataType.Int16:
+ target.SetValue ((*(short*)data), l);
+ data += 2;
+ break;
+ case TFDataType.Int8:
+ target.SetValue ((*(sbyte*)data), l);
+ data += 1;
+ break;
+ case TFDataType.Int64:
+ target.SetValue ((*(long*)data), l);
+ data += 8;
+ break;
+ case TFDataType.Bool:
+ target.SetValue ((*(bool*)data), l);
+ data += 1;
+ break;
+ case TFDataType.Complex128:
+ target.SetValue ((*(Complex*)data), l);
+ data += sizeof (Complex);
+ break;
+ case TFDataType.String:
+ throw new NotImplementedException ("String decoding not implemented for tensor vectors yet");
+ default:
+ throw new NotImplementedException ();
+ }
+ } else {
+ target = null;
+
+ long top = shape [level];
+ if (top < Int32.MaxValue) {
+ int itop = (int)top;
+
+ for (int i = 0; i < itop; i++) {
+ var childArray = FetchJaggedArray (t, dt, ref data, shape, level + 1);
+ if (target == null)
+ target = Array.CreateInstance (childArray.GetType (), shape [level]);
+
+ target.SetValue (childArray, i);
+ }
+ } else {
+ for (long l = 0; l < top; l++) {
+
+ var childArray = FetchJaggedArray (t, dt, ref data, shape, level + 1);
+ if (target == null)
+ target = Array.CreateInstance (childArray.GetType (), shape [level]);
+
+ target.SetValue (childArray, l);
+ }
+ }
+ return target;
+ }
+
+ return target;
+ }
+
+ private static void FetchMultiDimensionalArray (Array target, TFDataType dt, IntPtr data, long [] shape)
+ {
+ var idx = new int [shape.Length];
+ for (int i = 0; i < shape.Length; i++) {
+ if (shape [i] > Int32.MaxValue)
+ throw new ArgumentOutOfRangeException ("Shape can not be longer than 32 bits");
+ }
+ Copy (target, dt, shape, idx, 0, ref data);
+ }
+
+ private static unsafe void Copy (Array target, TFDataType dt, long [] shape, int [] idx, int level, ref IntPtr data)
+ {
+ if (level < shape.Length - 1) {
+ for (idx [level] = 0; idx [level] < shape [level]; idx [level]++)
+ Copy (target, dt, shape, idx, level + 1, ref data);
+ } else {
+ for (idx [level] = 0; idx [level] < shape [level]; idx [level]++) {
+ switch (dt) {
+ case TFDataType.Float:
+ target.SetValue ((*(float*)data), idx);
+ data += 4;
+ break;
+ case TFDataType.Double:
+ target.SetValue ((*(double*)data), idx);
+ data += 8;
+ break;
+ case TFDataType.Int32:
+ target.SetValue ((*(int*)data), idx);
+ data += 4;
+ break;
+ case TFDataType.UInt8:
+ target.SetValue ((*(byte*)data), idx);
+ data += 1;
+ break;
+ case TFDataType.Int16:
+ target.SetValue ((*(short*)data), idx);
+ data += 2;
+ break;
+ case TFDataType.Int8:
+ target.SetValue ((*(sbyte*)data), idx);
+ data += 1;
+ break;
+ case TFDataType.Int64:
+ target.SetValue ((*(long*)data), idx);
+ data += 8;
+ break;
+ case TFDataType.Bool:
+ target.SetValue ((*(bool*)data), idx);
+ data += 1;
+ break;
+ case TFDataType.Complex128:
+ target.SetValue ((*(Complex*)data), idx);
+ data += sizeof (Complex);
+ break;
+ case TFDataType.String:
+ throw new NotImplementedException ("String decoding not implemented for tensor vectors yet");
+ default:
+ throw new NotImplementedException ();
+ }
+ }
+ }
+ }
+
+ ///
+ /// Returns the value of the Tensor as a C# type if possible, or null if the data type can not be represented in C#
+ ///
+ ///
+ /// The default is set to false, which returns .NET multi-dimensional arrays for multi-dimensional
+ /// tensors. This is useful to feed the data back as a TFTensor created from an array. Set to
+ /// true if you want to get arrays pointing to arrays, which are slightly more convenient to work
+ /// with from C#
+ ///
+ ///
+ /// Jagged arrays create various intermediate arrays, while multi-dimensional arrays are more
+ /// efficient memory-wise.
+ ///
+ /// The value encodes the contents of the tensor, and could include simple values, arrays and multi-dimensional values.
+ public object GetValue (bool jagged = false)
+ {
+ var dims = NumDims;
+ if (dims == 0)
+ return FetchSimple (TensorType, Data);
+
+ var t = TypeFromTensorType (TensorType);
+ if (t == null)
+ return null;
+
+ if (dims == 1) {
+ var result = Array.CreateInstance (t, Shape [0]);
+ FetchFlatArray (result, TensorType, Data);
+ return result;
+ } else {
+ if (jagged) {
+ IntPtr data = Data;
+ return FetchJaggedArray (t, TensorType, ref data, Shape);
+ } else {
+ var result = Array.CreateInstance (t, Shape);
+ FetchMultiDimensionalArray (result, TensorType, Data, Shape);
+ return result;
+ }
+ }
+ }
+
+ ///
+ /// Returns a that represents the current .
+ ///
+ /// A that represents the current .
+ public override string ToString ()
+ {
+ var n = NumDims;
+ if (n == 0)
+ return GetValue ().ToString ();
+
+ StringBuilder sb = new StringBuilder ("[");
+ for (int i = 0; i < n; i++) {
+ sb.Append (TF_Dim (handle, i));
+ if (i + 1 < n)
+ sb.Append ("x");
+ }
+ sb.Append ("]");
+ return sb.ToString ();
+ }
+
+ }
+
+}
diff --git a/src/Microsoft.ML.Transforms/TensorFlow/TensorGeneric.cs b/src/Microsoft.ML.Transforms/TensorFlow/TensorGeneric.cs
new file mode 100644
index 0000000000..12e2bea004
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/TensorFlow/TensorGeneric.cs
@@ -0,0 +1,139 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+
+namespace Microsoft.ML.Transforms.TensorFlow
+{
+ internal partial class TFTensor
+ {
+ ///
+ /// Creates a tensor representing type T.
+ /// The tensor will be backed with a managed-heap-allocated T.
+ ///
+ /// .NET type of tensor to create
+ /// value of tensor
+ public static TFTensor CreateScalar(T data)
+ {
+ if (typeof(T) == typeof(System.Boolean))
+ {
+ return new TFTensor((System.Boolean)(object)data);
+ }
+ else if (typeof(T) == typeof(System.Byte))
+ {
+ return new TFTensor((System.Byte)(object)data);
+ }
+ else if (typeof(T) == typeof(System.Char))
+ {
+ return new TFTensor((System.Char)(object)data);
+ }
+ else if (typeof(T) == typeof(System.Numerics.Complex))
+ {
+ return new TFTensor((System.Numerics.Complex)(object)data);
+ }
+ else if (typeof(T) == typeof(System.Double))
+ {
+ return new TFTensor((System.Double)(object)data);
+ }
+ else if (typeof(T) == typeof(System.Single))
+ {
+ return new TFTensor((System.Single)(object)data);
+ }
+ else if (typeof(T) == typeof(System.Int32))
+ {
+ return new TFTensor((System.Int32)(object)data);
+ }
+ else if (typeof(T) == typeof(System.Int64))
+ {
+ return new TFTensor((System.Int64)(object)data);
+ }
+ else if (typeof(T) == typeof(System.SByte))
+ {
+ return new TFTensor((System.SByte)(object)data);
+ }
+ else if (typeof(T) == typeof(System.Int16))
+ {
+ return new TFTensor((System.Int16)(object)data);
+ }
+ else if (typeof(T) == typeof(System.UInt32))
+ {
+ return new TFTensor((System.UInt32)(object)data);
+ }
+ else if (typeof(T) == typeof(System.UInt64))
+ {
+ return new TFTensor((System.UInt64)(object)data);
+ }
+ else if (typeof(T) == typeof(System.UInt16))
+ {
+ return new TFTensor((System.UInt16)(object)data);
+ }
+ throw new NotSupportedException($"Unsupported type {typeof(T)}");
+ }
+
+ ///
+ /// Creates a tensor representing type T[].
+ /// T[] will be pinned and wrapped in a tensor.
+ ///
+ /// .NET type of tensor to create
+ /// value of tensor
+ /// shape of tensor
+ public static TFTensor Create(T[] data, TFShape shape)
+ {
+ if (typeof(T) == typeof(System.Boolean))
+ {
+ return new TFTensor(SetupTensor(TFDataType.Bool, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4));
+ }
+ else if (typeof(T) == typeof(System.Byte))
+ {
+ return new TFTensor(SetupTensor(TFDataType.UInt8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1));
+ }
+ else if (typeof(T) == typeof(System.Char))
+ {
+ return new TFTensor(SetupTensor(TFDataType.UInt8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1));
+ }
+ else if (typeof(T) == typeof(System.Numerics.Complex))
+ {
+ return new TFTensor(SetupTensor(TFDataType.Complex128, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 16));
+ }
+ else if (typeof(T) == typeof(System.Double))
+ {
+ return new TFTensor(SetupTensor(TFDataType.Double, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8));
+ }
+ else if (typeof(T) == typeof(System.Single))
+ {
+ return new TFTensor(SetupTensor(TFDataType.Float, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4));
+ }
+ else if (typeof(T) == typeof(System.Int32))
+ {
+ return new TFTensor(SetupTensor(TFDataType.Int32, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4));
+ }
+ else if (typeof(T) == typeof(System.Int64))
+ {
+ return new TFTensor(SetupTensor(TFDataType.Int64, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8));
+ }
+ else if (typeof(T) == typeof(System.SByte))
+ {
+ return new TFTensor(SetupTensor(TFDataType.Int8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1));
+ }
+ else if (typeof(T) == typeof(System.Int16))
+ {
+ return new TFTensor(SetupTensor(TFDataType.Int16, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 2));
+ }
+ else if (typeof(T) == typeof(System.UInt32))
+ {
+ return new TFTensor(SetupTensor(TFDataType.UInt32, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4));
+ }
+ else if (typeof(T) == typeof(System.UInt64))
+ {
+ return new TFTensor(SetupTensor(TFDataType.UInt64, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8));
+ }
+ else if (typeof(T) == typeof(System.UInt16))
+ {
+ return new TFTensor(SetupTensor(TFDataType.UInt16, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 2));
+ }
+ // note that we will get here for jagged arrays, which is intentional since we'd need to copy them.
+ throw new NotSupportedException($"Unsupported type {typeof(T)}");
+ }
+ }
+}
+
diff --git a/src/Microsoft.ML.Transforms/TensorFlow/TensorGeneric.tt b/src/Microsoft.ML.Transforms/TensorFlow/TensorGeneric.tt
new file mode 100644
index 0000000000..b2f44697c3
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/TensorFlow/TensorGeneric.tt
@@ -0,0 +1,99 @@
+<#@ template debug="false" hostspecific="false" language="C#" #>
+<#@ assembly name="System.Core" #>
+<#@ assembly name="System.Numerics" #>
+<#@ import namespace="System.Linq" #>
+<#@ import namespace="System.Text" #>
+<#@ import namespace="System.Collections.Generic" #>
+<#@ import namespace="System.Runtime.InteropServices" #>
+<#@ output extension=".cs" #>// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+
+namespace Microsoft.ML.Transforms.TensorFlow
+{
+ internal partial class TFTensor
+ {
+ ///
+ /// Creates a tensor representing type T.
+ /// The tensor will be backed with a managed-heap-allocated T.
+ ///
+ /// .NET type of tensor to create
+ /// value of tensor
+ public static TFTensor CreateScalar(T data)
+ {
+<# foreach (TypeConfiguration type in typeConfiguration) { #>
+ <#=GenerateIfStatementHeader(type)#>
+ {
+ return new TFTensor((<#=type.TypeName#>)(object)data);
+ }
+<# } #>
+ throw new NotSupportedException($"Unsupported type {typeof(T)}");
+ }
+
+ ///
+ /// Creates a tensor representing type T[].
+ /// T[] will be pinned and wrapped in a tensor.
+ ///
+ /// .NET type of tensor to create
+ /// value of tensor
+ /// shape of tensor
+ public static TFTensor Create(T[] data, TFShape shape)
+ {
+<# foreach (TypeConfiguration type in typeConfiguration) { #>
+ <#=GenerateIfStatementHeader(type)#>
+ {
+ return new TFTensor(SetupTensor(TFDataType.<#=type.TFDataType#>, shape, (Array)(object)data, 0, ((Array)(object)data).Length, <#=type.Size#>));
+ }
+<# } #>
+ // note that we will get here for jagged arrays, which is intentional since we'd need to copy them.
+ throw new NotSupportedException($"Unsupported type {typeof(T)}");
+ }
+ }
+}
+
+<#+
+ public class TypeConfiguration
+ {
+ public TypeConfiguration(Type type, string tfDataType)
+ {
+ Type = type;
+ TFDataType = tfDataType;
+ }
+ public string TypeName
+ {
+ get { return Type.ToString(); }
+ }
+ public Type Type { get; }
+ public string TFDataType { get; }
+ public int Size
+ {
+ get { return Marshal.SizeOf(Type); }
+ }
+ }
+
+ public string GenerateIfStatementHeader(TypeConfiguration type, string lhs = "typeof(T)")
+ {
+ string keyword = (type == typeConfiguration[0]) ? "if" : "else if";
+ return $"{keyword} ({lhs} == typeof({type.TypeName}))";
+ }
+
+ public TypeConfiguration[] typeConfiguration = new []
+ {
+ new TypeConfiguration(typeof(bool), "Bool"),
+ new TypeConfiguration(typeof(byte), "UInt8"),
+ new TypeConfiguration(typeof(char), "UInt8"),
+ new TypeConfiguration(typeof(System.Numerics.Complex), "Complex128"),
+ // new TypeConfiguration(typeof(decimal), "unknown"), TF doesn't appear to have 128-bit floating-point
+ new TypeConfiguration(typeof(double),"Double"),
+ new TypeConfiguration(typeof(float), "Float"),
+ new TypeConfiguration(typeof(int), "Int32"),
+ new TypeConfiguration(typeof(long), "Int64"),
+ new TypeConfiguration(typeof(sbyte), "Int8"),
+ new TypeConfiguration(typeof(short), "Int16"),
+ new TypeConfiguration(typeof(uint), "UInt32"),
+ new TypeConfiguration(typeof(ulong), "UInt64"),
+ new TypeConfiguration(typeof(ushort), "UInt16")
+ // TODO, map other types
+ };
+#>
\ No newline at end of file
diff --git a/src/Microsoft.ML.Transforms/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.Transforms/TensorFlow/Tensorflow.cs
new file mode 100644
index 0000000000..76c74a8693
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/TensorFlow/Tensorflow.cs
@@ -0,0 +1,2147 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Globalization;
+using System.Linq;
+
+// We use this TF_Xxx as the native "TF_Xxx *" as those are opaque
+using TF_Status = System.IntPtr;
+using TF_SessionOptions = System.IntPtr;
+using TF_Graph = System.IntPtr;
+using TF_OperationDescription = System.IntPtr;
+using TF_Operation = System.IntPtr;
+using TF_Session = System.IntPtr;
+using TF_DeprecatedSession = System.IntPtr;
+using TF_Tensor = System.IntPtr;
+using TF_ImportGraphDefOptions = System.IntPtr;
+using TF_Library = System.IntPtr;
+using TF_BufferPtr = System.IntPtr;
+using TF_Function = System.IntPtr;
+using TF_DeviceList = System.IntPtr;
+
+using size_t = System.UIntPtr;
+using System.Numerics;
+using System.Collections.Generic;
+using System.Linq.Expressions;
+
+#pragma warning disable MSML_GeneralName
+#pragma warning disable MSML_PrivateFieldName
+#pragma warning disable MSML_ParameterLocalVarName
+
+namespace Microsoft.ML.Transforms.TensorFlow
+{
+ internal static partial class NativeBinding
+ {
+ public const string TensorFlowLibrary = "libtensorflow";
+ public const string TensorFlowLibraryGPU = "libtensorflowgpu";
+
+ internal static string GetStr (this IntPtr x) => Marshal.PtrToStringAnsi (x);
+ }
+
+ ///
+ /// Contains TensorFlow fundamental methods and utility functions.
+ ///
+ internal static class TFCore
+ {
+ internal static bool UseCPU = true;
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe IntPtr TF_Version ();
+
+ static TFCore ()
+ {
+ Init ();
+ }
+
+ internal static void Init ()
+ {
+ CheckSize ();
+ }
+
+ ///
+ /// Returns the version of the TensorFlow runtime in use.
+ ///
+ /// The version.
+ public static string Version => TF_Version ().GetStr ();
+
+ // extern size_t TF_DataTypeSize (TF_DataType dt);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern IntPtr TF_DataTypeSize (TFDataType dt);
+
+ ///
+ /// Gets the size in bytes of the specified TensorFlow data type.
+ ///
+ /// The data type size.
+ /// Dt.
+ public static long GetDataTypeSize (TFDataType dt) => (long)TF_DataTypeSize (dt);
+
+ // extern TF_Buffer * TF_GetAllOpList ();
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe IntPtr TF_GetAllOpList ();
+
+ ///
+ /// Retrieves the ProtocolBuffer describing all of the available operations in
+ /// the TensorFlow library in current use.
+ ///
+ /// The buffer contains a ProtocolBuffer encoded payload, you need a ProtocolBuffer reader to process the contents.
+ public static TFBuffer GetAllOpList ()
+ {
+ return new TFBuffer (TF_GetAllOpList ());
+ }
+
+ private static void CheckSize ()
+ {
+ unsafe {
+ if (sizeof (IntPtr) == 4) {
+ Console.Error.WriteLine (
+ "The TensorFlow native libraries were compiled in 64 bit mode, you must run in 64 bit mode\n" +
+ "With Mono, do that with mono --arch=64 executable.exe, if using an IDE like MonoDevelop,\n" +
+ "Xamarin Studio or Visual Studio for Mac, Build/Compiler settings, make sure that " +
+ "\"Platform Target\" has x64 selected.");
+ throw new Exception ();
+
+ }
+ }
+ }
+ }
+
+ ///
+ /// Base class for many TensorFlow data types that provides a common idiom to dispose and
+ /// release resources associated with the native data types. Generally, you do not need to use this.
+ ///
+ ///
+ ///
+ /// This implements the Dispose pattern in a reusable form for TensorFlow types.
+ ///
+ ///
+ /// Subclasses invoke the constructor with the handle that this will wrap, and must
+ /// override the NativeDispose method (internal) to release the associated resource.
+ ///
+ ///
+ internal abstract class TFDisposable : IDisposable
+ {
+ internal IntPtr handle;
+
+ ///
+ /// Returns the opaque handle to the object that this TFDisposable owns.
+ ///
+ /// The handle.
+ public IntPtr Handle => handle;
+
+ static TFDisposable ()
+ {
+ TFCore.Init ();
+ }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public TFDisposable ()
+ { }
+
+ ///
+ /// Initializes a new instance of the class
+ /// from the handle that it will wrap.
+ ///
+ public TFDisposable (IntPtr handle)
+ {
+ this.handle = handle;
+ }
+
+ ///
+ /// Releases all resource used by the object.
+ ///
+ /// Call Dispose when you are finished using the . The
+ /// Dispose method leaves the in an unusable state. After
+ /// calling Dispose, you must release all references to the so
+ /// the garbage collector can reclaim the memory that the was occupying.
+ public void Dispose ()
+ {
+ Dispose (true);
+ GC.SuppressFinalize (this);
+ }
+
+ ~TFDisposable ()
+ {
+ Dispose (false);
+ }
+
+ // Must be implemented in subclasses to dispose the unmanaged object, it does
+ // not need to take care of zeroing out the handle, that is done by the Dispose
+ // method inherited from TFDisposable
+ internal abstract void NativeDispose (IntPtr handle);
+
+ ///
+ /// Dispose the specified object
+ ///
+ /// If set to true it means that this method was called from Dispose, otherwise from the finalizer.
+ public virtual void Dispose (bool disposing)
+ {
+ if (disposing) {
+ if (handle != IntPtr.Zero)
+ NativeDispose (handle);
+ handle = IntPtr.Zero;
+ }
+ }
+
+ internal static void ObjectDisposedException ()
+ {
+ throw new ObjectDisposedException ("The object was disposed");
+ }
+ }
+
+ ///
+ /// Base class for many TensorFlow data types that provides a common idiom to dispose and
+ /// release resources associated with the native data types and whose unmanaged resource
+ /// disposing can be called from a background thread (the finalizer). Users do not
+ /// need to deal with this class.
+ ///
+ ///
+ /// Some object deletion APIs in TensorFlow can be invoked from a background thread,
+ /// so the release methods are suitable to be invoked from the Finalizer thread, in
+ /// those scenarios, subclass from this class rather than the TFDisposable class.
+ ///
+ internal abstract class TFDisposableThreadSafe : TFDisposable {
+ ///
+ /// Initializes a new instance of the class
+ /// from the handle that it will wrap.
+ ///
+ public TFDisposableThreadSafe (IntPtr handle) : base (handle)
+ {
+ }
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public TFDisposableThreadSafe ()
+ { }
+
+ ///
+ /// Dispose the object, unlike the default implementation in TFDisposable,
+ /// this will release the unmanaged resources from a background thread.
+ ///
+ /// If set to true disposing.
+ public override void Dispose (bool disposing)
+ {
+ if (handle != IntPtr.Zero)
+ NativeDispose (handle);
+ handle = IntPtr.Zero;
+ }
+ }
+
+ ///
+ /// TensorFlow Exception
+ ///
+ internal class TFException : Exception {
+ ///
+ /// Initializes a new instance of the class with a message.
+ ///
+ /// Message.
+ public TFException (string message) : base (message) { }
+ }
+
+ ///
+ /// Used to track the result of TensorFlow operations.
+ ///
+ ///
+ ///
+ /// TFStatus is used to track the status of a call to some TensorFlow
+ /// operations. Instances of this object are passed to various
+ /// TensorFlow operations and you can use the
+ /// to quickly check if the operation succeeded, or get more detail from the
+ /// and a human-readable text
+ /// using the property.
+ ///
+ ///
+ /// The convenience can be used
+ /// to raise a if the status of the
+ /// operation did not succeed.
+ ///
+ ///
+ internal class TFStatus : TFDisposable
+ {
+ // extern TF_Status * TF_NewStatus ();
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Status TF_NewStatus ();
+
+ ///
+ /// Per-thread global status that you can use if you do not need to create a new instance of this object.
+ ///
+ ///
+ /// This is provided as a convenience for APIs that take a TFStatus. While the TFStatus is usually an
+ /// optional parameter, when it is made optional, API calls that fail raise an exception. Use this
+ /// property to pass a TFStatus without having to allocate a new one. The problem with this of course
+ /// is that you risk having multiple parts of your code override this thread-global variable.
+ ///
+ [ThreadStatic] public static TFStatus Default = new TFStatus ();
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public TFStatus () : base (TF_NewStatus ())
+ {
+ }
+
+ // extern void TF_DeleteStatus (TF_Status *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_DeleteStatus (TF_Status status);
+
+ internal override void NativeDispose (IntPtr handle)
+ {
+ TF_DeleteStatus (handle);
+ }
+
+ // extern void TF_SetStatus (TF_Status *s, TF_Code code, const char *msg);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_SetStatus (TF_Status s, TFCode code, string msg);
+
+ ///
+ /// Sets the status code on this TFStatus.
+ ///
+ /// Code.
+ /// Message.
+ public void SetStatusCode (TFCode code, string msg)
+ {
+ TF_SetStatus (handle, code, msg);
+ }
+
+ // extern TF_Code TF_GetCode (const TF_Status *s);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TFCode TF_GetCode (TF_Status s);
+
+ ///
+ /// Gets the status code for the status code.
+ ///
+ /// The status code as an enumeration.
+ public TFCode StatusCode {
+ get {
+ if (handle == IntPtr.Zero)
+ throw new ObjectDisposedException ("TFStatus");
+ return TF_GetCode (handle);
+ }
+ }
+
+ // extern const char * TF_Message (const TF_Status *s);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe IntPtr TF_Message (TF_Status s);
+
+ ///
+ /// Gets a human-readable status message.
+ ///
+ /// The status message.
+ public string StatusMessage => TF_Message (handle).GetStr ();
+
+ ///
+ /// Returns a that represents the current .
+ ///
+ /// A that represents the current .
+ public override string ToString ()
+ {
+ if (handle == IntPtr.Zero)
+ throw new ObjectDisposedException ("TFStatus");
+
+ return string.Format ("[TFStatus: StatusCode={0}, StatusMessage={1}]", StatusCode, StatusMessage);
+ }
+
+ ///
+ /// Gets a value indicating whether this state has been set to ok.
+ ///
+ /// true if ok; otherwise, false.
+ public bool Ok => StatusCode == TFCode.Ok;
+
+ ///
+ /// Gets a value indicating whether this state has been set to an error.
+ ///
+ /// true if error; otherwise, false.
+ public bool Error => StatusCode != TFCode.Ok;
+
+ ///
+ /// Convenience method that raises an exception if the current status is an error.
+ ///
+ ///
+ /// You can use this method as a convenience to raise an exception after you
+ /// invoke an operation if the operation did not succeed.
+ ///
+ public void Raise ()
+ {
+ if (TF_GetCode (handle) != TFCode.Ok)
+ throw new TFException (StatusMessage);
+ }
+
+ //
+ // Utility function used to simplify implementing the idiom
+ // where the user optionally provides a TFStatus: if it is provided,
+ // the error is returned there; if it is not provided, then an
+ // exception is raised.
+ //
+
+ // Returns true when the status is OK. When 'incomingStatus' is null this
+ // TFStatus is owned by the binding itself, so errors are reported by
+ // throwing, and the status is disposed when 'last' is true (i.e. this is
+ // the final check in the call sequence).
+ internal bool CheckMaybeRaise (TFStatus incomingStatus, bool last = true)
+ {
+ if (incomingStatus == null) {
+ // Guard against use-after-dispose explicitly. This replaces a
+ // Console.WriteLine("oops") debug leftover; the StatusCode getter
+ // below threw the same ObjectDisposedException immediately after,
+ // so the observable behavior is unchanged minus the console noise.
+ if (handle == IntPtr.Zero)
+ throw new ObjectDisposedException ("TFStatus");
+ if (StatusCode != TFCode.Ok) {
+ var e = new TFException (StatusMessage);
+ if (last)
+ Dispose ();
+ throw e;
+ }
+ if (last)
+ Dispose ();
+ return true;
+ }
+ return StatusCode == TFCode.Ok;
+ }
+
+ // Returns the caller-provided status, or a freshly allocated TFStatus when the caller passed null.
+ internal static TFStatus Setup (TFStatus incoming)
+ {
+ return incoming ?? new TFStatus ();
+ }
+ }
+
+ /// <summary>
+ /// The session options object holds configuration options that you want to use during your session, like the TensorFlow target or the configuration.
+ /// </summary>
+ internal class TFSessionOptions : TFDisposable
+ {
+ // extern TF_SessionOptions * TF_NewSessionOptions ();
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_SessionOptions TF_NewSessionOptions ();
+
+ /// <summary>
+ /// Initializes a new instance of the <see cref="TFSessionOptions"/> class, allocating the native options object.
+ /// </summary>
+ public TFSessionOptions () : base (TF_NewSessionOptions ()) { }
+
+ // extern void TF_DeleteSessionOptions (TF_SessionOptions *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_DeleteSessionOptions (TF_SessionOptions options);
+ // Releases the native options object; invoked by the TFDisposable dispose machinery.
+ internal override void NativeDispose (IntPtr handle)
+ {
+ TF_DeleteSessionOptions (handle);
+ }
+
+ // extern void TF_SetTarget (TF_SessionOptions *options, const char *target);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_SetTarget (TF_SessionOptions options, string target);
+
+ /// <summary>
+ /// Sets the target in options.
+ /// </summary>
+ /// <param name="target">
+ /// target can be empty, a single entry, or a comma separated list of entries.
+ /// Each entry is in one of the following formats: "local", ip:port, host:port.
+ /// </param>
+ public void SetTarget (string target)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+
+ TF_SetTarget (handle, target);
+ }
+
+ // extern void TF_SetConfig (TF_SessionOptions *options, const void *proto, size_t proto_len, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_SetConfig (TF_SessionOptions options, IntPtr proto, size_t proto_len, TF_Status status);
+
+ /// <summary>
+ /// Sets the configuration information for the session.
+ /// </summary>
+ /// <param name="protoData">Serialized protocol buffer for the tensorflow.ConfigProto message.</param>
+ /// <param name="length">Length of the buffer.</param>
+ /// <param name="status">If config was not parsed successfully as a ConfigProto, the error is recorded here; if null, a TFException is raised instead.</param>
+ /// <remarks>
+ /// The configuration option is a Protocol Buffer representing the tensorflow.ConfigProto.
+ /// </remarks>
+ public void SetConfig (IntPtr protoData, int length, TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+
+ var cstatus = TFStatus.Setup (status);
+
+ TF_SetConfig (handle, protoData, (UIntPtr)length, cstatus.handle);
+ cstatus.CheckMaybeRaise (status);
+ }
+
+ }
+
+ /// <summary>
+ /// Represents a computation graph. Graphs may be shared between sessions and are thread safe.
+ /// </summary>
+ /// <remarks>
+ /// <para>
+ /// Graphs consist of operations (represented by TFOperation objects); these can be named, or
+ /// the runtime will automatically assign a name.
+ /// </para>
+ /// <para>
+ /// For debugging purposes, you might want to group operations together; for this, call the
+ /// WithScope method with your new scope, which will create a new namespace for your object names.
+ /// </para>
+ /// <para>
+ /// For example, if you call WithScope ("demo"), and add an operation named "add" inside the
+ /// scope, the full name of the operation will be "demo/add"; if you create a new scope inside, say
+ /// "hot", and add a "sub" operation there the result will be "demo/hot/sub".
+ /// </para>
+ /// </remarks>
+ internal partial class TFGraph : TFDisposableThreadSafe
+ {
+ // extern TF_Graph * TF_NewGraph ();
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Graph TF_NewGraph ();
+
+ /// <summary>
+ /// Initializes a new instance of the <see cref="TFGraph"/> class backed by a freshly allocated native graph.
+ /// </summary>
+ public TFGraph () : base (TF_NewGraph ())
+ {
+ }
+
+ // extern void TF_DeleteGraph (TF_Graph *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_DeleteGraph (TF_Graph graph);
+
+ // Releases the native graph; invoked by the TFDisposable dispose machinery.
+ internal override void NativeDispose (IntPtr handle)
+ {
+ TF_DeleteGraph (handle);
+ }
+
+ // extern int TF_GraphGetTensorNumDims (TF_Graph *graph, TF_Output output, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe int TF_GraphGetTensorNumDims (TF_Graph graph, TFOutput output, TF_Status status);
+
+ // extern void TF_GraphGetTensorShape (TF_Graph *graph, TF_Output output, int64_t *dims, int num_dims, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_GraphGetTensorShape (TF_Graph graph, TFOutput output, long [] dims, int num_dims, TF_Status status);
+
+ /// <summary>
+ /// Returns the shape of the tensor specified in <paramref name="output"/>.
+ /// </summary>
+ /// <returns>
+ /// The tensor shape. If the number of dimensions is unknown, <see cref="TFShape.Unknown"/> is returned.
+ /// Otherwise each element of the shape is the size of the corresponding dimension, with -1 for an unknown dimension.
+ /// </returns>
+ /// <param name="output">The tensor that you want to look up.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ public TFShape GetTensorShape (TFOutput output, TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ var cstatus = TFStatus.Setup (status);
+ var n = TF_GraphGetTensorNumDims (handle, output, cstatus.handle);
+ if (!cstatus.CheckMaybeRaise (status, last: false))
+ return TFShape.Unknown;
+ if (n == -1)
+ return TFShape.Unknown;
+
+ var dims = new long [n];
+ TF_GraphGetTensorShape (handle, output, dims, dims.Length, cstatus.handle);
+ cstatus.CheckMaybeRaise (status);
+ return new TFShape (dims);
+ }
+
+ // extern void TF_GraphToGraphDef (TF_Graph *graph, TF_Buffer *output_graph_def, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_GraphToGraphDef (TF_Graph graph, LLBuffer* output_graph_def, TF_Status status);
+
+ /// <summary>
+ /// Write out a serialized representation of the graph (as a GraphDef protocol buffer message) into <paramref name="outputGraphDef"/>.
+ /// </summary>
+ /// <param name="outputGraphDef">Target buffer where the graph is serialized into.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ public void ToGraphDef (TFBuffer outputGraphDef, TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ if (outputGraphDef == null)
+ throw new ArgumentNullException (nameof (outputGraphDef));
+
+ var cstatus = TFStatus.Setup (status);
+ unsafe
+ {
+ TF_GraphToGraphDef (handle, outputGraphDef.LLBuffer, cstatus.handle);
+ }
+ cstatus.CheckMaybeRaise (status);
+ }
+
+ // extern void TF_GraphImportGraphDef (TF_Graph *graph, const TF_Buffer *graph_def, const TF_ImportGraphDefOptions *options, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_GraphImportGraphDef (TF_Graph graph, LLBuffer* graph_def, TF_ImportGraphDefOptions options, TF_Status status);
+
+ /// <summary>
+ /// Import a serialized graph into this graph, using the specified prefix.
+ /// </summary>
+ /// <param name="graphDef">A buffer containing the serialized graph.</param>
+ /// <param name="prefix">A prefix that will be prepended to names of nodes in the graphDef when they are imported into the graph.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ public void Import (TFBuffer graphDef, string prefix = "", TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ if (graphDef == null)
+ throw new ArgumentNullException (nameof (graphDef));
+ if (prefix == null)
+ throw new ArgumentNullException (nameof (prefix));
+
+ using (var options = new TFImportGraphDefOptions ()) {
+ options.SetPrefix (prefix);
+ Import (graphDef, options, status);
+ }
+ }
+
+ /// <summary>
+ /// Import a serialized graph into this graph, using the specified importing options.
+ /// </summary>
+ /// <param name="graphDef">A buffer containing the serialized graph.</param>
+ /// <param name="options">Importing graph options.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ public void Import (TFBuffer graphDef, TFImportGraphDefOptions options, TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ if (graphDef == null)
+ throw new ArgumentNullException (nameof (graphDef));
+ if (options == null)
+ throw new ArgumentNullException (nameof (options));
+
+ var cstatus = TFStatus.Setup (status);
+ unsafe
+ {
+ TF_GraphImportGraphDef (handle, graphDef.LLBuffer, options.handle, cstatus.handle);
+ }
+ cstatus.CheckMaybeRaise (status);
+ }
+
+ /// <summary>
+ /// Import a serialized graph held in a byte array into this graph, using the specified prefix.
+ /// </summary>
+ /// <param name="buffer">A byte array containing the serialized graph.</param>
+ /// <param name="prefix">A prefix that will be prepended to names of nodes in the graph when they are imported into the graph.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ public void Import (byte [] buffer, string prefix = "", TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ if (buffer == null)
+ throw new ArgumentNullException (nameof (buffer));
+ if (prefix == null)
+ throw new ArgumentNullException (nameof (prefix));
+ using (var options = new TFImportGraphDefOptions ()) {
+ options.SetPrefix (prefix);
+ Import (buffer, options, status);
+ }
+ }
+
+ /// <summary>
+ /// Import a serialized graph held in a byte array into this graph, using the specified import options.
+ /// </summary>
+ /// <param name="buffer">A byte array containing the serialized graph.</param>
+ /// <param name="options">Importing graph options.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ /// <remarks>
+ /// If you are trying to load a file stored using the SavedModel file format, you should use the FromSavedModel API instead.
+ /// </remarks>
+ public void Import (byte [] buffer, TFImportGraphDefOptions options, TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ if (buffer == null)
+ throw new ArgumentNullException (nameof (buffer));
+ if (options == null)
+ throw new ArgumentNullException (nameof (options));
+ // Delegate to the TFBuffer overload, which performs the native call and
+ // reports errors through 'status'. The previous extra
+ // 'cstatus.CheckMaybeRaise (cstatus)' checked a status object that was
+ // never passed to native code (so it never raised) and leaked a
+ // TFStatus when 'status' was null.
+ using (var tb = new TFBuffer (buffer, 0, buffer.Length))
+ Import (tb, options, status);
+ }
+
+ // extern TF_Operation * TF_GraphOperationByName (TF_Graph *graph, const char *oper_name);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Operation TF_GraphOperationByName (TF_Graph graph, string oper_name);
+
+ /// <summary>
+ /// Gets the TFOperation with the specified name, or null if the named operation does not exist in the graph.
+ /// </summary>
+ /// <param name="name">Name to lookup.</param>
+ public TFOperation this [string name] {
+ get {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ var h = TF_GraphOperationByName (handle, name);
+ if (h == IntPtr.Zero)
+ return null;
+ return new TFOperation (this, h);
+ }
+ }
+
+ // extern char* TF_GraphDebugString (TF_Graph *graph, size_t *len);
+ // Declared to return IntPtr deliberately: declaring the return as 'string'
+ // makes the CLR marshaller free the TensorFlow-allocated buffer with the
+ // CoTaskMem allocator, which can corrupt the native heap.
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern IntPtr TF_GraphDebugString (TF_Graph graph, out IntPtr len);
+
+ /// <summary>
+ /// Returns a human-readable text representation of the graph.
+ /// </summary>
+ public override string ToString ()
+ {
+ IntPtr len;
+ var debug = TF_GraphDebugString (Handle, out len);
+ // NOTE(review): the native buffer is allocated by TensorFlow; confirm
+ // its ownership/free semantics to avoid a per-call leak.
+ return Marshal.PtrToStringAnsi (debug);
+ }
+ }
+
+ /// <summary>
+ /// Represents a computation node in the graph. TensorFlow operations are attached to a <see cref="TFGraph"/>.
+ /// </summary>
+ /// <remarks>
+ /// TFOperations are usually created by invoking one of the graph-building methods,
+ /// but they can also be constructed manually using the low-level API.
+ /// </remarks>
+ internal partial class TFOperation
+ {
+ internal IntPtr handle;
+
+ /// <summary>
+ /// Gets the handle to the unmanaged TF_Operation object.
+ /// </summary>
+ /// <value>The handle.</value>
+ public IntPtr Handle => handle;
+
+ // Keep a reference to the owning graph so it is not collected while
+ // TFOperation instances are still alive.
+ internal TFGraph graph;
+
+ internal TFOperation (TFGraph graph, IntPtr handle)
+ {
+ this.graph = graph;
+ this.handle = handle;
+ }
+
+ /// <summary>
+ /// Returns the handle to the idx-th output of the operation.
+ /// </summary>
+ /// <param name="idx">Index of the output in the operation.</param>
+ public TFOutput this [int idx] => new TFOutput (this, idx);
+ }
+
+ /// <summary>
+ /// Device type.
+ /// </summary>
+ internal enum DeviceType
+ {
+ /// <summary>
+ /// The device is the Central Processing Unit (CPU).
+ /// </summary>
+ CPU,
+
+ /// <summary>
+ /// The device is a Graphics Processing Unit (GPU).
+ /// </summary>
+ GPU,
+
+ /// <summary>
+ /// The device is a Tensor Processing Unit (TPU).
+ /// </summary>
+ TPU
+ }
+
+ /// <summary>
+ /// Describes the device attributes of a device available to a session.
+ /// </summary>
+ internal class DeviceAttributes
+ {
+ internal DeviceAttributes (string name, DeviceType deviceType, long memoryLimitBytes)
+ {
+ Name = name;
+ DeviceType = deviceType;
+ MemoryLimitBytes = memoryLimitBytes;
+ }
+
+ /// <summary>
+ /// The full name of the device (e.g. /job:worker/replica:0/...).
+ /// </summary>
+ public string Name { get; private set; }
+
+ /// <summary>
+ /// Gets the type of the device.
+ /// </summary>
+ /// <value>The type of the device.</value>
+ public DeviceType DeviceType { get; private set; }
+
+ /// <summary>
+ /// The amount of memory associated with a given device, in bytes.
+ /// </summary>
+ /// <value>The memory limit in bytes.</value>
+ public long MemoryLimitBytes { get; private set; }
+ }
+
+ /// <summary>
+ /// Contains options that are used to control how graph importing works.
+ /// </summary>
+ internal class TFImportGraphDefOptions : TFDisposable
+ {
+ // extern TF_ImportGraphDefOptions * TF_NewImportGraphDefOptions ();
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_ImportGraphDefOptions TF_NewImportGraphDefOptions ();
+
+ /// <summary>
+ /// Initializes a new instance, allocating the native import options object.
+ /// </summary>
+ public TFImportGraphDefOptions () : base (TF_NewImportGraphDefOptions ())
+ {
+ }
+
+ // extern void TF_DeleteImportGraphDefOptions (TF_ImportGraphDefOptions *opts);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_DeleteImportGraphDefOptions (TF_ImportGraphDefOptions opts);
+
+ // Releases the native options object; invoked by the TFDisposable dispose machinery.
+ internal override void NativeDispose (IntPtr handle)
+ {
+ TF_DeleteImportGraphDefOptions (handle);
+ }
+
+ // extern void TF_ImportGraphDefOptionsSetPrefix (TF_ImportGraphDefOptions *opts, const char *prefix);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_ImportGraphDefOptionsSetPrefix (TF_ImportGraphDefOptions opts, string prefix);
+
+ /// <summary>
+ /// Sets the prefix to prepend to the names of imported nodes.
+ /// </summary>
+ public void SetPrefix (string prefix)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ TF_ImportGraphDefOptionsSetPrefix (handle, prefix);
+ }
+
+ // extern void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions *opts, const char* src_name, int src_index, TF_Output dst);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions opts, string src_name, int src_index, TFOutput dst);
+
+ /// <summary>
+ /// Adds an input mapping from a source name and index to a destination output.
+ /// </summary>
+ /// <param name="srcName">Source name.</param>
+ /// <param name="srcIndex">Source index (in the source).</param>
+ /// <param name="dst">Replacement value for the srcName:srcIndex.</param>
+ /// <remarks>
+ /// Set any imported nodes with input `src_name:src_index` to have that input
+ /// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
+ /// `dst` references a node already existing in the graph being imported into.
+ /// </remarks>
+ public void AddInputMapping (string srcName, int srcIndex, TFOutput dst)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ TF_ImportGraphDefOptionsAddInputMapping (handle, srcName, srcIndex, dst);
+ }
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern void TF_ImportGraphDefOptionsAddControlDependency (TF_ImportGraphDefOptions opts, TF_Operation oper);
+
+ /// <summary>
+ /// Cause the imported graph to have a control dependency on the provided operation.
+ /// </summary>
+ /// <param name="operation">This operation should exist in the graph being imported to.</param>
+ public void AddControlDependency (TFOperation operation)
+ {
+ if (operation == null)
+ throw new ArgumentNullException (nameof (operation));
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+
+ TF_ImportGraphDefOptionsAddControlDependency (handle, operation.handle);
+ }
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern void TF_ImportGraphDefOptionsAddReturnOutput (TF_ImportGraphDefOptions opts, string oper_name, int index);
+
+ /// <summary>
+ /// Add an output in the graph definition to be returned via the return outputs parameter.
+ /// </summary>
+ /// <param name="operName">Operation name.</param>
+ /// <param name="index">Operation index.</param>
+ /// <remarks>
+ /// If the output is remapped via an input
+ /// mapping, the corresponding existing tensor in graph will be returned.
+ /// </remarks>
+ public void AddReturnOutput (string operName, int index)
+ {
+ if (operName == null)
+ throw new ArgumentNullException (nameof (operName));
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ TF_ImportGraphDefOptionsAddReturnOutput (handle, operName, index);
+ }
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern int TF_ImportGraphDefOptionsNumReturnOutputs (TF_ImportGraphDefOptions opts);
+
+ /// <summary>
+ /// Gets the number of return outputs added via <see cref="AddReturnOutput"/>.
+ /// </summary>
+ /// <value>The number of return outputs.</value>
+ public int NumReturnOutputs {
+ get {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ return TF_ImportGraphDefOptionsNumReturnOutputs (handle);
+ }
+ }
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern void TF_ImportGraphDefOptionsRemapControlDependency (TF_ImportGraphDefOptions opts, string srcName, TF_Operation dst);
+
+ /// <summary>
+ /// Sets any imported nodes with a given control input to have it replaced with an operation.
+ /// </summary>
+ /// <param name="srcName">Node in the graph to be imported.</param>
+ /// <param name="destination">References an operation that already exists in the graph being imported.</param>
+ /// <remarks>
+ /// Set any imported nodes with control input srcName to have that input
+ /// replaced with destination.
+ /// </remarks>
+ public void RemapControlDependency (string srcName, TFOperation destination)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ if (srcName == null)
+ throw new ArgumentNullException (nameof (srcName));
+ if (destination == null)
+ throw new ArgumentNullException (nameof (destination));
+ if (destination.Handle == IntPtr.Zero)
+ throw new ObjectDisposedException (nameof (destination));
+ TF_ImportGraphDefOptionsRemapControlDependency (handle, srcName, destination.Handle);
+ }
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern void TF_ImportGraphDefOptionsSetUniquifyNames (TF_ImportGraphDefOptions opts, byte uniquify);
+
+ /// <summary>
+ /// Set whether to uniquify imported operation names.
+ /// </summary>
+ /// <param name="uniquifyNames">If set to <c>true</c> imported operation names will be modified if their name already exists in the graph.
+ /// If set to <c>false</c> conflicting names will be treated as an error.
+ /// </param>
+ /// <remarks>
+ /// Note that this option has no effect if a prefix is set, since the prefix will guarantee all names are unique.
+ /// Defaults to false.
+ /// </remarks>
+ public void SetUniquifyNames (bool uniquifyNames)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+
+ TF_ImportGraphDefOptionsSetUniquifyNames (handle, uniquifyNames ? (byte) 1 : (byte) 0);
+ }
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern void TF_ImportGraphDefOptionsSetUniquifyPrefix (TF_ImportGraphDefOptions opts, byte uniquify_prefix);
+
+ /// <summary>
+ /// Sets the uniquify prefix. This option has no effect if no prefix is specified.
+ /// </summary>
+ /// <param name="uniquifyPrefix">If set to <c>true</c> the specified prefix will be modified if it already exists as an
+ /// operation name or prefix in the graph.
+ /// If set to <c>false</c> a conflicting prefix will be treated as an error.
+ /// </param>
+ public void SetUniquifyPrefix (bool uniquifyPrefix)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ TF_ImportGraphDefOptionsSetUniquifyPrefix (handle, uniquifyPrefix ? (byte)1 : (byte)0);
+ }
+ }
+
+ ///
+ /// Drives the execution of a graph
+ ///
+ ///
+ ///
+ /// This creates a new context to execute a TFGraph. You can use the
+ /// constructor to create an empty session, or you can load an existing
+ /// model using the static method in this class.
+ ///
+ ///
+ /// To execute operations with the graph, call the method
+ /// which returns an object that you can use to build the operation by providing
+ /// the inputs, requesting the operations that you want to execute and the desired outputs.
+ ///
+ ///
+ /// The method is a high-level helper function that wraps a
+ /// call to the method which just takes too many parameters that must
+ /// be kept in sync.
+ ///
+ ///
+ internal class TFSession : TFDisposableThreadSafe
+ {
+ // extern TF_Session * TF_NewSession (TF_Graph *graph, const TF_SessionOptions *opts, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Session TF_NewSession (TF_Graph graph, TF_SessionOptions opts, TF_Status status);
+
+ /// <summary>
+ /// Gets the graph associated with this TensorFlow session.
+ /// </summary>
+ /// <value>The graph.</value>
+ public TFGraph Graph { get; private set; }
+
+ // Wraps an already-created native session handle (used by FromSavedModel).
+ private TFSession (IntPtr handle, TFGraph graph) : base (handle)
+ {
+ Graph = graph;
+ }
+
+ /// <summary>
+ /// Creates a new execution session associated with the specified session graph with some configuration options.
+ /// </summary>
+ /// <param name="graph">The Graph to which this session is associated.</param>
+ /// <param name="sessionOptions">Session options.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ public TFSession (TFGraph graph, TFSessionOptions sessionOptions, TFStatus status = null) : base (IntPtr.Zero)
+ {
+ Graph = graph;
+ var cstatus = TFStatus.Setup (status);
+ var h = TF_NewSession (graph.handle, sessionOptions.handle, cstatus.handle);
+ // NOTE(review): if CheckMaybeRaise throws here, 'h' (if non-zero) is
+ // never stored and the native session is not released — confirm whether
+ // TF_NewSession can return a valid handle together with an error status.
+ cstatus.CheckMaybeRaise (status);
+ handle = h;
+ }
+
+ /// <summary>
+ /// Creates a new execution session associated with the specified session graph, using default session options.
+ /// </summary>
+ /// <param name="graph">The Graph to which this session is associated.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ public TFSession (TFGraph graph, TFStatus status = null) : base (IntPtr.Zero)
+ {
+ Graph = graph;
+ var cstatus = TFStatus.Setup (status);
+ // The local holds a session handle, so use the TF_Session alias (it was
+ // previously mistyped as TF_Status, which compiled only because both
+ // alias IntPtr but misled readers).
+ TF_Session h;
+ using (var empty = new TFSessionOptions ())
+ {
+ h = TF_NewSession (graph.handle, empty.handle, cstatus.handle);
+ }
+ cstatus.CheckMaybeRaise (status);
+ handle = h;
+ }
+
+ /// <summary>
+ /// Creates a new execution session with an empty graph.
+ /// </summary>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ /// <remarks>
+ /// The created graph can be retrieved using the Graph property on the session.
+ /// </remarks>
+ public TFSession (TFStatus status = null) : this (new TFGraph (), status)
+ {
+ }
+
+ // extern TF_Session * TF_LoadSessionFromSavedModel (const TF_SessionOptions *session_options, const TF_Buffer *run_options, const char *export_dir, const char *const *tags, int tags_len, TF_Graph *graph, TF_Buffer *meta_graph_def, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_Session TF_LoadSessionFromSavedModel (TF_SessionOptions session_options, LLBuffer* run_options, string export_dir, string [] tags, int tags_len, TF_Graph graph, LLBuffer* meta_graph_def, TF_Status status);
+
+ // Native device-enumeration API used by ListDevices below.
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe TF_DeviceList TF_SessionListDevices (TF_Session session, TF_Status status);
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe int TF_DeviceListCount (TF_DeviceList list);
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe IntPtr TF_DeviceListName (TF_DeviceList list, int index, TF_Status status);
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe IntPtr TF_DeviceListType (TF_DeviceList list, int index, TF_Status status);
+
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe long TF_DeviceListMemoryBytes (TF_DeviceList list, int index, TF_Status status);
+
+ // Frees a TF_DeviceList returned by TF_SessionListDevices.
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_DeleteDeviceList (TF_DeviceList list);
+
+ /// <summary>
+ /// Lists available devices in this session.
+ /// </summary>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ public IEnumerable<DeviceAttributes> ListDevices (TFStatus status = null)
+ {
+ var cstatus = TFStatus.Setup (status);
+ var rawDeviceList = TF_SessionListDevices (Handle, cstatus.handle);
+ var size = TF_DeviceListCount (rawDeviceList);
+
+ // Generic type arguments restored: the declaration had been mangled to
+ // the raw 'List'/'IEnumerable', which does not compile.
+ var list = new List<DeviceAttributes> (size);
+ try {
+ for (var i = 0; i < size; i++) {
+ var name = Marshal.PtrToStringAnsi (TF_DeviceListName (rawDeviceList, i, cstatus.handle));
+ var deviceType = (DeviceType) Enum.Parse (typeof (DeviceType), Marshal.PtrToStringAnsi (TF_DeviceListType (rawDeviceList, i, cstatus.handle)));
+ var memory = TF_DeviceListMemoryBytes (rawDeviceList, i, cstatus.handle);
+
+ list.Add (new DeviceAttributes (name, deviceType, memory));
+ }
+ } finally {
+ // Always release the native device list, even if Enum.Parse throws.
+ TF_DeleteDeviceList (rawDeviceList);
+ }
+
+ return list;
+ }
+
+ /// <summary>
+ /// Creates a session and graph from a model stored in the SavedModel file format.
+ /// </summary>
+ /// <returns>On success, a new TFSession whose graph has been populated from the model; null if the load failed and a caller-provided status captured the error.</returns>
+ /// <param name="sessionOptions">Session options to use for the new session.</param>
+ /// <param name="runOptions">Options to use to initialize the state (can be null).</param>
+ /// <param name="exportDir">Must be set to the path of the exported SavedModel.</param>
+ /// <param name="tags">Must include the set of tags used to identify one MetaGraphDef in the SavedModel.</param>
+ /// <param name="graph">This must be a newly created graph.</param>
+ /// <param name="metaGraphDef">On success, this will be populated on return with the contents of the MetaGraphDef.</param>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ /// <remarks>
+ /// <para>
+ /// This function creates a new session using the specified options and then initializes
+ /// the state (restoring tensors and other assets) using the run options.
+ /// </para>
+ /// <para>
+ /// This function loads the data that was saved using the SavedModel file format, as described
+ /// here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
+ /// </para>
+ /// <para>
+ /// NOTE(review): this method uses no instance state and could be static; also,
+ /// metaGraphDef is null-checked above, so the 'metaGraphDef == null' ternary below
+ /// is dead code — confirm whether a null metaGraphDef was meant to be allowed.
+ /// </para>
+ /// </remarks>
+ public TFSession FromSavedModel (TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string [] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
+ {
+ if (graph == null)
+ throw new ArgumentNullException (nameof (graph));
+ if (tags == null)
+ throw new ArgumentNullException (nameof (tags));
+ if (exportDir == null)
+ throw new ArgumentNullException (nameof (exportDir));
+ if (metaGraphDef == null)
+ throw new ArgumentNullException (nameof (metaGraphDef));
+ var cstatus = TFStatus.Setup (status);
+ unsafe
+ {
+ var h = TF_LoadSessionFromSavedModel (sessionOptions.handle, runOptions == null ? null : runOptions.LLBuffer, exportDir, tags, tags.Length, graph.handle, metaGraphDef == null ? null : metaGraphDef.LLBuffer, cstatus.handle);
+
+ if (cstatus.CheckMaybeRaise (status)) {
+ return new TFSession (h, graph);
+ }
+ }
+ return null;
+ }
+
+ // extern void TF_CloseSession (TF_Session *, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_CloseSession (TF_Session session, TF_Status status);
+
+ /// <summary>
+ /// Closes the session. Contacts any other processes associated with the session, if applicable.
+ /// </summary>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ /// <remarks>
+ /// Can not be called after calling DeleteSession.
+ /// </remarks>
+ public void CloseSession (TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ var cstatus = TFStatus.Setup (status);
+ TF_CloseSession (handle, cstatus.handle);
+ cstatus.CheckMaybeRaise (status);
+ }
+
+ // extern void TF_DeleteSession (TF_Session *, TF_Status *status);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_DeleteSession (TF_Session session, TF_Status status);
+
+ /// <summary>
+ /// Deletes the session.
+ /// </summary>
+ /// <param name="status">Status buffer; if specified a status code will be left here, otherwise a TFException is raised on error.</param>
+ /// <remarks>
+ /// NOTE(review): this frees the native session but leaves 'handle' set, so later
+ /// calls (and Dispose) may operate on a dangling pointer — confirm intended usage.
+ /// </remarks>
+ public void DeleteSession (TFStatus status = null)
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException ();
+ var cstatus = TFStatus.Setup (status);
+ TF_DeleteSession (handle, cstatus.handle);
+ cstatus.CheckMaybeRaise (status);
+ }
+
+ // Releases the native session, discarding any error status; invoked by the
+ // TFDisposable dispose machinery.
+ internal override void NativeDispose (IntPtr handle)
+ {
+ using (var s = new TFStatus ()) {
+ TF_DeleteSession (handle, s.handle);
+ }
+ }
+
+ // extern void TF_SessionRun (TF_Session *session, const TF_Buffer *run_options, const TF_Output *inputs, TF_Tensor *const *input_values, int ninputs, const TF_Output *outputs, TF_Tensor **output_values, int noutputs, const TF_Operation *const *target_opers, int ntargets, TF_Buffer *run_metadata, TF_Status *);
+ [DllImport (NativeBinding.TensorFlowLibrary)]
+ private static extern unsafe void TF_SessionRun (TF_Session session, LLBuffer* run_options, TFOutput [] inputs, TF_Tensor [] input_values, int ninputs, TFOutput [] outputs, TF_Tensor [] output_values, int noutputs, TF_Operation [] target_opers, int ntargets, LLBuffer* run_metadata, TF_Status status);
+
+ ///
+ /// Use the runner class to easily configure inputs, outputs and targets to be passed to the session runner.
+ ///
+ ///
+ ///
+ /// The runner has a simple API that allows developers to call the AddTarget, AddInput, AddOutput and Fetch
+ /// to construct the parameters that will be passed to the TFSession.Run method.
+ ///
+ ///
+ /// Instances of this class are created by calling the GetRunner method on the TFSession.
+ ///
+ ///
+ /// The various methods in this class return an instance to the Runner itsel, to allow
+ /// to easily construct chains of execution like this:
+ ///
+ ///
+ /// var result = session.GetRunner ().AddINput (myInput).Fetch (MyOutput).Run ();
+ ///
+ ///
+ /// You do not need to chain the operations, this works just the same:
+ ///
+ ///
+ /// runner = session.GetRunner ();
+ /// runner.AddInput(myInput);
+ /// runner.Fetch(myOutput);
+ /// var results = runner.Run();
+ ///
+ ///
+ public class Runner
+ {
+ // Generic type arguments restored below: the declarations had been mangled
+ // to the raw 'List', which does not compile. The element types are fixed by
+ // the Add() call sites in this class: inputs/outputs hold TFOutput,
+ // inputValues holds TFTensor, targets holds TFOperation.
+ private List<TFOutput> inputs;
+ private List<TFOutput> outputs;
+ private List<TFTensor> inputValues;
+ private List<TFOperation> targets;
+ private TFSession session;
+
+ internal Runner (TFSession session)
+ {
+ inputs = new List<TFOutput> ();
+ outputs = new List<TFOutput> ();
+ inputValues = new List<TFTensor> ();
+ targets = new List<TFOperation> ();
+ this.session = session;
+ RunMetadata = null;
+ RunOptions = null;
+ }
+
+ /// <summary>
+ /// Adds an input to the session.
+ /// </summary>
+ /// <returns>An instance of the runner, so you can easily chain the operations together.</returns>
+ /// <param name="input">Incoming port.</param>
+ /// <param name="value">Value to assign to the incoming port.</param>
+ public Runner AddInput (TFOutput input, TFTensor value)
+ {
+ if (value == null)
+ throw new ArgumentNullException (nameof (value));
+ inputs.Add (input);
+ inputValues.Add (value);
+ return this;
+ }
+
+ /// <summary>
+ /// Adds an input to the session specified by name, with an optional index in the operation (separated by a colon).
+ /// </summary>
+ /// <returns>An instance of the runner, so you can easily chain the operations together.</returns>
+ /// <param name="input">Incoming port, with an optional index separated by a colon.</param>
+ /// <param name="value">Value to assign to the incoming port.</param>
+ public Runner AddInput (string input, TFTensor value)
+ {
+ if (value == null)
+ throw new ArgumentNullException (nameof (value));
+ inputs.Add (ParseOutput (input));
+ inputValues.Add (value);
+ return this;
+ }
+
+ /// <summary>
+ /// Adds the specified operations as the ones to be retrieved.
+ /// </summary>
+ /// <returns>An instance of the runner, so you can easily chain the operations together.</returns>
+ /// <param name="targets">One or more targets.</param>
+ public Runner AddTarget (params TFOperation [] targets)
+ {
+ this.targets.AddRange (targets);
+ return this;
+ }
+
+ // Parses user strings that contain both the operation name and an index,
+ // e.g. "op:2". Falls back to output 0 when no (valid) index suffix is
+ // present, including when the string ends in a bare ':'.
+ private TFOutput ParseOutput (string operation)
+ {
+ var p = operation.IndexOf (':');
+ if (p != -1 && p != operation.Length - 1){
+ var op = operation.Substring (0, p);
+ if (int.TryParse (operation.Substring (p + 1), out var idx)){
+ return session.Graph [op] [idx];
+ }
+ }
+ // NOTE(review): if the named operation does not exist, Graph[...] returns
+ // null and this throws NullReferenceException — confirm desired behavior.
+ return session.Graph [operation] [0];
+ }
+
+ ///
+ /// Adds the specified operation names as the ones to be retrieved.
+ ///
+ /// An instance to the runner, so you can easily chain the operations together.
+ /// One or more target names.
+ public Runner AddTarget (params string [] targetNames)
+ {
+ foreach (var tn in targetNames)
+ targets.Add (session.Graph [tn]);
+ return this;
+ }
+
+ ///
+ /// Makes the Run method return the index-th output of the tensor referenced by operation.
+ ///
+ /// The instance of runner, to allow chaining operations.
+ /// The name of the operation in the graph.
+ /// The index of the output in the operation.
+ public Runner Fetch (string operation, int index)
+ {
+ var op = session.Graph [operation];
+ outputs.Add (op [index]);
+ return this;
+ }
+
+ ///
+ /// Makes the Run method return the output of the tensor referenced by operation, the operation string can contain the output index.
+ ///
+ /// The instance of runner, to allow chaining operations.
+ /// The name of the operation in the graph, which might be a simple name, or it might be name:index,
+ /// where the index is the .
+ public Runner Fetch (string operation)
+ {
+ var op = ParseOutput (operation);
+ outputs.Add (op);
+ return this;
+ }
+
+ ///
+ /// Makes the Run method return the output of the tensor referenced by output
+ ///
+ /// The instance of runner, to allow chaining operations.
+ /// The output referencing a specified tensor.
+ public Runner Fetch (TFOutput output)
+ {
+ outputs.Add (output);
+ return this;
+ }
+
+ ///
+ /// Makes the Run method return the output of all the tensor referenced by outputs.
+ ///
+ /// The instance of runner, to allow chaining operations.
+ /// The outputs referencing a specified tensor.
+ public Runner Fetch (params TFOutput [] outputs)
+ {
+ foreach (var output in outputs)
+ this.outputs.Add (output);
+ return this;
+ }
+
+ ///
+ /// Makes the Run method return the output of all the tensor referenced by outputs.
+ ///
+ /// The instance of runner, to allow chaining operations.
+ /// The output sreferencing a specified tensor.
+ public Runner Fetch (params string [] outputs)
+ {
+ foreach (var output in outputs)
+ this.outputs.Add (ParseOutput (output));
+ return this;
+ }
+
+ ///
+ /// Protocol buffer encoded block containing the metadata passed to the method.
+ ///
+ public TFBuffer RunMetadata;
+
+ ///
+ /// Protocol buffer encoded block containing the run options passed to the method.
+ ///
+ public TFBuffer RunOptions;
+
+ ///
+ /// Execute the graph fragments necessary to compute all requested fetches.
+ ///
+ /// One TFTensor for each call to Fetch that you made, in the order that you made them.
+ /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error.
+ public TFTensor [] Run (TFStatus status = null)
+ {
+ return session.Run (inputs.ToArray (), inputValues.ToArray (), outputs.ToArray (), targets.ToArray (), RunMetadata, RunOptions, status);
+ }
+
+ ///
+ /// Run the specified operation, by adding it implicity to the output, single return value
+ ///
+ /// The output of the operation.
+ /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error.
+ ///
+ /// This method is a convenience method, and when you call it, it will clear any
+ /// calls that you might have done to Fetch() and use the specified operation to Fetch
+ /// instead.
+ ///
+ public TFTensor Run (TFOutput operation, TFStatus status = null)
+ {
+ outputs.Clear ();
+ Fetch (operation);
+ return Run (status) [0];
+ }
+
+ }
+
+        /// <summary>
+        /// Gets a new runner, this provides a simpler API to prepare the inputs to run on a session.
+        /// </summary>
+        /// <returns>The runner.</returns>
+        /// <remarks>
+        /// The runner has a simple API that allows developers to call the AddTarget, AddInput, AddOutput and Fetch
+        /// methods to construct the parameters that will be passed to the TFSession.Run method.
+        ///
+        /// The Run method will return an array of TFTensor values, one for each invocation to the Fetch method.
+        /// </remarks>
+        public Runner GetRunner () => new Runner (this);
+
+        /// <summary>
+        /// Executes a pipeline given the specified inputs, inputValues, outputs, targetOpers, runMetadata and runOptions.
+        /// A simpler API is available by calling the <see cref="GetRunner"/> method which performs all the bookkeeping
+        /// necessary.
+        /// </summary>
+        /// <returns>An array of tensors fetched from the requested outputs.</returns>
+        /// <param name="inputs">Inputs nodes.</param>
+        /// <param name="inputValues">Input values.</param>
+        /// <param name="outputs">Output nodes.</param>
+        /// <param name="targetOpers">Target operations to execute.</param>
+        /// <param name="runMetadata">Run metadata, a buffer containing the protocol buffer encoded value for https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/core/protobuf/config.proto.</param>
+        /// <param name="runOptions">Run options, a buffer containing the protocol buffer encoded value for https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/core/protobuf/config.proto.</param>
+        /// <param name="status">Status buffer, if specified a status code will be left here, if not specified, an exception is raised if there is an error.</param>
+        public TFTensor [] Run (TFOutput [] inputs, TFTensor [] inputValues, TFOutput [] outputs, TFOperation [] targetOpers = null, TFBuffer runMetadata = null, TFBuffer runOptions = null, TFStatus status = null)
+        {
+            // NOTE(review): ObjectDisposedException() looks like a throw-helper defined elsewhere
+            // on this class (not visible in this view) — confirm it actually throws.
+            if (handle == IntPtr.Zero)
+                ObjectDisposedException ();
+            if (inputs == null)
+                throw new ArgumentNullException (nameof (inputs));
+            if (inputValues == null)
+                throw new ArgumentNullException (nameof (inputValues));
+            if (outputs == null)
+                throw new ArgumentNullException (nameof (outputs));
+            int iLen = inputs.Length;
+            if (iLen != inputValues.Length)
+                throw new ArgumentException ("inputs and inputValues have different lengths", "inputs");
+            int oLen = outputs.Length;
+
+            // runOptions and runMetadata might be null
+            var cstatus = TFStatus.Setup (status);
+
+            // Create arrays for the unmanaged versions: ivals holds the raw native tensor
+            // handles extracted from the managed TFTensor wrappers.
+            var ivals = new IntPtr [iLen];
+            for (int i = 0; i < iLen; i++)
+                ivals [i] = inputValues [i].handle;
+
+            // I believe this might not be necessary, the output values in TF_SessionRun looks like a write-only result
+            var ovals = new IntPtr [outputs.Length];
+            IntPtr [] topers = null;
+            int tLen = 0;
+            if (targetOpers != null) {
+                tLen = targetOpers.Length;
+                topers = new IntPtr [tLen];
+                for (int i = 0; i < tLen; i++)
+                    topers [i] = targetOpers [i].Handle;
+            }
+
+            unsafe
+            {
+                // Native call fills ovals with freshly allocated tensor handles on success.
+                TF_SessionRun (handle, runOptions == null ? null : runOptions.LLBuffer, inputs, ivals, iLen, outputs, ovals, oLen, topers, tLen, runMetadata == null ? null : runMetadata.LLBuffer, cstatus.handle);
+            }
+            cstatus.CheckMaybeRaise (status);
+            // Wrap each native handle in a managed TFTensor that owns its lifetime.
+            var result = new TFTensor [oLen];
+            for (int i = 0; i < oLen; i++) {
+                result [i] = new TFTensor (ovals [i]);
+            }
+            return result;
+        }
+ }
+
+    /// <summary>
+    /// The data type for a specific tensor.
+    /// </summary>
+    /// <remarks>
+    /// Tensors have uniform data types, all the elements of the tensor are of this
+    /// type and they dictate how TensorFlow will treat the data stored.
+    /// </remarks>
+    internal enum TFDataType : uint
+    {
+        /// <summary>
+        /// The TFDataType has not been set
+        /// </summary>
+        Unknown = 0,
+
+        /// <summary>
+        /// Single precision floating point, 32-bits (C# float)
+        /// </summary>
+        Float = 1,
+        /// <summary>
+        /// Double precision floating point, 64-bits (C# double)
+        /// </summary>
+        Double = 2,
+        /// <summary>
+        /// 32-bit signed integers (C# int)
+        /// </summary>
+        Int32 = 3,
+        /// <summary>
+        /// 8 bit unsigned integers (C# byte)
+        /// </summary>
+        UInt8 = 4,
+        /// <summary>
+        /// 16-bit signed integers (C# short)
+        /// </summary>
+        Int16 = 5,
+        /// <summary>
+        /// 8-bit signed integers (C# sbyte)
+        /// </summary>
+        Int8 = 6,
+        /// <summary>
+        /// Binary blob
+        /// </summary>
+        String = 7,
+        /// <summary>
+        /// Single precision complex numbers (two 32-bit floats)
+        /// </summary>
+        Complex64 = 8,
+        /// <summary>
+        /// 32-bit float based complex numbers (alias of Complex64 — same wire value, 8)
+        /// </summary>
+        Complex = 8,
+        /// <summary>
+        /// 64-bit signed integers (C# long)
+        /// </summary>
+        Int64 = 9,
+        /// <summary>
+        /// 8-bit boolean (C# bool)
+        /// </summary>
+        Bool = 10,
+        /// <summary>
+        /// Quantized 8-bit signed integer
+        /// </summary>
+        QInt8 = 11,
+        /// <summary>
+        /// Quantized 8-bit unsigned integer
+        /// </summary>
+        QUInt8 = 12,
+        /// <summary>
+        /// Quantized 32-bit signed integer
+        /// </summary>
+        QInt32 = 13,
+        /// <summary>
+        /// Float32 truncated to 16 bits. Only for cast operations.
+        /// </summary>
+        BFloat16 = 14,
+        /// <summary>
+        /// Quantized 16-bit signed integer
+        /// </summary>
+        QInt16 = 15,
+        /// <summary>
+        /// Quantized 16-bit unsigned integer
+        /// </summary>
+        QUInt16 = 16,
+        /// <summary>
+        /// 16-bit unsigned integers (C# ushort)
+        /// </summary>
+        UInt16 = 17,
+        /// <summary>
+        /// Double precision complex numbers (two 64-bit floats)
+        /// </summary>
+        Complex128 = 18,
+
+        /// <summary>
+        /// Half floats - 16-bit half precision floating point.
+        /// </summary>
+        Half = 19,
+
+        /// <summary>
+        /// Handle to a mutable resource.
+        /// </summary>
+        Resource = 20,
+
+        /// <summary>
+        /// Variant data type
+        /// </summary>
+        Variant = 21,
+
+        /// <summary>
+        /// 32-bit unsigned integers
+        /// </summary>
+        UInt32 = 22,
+
+        /// <summary>
+        /// 64-bit unsigned integers
+        /// </summary>
+        UInt64 = 23
+    }
+
+    /// <summary>
+    /// Status code for invoking a tensorflow operation.
+    /// </summary>
+    internal enum TFCode : uint
+    {
+        /// <summary>
+        /// Not an error; returned on success
+        /// </summary>
+        Ok = 0,
+        /// <summary>
+        /// The operation was cancelled (typically by the caller).
+        /// </summary>
+        Cancelled = 1,
+        /// <summary>
+        /// Unknown error. An example of where this error may be returned is
+        /// if a Status value received from another address space belongs to
+        /// an error-space that is not known in this address space. Also
+        /// errors raised by APIs that do not return enough error information
+        /// may be converted to this error.
+        /// </summary>
+        Unknown = 2,
+
+        /// <summary>
+        /// Client specified an invalid argument. Note that this differs
+        /// from FailedPrecondition. InvalidArgument indicates arguments
+        /// that are problematic regardless of the state of the system
+        /// (e.g., a malformed file name).
+        /// </summary>
+        InvalidArgument = 3,
+
+        /// <summary>
+        /// Deadline expired before operation could complete. For operations
+        /// that change the state of the system, this error may be returned
+        /// even if the operation has completed successfully. For example, a
+        /// successful response from a server could have been delayed long
+        /// enough for the deadline to expire.
+        /// </summary>
+        DeadlineExceeded = 4,
+
+        /// <summary>
+        /// Some requested entity (e.g., file or directory) was not found.
+        /// For privacy reasons, this code may be returned when the client
+        /// does not have the access right to the entity.
+        /// </summary>
+        NotFound = 5,
+
+        /// <summary>
+        /// Some entity that we attempted to create (e.g., file or directory) already exists.
+        /// </summary>
+        AlreadyExists = 6,
+
+        /// <summary>
+        /// The caller does not have permission to execute the specified
+        /// operation. PermissionDenied must not be used for rejections
+        /// caused by exhausting some resource (use ResourceExhausted
+        /// instead for those errors). PermissionDenied must not be
+        /// used if the caller can not be identified (use Unauthenticated
+        /// instead for those errors).
+        /// </summary>
+        PermissionDenied = 7,
+
+        /// <summary>
+        /// The request does not have valid authentication credentials for the
+        /// operation.
+        /// </summary>
+        Unauthenticated = 16,
+
+        /// <summary>
+        /// Some resource has been exhausted, perhaps a per-user quota, or
+        /// perhaps the entire file system is out of space.
+        /// </summary>
+        ResourceExhausted = 8,
+
+        /// <summary>
+        /// Operation was rejected because the system is not in a state
+        /// required for the operation's execution. For example, directory
+        /// to be deleted may be non-empty, an rmdir operation is applied to
+        /// a non-directory, etc.
+        ///
+        /// A litmus test that may help a service implementor in deciding
+        /// between FailedPrecondition, Aborted, and Unavailable:
+        ///
+        /// (a) Use Unavailable if the client can retry just the failing call.
+        /// (b) Use Aborted if the client should retry at a higher-level
+        /// (e.g., restarting a read-modify-write sequence).
+        /// (c) Use FailedPrecondition if the client should not retry until
+        /// the system state has been explicitly fixed. E.g., if an "rmdir"
+        /// fails because the directory is non-empty, FailedPrecondition
+        /// should be returned since the client should not retry unless
+        /// they have first fixed up the directory by deleting files from it.
+        /// (d) Use FailedPrecondition if the client performs conditional
+        /// REST Get/Update/Delete on a resource and the resource on the
+        /// server does not match the condition. E.g., conflicting
+        /// read-modify-write on the same resource.
+        /// </summary>
+        FailedPrecondition = 9,
+
+        /// <summary>
+        /// The operation was aborted, typically due to a concurrency issue
+        /// like sequencer check failures, transaction aborts, etc.
+        ///
+        /// See litmus test above for deciding between FailedPrecondition,
+        /// Aborted and Unavailable
+        /// </summary>
+        Aborted = 10,
+
+        /// <summary>
+        /// Operation tried to iterate past the valid input range. E.g., seeking or
+        /// reading past end of file.
+        ///
+        /// Unlike InvalidArgument, this error indicates a problem that may
+        /// be fixed if the system state changes. For example, a 32-bit file
+        /// system will generate InvalidArgument if asked to read at an
+        /// offset that is not in the range [0,2^32-1], but it will generate
+        /// OutOfRange if asked to read from an offset past the current
+        /// file size.
+        ///
+        /// There is a fair bit of overlap between FailedPrecondition and
+        /// OutOfRange. We recommend using OutOfRange (the more specific
+        /// error) when it applies so that callers who are iterating through
+        /// a space can easily look for an OutOfRange error to detect when
+        /// they are done.
+        /// </summary>
+        OutOfRange = 11,
+
+        /// <summary>
+        /// Operation is not implemented or not supported/enabled in this service.
+        /// </summary>
+        Unimplemented = 12,
+
+        /// <summary>
+        /// Internal errors. Means some invariants expected by underlying
+        /// system has been broken. If you see one of these errors,
+        /// something is very broken.
+        /// </summary>
+        Internal = 13,
+
+        /// <summary>
+        /// The service is currently unavailable. This is most likely a
+        /// transient condition and may be corrected by retrying with
+        /// a backoff.
+        ///
+        /// See litmus test above for deciding between FailedPrecondition,
+        /// Aborted, and Unavailable.
+        /// </summary>
+        Unavailable = 14,
+
+        /// <summary>
+        /// Unrecoverable data loss or corruption.
+        /// </summary>
+        DataLoss = 15
+    }
+
+    /// <summary>
+    /// Represents a specific input of an operation.
+    /// </summary>
+    [StructLayout (LayoutKind.Sequential)]
+    internal struct TFInput
+    {
+        /// <summary>
+        /// The operation that this input is for
+        /// </summary>
+        // NOTE(review): TF_Operation is a native-handle alias declared elsewhere in this file;
+        // the field layout must mirror the C API's TF_Input struct exactly.
+        public unsafe TF_Operation Operation;
+
+        /// <summary>
+        /// The index of the output within the Operation
+        /// </summary>
+        public int Index;
+
+        // extern TF_Output TF_OperationInput (TF_Input oper_in);
+        [DllImport (NativeBinding.TensorFlowLibrary)]
+        private static extern TFOutput TF_OperationInput (TFInput oper_in);
+
+        /// <summary>
+        /// Returns the output (producer) that feeds the given input slot.
+        /// </summary>
+        /// <param name="operIn">The input slot to query.</param>
+        public TFOutput GetOutput (TFInput operIn)
+        {
+            return TF_OperationInput (operIn);
+        }
+
+        // extern TF_DataType TF_OperationInputType (TF_Input oper_in);
+        [DllImport (NativeBinding.TensorFlowLibrary)]
+        private static extern TFDataType TF_OperationInputType (TFInput oper_in);
+
+        /// <summary>
+        /// The data type expected by this input slot.
+        /// </summary>
+        public TFDataType InputType => TF_OperationInputType (this);
+
+    }
+
+    /// <summary>
+    /// Represents a specific output of an operation on a tensor.
+    /// </summary>
+    /// <remarks>
+    /// <para>
+    /// TFOutput objects represent one of the outputs of an operation in the graph
+    /// (TFGraph). Outputs have a data type, and eventually a shape that you can
+    /// retrieve from the graph.
+    /// </para>
+    /// <para>
+    /// These can be passed as an input argument to a function for adding operations
+    /// to a graph, or to the TFSession's Run and GetRunner method as values to be
+    /// fetched.
+    /// </para>
+    /// </remarks>
+    [StructLayout (LayoutKind.Sequential)]
+    internal struct TFOutput
+    {
+        // Native operation handle; layout must mirror the C API's TF_Output struct.
+        private unsafe TF_Operation LLOperation;
+
+        /// <summary>
+        /// The index of the output within the operation.
+        /// </summary>
+        public int Index;
+
+        // extern int TF_OperationOutputNumConsumers (TF_Output oper_out);
+        [DllImport (NativeBinding.TensorFlowLibrary)]
+        private static extern int TF_OperationOutputNumConsumers (TFOutput oper_out);
+
+        /// <summary>
+        /// Gets the number of consumers.
+        /// </summary>
+        /// <value>The number of consumers.</value>
+        /// <remarks>
+        /// This number can change when new operations are added to the graph.
+        /// </remarks>
+        public int NumConsumers => TF_OperationOutputNumConsumers (this);
+
+        // extern TF_DataType TF_OperationOutputType (TF_Output oper_out);
+        [DllImport (NativeBinding.TensorFlowLibrary)]
+        private static extern TFDataType TF_OperationOutputType (TFOutput oper_out);
+
+        /// <summary>
+        /// Gets the type of the output.
+        /// </summary>
+        /// <value>The type of the output.</value>
+        public TFDataType OutputType => LLOperation == IntPtr.Zero ? TFDataType.Unknown : TF_OperationOutputType (this);
+
+        /// <summary>
+        /// Initializes a new TFOutput instance.
+        /// </summary>
+        /// <param name="operation">The operation to which to attach the output.</param>
+        /// <param name="index">The index of the output within the operation, if not specified, it defaults to zero.</param>
+        public TFOutput (TFOperation operation, int index = 0)
+        {
+            if (operation == null)
+                throw new ArgumentNullException (nameof (operation));
+            LLOperation = operation.Handle;
+            Index = index;
+        }
+
+        /// <summary>
+        /// Initializes a new TFOutput instance from another TFOutput
+        /// </summary>
+        /// <param name="output">The other TFOutput that is having its operation attached.</param>
+        /// <param name="index">The index of the output within the operation, if not specified, it defaults to zero.</param>
+        // NOTE(review): the string below is passed as the paramName, not the message, of
+        // ArgumentNullException — and an ArgumentException would fit better since `output`
+        // itself cannot be null; also `== null` here vs `== IntPtr.Zero` in OutputType is
+        // inconsistent. Left untouched because callers may catch this exact exception type.
+        public TFOutput (TFOutput output, int index = 0)
+        {
+            if (output.LLOperation == null)
+                throw new ArgumentNullException ("Outputs does not have a valid operation pointer");
+            LLOperation = output.LLOperation;
+            Index = index;
+        }
+
+        // extern int TF_OperationOutputConsumers (TF_Output oper_out, TF_Input *consumers, int max_consumers);
+        [DllImport (NativeBinding.TensorFlowLibrary)]
+        private static extern unsafe int TF_OperationOutputConsumers (TFOutput oper_out, TFInput* consumers, int max_consumers);
+
+        /// <summary>
+        /// Get list of all current consumers of a specific output of an operation
+        /// </summary>
+        /// <value>The output consumers.</value>
+        /// <remarks>
+        /// A concurrent modification of the graph can increase the number of consumers of
+        /// an operation.
+        /// This can return null if the TFOutput does not point to a valid object.
+        /// </remarks>
+        // NOTE(review): `fixed (... &result [0])` throws IndexOutOfRangeException when
+        // NumConsumers is 0 — confirm callers never hit the zero-consumer case.
+        public TFInput [] OutputConsumers {
+            get {
+                var result = new TFInput [NumConsumers];
+                unsafe
+                {
+                    fixed (TFInput* first = &result [0])
+                        TF_OperationOutputConsumers (this, first, result.Length);
+                }
+                return result;
+            }
+        }
+
+        /// <summary>
+        /// The associated operation.
+        /// </summary>
+        /// <value>The operation.</value>
+        public TFOperation Operation => new TFOperation (null, LLOperation);
+
+        /// <summary>
+        /// Returns a string that represents the current TFOutput.
+        /// </summary>
+        public override string ToString ()
+        {
+            return string.Format ("[{3} Index={1} Operation={2} (0x{0:X})]", (long) LLOperation, Index, Operation, OutputType);
+        }
+    }
+
+    /// <summary>
+    /// Low-level: Enumeration describing the types of a metadata attribute
+    /// </summary>
+    internal enum TFAttributeType : uint
+    {
+        /// <summary>
+        /// The type of the attribute is a string
+        /// </summary>
+        String = 0,
+
+        /// <summary>
+        /// The type of the attribute is an int.
+        /// </summary>
+        Int = 1,
+
+        /// <summary>
+        /// The type of the attribute is a float
+        /// </summary>
+        Float = 2,
+
+        /// <summary>
+        /// The type of the attribute is a bool.
+        /// </summary>
+        Bool = 3,
+
+        /// <summary>
+        /// The type of the attribute is a type.
+        /// </summary>
+        Type = 4,
+
+        /// <summary>
+        /// The type of the attribute is a tensor shape
+        /// </summary>
+        Shape = 5,
+
+        /// <summary>
+        /// The type of the attribute is a tensor
+        /// </summary>
+        Tensor = 6,
+
+        /// <summary>
+        /// The type of the attribute is a placeholder
+        /// </summary>
+        Placeholder = 7,
+
+        /// <summary>
+        /// The type of the attribute is a function
+        /// </summary>
+        Func = 8
+    }
+
+    /// <summary>
+    /// Low-level: this describes the tensorflow type information for an attribute in the low-level attributes used by operations.
+    /// </summary>
+    /// <remarks>
+    /// This is a low-level structure returned by the native attribute-metadata query.
+    /// It is included for completeness, but is not generally used from C#, as the
+    /// high-level bindings expose the same information.
+    /// </remarks>
+    [StructLayout (LayoutKind.Sequential)]
+    internal struct TFAttributeMetadata
+    {
+        // Native bool: any non-zero value means the attribute holds a list of values.
+        private byte isList;
+        public bool IsList => isList != 0;
+        public long ListSize;
+        public TFAttributeType Type;
+        public long TotalSize;
+
+        /// <summary>
+        /// Returns a string that represents the current TFAttributeMetadata.
+        /// </summary>
+        public override string ToString ()
+        {
+            // The interpolated string is sufficient on its own; the original wrapped it
+            // in a redundant string.Format call that added a pointless extra formatting pass.
+            return $"[TFAttributeMetadata IsList={IsList} ListSize={ListSize} Type={Type} TotalSize={TotalSize}]";
+        }
+    }
+
+    /// <summary>
+    /// Represents the shape of a tensor, it describes how many dimensions the tensor has in a given axis
+    /// </summary>
+    /// <remarks>
+    /// <para>
+    /// The shapes can be created by calling the constructor with the number of dimensions
+    /// in the shape. The null value is used to specify that the shape is unknown,
+    /// an empty array is used to create a scalar, and other values are used to specify
+    /// the number of dimensions.
+    /// </para>
+    /// <para>
+    /// For the unknown case, you can use <see cref="Unknown"/>; for
+    /// scalars, you can use the <see cref="Scalar"/> shape.
+    /// </para>
+    /// <para>
+    /// To create a 2-element vector, use: new TFShape (2).
+    /// To create a 2x3 matrix, use: new TFShape (2, 3).
+    /// </para>
+    /// <para>
+    /// To create a shape with an unknown number of elements, you can pass the value
+    /// -1. This is typically used to indicate the shape of tensors that represent a
+    /// variable-sized batch of values. To create a matrix with 4 columns and an
+    /// unknown number of rows: var batch = new TFShape (-1, 4).
+    /// </para>
+    /// </remarks>
+    internal class TFShape
+    {
+        /// <summary>
+        /// Represents an unknown number of dimensions in the tensor.
+        /// </summary>
+        /// <value>The unknown shape.</value>
+        public static TFShape Unknown => new TFShape (null);
+
+        /// <summary>
+        /// This shape is used to represent scalar values.
+        /// </summary>
+        /// <value>The scalar shape.</value>
+        public static TFShape Scalar => new TFShape (new long [0]);
+
+        // null means "unknown shape"; a -1 entry means that dimension's size is unknown.
+        internal long [] dims;
+
+        /// <summary>
+        /// Initializes a new instance of the TFShape class.
+        /// </summary>
+        /// <param name="args">This is a params argument, so you can provide multiple values to it.
+        /// A null value means that this is an unknown shape, a single value is used to create a vector,
+        /// two values are used to create a 2-D matrix and so on.
+        /// </param>
+        public TFShape (params long [] args)
+        {
+            dims = args;
+        }
+
+        /// <summary>
+        /// Gets the length of the specified dimension in the tensor
+        /// </summary>
+        /// <returns>The length, -1 for shapes that have an unknown dimension.</returns>
+        /// <param name="dimension">Dimension.</param>
+        public int GetLength (int dimension) => dims == null ? -1 : dims.GetLength (dimension);
+
+        /// <summary>
+        /// Number of dimensions represented by this shape.
+        /// </summary>
+        /// <value>The number of dimensions, -1 if the number of dimensions is unknown, 0 if the shape represents a scalar, 1 for a vector, 2 for a matrix and so on.</value>
+        public int NumDimensions => dims == null ? -1 : dims.Length;
+
+        /// <summary>
+        /// Gets a value indicating whether all the dimensions in the shape are fully specified.
+        /// </summary>
+        /// <value>true if the shape is fully specified; otherwise, false.</value>
+        public bool IsFullySpecified {
+            get {
+                if (dims == null)
+                    return false;
+                foreach (var j in dims)
+                    if (j == -1)
+                        return false;
+                return true;
+            }
+        }
+
+        /// <summary>
+        /// Returns the shape as an array
+        /// </summary>
+        /// <returns>null if the shape represents an unknown shape, otherwise an array with N elements, one per dimension, and each element can be either -1 (if the dimension size is unspecified) or the size of the dimension.</returns>
+        public long [] ToArray ()
+        {
+            if (dims == null)
+                return null;
+
+            var ret = (long [])dims.Clone ();
+            return ret;
+        }
+
+        /// <summary>
+        /// Returns the shape as an int array
+        /// </summary>
+        /// <returns>null if the shape represents an unknown shape, otherwise an array with N elements, one per dimension, and each element can be either -1 (if the dimension size is unspecified) or the size of the dimension.</returns>
+        /// <remarks>The conversion is checked: a dimension larger than Int32.MaxValue throws OverflowException; use IsLongArray to test first.</remarks>
+        public int [] ToIntArray ()
+        {
+            if (dims == null)
+                return null;
+
+            var ret = new int [dims.Length];
+            for (int i = 0; i < dims.Length; i++) {
+                checked {
+                    ret [i] = (int) dims [i];
+                }
+            }
+            return ret;
+        }
+
+        /// <summary>
+        /// Gets a value indicating whether one of the dimensions in the shape is larger than Int32.MaxValue.
+        /// </summary>
+        /// <value>true if any dimension exceeds Int32.MaxValue; otherwise, false.</value>
+        public bool IsLongArray {
+            get {
+                // Fix: every other member guards against dims == null (the Unknown shape),
+                // but this property previously dereferenced it and threw NullReferenceException.
+                // An unknown shape has no dimensions, so it cannot contain a long one.
+                if (dims == null)
+                    return false;
+
+                foreach (var l in dims)
+                    if (l > Int32.MaxValue)
+                        return true;
+
+                return false;
+            }
+        }
+
+        /// <summary>
+        /// Returns a string that represents the current TFShape, e.g. "[2, ?, 3]" or "unknown".
+        /// </summary>
+        public override string ToString ()
+        {
+            if (dims == null)
+                return "unknown";
+            return "[" + String.Join (", ", dims.Select (x => x == -1 ? "?" : x.ToString ())) + "]";
+        }
+
+        /// <summary>
+        /// Gets the dimension for the specified index.
+        /// </summary>
+        /// <param name="idx">Index.</param>
+        public long this [int idx] => dims [idx];
+
+        /// <summary>
+        /// Returns the shape as a 1-dimensional tensor with each element corresponding to the specified shape dimension.
+        /// </summary>
+        /// <returns>The tensor.</returns>
+        public TFTensor AsTensor ()
+        {
+            return new TFTensor (ToIntArray ());
+        }
+
+        /// <summary>
+        /// Adds a TFShape to a TFShape, yielding a shape made up of the concatenation of the first and the second shapes.
+        /// </summary>
+        /// <param name="left">The first TFShape to add.</param>
+        /// <param name="right">The second TFShape to add.</param>
+        /// <returns>The TFShape that is the concatenation of left and right; a null operand yields the other operand.</returns>
+        public static TFShape operator + (TFShape left, TFShape right)
+        {
+            if (left == null)
+                return right;
+            if (right == null)
+                return left;
+
+            var full = new long [left.dims.Length + right.dims.Length];
+            Array.Copy (left.dims, full, left.dims.Length);
+            Array.Copy (right.dims, 0, full, left.dims.Length, right.dims.Length);
+            return new TFShape (full);
+        }
+
+        /// <summary>
+        /// Performs an implicit conversion from TFShape to TFTensor.
+        /// </summary>
+        /// <param name="shape">The shape.</param>
+        /// <returns>The result of the conversion.</returns>
+        public static implicit operator TFTensor (TFShape shape)
+        {
+            return shape.AsTensor ();
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Transforms/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.Transforms/TensorFlow/TensorflowUtils.cs
new file mode 100644
index 0000000000..c52b2dc266
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/TensorFlow/TensorflowUtils.cs
@@ -0,0 +1,93 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Reflection;
+using System.Runtime.InteropServices;
+using System.Text;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Transforms;
+using Microsoft.ML.Transforms.TensorFlow;
+
+namespace Microsoft.ML.Transforms.TensorFlow
+{
+    internal partial class TensorflowUtils
+    {
+        /// <summary>
+        /// Maps a TensorFlow tensor element type to the corresponding ML.NET DataKind.
+        /// </summary>
+        /// <param name="type">The TensorFlow data type to map.</param>
+        /// <returns>The equivalent ML.NET data kind.</returns>
+        /// <exception cref="NotSupportedException">Thrown when the type has no ML.NET equivalent.</exception>
+        internal static DataKind Tf2MlNetType(TFDataType type)
+        {
+            switch (type)
+            {
+                case TFDataType.Float:
+                    return DataKind.R4;
+                case TFDataType.Double:
+                    return DataKind.R8;
+                case TFDataType.Int32:
+                    return DataKind.I4;
+                case TFDataType.Int64:
+                    return DataKind.I8;
+                case TFDataType.UInt32:
+                    return DataKind.U4;
+                case TFDataType.UInt64:
+                    return DataKind.U8;
+                case TFDataType.Bool:
+                    return DataKind.Bool;
+                case TFDataType.String:
+                    return DataKind.TX;
+                default:
+                    throw new NotSupportedException("Tensorflow type not supported.");
+            }
+        }
+
+        /// <summary>
+        /// Returns whether the (item) type of the given column can be represented as a TensorFlow tensor type.
+        /// </summary>
+        /// <param name="type">The column type to check; for vector types the item type is checked.</param>
+        internal static bool IsTypeSupportedInTf(ColumnType type)
+        {
+            try
+            {
+                // TensorTypeFromType throws ArgumentOutOfRangeException for unsupported CLR
+                // types, so the exception doubles as the "not supported" signal here.
+                var rawType = type.IsVector ? type.ItemType.RawType : type.RawType;
+                TFTensor.TensorTypeFromType(rawType);
+                return true;
+            }
+            catch (ArgumentOutOfRangeException)
+            {
+                return false;
+            }
+        }
+
+        /// <summary>
+        /// Copies <paramref name="size"/> elements of unmanaged tensor data into a newly allocated managed array.
+        /// </summary>
+        /// <param name="data">Pointer to the start of the unmanaged buffer.</param>
+        /// <param name="size">Number of elements to copy.</param>
+        public static unsafe T[] FetchData<T>(IntPtr data, int size)
+        {
+            // Delegate to the in-place overload so the pinning/copy logic lives in one place.
+            var result = new T[size];
+            FetchData(data, result);
+            return result;
+        }
+
+        /// <summary>
+        /// Copies unmanaged tensor data into the provided managed array, filling it completely.
+        /// </summary>
+        /// <param name="data">Pointer to the start of the unmanaged buffer.</param>
+        /// <param name="result">Destination array; its full length is copied.</param>
+        public static unsafe void FetchData<T>(IntPtr data, T[] result)
+        {
+            // Pin the array so the raw memory copy cannot race with a GC compaction.
+            GCHandle handle = GCHandle.Alloc(result, GCHandleType.Pinned);
+            try
+            {
+                IntPtr target = handle.AddrOfPinnedObject();
+                // Widen before multiplying: length * element-size previously multiplied in
+                // 32-bit arithmetic and could overflow for buffers larger than 2GB.
+                long sizeInBytes = (long)result.Length * Marshal.SizeOf(typeof(T));
+                Buffer.MemoryCopy(data.ToPointer(), target.ToPointer(), sizeInBytes, sizeInBytes);
+            }
+            finally
+            {
+                // Free the handle even if the copy faults, otherwise the array stays pinned forever.
+                handle.Free();
+            }
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Transforms/TensorflowTransform.cs b/src/Microsoft.ML.Transforms/TensorflowTransform.cs
new file mode 100644
index 0000000000..9369e4efbf
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/TensorflowTransform.cs
@@ -0,0 +1,601 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.InteropServices;
+using System.Text;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Transforms;
+using Microsoft.ML.Transforms.TensorFlow;
+
+[assembly: LoadableClass(TensorflowTransform.Summary, typeof(TensorflowTransform), typeof(TensorflowTransform.Arguments), typeof(SignatureDataTransform),
+ TensorflowTransform.UserName, TensorflowTransform.ShortName)]
+
+[assembly: LoadableClass(TensorflowTransform.Summary, typeof(TensorflowTransform), null, typeof(SignatureLoadDataTransform),
+ TensorflowTransform.UserName, TensorflowTransform.LoaderSignature)]
+
+namespace Microsoft.ML.Transforms
+{
+ public sealed class TensorflowTransform : RowToRowMapperTransformBase
+ {
+ public sealed class Column : ManyToOneColumn
+ {
+ public static Column Parse(string str)
+ {
+ Contracts.AssertNonEmpty(str);
+
+ var res = new Column();
+ if (res.TryParse(str))
+ return res;
+ return null;
+ }
+
+ public bool TryUnparse(StringBuilder sb)
+ {
+ Contracts.AssertValue(sb);
+ return TryUnparseCore(sb);
+ }
+ }
+
+ public sealed class Arguments : TransformInputBase
+ {
+ [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
+ public Column[] Column;
+
+ [Argument(ArgumentType.Required, HelpText = "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more detail(s).", ShortName = "ModelDir", SortOrder = 2)]
+ public string ModelFile;
+ }
+
+ private sealed class Bindings : ManyToOneColumnBindingsBase
+ {
+ public sealed class TFColInfo
+ {
+ public readonly string[] InputColNames;
+ public readonly TFShape[] TfShapes;
+ public readonly TFDataType[] TfTypes;
+
+ public TFColInfo(string[] inputColNames, TFShape[] tfShapes, TFDataType[] tfType)
+ {
+ Contracts.AssertNonEmpty(tfShapes);
+ Contracts.AssertNonEmpty(tfType);
+ Contracts.Assert(tfType.Length == tfType.Length);
+
+ InputColNames = inputColNames;
+ TfShapes = tfShapes;
+ TfTypes = tfType;
+ }
+ }
+
+ public readonly TFColInfo[] TfColInfo;
+ public readonly string[] OutputColNames;
+ public readonly ColumnType[] OutputCols;
+ public readonly TFDataType[] OutputTFTypes;
+
+ public Bindings(Column[] columns, ISchema schemaInput, TensorflowTransform parent)
+ : base(columns, schemaInput, TestTypes)
+ {
+ OutputCols = new ColumnType[columns.Length];
+ OutputTFTypes = new TFDataType[columns.Length];
+ OutputColNames = new string[columns.Length];
+ TfColInfo = new TFColInfo[columns.Length];
+ for (int i=0; i t.IsVector))
+ return "All source columns must be of vector type";
+
+ if (!types.All(t => TensorflowUtils.IsTypeSupportedInTf(t)))
+ return "One of the input types is not supported in Tensorflow";
+
+ return null;
+ }
+ }
+
+ public const string Summary = "Transforms the data using the tenorflow model.";
+ public const string UserName = "TensorflowTransform";
+ public const string ShortName = "TFTransform";
+
+ public const string LoaderSignature = "TFTransform";
+ private const string RegistrationName = "Tensorflow";
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "TENSFLOW",
+ verWrittenCur: 0x00010001, // Initial
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature);
+ }
+
+ private readonly Bindings _bindings;
+
+ /// <summary>
+ /// Tensorflow session object
+ /// </summary>
+ private readonly TFSession _session;
+
+ public override ISchema Schema => _bindings;
+
+ /// <summary>
+ /// Any missing dimension can be a batch dimension.
+ /// Currently setting it to 1.
+ /// </summary>
+ private readonly int _batchSize;
+
+ /// <summary>
+ /// Convenience constructor for public facing API.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
+ /// <param name="modelFile">This is the frozen tensorflow model file. https://www.tensorflow.org/mobile/prepare_models</param>
+ /// <param name="name">Name of the output column. Keep it same as in the Tensorflow model.</param>
+ /// <param name="source">Name of the input column(s). Keep it same as in the Tensorflow model.</param>
+ public TensorflowTransform(IHostEnvironment env, IDataView input, string modelFile, string name, params string[] source)
+ : this(env, new Arguments() { Column = new[] { new Column() { Source = source, Name = name } }, ModelFile = modelFile }, input)
+ {
+ }
+
+ public TensorflowTransform(IHostEnvironment env, Arguments args, IDataView input)
+ : base(env, RegistrationName, input)
+ {
+ Host.CheckValue(args, nameof(args));
+ Host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));
+ for (int i = 0; i < args.Column.Length; i++)
+ Host.CheckUserArg(Utils.Size(args.Column[i].Source) > 0, nameof(args.Column));
+ Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
+ Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
+
+ _batchSize = 1; // Currently setting it to 1.
+ _session = LoadTFSession(args.ModelFile);
+ _bindings = new Bindings(args.Column, Source.Schema, this);
+ }
+
+ private TensorflowTransform(IHost host, ModelLoadContext ctx, IDataView input)
+ : base(host, input)
+ {
+ Host.AssertValue(ctx);
+
+ _batchSize = ctx.Reader.ReadInt32();
+#pragma warning disable MSML_NoMessagesForLoadContext
+ Host.CheckDecode(_batchSize > 0, "BatchSize must be positive.");
+#pragma warning restore MSML_NoMessagesForLoadContext
+
+ byte[] data = null;
+ if (!ctx.TryLoadBinaryStream("TFModel", r => data = r.ReadByteArray()))
+ throw Host.ExceptDecode();
+
+ var graph = new TFGraph();
+ try
+ {
+ graph.Import(data);
+ _session = new TFSession(graph);
+ }
+ catch (Exception ex)
+ {
+#pragma warning disable MSML_NoMessagesForLoadContext
+ throw Host.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
+#pragma warning restore MSML_NoMessagesForLoadContext
+
+ }
+ _bindings = new Bindings(ctx, Source.Schema, this);
+ }
+
+ ~TensorflowTransform()
+ {
+ _session.CloseSession();
+ _session.Dispose();
+ }
+
+ public static TensorflowTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ var h = env.Register(RegistrationName);
+ h.CheckValue(ctx, nameof(ctx));
+ h.CheckValue(input, nameof(input));
+ ctx.CheckAtModel(GetVersionInfo());
+ return h.Apply("Loading Model", ch => new TensorflowTransform(h, ctx, input));
+ }
+
+ private TFSession LoadTFSession(string modelFile)
+ {
+ var graph = new TFGraph();
+ try
+ {
+ graph.Import(File.ReadAllBytes(modelFile), "");
+ }
+ catch (Exception ex)
+ {
+#pragma warning disable MSML_NoMessagesForLoadContext
+ throw Host.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
+#pragma warning restore MSML_NoMessagesForLoadContext
+
+ }
+ return new TFSession(graph);
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ Host.AssertValue(ctx);
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ ctx.Writer.Write(_batchSize);
+
+ var buffer = new TFBuffer();
+ _session.Graph.ToGraphDef(buffer);
+
+ ctx.SaveBinaryStream("TFModel", w =>
+ {
+ w.WriteByteArray(buffer.ToArray());
+ });
+ _bindings.Save(ctx);
+ }
+
+ public void DisposeTFSession()
+ {
+ _session.CloseSession();
+ _session.DeleteSession();
+ }
+
+ private ValueGetter<T> GetSrcGetter<T>(IRow input, int iinfo, int isrc)
+ {
+ return input.GetGetter<T>(_bindings.Infos[iinfo].SrcIndices[isrc]);
+ }
+
+ private ITensorValueGetter CreateTensorValueGetter<T>(IRow input, ColumnType type, int colIndex, TFShape tfShape)
+ {
+ if (type.IsVector)
+ return new TensorValueGetterVec<T>(input, colIndex, tfShape);
+ else
+ return new TensorValueGetter<T>(input, colIndex);
+ }
+
+ private ITensorValueGetter CreateTensorValueGetterVec(IRow input, TFDataType tfType, ColumnType columnType, int colIndex, TFShape tfShape)
+ {
+ var type = TFTensor.TypeFromTensorType(tfType);
+ if (type != null)
+ {
+ return Utils.MarshalInvoke(CreateTensorValueGetter<int>, type, input, columnType, colIndex, tfShape);
+ }
+
+ throw Host.ExceptNotSupp("Tensorflow type not supported");
+ }
+
+ private ITensorValueGetter[] GetTensorValueGetters(IRow input, int iinfo)
+ {
+ var info = _bindings.Infos[iinfo];
+ var tfInfo = _bindings.TfColInfo[iinfo];
+ var srcTensorGetters = new ITensorValueGetter[info.SrcIndices.Length];
+ for (int j = 0; j < info.SrcIndices.Length; j++)
+ {
+ int colIndex = _bindings.Infos[iinfo].SrcIndices[j];
+ srcTensorGetters[j] = CreateTensorValueGetterVec(input, tfInfo.TfTypes[j], info.SrcTypes[j], colIndex, tfInfo.TfShapes[j]);
+ }
+ return srcTensorGetters;
+ }
+
+ private Delegate MakeGetter(IRow input, int iinfo)
+ {
+ var info = _bindings.Infos[iinfo];
+ var outInfo = _bindings.OutputCols[iinfo];
+ var tfType = _bindings.OutputTFTypes[iinfo];
+ var type = TFTensor.TypeFromTensorType(tfType);
+ if (type != null)
+ {
+ return Utils.MarshalInvoke(MakeGetter<int>, outInfo.ItemType.RawType, input, iinfo, outInfo);
+ }
+
+ throw Host.ExceptNotSupp("Tensorflow type not supported");
+ }
+
+ private Delegate MakeGetter<T>(IRow input, int iinfo, ColumnType columnType)
+ {
+ Host.AssertValue(input);
+ Host.Assert(typeof(T) == columnType.ItemType.RawType);
+
+ var info = _bindings.Infos[iinfo];
+ var tfInfo = _bindings.TfColInfo[iinfo];
+ var srcTensorGetters = GetTensorValueGetters(input, iinfo);
+
+ ValueGetter<VBuffer<T>> valuegetter = (ref VBuffer<T> dst) =>
+ {
+ var runner = _session.GetRunner();
+ for (int i = 0; i < info.SrcIndices.Length; i++)
+ {
+ var inputName = tfInfo.InputColNames[i];
+ var type = info.SrcTypes[i];
+ runner.AddInput(inputName, srcTensorGetters[i].GetTensor());
+ }
+
+ var tensors = runner.Fetch(_bindings.OutputColNames[iinfo]).Run();
+
+ Contracts.Assert(tensors.Length > 0);
+
+ var values = dst.Values;
+ if (Utils.Size(values) != _bindings.OutputCols[iinfo].VectorSize)
+ values = new T[_bindings.OutputCols[iinfo].VectorSize];
+
+ TensorflowUtils.FetchData(tensors[0].Data, values);
+ dst = new VBuffer<T>(values.Length, values);
+ };
+ return valuegetter;
+ }
+
+ protected override Func<int, bool> GetDependenciesCore(Func<int, bool> predicate)
+ {
+ return _bindings.GetDependencies(predicate);
+ }
+
+ protected override Delegate[] CreateGetters(IRow input, Func<int, bool> active, out Action disp)
+ {
+ Func<int, bool> activeInfos =
+ iinfo =>
+ {
+ int col = _bindings.MapIinfoToCol(iinfo);
+ return active(col);
+ };
+
+ var getters = new Delegate[_bindings.InfoCount];
+ disp = null;
+ using (var ch = Host.Start("CreateGetters"))
+ {
+ for (int iinfo = 0; iinfo < _bindings.InfoCount; iinfo++)
+ {
+ if (!activeInfos(iinfo))
+ continue;
+ getters[iinfo] = MakeGetter(input, iinfo);
+ }
+ ch.Done();
+ return getters;
+ }
+ }
+
+ protected override int MapColumnIndex(out bool isSrc, int col)
+ {
+ return _bindings.MapColumnIndex(out isSrc, col);
+ }
+
+ protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
+ {
+ return true;
+ }
+
+ protected override IRowCursor GetRowCursorCore(Func<int, bool> predicate, IRandom rand = null)
+ {
+ Host.AssertValue(predicate, "predicate");
+ Host.AssertValueOrNull(rand);
+
+ var inputPred = _bindings.GetDependencies(predicate);
+ var active = _bindings.GetActive(predicate);
+ var input = Source.GetRowCursor(inputPred, rand);
+ return new RowCursor(Host, this, input, active);
+ }
+
+ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func<int, bool> predicate, int n, IRandom rand = null)
+ {
+ Host.CheckValue(predicate, nameof(predicate));
+ Host.CheckValueOrNull(rand);
+
+ var inputPred = _bindings.GetDependencies(predicate);
+ var active = _bindings.GetActive(predicate);
+ var inputs = Source.GetRowCursorSet(out consolidator, inputPred, n, rand);
+ Host.AssertNonEmpty(inputs);
+
+ if (inputs.Length == 1 && n > 1 && _bindings.AnyNewColumnsActive(predicate))
+ inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n);
+ Host.AssertNonEmpty(inputs);
+
+ var cursors = new IRowCursor[inputs.Length];
+ for (int i = 0; i < inputs.Length; i++)
+ cursors[i] = new RowCursor(Host, this, inputs[i], active);
+ return cursors;
+ }
+
+ private interface ITensorValueGetter
+ {
+ TFTensor GetTensor();
+ }
+
+ private class TensorValueGetter<T> : ITensorValueGetter
+ {
+ private readonly ValueGetter<T> _srcgetter;
+
+ public TensorValueGetter(IRow input, int colIndex)
+ {
+ _srcgetter = input.GetGetter<T>(colIndex);
+ }
+ public TFTensor GetTensor()
+ {
+ var scalar = default(T);
+ _srcgetter(ref scalar);
+ return TFTensor.CreateScalar(scalar);
+ }
+ }
+
+ private class TensorValueGetterVec<T> : ITensorValueGetter
+ {
+ private readonly ValueGetter<VBuffer<T>> _srcgetter;
+ private readonly TFShape _tfShape;
+ private VBuffer<T> _vBuffer;
+ private VBuffer<T> _vBufferDense;
+ public TensorValueGetterVec(IRow input, int colIndex, TFShape tfShape)
+ {
+ _srcgetter = input.GetGetter<VBuffer<T>>(colIndex);
+ _tfShape = tfShape;
+ _vBuffer = default;
+ _vBufferDense = default;
+ }
+ public TFTensor GetTensor()
+ {
+ _srcgetter(ref _vBuffer);
+ _vBuffer.CopyToDense(ref _vBufferDense);
+ return TFTensor.Create(_vBufferDense.Values, _tfShape);
+ }
+ }
+
+ private sealed class RowCursor : SynchronizedCursorBase<IRowCursor>, IRowCursor
+ {
+ private readonly Bindings _bindings;
+ private readonly bool[] _active;
+ private readonly Delegate[] _getters;
+
+ public RowCursor(IChannelProvider provider, TensorflowTransform parent, IRowCursor input, bool[] active)
+ : base(provider, input)
+ {
+ Ch.AssertValue(parent);
+ Ch.Assert(active == null || active.Length == parent._bindings.ColumnCount);
+
+ _bindings = parent._bindings;
+ _active = active;
+
+ _getters = new Delegate[_bindings.Infos.Length];
+ for (int i = 0; i < _bindings.Infos.Length; i++)
+ {
+ if (IsIndexActive(i))
+ _getters[i] = parent.MakeGetter(Input, i);
+ }
+ }
+
+ public ISchema Schema { get { return _bindings; } }
+
+ private bool IsIndexActive(int iinfo)
+ {
+ Ch.Assert(0 <= iinfo & iinfo < _bindings.Infos.Length);
+ return _active == null || _active[_bindings.MapIinfoToCol(iinfo)];
+ }
+
+ public bool IsColumnActive(int col)
+ {
+ Ch.Check(0 <= col && col < _bindings.ColumnCount);
+ return _active == null || _active[col];
+ }
+
+ public ValueGetter<TValue> GetGetter<TValue>(int col)
+ {
+ Ch.Check(IsColumnActive(col));
+
+ bool isSrc;
+ int index = _bindings.MapColumnIndex(out isSrc, col);
+ if (isSrc)
+ return Input.GetGetter<TValue>(index);
+
+ Ch.Assert(_getters[index] != null);
+ var fn = _getters[index] as ValueGetter<TValue>;
+ if (fn == null)
+ throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
+ return fn;
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj
index 888a983e51..e66b88399d 100644
--- a/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj
+++ b/test/Microsoft.ML.FSharp.Tests/Microsoft.ML.FSharp.Tests.fsproj
@@ -48,5 +48,8 @@
+
+
+
diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
index ed1b948384..1a1a2d355e 100644
--- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
+++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
@@ -26,4 +26,9 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
new file mode 100644
index 0000000000..47e9e85de8
--- /dev/null
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -0,0 +1,173 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.Runtime.LightGBM;
+using Microsoft.ML.Transforms;
+using System;
+using System.Collections.Generic;
+using Xunit;
+
+namespace Microsoft.ML.Scenarios
+{
+ public partial class ScenariosTests
+ {
+ private class TestData
+ {
+ public float[] a;
+ public float[] b;
+ }
+ [Fact]
+ public void TensorflowTransformMatrixMultiplicationTest()
+ {
+ var model_location = GetDataPath("model_matmul/frozen_saved_model.pb");
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ // Pipeline
+ var loader = ComponentCreation.CreateDataView(env,
+ new List<TestData>(new TestData[] { new TestData() { a = new[] { 1.0f, 2.0f,
+ 3.0f, 4.0f },
+ b = new[] { 1.0f, 2.0f,
+ 3.0f, 4.0f } },
+ new TestData() { a = new[] { 2.0f, 2.0f,
+ 2.0f, 2.0f },
+ b = new[] { 3.0f, 3.0f,
+ 3.0f, 3.0f } } }));
+
+ var trans = new TensorflowTransform(env, loader, model_location, "c", "a", "b");
+
+ using (var cursor = trans.GetRowCursor(a => true))
+ {
+ var cgetter = cursor.GetGetter<VBuffer<float>>(2);
+ Assert.True(cursor.MoveNext());
+ VBuffer<float> c = default;
+ cgetter(ref c);
+
+ Assert.Equal(1.0 * 1.0 + 2.0 * 3.0, c.Values[0]);
+ Assert.Equal(1.0 * 2.0 + 2.0 * 4.0, c.Values[1]);
+ Assert.Equal(3.0 * 1.0 + 4.0 * 3.0, c.Values[2]);
+ Assert.Equal(3.0 * 2.0 + 4.0 * 4.0, c.Values[3]);
+
+ Assert.True(cursor.MoveNext());
+ c = default;
+ cgetter(ref c);
+
+ Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[0]);
+ Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[1]);
+ Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[2]);
+ Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[3]);
+
+ Assert.False(cursor.MoveNext());
+
+ }
+ }
+ }
+
+ [Fact]
+ public void TensorflowTransformMNISTConvTest()
+ {
+ var model_location = GetDataPath("mnist_model/frozen_saved_model.pb");
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ var dataPath = GetDataPath("mnist_train.1K.tsv");
+ var testDataPath = GetDataPath("mnist_test.1K.tsv");
+
+ // Pipeline
+ var loader = new TextLoader(env,
+ new TextLoader.Arguments()
+ {
+ Separator = "tab",
+ HasHeader = true,
+ Column = new[]
+ {
+ new TextLoader.Column()
+ {
+ Name = "Label",
+ Source = new [] { new TextLoader.Range() { Min=0, Max=0} },
+ Type = DataKind.Num
+ },
+
+ new TextLoader.Column()
+ {
+ Name = "Placeholder",
+ Source = new [] { new TextLoader.Range() { Min=1, Max=784} },
+ Type = DataKind.Num
+ }
+ }
+ }, new MultiFileSource(dataPath));
+
+ IDataView trans = new TensorflowTransform(env, loader, model_location, "Softmax", "Placeholder");
+ trans = new ConcatTransform(env, trans, "reshape_input", "Placeholder");
+ trans = new TensorflowTransform(env, trans, model_location, "dense/Relu", "reshape_input");
+ trans = new ConcatTransform(env, trans, "Features", "Softmax", "dense/Relu");
+
+ var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments());
+
+ var cached = new CacheDataView(env, trans, prefetch: null);
+ var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
+ var pred = trainer.Train(trainRoles);
+
+ // Get scorer and evaluate the predictions from test data
+ IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
+ var metrics = Evaluate(env, testDataScorer);
+
+ Assert.Equal(0.99, metrics.AccuracyMicro, 2);
+ Assert.Equal(0.99, metrics.AccuracyMicro, 2);
+
+ // Create prediction engine and test predictions
+ var model = env.CreatePredictionEngine<MNISTData, MNISTPrediction>(testDataScorer);
+
+ var sample1 = new MNISTData()
+ {
+ Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 18, 18, 126, 136, 175, 26,
+ 166, 255, 247, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 253, 253, 253, 253,
+ 225, 172, 253, 242, 195, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253, 253, 253,
+ 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253, 198,
+ 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 11, 0,
+ 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }
+ };
+
+ var prediction = model.Predict(sample1);
+
+ float max = -1;
+ int maxIndex = -1;
+ for (int i = 0; i < prediction.PredictedLabels.Length; i++)
+ {
+ if (prediction.PredictedLabels[i] > max)
+ {
+ max = prediction.PredictedLabels[i];
+ maxIndex = i;
+ }
+ }
+
+ Assert.Equal(5, maxIndex);
+ }
+ }
+
+ public class MNISTData
+ {
+ [Column("1")]
+ public float Label;
+
+ [VectorType(784)]
+ public float[] Placeholder;
+ }
+
+ public class MNISTPrediction
+ {
+ [ColumnName("Score")]
+ public float[] PredictedLabels;
+ }
+ }
+}
diff --git a/test/data/mnist_model/frozen_saved_model.pb b/test/data/mnist_model/frozen_saved_model.pb
new file mode 100644
index 0000000000..e58f7f2b95
Binary files /dev/null and b/test/data/mnist_model/frozen_saved_model.pb differ
diff --git a/test/data/mnist_model/saved_model.pb b/test/data/mnist_model/saved_model.pb
new file mode 100644
index 0000000000..652a46e046
Binary files /dev/null and b/test/data/mnist_model/saved_model.pb differ
diff --git a/test/data/mnist_model/variables/variables.data-00000-of-00001 b/test/data/mnist_model/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000..0c213724e0
Binary files /dev/null and b/test/data/mnist_model/variables/variables.data-00000-of-00001 differ
diff --git a/test/data/mnist_model/variables/variables.index b/test/data/mnist_model/variables/variables.index
new file mode 100644
index 0000000000..98852e4acc
Binary files /dev/null and b/test/data/mnist_model/variables/variables.index differ
diff --git a/test/data/model_matmul/frozen_saved_model.pb b/test/data/model_matmul/frozen_saved_model.pb
new file mode 100644
index 0000000000..b5196bae47
Binary files /dev/null and b/test/data/model_matmul/frozen_saved_model.pb differ
diff --git a/test/data/model_matmul/saved_model.pb b/test/data/model_matmul/saved_model.pb
new file mode 100644
index 0000000000..6ba20ea8e7
Binary files /dev/null and b/test/data/model_matmul/saved_model.pb differ