loss
torch_to_nnef.op.aten.loss
ATen loss-family op emitters (mse_loss, nll_loss, cross_entropy_loss, ...).
Each loss is decomposed into a pointwise NNEF fragment (where a
pointwise form makes sense, e.g. mse, bce-with-logits, kl_div) followed
by a full-tensor mean_reduce / sum_reduce + squeeze chain selected by
torch's reduction enum (0 = none, 1 = mean, 2 = sum).
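Here is a plain-Python sketch of the reduction enum applied to a flat list of per-element losses. `apply_reduction` is a hypothetical name used only for illustration. The emitter itself produces NNEF mean_reduce / sum_reduce + squeeze, not Python code.

```python
from typing import List, Union

# Torch's reduction enum as quoted above: 0 = none, 1 = mean, 2 = sum.
REDUCTION_NONE, REDUCTION_MEAN, REDUCTION_SUM = 0, 1, 2


def apply_reduction(per_elem: List[float], reduction: int) -> Union[List[float], float]:
    """Hypothetical helper: reduce pointwise losses as the enum prescribes."""
    if reduction == REDUCTION_NONE:
        return list(per_elem)  # keep the pointwise shape
    if reduction == REDUCTION_MEAN:
        return sum(per_elem) / len(per_elem)  # full-tensor mean -> scalar
    if reduction == REDUCTION_SUM:
        return sum(per_elem)  # full-tensor sum -> scalar
    raise ValueError(f"unknown reduction enum: {reduction}")


print(apply_reduction([1.0, 2.0, 3.0], REDUCTION_MEAN))  # 2.0
```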
binary_cross_entropy
Map aten::binary_cross_entropy(input, target, weight, reduction).
Pointwise -(t*log(x) + (1-t)*log(1-x)), with x = input and t = target,
computed by the binary_cross_entropy fragment. weight (training-side
per-sample reweighting) is not currently supported.
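Below is a minimal plain-Python sketch of the pointwise formula. The function name is hypothetical, and flat lists stand in for tensors. The sketch does not clamp the logs, so every x must lie strictly inside (0, 1).

```python
import math
from typing import List


def binary_cross_entropy(x: List[float], t: List[float]) -> List[float]:
    """Pointwise -(t*log(x) + (1-t)*log(1-x)); x must be probabilities in (0, 1)."""
    return [-(ti * math.log(xi) + (1.0 - ti) * math.log(1.0 - xi)) for xi, ti in zip(x, t)]


loss = binary_cross_entropy([0.5, 0.9], [1.0, 0.0])
print([round(v, 4) for v in loss])  # [0.6931, 2.3026]
```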
binary_cross_entropy_with_logits
Map aten::binary_cross_entropy_with_logits to NNEF.
Signature: (input, target, weight, pos_weight, reduction).
The pointwise BCE, in its numerically stable softplus formulation,
lives in the binary_cross_entropy_with_logits fragment; the weight /
pos_weight modulators are not currently supported.
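One common stable rewrite of BCE-with-logits is softplus(x) - t*x, with softplus(z) = max(z, 0) + log1p(exp(-|z|)). The sketch below assumes that form. The fragment's exact expression may differ, and the names are hypothetical.

```python
import math
from typing import List


def softplus(z: float) -> float:
    # Stable log(1 + exp(z)): only ever exponentiates a non-positive number.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_with_logits(x: List[float], t: List[float]) -> List[float]:
    # -(t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))) simplifies to softplus(x) - t*x.
    return [softplus(xi) - ti * xi for xi, ti in zip(x, t)]


print(bce_with_logits([1000.0, -1000.0], [1.0, 0.0]))  # [0.0, 0.0], no overflow
```

Evaluating log(sigmoid(x)) naively breaks down for large |x| (for example, exp(1000) overflows). The rewritten form stays finite.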
cosine_embedding_loss
Map aten::cosine_embedding_loss to NNEF.
Signature: (input1, input2, target, margin, reduction). Per-sample
cosine similarity is taken along the feature axis (the last axis of
rank-2 inputs); the +1 / -1 target split then picks 1 - cos_sim where
target == 1 and max(0, cos_sim - margin) where target == -1.
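Here is a plain-Python sketch of the per-row computation for rank-2 inputs. It assumes targets are exactly +1 or -1. The function name is hypothetical, and the sketch omits any epsilon guard against zero-norm rows.

```python
import math
from typing import List


def cosine_embedding_loss(x1: List[List[float]], x2: List[List[float]],
                          target: List[int], margin: float = 0.0) -> List[float]:
    """Unreduced per-row loss; similarity taken along the last axis."""
    out = []
    for a, b, y in zip(x1, x2, target):
        dot = sum(ai * bi for ai, bi in zip(a, b))
        norm_a = math.sqrt(sum(ai * ai for ai in a))
        norm_b = math.sqrt(sum(bi * bi for bi in b))
        cos = dot / (norm_a * norm_b)
        # target == 1 -> 1 - cos; target == -1 (assumed otherwise) -> hinge on cos - margin.
        out.append(1.0 - cos if y == 1 else max(0.0, cos - margin))
    return out


print(cosine_embedding_loss([[1.0, 0.0], [1.0, 0.0]],
                            [[0.0, 1.0], [1.0, 0.0]], [1, -1], margin=0.5))  # [1.0, 0.5]
```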
cross_entropy_loss
Map aten::cross_entropy_loss to NNEF.
Lowers to nll_loss(log_softmax(input, dim=1), target, ...).
weight / ignore_index / label_smoothing are not currently
supported (the emitter raises on non-default values).
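The lowering can be sketched in plain Python: log_softmax over the class axis, followed by an unreduced NLL gather. All names are illustrative. weight, ignore_index and label_smoothing are left out, which matches the restriction above.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    # Subtract the row max before exponentiating for numerical stability.
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def nll_loss(logp: List[List[float]], target: List[int]) -> List[float]:
    # Per-sample -logp[n, target[n]] along the class axis (axis 1).
    return [-row[c] for row, c in zip(logp, target)]


def cross_entropy(logits: List[List[float]], target: List[int]) -> List[float]:
    # nll_loss(log_softmax(input, dim=1), target), unreduced.
    return nll_loss([log_softmax(r) for r in logits], target)


print(cross_entropy([[0.0, 0.0]], [1]))  # approximately [0.6931], i.e. log(2)
```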
hinge_embedding_loss
Map aten::hinge_embedding_loss(input, target, margin, reduction).
Pointwise: input where target == 1, max(0, margin - input) where
target == -1.
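Here is a short plain-Python sketch of the pointwise rule. The function name is hypothetical, flat lists stand in for tensors, and targets are assumed to be +1 or -1.

```python
from typing import List


def hinge_embedding_loss(x: List[float], y: List[int], margin: float = 1.0) -> List[float]:
    # target == 1 keeps the input; target == -1 (assumed otherwise) applies the hinge.
    return [xi if yi == 1 else max(0.0, margin - xi) for xi, yi in zip(x, y)]


print(hinge_embedding_loss([0.25, 0.25, 2.0], [1, -1, -1]))  # [0.25, 0.75, 0.0]
```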
huber_loss
Map PyTorch aten::huber_loss(input, target, reduction, delta).
Pointwise piecewise, with diff = input - target: 0.5 * diff^2 when
|diff| < delta, linear delta * (|diff| - 0.5 * delta) otherwise.
The reduction is applied by the emitter.
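The sketch below implements the piecewise Huber formula in plain Python. It follows torch's HuberLoss definition: 0.5 * diff^2 below delta, delta * (|diff| - 0.5 * delta) above. The function name is hypothetical. Both branches meet at |diff| == delta, which keeps the loss continuous.

```python
from typing import List


def huber_loss(x: List[float], t: List[float], delta: float = 1.0) -> List[float]:
    out = []
    for xi, ti in zip(x, t):
        d = abs(xi - ti)
        # Quadratic near zero, linear in the tails.
        out.append(0.5 * d * d if d < delta else delta * (d - 0.5 * delta))
    return out


print(huber_loss([0.5, 3.0], [0.0, 0.0]))  # [0.125, 2.5]
```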
kl_div
Map aten::kl_div(input, target, reduction, log_target) to NNEF.
Two pointwise fragments, picked by log_target:
- kl_div (default): target * (log(target) - input)
- kl_div_log_target: exp(target) * (target - input)
input is assumed to hold log-probabilities (callers normally feed
log_softmax(...)). Torch's reduction='batchmean' never reaches this
op directly: torch.nn.functional.kl_div calls it with reduction=sum
and then divides the result by the batch size, so the aten reduction
enum here is only 0 / 1 / 2.
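The two fragment formulas can be sketched in plain Python. `kl_div` here is a hypothetical helper over flat lists, and treating 0 * log(0) as 0 is a convention of this sketch. The usage lines check that both variants agree whether the target is passed as probabilities or as log-probabilities.

```python
import math
from typing import List


def kl_div(inp: List[float], target: List[float], log_target: bool = False) -> List[float]:
    """Pointwise KL terms; inp holds log-probabilities (e.g. log_softmax output)."""
    if log_target:
        # kl_div_log_target: exp(target) * (target - input)
        return [math.exp(t) * (t - i) for i, t in zip(inp, target)]
    # kl_div: target * (log(target) - input); 0 * log(0) taken as 0 (sketch convention)
    return [t * (math.log(t) - i) if t > 0 else 0.0 for i, t in zip(inp, target)]


p = [0.5, 0.5]
log_q = [math.log(0.25), math.log(0.75)]
plain = kl_div(log_q, p)
logged = kl_div(log_q, [math.log(v) for v in p], log_target=True)
print(abs(sum(plain) - sum(logged)) < 1e-12)  # True: both variants agree
```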
l1_loss
Map PyTorch aten::l1_loss(input, target, reduction) to NNEF.
Pointwise |input - target| via the l1_loss fragment, then
reduced. Like mse_loss, torch broadcasts upstream via
aten::broadcast_tensors, so the fragment assumes matching shapes.
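Here is a plain-Python sketch of the pointwise step. The name is hypothetical. The length assertion stands in for the matching-shape assumption that upstream broadcasting guarantees.

```python
from typing import List


def l1_loss(x: List[float], t: List[float]) -> List[float]:
    # Shapes must already match: broadcasting happened upstream (aten::broadcast_tensors).
    assert len(x) == len(t), "sketch assumes pre-broadcast, same-shape inputs"
    return [abs(xi - ti) for xi, ti in zip(x, t)]


per_elem = l1_loss([1.0, -2.0, 3.5], [0.0, 0.0, 4.0])
print(per_elem)  # [1.0, 2.0, 0.5]
```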
margin_ranking_loss
Map aten::margin_ranking_loss to NNEF.
Signature: (input1, input2, target, margin, reduction). Pointwise
max(0, -target * (input1 - input2) + margin). Inputs are
same-shape (broadcast handled upstream); reduction is applied by
the emitter.
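Below is a plain-Python sketch of the pointwise formula, using a hypothetical name and flat lists. With target = 1 the first input should rank higher, so a pair is penalized unless input1 exceeds input2 by at least margin.

```python
from typing import List


def margin_ranking_loss(x1: List[float], x2: List[float], y: List[float],
                        margin: float = 0.0) -> List[float]:
    # max(0, -target * (input1 - input2) + margin), element by element.
    return [max(0.0, -yi * (a - b) + margin) for a, b, yi in zip(x1, x2, y)]


print(margin_ranking_loss([2.0, 2.0], [1.0, 1.0], [1.0, -1.0], margin=0.5))  # [0.0, 1.5]
```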
mse_loss
Map PyTorch aten::mse_loss(input, target, reduction) to NNEF.
Pointwise (input - target) ** 2 is delegated to the mse_loss
fragment, then reduced if reduction != none. Torch broadcasts
input / target upstream of the aten op (we see a separate
aten::broadcast_tensors in the trace), so the fragment can assume
matching shapes.
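This plain-Python sketch combines the pointwise square with the reduction enum. The function name is hypothetical, and flat lists stand in for already-broadcast tensors.

```python
from typing import List, Union


def mse_loss(x: List[float], t: List[float], reduction: int = 1) -> Union[List[float], float]:
    # Pointwise (input - target) ** 2 on same-shape inputs.
    sq = [(xi - ti) ** 2 for xi, ti in zip(x, t)]
    if reduction == 0:  # none
        return sq
    if reduction == 1:  # mean
        return sum(sq) / len(sq)
    if reduction == 2:  # sum
        return sum(sq)
    raise ValueError(f"unknown reduction enum: {reduction}")


print(mse_loss([1.0, 2.0, 4.0], [1.0, 0.0, 1.0]))  # 4.333333333333333
```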
nll_loss
Map PyTorch's nll_loss family to NNEF.
Signature (all three variants):
nll_loss(input, target, weight, reduction, ignore_index).
The per-sample loss is -input[n, target[n], ...], gathered along the
class axis (axis 1). Class weighting and ignore-index masking are
common training-side knobs; we raise T2NErrorNotImplemented for both
until a real need shows up.
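This plain-Python sketch shows the unreduced gather for the (N, C, D) case, which makes clear what -input[n, target[n], ...] means once a spatial axis is present. The function name is hypothetical. Class weights and ignore_index are left out on purpose, mirroring the restriction above.

```python
from typing import List


def nll_loss_nd(logp: List[List[List[float]]], target: List[List[int]]) -> List[List[float]]:
    """Unreduced NLL for input of shape (N, C, D) and target of shape (N, D)."""
    n_batch, n_spatial = len(logp), len(logp[0][0])
    return [
        # Pick the class given by target[n][d] on axis 1 and negate it.
        [-logp[n][target[n][d]][d] for d in range(n_spatial)]
        for n in range(n_batch)
    ]


# One sample, two classes, two spatial positions.
logp = [[[-0.1, -2.0],    # class 0 at positions 0 and 1
         [-2.5, -0.2]]]   # class 1 at positions 0 and 1
print(nll_loss_nd(logp, [[0, 1]]))  # [[0.1, 0.2]]
```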
smooth_l1_loss
Map aten::smooth_l1_loss(input, target, reduction, beta).
Same piecewise shape as huber_loss with a different scaling: the
quadratic branch is 0.5 * diff^2 / beta and the linear branch is
|diff| - 0.5 * beta (vs huber's delta * (|diff| - 0.5 * delta)), so
huber_loss with delta == beta equals beta * smooth_l1_loss.
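A quick plain-Python check of the scaling relationship follows. The helper names are hypothetical, and the sketch assumes beta > 0.

```python
def smooth_l1(diff: float, beta: float) -> float:
    d = abs(diff)
    return 0.5 * d * d / beta if d < beta else d - 0.5 * beta


def huber(diff: float, delta: float) -> float:
    d = abs(diff)
    return 0.5 * d * d if d < delta else delta * (d - 0.5 * delta)


# With delta == beta the two losses differ only by a factor of beta.
beta = 2.0
for diff in [0.5, 1.5, 4.0]:
    print(huber(diff, beta), beta * smooth_l1(diff, beta))  # equal pairs
```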
soft_margin_loss
Map aten::soft_margin_loss(input, target, reduction).
Pointwise log(1 + exp(-target * input)) via the
numerically-stable softplus reformulation in the fragment.
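This plain-Python sketch uses one common stable softplus form: log(1 + exp(z)) = max(z, 0) + log1p(exp(-|z|)), with z = -target * input. The name is hypothetical, and the fragment may write the expression differently.

```python
import math
from typing import List


def soft_margin_loss(x: List[float], y: List[float]) -> List[float]:
    out = []
    for xi, yi in zip(x, y):
        z = -yi * xi
        # log(1 + exp(z)) as a softplus that only exponentiates non-positive values.
        out.append(max(z, 0.0) + math.log1p(math.exp(-abs(z))))
    return out


print([round(v, 4) for v in soft_margin_loss([0.0, -800.0], [1.0, 1.0])])  # [0.6931, 800.0]
```

Evaluating math.exp(800) directly would raise OverflowError. The rewritten form avoids that overflow.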
triplet_margin_loss
Map aten::triplet_margin_loss to NNEF.
Signature: (anchor, positive, negative, margin, p, eps, swap,
reduction). The per-sample loss is max(0, d(a, p) - d(a, n) + margin),
where d is the p-norm distance along the trailing feature axis;
swap=True replaces d(a, n) with min(||a-n||, ||p-n||) and is emitted
via a second fragment call + min (which keeps the main fragment focused).
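This plain-Python sketch computes the per-sample triplet loss max(0, d(a, p) - d(a, n) + margin), including the swap branch as a second distance plus min. All names are hypothetical. Adding eps to the difference follows the formula documented for torch.nn.PairwiseDistance; how eps enters here is an assumption.

```python
from typing import List


def pairwise_distance(a: List[float], b: List[float], p: float, eps: float) -> float:
    # p-norm of (a - b + eps), as documented for torch.nn.PairwiseDistance (assumed here).
    return sum(abs(ai - bi + eps) ** p for ai, bi in zip(a, b)) ** (1.0 / p)


def triplet_margin_loss(anchor: List[List[float]], positive: List[List[float]],
                        negative: List[List[float]], margin: float = 1.0,
                        p: float = 2.0, eps: float = 1e-6, swap: bool = False) -> List[float]:
    out = []
    for a, pos, neg in zip(anchor, positive, negative):
        d_ap = pairwise_distance(a, pos, p, eps)
        d_an = pairwise_distance(a, neg, p, eps)
        if swap:
            # Swap branch: second distance computation + min.
            d_an = min(d_an, pairwise_distance(pos, neg, p, eps))
        out.append(max(0.0, d_ap - d_an + margin))
    return out


args = ([[0.0, 0.0]], [[0.0, 1.0]], [[0.0, 2.0]])
print(triplet_margin_loss(*args, eps=0.0))             # [0.0]
print(triplet_margin_loss(*args, eps=0.0, swap=True))  # [1.0]
```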