Skip to content

Commit 3b7e3b9

Browse files
authored
merge fp16_bs fix (microsoft#3609)
fp16 brain script fix accumulator
1 parent 4003c08 commit 3b7e3b9

File tree

4 files changed

+45
-9
lines changed

4 files changed

+45
-9
lines changed

Source/SGDLib/AccumulatorAggregation.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ template <typename ElemType>
9595
void UpdateEpochEvaluationForAccumulatedResult(
9696
std::vector<EpochCriterion>& epochEvalErrors,
9797
const std::vector<ComputationNodeBasePtr>& evaluationNodes,
98-
CriterionAccumulator<ElemType> localEpochEvalErrors,
98+
CriterionAccumulatorBase& localEpochEvalErrors,
9999
std::function<bool(ComputationNodeBasePtr)> containsAccumulatedResult
100100
)
101101
{
@@ -120,7 +120,7 @@ void AggregateAccumulatorValuesAndUpdateEpochEvaluation(
120120
std::shared_ptr<MPIWrapper> mpi,
121121
std::vector<EpochCriterion>& epochEvalErrors,
122122
const std::vector<ComputationNodeBasePtr>& evaluationNodes,
123-
CriterionAccumulator<ElemType> localEpochEvalErrors,
123+
CriterionAccumulatorBase& localEpochEvalErrors,
124124
std::function<bool(ComputationNodeBasePtr)> containsAccumulatedResult,
125125
size_t packThresholdSizeInBytes = DEFAULT_PACK_THRESHOLD_SIZE_IN_BYTES)
126126
{

Source/SGDLib/Criterion.h

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,17 @@ struct EpochCriterion : public std::pair<double, size_t>
7474

7575
// We accumulate criteria in this struct.
7676
// Criteria are accumulated together with their counts (counts depend on sequence lengths, and different criteria may have different sequence lengths).
77+
struct CriterionAccumulatorBase
78+
{
79+
CriterionAccumulatorBase() {};
80+
virtual ~CriterionAccumulatorBase() {};
81+
virtual const CriterionAccumulatorBase& Add(size_t i, size_t numSamplesInMinibatch) = 0;
82+
virtual const CriterionAccumulatorBase& Assign(size_t i, size_t numSamplesInMinibatch) = 0;
83+
virtual EpochCriterion GetCriterion(size_t i) const = 0;
84+
};
85+
7786
template <class ElemType>
78-
struct CriterionAccumulator
87+
struct CriterionAccumulator : CriterionAccumulatorBase
7988
{
8089
// constructor params:
8190
// criterionNodes - list of criterion nodes
@@ -91,16 +100,16 @@ struct CriterionAccumulator
91100
}
92101
// 'i' is the index of the element we add into (multiple eval criteria share the same matrix object)
93102
// Add() accumulates into element 'i'; Assign() overwrites it (they call Accumulate() with reset=false and reset=true, respectively).
94-
const CriterionAccumulator& Add(size_t i, size_t numSamplesInMinibatch)
103+
// Accumulate into element 'i' (forwards to Accumulate() without resetting).
const CriterionAccumulator& Add(size_t i, size_t numSamplesInMinibatch) override
{
    const bool overwrite = false;
    return Accumulate(i, numSamplesInMinibatch, /*reset=*/overwrite);
}
98-
const CriterionAccumulator& Assign(size_t i, size_t numSamplesInMinibatch)
107+
// Overwrite element 'i' instead of accumulating (forwards to Accumulate() with reset).
const CriterionAccumulator& Assign(size_t i, size_t numSamplesInMinibatch) override
{
    const bool overwrite = true;
    return Accumulate(i, numSamplesInMinibatch, /*reset=*/overwrite);
}
102111
// retrieve an accumulated result as a pair (numerator, denominator)
103-
EpochCriterion GetCriterion(size_t i) const
112+
virtual EpochCriterion GetCriterion(size_t i) const override
104113
{
105114
// BUGBUG: For unknown reasons, this (or the other below) check makes a difference for MPI configs.
106115
// If it is left out, then training and test configs end up being scaled by the same factor close to 1.
@@ -194,4 +203,26 @@ struct CriterionAccumulator
194203
const std::vector<ComputationNodeBasePtr> m_accumulatorCriterionNodes;
195204
};
196205

206+
class CriterionAccumulatorFactory
207+
{
208+
public:
209+
template <class ElemType>
210+
static shared_ptr<CriterionAccumulatorBase> CreateCriterionAccumulator(
211+
const std::vector<ComputationNodeBasePtr>& criterionNodes, DEVICEID_TYPE deviceId,
212+
const std::vector<ComputationNodeBasePtr>& accumulatorCriterionNodesNodes = {})
213+
{
214+
// Both half and float use float as accumulator
215+
if (std::is_same<ElemType, float>() || std::is_same<ElemType, half>())
216+
{
217+
return make_shared<CriterionAccumulator<float>>(criterionNodes, deviceId, accumulatorCriterionNodesNodes);
218+
}
219+
else if (std::is_same<ElemType, double>())
220+
{
221+
return make_shared<CriterionAccumulator<double>>(criterionNodes, deviceId, accumulatorCriterionNodesNodes);
222+
}
223+
RuntimeError("CreateCriterionAccumulator: unsupported node element type!");
224+
}
225+
226+
};
227+
197228
}}}

Source/SGDLib/SGD.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,10 +1148,14 @@ size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
11481148
// NOTE: the following two local matrices are not used in distGradAgg path
11491149
// assume only one training criterion node for each epoch.
11501150
// The criterion values are accumulated here over the minibatches (without having to pull them off the GPU).
1151-
CriterionAccumulator<ElemType> localEpochCriterion(criterionNodes, net->GetDeviceId());
1152-
CriterionAccumulator<ElemType> localEpochEvalErrors(
1151+
// For half, the criterion and error nodes should be float nodes
1152+
shared_ptr<CriterionAccumulatorBase> localEpochCriterionPtr = CriterionAccumulatorFactory::CreateCriterionAccumulator<ElemType>(
1153+
criterionNodes, net->GetDeviceId());
1154+
shared_ptr<CriterionAccumulatorBase> localEpochEvalErrorsPtr = CriterionAccumulatorFactory::CreateCriterionAccumulator<ElemType>(
11531155
evaluationNodes, net->GetDeviceId(),
11541156
{evaluationNodesWhichAccumulateResult.begin(), evaluationNodesWhichAccumulateResult.end()});
1157+
CriterionAccumulatorBase& localEpochCriterion = *localEpochCriterionPtr;
1158+
CriterionAccumulatorBase& localEpochEvalErrors = *localEpochEvalErrorsPtr;
11551159

11561160
// --- MAIN MINIBATCH LOOP
11571161

Source/SGDLib/SimpleEvaluator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,10 @@ class SimpleEvaluator
108108
if (numSubminibatchesNeeded > 1)
109109
smbDispatcher.Init(m_net, learnableNodes, criterionNodes, evalNodes);
110110

111-
CriterionAccumulator<ElemType> localEpochEvalErrors(
111+
shared_ptr<CriterionAccumulatorBase> localEpochEvalErrorsPtr = CriterionAccumulatorFactory::CreateCriterionAccumulator<ElemType>(
112112
evalNodes, m_net->GetDeviceId(),
113113
{evalNodesWhichAccumulateResult.begin(), evalNodesWhichAccumulateResult.end()});
114+
CriterionAccumulatorBase& localEpochEvalErrors = *localEpochEvalErrorsPtr;
114115

115116
const size_t numIterationsBeforePrintingProgress = 100;
116117
size_t numItersSinceLastPrintOfProgress = 0;

0 commit comments

Comments
 (0)