Skip to content

Commit

Permalink
Optimize memory usage when calculating cross entropy loss
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Sep 22, 2023
1 parent 71e9372 commit 8f2af2d
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions Seq2SeqSharp/Tools/ComputeGraphTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2600,14 +2600,11 @@ void backward()

private (float, IWeightTensor) CalculateEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float smooth, float gamma)
{
IWeightTensor loss = null;
float lossValue = 0.0f;

var scatterIdxTensor = View(truthTgtSeqs, new long[] { -1, 1 });
var scatterTrue = Scatter(scatterIdxTensor, 1.0f, 1, needGradient: false, shape: probs.Sizes);
var scatterFalse = Sub(1.0f, scatterTrue);
var probsFalse = Sub(1.0f, probs);
loss = EltMulMulAdd(scatterTrue, probs, scatterFalse, probsFalse);
var loss = EltMulMulAdd(scatterTrue, probs, scatterFalse, probsFalse);
if (smooth > 0.0f)
{
loss = Add(loss, smooth);
Expand All @@ -2621,14 +2618,14 @@ void backward()
}

loss = Log(loss);
loss = Mul(loss, -1.0f);
loss = Mul(loss, -1.0f, inPlace: true);

if (focalFactor != null)
{
loss = EltMul(loss, focalFactor);
}
var lossTrue = Gather(loss, scatterIdxTensor, 1, runGradients: false);
lossValue = lossTrue.ToWeightArray().Sum() / loss.ElementCount;
var lossValue = lossTrue.ToWeightArray().Sum() / loss.ElementCount;

return (lossValue, loss);
}
Expand Down

0 comments on commit 8f2af2d

Please sign in to comment.