Skip to content

Commit 9688ac6

Browse files
Authored commit 9688ac6: "fp16 - brain script add to CNTK.core.bs" (microsoft#3617)
1 parent: 764c8c4

File tree

3 files changed

+56
-11
lines changed

3 files changed

+56
-11
lines changed

Source/CNTK/BrainScript/CNTKCoreLib/CNTK.core.bs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ CNTK2 = [
412412
// Changes: dims -> shape
413413
DynamicAxis(tag='') = new ComputationNode [ operation = 'DynamicAxis' ; /*plus the function args*/ ]
414414
# TODO: Is it a good idea to default to "feature"?
415-
Input(shape, dynamicAxis='', tag='feature') = new ComputationNode [ operation = 'InputValue' ; shape = new TensorShape [ /*shape*/ ] ; isImage = false /*plus the function args*/ ]
415+
Input(shape, dynamicAxis='', tag='feature', precision=precision) = new ComputationNode [ operation = 'InputValue' ; shape = new TensorShape [ /*shape*/ ] ; isImage = false /*plus the function args*/ ]
416416

417417
// 2. Variables and constants
418418
// Changes: ParameterTensor -> _Parameter; "dims" -> "shape"
@@ -570,14 +570,14 @@ ParameterTensor {dims, learningRateMultiplier = 1.0, init = ''/*|uniform|fixedVa
570570
ConstantFromString(literal, tag='') = ParameterTensor((0)/*dim, will be inferred*/, initFromLiteral = literal, learningRateMultiplier = 0.0)
571571
# TODO: Deprecate ConstantFromString() in favor of Constant(array expression)
572572
DynamicAxis(tag='') = new ComputationNode [ operation = 'DynamicAxis' ; /*plus the function args*/ ]
573-
Input(dims, dynamicAxis='', sparse=false, tag='feature') =
574-
if sparse then SparseInput(dims, dynamicAxis=dynamicAxis, tag=tag)
573+
Input(dims, dynamicAxis='', sparse=false, tag='feature', precision=precision) =
574+
if sparse then SparseInput(dims, dynamicAxis=dynamicAxis, tag=tag, precision=precision)
575575
else new ComputationNode [ operation = 'InputValue' ; shape = new TensorShape [ /*dims*/ ] ; isImage = false /*plus the function args*/ ]
576576
# TODO: change from dynamicAxis by name to dynamicAxis being an actual object
577577
# the following variants of Input() are deprecated
578-
SparseInput(dims, dynamicAxis='', tag='feature') = new ComputationNode [ operation = 'SparseInputValue' ; shape = new TensorShape [ /*dims*/ ] ; isImage = false /*plus the function args*/ ]
579-
ImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', dynamicAxis='', tag='feature') = new ComputationNode [ operation = 'InputValue' ; isImage = true /*plus the function args*/ ]
580-
SparseImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', dynamicAxis='', tag='feature') = new ComputationNode [ operation = 'SparseInputValue' ; isImage = true /*plus the function args*/ ]
578+
SparseInput(dims, dynamicAxis='', tag='feature', precision=precision) = new ComputationNode [ operation = 'SparseInputValue' ; shape = new TensorShape [ /*dims*/ ] ; isImage = false /*plus the function args*/ ]
579+
ImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', dynamicAxis='', tag='feature', precision=precision) = new ComputationNode [ operation = 'InputValue' ; isImage = true /*plus the function args*/ ]
580+
SparseImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', dynamicAxis='', tag='feature', precision=precision) = new ComputationNode [ operation = 'SparseInputValue' ; isImage = true /*plus the function args*/ ]
581581
EnvironmentInput(propertyName, tag='') = new ComputationNode [ operation = 'EnvironmentInput' /*plus the function args*/ ]
582582
# TODO: make 'dims' the first parameter, think ConstantTensor<dims> (val)
583583
ConstantTensor(val, dims, tag='') = ParameterTensor(dims, learningRateMultiplier = 0, initValue = val)
@@ -706,6 +706,8 @@ TransposeTimes(leftMatrix, rightMatrix, tag='') = new ComputationNode [ operatio
706706
QuantizedTimes(leftMatrix, rightMatrix, bitSmoothingA=1, bitSmoothingB=1, outputRank=1, inferInputRankToMap=-1, tag='') = new ComputationNode [ operation = 'QuantizedTimes' ; inputs = _AsNodes (leftMatrix : rightMatrix) /*plus the function args*/ ]
707707
Where(cond, tag='') = new ComputationNode [ operation = 'Where' ; inputs = _AsNodes (cond) /*plus the function args*/ ]
708708
709+
Cast(node, precision='', tag='') = new ComputationNode [ operation = 'Cast' ; inputs = _AsNodes (node) /*plus the function args*/ ]
710+
709711
##############################################################################
710712
# non-neural-network functions
711713
##############################################################################

Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,54 @@ shared_ptr<ComputationNodeBase> NewComputationNodeFromConfig(const Microsoft::MS
208208
wstring precision = configp->Get(L"precision"); // dispatch on ElemType
209209
wstring operationName = configp->Get(L"operation");
210210
ComputationNodeBasePtr node;
211-
if (precision == L"float")
212-
node = CreateNode<float>(operationName, configp);
213-
else if (precision == L"double")
214-
node = CreateNode<double>(operationName, configp);
211+
if (operationName == OperationName2Of(CastNode))
212+
{
213+
auto inputs = ComputationNodeBase::GetInputsFromConfig(configp);
214+
if (inputs.empty())
215+
RuntimeError("NewComputationNodeFromConfig: No inputs found for Cast node.");
216+
217+
if (precision == L"float16" || precision == L"half")
218+
{
219+
if (inputs[0]->Is<ComputationNode<float>>())
220+
node = CreateNode2<half, float>(operationName, configp);
221+
else if (inputs[0]->Is<ComputationNode<double>>())
222+
node = CreateNode2<half, double>(operationName, configp);
223+
else
224+
RuntimeError("NewComputationNodeFromConfig: for CastNode to cast to half, input must be 'float' or 'double'");
225+
}
226+
else if (precision == L"float")
227+
{
228+
if (inputs[0]->Is<ComputationNode<half>>())
229+
node = CreateNode2<float, half>(operationName, configp);
230+
else if (inputs[0]->Is<ComputationNode<double>>())
231+
node = CreateNode2<float, double>(operationName, configp);
232+
else
233+
RuntimeError("NewComputationNodeFromConfig: for CastNode to cast to float, input must be 'float16' or 'double'");
234+
}
235+
else if (precision == L"double")
236+
{
237+
if (inputs[0]->Is<ComputationNode<half>>())
238+
node = CreateNode2<double, half>(operationName, configp);
239+
else if (inputs[0]->Is<ComputationNode<float>>())
240+
node = CreateNode2<double, float>(operationName, configp);
241+
else
242+
RuntimeError("NewComputationNodeFromConfig: for CastNode to cast to double, input must be 'float' or 'float16'");
243+
}
244+
else
245+
RuntimeError("NewComputationNodeFromConfig: CastNode - need to specify 'precision' parameter: 'float', 'double' or 'float16'.");
246+
}
215247
else
216-
RuntimeError("NewStandardNode: Invalid value '%ls' for 'precision' parameter. Must be 'float' or 'double'.", precision.c_str());
248+
{
249+
if (precision == L"float")
250+
node = CreateNode<float>(operationName, configp);
251+
else if (precision == L"float16" || precision == L"half")
252+
node = CreateNode<half>(operationName, configp);
253+
else if (precision == L"double")
254+
node = CreateNode<double>(operationName, configp);
255+
else
256+
RuntimeError("NewComputationNodeFromConfig: Invalid value '%ls' for 'precision' parameter. Must be 'float16', 'float' or 'double'.", precision.c_str());
257+
}
258+
217259
// add a tag
218260
// Tags are used to declare special node types to ComputationNetwork.
219261
// For now we support only a single tag, but we could in the future easily extend this to an array of tags.

Source/ComputationNetworkLib/LinearAlgebraNodes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,6 +2013,7 @@ class CastNode : public UnaryElementWiseNode<ElemType>
20132013
static const std::wstring TypeName() { return L"Cast"; }
20142014

20152015
public:
2016+
DeclareConstructorFromConfig(CastNode);
20162017
CastNode(DEVICEID_TYPE deviceId, const wstring& name)
20172018
: Base(deviceId, name)
20182019
{

0 commit comments

Comments (0)