Skip to content

Commit 91b78a7

Browse files
committed
Merge branch 'yanchen/load-from-memory'
Signed-off-by: Yang Chen <[email protected]>
2 parents f843f91 + cfe8f6c commit 91b78a7

File tree

9 files changed

+92
-11
lines changed

9 files changed

+92
-11
lines changed

Source/CNTKv2LibraryDll/API/CNTKLibrary.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3627,7 +3627,8 @@ namespace CNTK
36273627
/// Load a Function from a memory buffer
36283628
///
36293629
CNTK_API static FunctionPtr Load(const char* buffer, size_t length,
3630-
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
3630+
const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice(),
3631+
ModelFormat format = ModelFormat::CNTKv2);
36313632

36323633
///
36333634
/// Load a Function from an istream. The legacy V1 model is not supported.

Source/CNTKv2LibraryDll/API/CNTKLibraryC.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,22 @@ CNTK_API CNTK_StatusCode CNTK_LoadModel(
109109
/*[in]*/ const CNTK_DeviceDescriptor* device,
110110
/*[out]*/ CNTK_ModelHandle* model);
111111

112+
//
113+
// Loads a model from the specified buffer and returns an opaque handle to the model
114+
// that should be passed to further operations.
115+
//
116+
// Parameters:
117+
// modelData [in]: a buffer that holds the CNTK model
118+
// modelDataLen [in]: the length of the buffer
119+
// device [in]: device descriptor
120+
// model [out]: the resulting loaded model
121+
//
122+
CNTK_StatusCode CNTK_LoadModel_FromArray(
123+
/*[in]*/ const void* modelData,
124+
/*[in]*/ int modelDataLen,
125+
/*[in]*/ const CNTK_DeviceDescriptor* device,
126+
/*[out]*/ CNTK_ModelHandle* model);
127+
112128
enum CNTK_ParameterCloningMethod
113129
{
114130
///

Source/CNTKv2LibraryDll/API/Internals/EvaluatorWrapper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ namespace CNTK
152152
public:
153153
CNTKEvaluatorWrapper(const char* modelFilePath, const CNTK_DeviceDescriptor* device);
154154
CNTKEvaluatorWrapper(const char* modelFilePath, DeviceDescriptor device);
155+
CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, const CNTK_DeviceDescriptor* device);
156+
CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, DeviceDescriptor device);
155157
CNTKEvaluatorWrapper(FunctionPtr model, DeviceDescriptor device);
156158

157159
void GetModelArgumentsInfo(CNTK_Variable** inputs, uint32_t* numInputs) override;

Source/CNTKv2LibraryDll/CNTKLibraryC.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,22 @@ CNTK_StatusCode CNTK_LoadModel(const char* modelFilePath, const CNTK_DeviceDescr
104104
return ExceptionCatcher::Call([&]() { *handle = new CNTKEvaluatorWrapper(modelFilePath, device); });
105105
}
106106

107+
CNTK_StatusCode CNTK_LoadModel_FromArray(const void* modelData, int modelDataLen,
108+
const CNTK_DeviceDescriptor* device, CNTK_ModelHandle* handle)
109+
{
110+
if (!handle)
111+
return StatusCode(CNTK_ERROR_NULL_POINTER, "'handle' parameter is not allowed to be null");
112+
113+
if (!modelData)
114+
return StatusCode(CNTK_ERROR_NULL_POINTER, "'modelData' parameter is not allowed to be null");
115+
116+
if (modelDataLen <= 0)
117+
return StatusCode(CNTK_ERROR_INVALID_INPUT, "'modelDataLen' parameter must be greater than zero");
118+
119+
*handle = nullptr;
120+
return ExceptionCatcher::Call([&]() { *handle = new CNTKEvaluatorWrapper(modelData, modelDataLen, device); });
121+
}
122+
107123
CNTK_StatusCode CNTK_CloneModel(CNTK_ModelHandle model, CNTK_ParameterCloningMethod method, bool flatten, CNTK_ModelHandle* cloned)
108124
{
109125
if (model == CNTK_INVALID_MODEL_HANDLE)

Source/CNTKv2LibraryDll/EvaluatorWrapper.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ namespace CNTK
3232
CNTKEvaluatorWrapper(modelFilePath, GetDeviceDescriptor(device))
3333
{}
3434

35+
CNTKEvaluatorWrapper::CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, DeviceDescriptor device) :
36+
CNTKEvaluatorWrapper(Function::Load(static_cast<const char*>(modelData), modelDataLen, device), device)
37+
{}
38+
39+
CNTKEvaluatorWrapper::CNTKEvaluatorWrapper(const void* modelData, int modelDataLen, const CNTK_DeviceDescriptor* device) :
40+
CNTKEvaluatorWrapper(modelData, modelDataLen, GetDeviceDescriptor(device))
41+
{}
42+
3543
void CNTKEvaluatorWrapper::GetModelArgumentsInfo(CNTK_Variable** inputs, uint32_t* numInputs)
3644
{
3745
assert(inputs != nullptr);

Source/CNTKv2LibraryDll/Function.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ namespace CNTK
533533
return nullptr;
534534
}
535535

536-
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice)
536+
/*static*/ FunctionPtr Function::Load(const char *buffer, size_t length, const DeviceDescriptor& computeDevice, ModelFormat format)
537537
{
538538
if ((buffer == nullptr) || (length <= 0))
539539
InvalidArgument("The model buffer should not be null and its length should be greater than 0");
@@ -547,15 +547,32 @@ namespace CNTK
547547
}
548548
};
549549

550-
if (Internal::IsLegacyModel(buffer, length))
551-
InvalidArgument("Loading a legacy model from byte array is not supported.");
552-
else
550+
switch (format)
551+
{
552+
case ModelFormat::CNTKv2:
553553
{
554-
modelStreamBuffer buf(buffer, length);
555-
std::istream modelStream(&buf);
554+
if (Internal::IsLegacyModel(buffer, length)) {
555+
InvalidArgument("Loading a legacy model from byte array is not supported.");
556+
}
557+
else
558+
{
559+
modelStreamBuffer buf(buffer, length);
560+
std::istream modelStream(&buf);
561+
562+
return Load(modelStream, computeDevice);
563+
}
564+
break;
565+
}
566+
567+
case ModelFormat::ONNX:
568+
return ONNXFormat::Load(static_cast<const void*>(buffer), length, computeDevice);
569+
break;
556570

557-
return Load(modelStream, computeDevice);
571+
default:
572+
InvalidArgument("unsupported ModelFormat.");
558573
}
574+
575+
return nullptr;
559576
}
560577

561578
/*static*/ FunctionPtr Function::Load(std::istream& inputStream, const DeviceDescriptor& computeDevice)

Source/CNTKv2LibraryDll/proto/onnx/ONNX.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,23 @@ FunctionPtr ONNXFormat::Load(const std::wstring& filepath, const DeviceDescripto
121121
FunctionPtr cntkFunction = ONNXToCNTK::CreateGraph(&model->MainGraph(), computeDevice, ToLegacyString(ToUTF8(filepath)));
122122
return cntkFunction;
123123
}
124+
125+
FunctionPtr ONNXFormat::Load(const void* model_data, int model_data_len, const DeviceDescriptor& computeDevice)
126+
{
127+
InitializeLotusIR();
128+
129+
onnx::ModelProto model_proto;
130+
const bool result = model_proto.ParseFromArray(model_data, model_data_len);
131+
if (!result) {
132+
LogicError("protobuf failed to parse model");
133+
}
134+
135+
std::shared_ptr<onnxruntime::Model> model;
136+
onnxruntime::common::Status loadStatus = onnxruntime::Model::Load(model_proto, model);
137+
138+
if (!loadStatus.IsOK())
139+
LogicError("Failed to load model: '%s'", loadStatus.ErrorMessage().c_str());
140+
141+
FunctionPtr cntkFunction = ONNXToCNTK::CreateGraph(&model->MainGraph(), computeDevice);
142+
return cntkFunction;
143+
}

Source/CNTKv2LibraryDll/proto/onnx/ONNX.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ namespace CNTK
1616
public:
1717
static void Save(const FunctionPtr& src, const std::wstring& filepath, bool useExternalFilesToStoreParameters = false);
1818
static FunctionPtr Load(const std::wstring& filepath, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
19+
static FunctionPtr Load(const void* model_data, int model_data_len, const DeviceDescriptor& computeDevice = DeviceDescriptor::UseDefaultDevice());
1920
private:
2021
static void InitializeLotusIR();
2122
static std::once_flag op_schema_initializer_flag_;
2223
};
23-
}
24+
}

bindings/csharp/CNTKLibraryManagedDll/ShimApiClasses/FunctionShim.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,9 @@ public static Function Load(string filepath, DeviceDescriptor computeDevice, Mod
273273
/// <param name="modelBuffer"></param>
274274
/// <param name="computeDevice"></param>
275275
/// <returns></returns>
276-
public static Function Load(byte[] modelBuffer, DeviceDescriptor computeDevice)
276+
public static Function Load(byte[] modelBuffer, DeviceDescriptor computeDevice, ModelFormat format = ModelFormat.CNTKv2)
277277
{
278-
return _Load(modelBuffer, (uint)modelBuffer.Length, computeDevice);
278+
return _Load(modelBuffer, (uint)modelBuffer.Length, computeDevice, format);
279279
}
280280

281281
/// <summary>

0 commit comments

Comments
 (0)