selector

torch_to_nnef.op.aten.selector

argsort

argsort(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:argsort' to NNEF.

bucketize

bucketize(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:bucketize' to NNEF.

bucketize(input, boundaries, out_int32, right) returns, for each value in input, its bucket index in boundaries (a sorted 1-D tensor). Decomposed via a broadcast compare plus a sum_reduce over the comparison mask.
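
A minimal eager-mode sketch of that decomposition, checked against torch.bucketize (the torch calls here stand in for the emitted NNEF compare/reduce ops):

```python
import torch

def bucketize_decomposed(x, boundaries, right=False):
    # Compare every input value against every boundary (broadcast),
    # then count the hits: the count is exactly the bucket index.
    cmp = x.unsqueeze(-1) >= boundaries if right else x.unsqueeze(-1) > boundaries
    return cmp.sum(dim=-1)

x = torch.tensor([1.0, 4.0, 9.0])
b = torch.tensor([2.0, 4.0, 6.0, 8.0])
assert torch.equal(bucketize_decomposed(x, b), torch.bucketize(x, b))
assert torch.equal(bucketize_decomposed(x, b, right=True),
                   torch.bucketize(x, b, right=True))
```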

diagonal

diagonal(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:diagonal' to NNEF (tract path).

Strategy: bring (dim1, dim2) to the last two axes via transpose, slice each axis to the diagonal window, then evaluate the einsum ...ii->...i (leading axes pass through) with tract_core_einsum. The slice begin on each axis encodes the offset:

```
begin_a1 = max(0, -offset)
begin_a2 = max(0,  offset)
L        = min(s1 - begin_a1, s2 - begin_a2)
```

offset is interpreted in the user's (dim1, dim2) order; when we sort axes so that a1 < a2, its sign flips. Empty diagonals (L <= 0) raise T2NErrorNotImplemented, since static zero-extent axes are awkward to represent.
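
An eager-mode sketch of the window arithmetic, keeping the user's (dim1, dim2) order so no sign flip is needed; checked against torch.diagonal:

```python
import torch

def diagonal_decomposed(x, offset=0, dim1=0, dim2=1):
    d1, d2 = dim1 % x.ndim, dim2 % x.ndim
    rest = [d for d in range(x.ndim) if d not in (d1, d2)]
    x = x.permute(rest + [d1, d2])        # (dim1, dim2) to the last two axes
    s1, s2 = x.shape[-2], x.shape[-1]
    begin_a1, begin_a2 = max(0, -offset), max(0, offset)
    L = min(s1 - begin_a1, s2 - begin_a2)
    assert L > 0, "empty diagonal (would raise T2NErrorNotImplemented)"
    window = x[..., begin_a1:begin_a1 + L, begin_a2:begin_a2 + L]
    return torch.einsum('...ii->...i', window)

x = torch.arange(24.).reshape(2, 3, 4)
assert torch.equal(diagonal_decomposed(x, 1, 1, 2), torch.diagonal(x, 1, 1, 2))
```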

embedding

embedding(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:embedding' to NNEF.

embedding_bag

embedding_bag(g, node, name_to_tensor, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:embedding_bag' to NNEF.

embedding_bag(weight, indices, offsets, scale_grad, mode, sparse, per_sample_weights, include_last_offset, padding_idx) returns a 4-tuple in torch (output, offset2bag, bag_size, max_indices); we only emit the first output (the bag-reduced embeddings). The other three are gradient bookkeeping and typically aren't consumed in inference traces.

Decomposition (statically-known offsets):

```
emb = tract_core_gather(weight, indices, axis=0)       # (K, D)
for each bag b:
    bag_out_b = reduce(emb[offsets[b]:end_b], axis=0)
output = concat(bag_outs, axis=0)                      # (B, D)
```

Equal bag sizes collapse to a single reshape + reduce + squeeze.
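
An eager sketch of the static-offsets decomposition for mode='sum' (the helper is illustrative; the export emits gather / slice / sum_reduce / concat):

```python
import torch

def embedding_bag_sum(weight, indices, offsets):
    emb = weight[indices]                              # gather -> (K, D)
    ends = offsets.tolist()[1:] + [len(indices)]       # end of each bag
    bags = [emb[b:e].sum(dim=0, keepdim=True)
            for b, e in zip(offsets.tolist(), ends)]
    return torch.cat(bags, dim=0)                      # (B, D)

weight = torch.randn(10, 4)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])
ref = torch.nn.functional.embedding_bag(indices, weight, offsets, mode='sum')
assert torch.allclose(embedding_bag_sum(weight, indices, offsets), ref)
```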

gather

gather(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:gather' to NNEF.

index_

index_(node, op_helper, inference_target, **kwargs)

Translate aten::index to NNEF.

The NNEF fragment being targeted:

```
fragment gather<?>(
    input: tensor<?>,      # the tensor to gather from
    indices: tensor,       # the indices to gather at
    axis: integer = 0      # the axis to gather at
) -> ( output: tensor<?> )
```

In the Torch IR, the index node is structured as a list of n values, where n <= the rank of the input node; each value is either a constant or a tensor. If the constant is None, the full dimension is kept.
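
For example, in a traced graph a partial advanced-indexing expression produces exactly that structure:

```python
import torch

x = torch.randn(3, 4, 5)
rows = torch.tensor([0, 2])

# x[:, rows] lowers to aten::index(x, [None, rows]):
# the None constant keeps the whole first dimension,
# while `rows` gathers along axis 1.
y = x[:, rows]
assert y.shape == (3, 2, 5)
```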

index_add

index_add(g, node, name_to_tensor, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:index_add' to NNEF.

index_add(self, dim, index, source, alpha) adds alpha * source into the input at the index slabs. We pre-multiply source by alpha (when not 1) and reuse the scatter backbone with reduction='add'.
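
An eager equivalent of that lowering (torch.scatter_add standing in for tract_core_scatter_elements with reduction='add'):

```python
import torch

x = torch.zeros(5, 3)
index = torch.tensor([0, 4, 2])
source = torch.randn(3, 3)
alpha = 2.0

# Pre-scale source, broadcast the 1-D index to source's shape,
# then scatter-add along dim.
idx = index.unsqueeze(-1).expand_as(source)
out = x.scatter_add(0, idx, alpha * source)
assert torch.allclose(out, x.index_add(0, index, source, alpha=alpha))
```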

index_copy

index_copy(g, node, name_to_tensor, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:index_copy' to NNEF.

index_copy(self, dim, index, source) overwrites slabs along dim: out[..., index[k], ...] = source[..., k, ...]. Same scatter backbone as index_fill, just with the user's source directly.
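
The same backbone in eager form, with reduction 'none' (a plain scatter) and the source used as-is:

```python
import torch

x = torch.zeros(5, 3)
index = torch.tensor([0, 4, 2])
source = torch.arange(9.).reshape(3, 3)

idx = index.unsqueeze(-1).expand_as(source)   # broadcast index over D
out = x.scatter(0, idx, source)               # overwrite the slabs
assert torch.equal(out, x.index_copy(0, index, source))
```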

index_fill

index_fill(g, node, name_to_tensor, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:index_fill' to NNEF.

index_fill(self, dim, index, value) writes the scalar value at every (..., index[k], ...) position along dim. Lowered to tract_core_scatter_elements with reduction='none' against an all-value constant of the broadcast shape.
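
Eager equivalent: scatter an all-value constant of the broadcast shape (torch.scatter standing in for tract_core_scatter_elements with reduction='none'):

```python
import torch

x = torch.zeros(2, 4)
index = torch.tensor([1, 3])
value = 7.0

idx = index.expand(2, 2)           # index broadcast over the other dims
src = torch.full((2, 2), value)    # the all-value constant
out = x.scatter(1, idx, src)
assert torch.equal(out, x.index_fill(1, index, value))
```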

index_put

index_put(g, node, name_to_tensor, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:index_put' (and its in-place variant 'aten:index_put_') to NNEF.

index_put(self, indices: Tensor?[], values, accumulate) writes values into self at the positions indexed by indices. Only the len(indices) == 1 case with a single 1-D int index along axis 0 is supported -- that's the out[idx] = values pattern that covers most realistic usage. Lowered to tract_core_scatter_elements with reduction 'add' (when accumulate=True) or 'none' (overwrite).
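
Eager form of the supported pattern and its scatter equivalent:

```python
import torch

x = torch.zeros(5, 3)
idx = torch.tensor([1, 3])          # single 1-D int index along axis 0
values = torch.ones(2, 3)

ref = x.index_put((idx,), values, accumulate=True)

expanded = idx.unsqueeze(-1).expand_as(values)
out = x.scatter_add(0, expanded, values)   # reduction='add'
assert torch.equal(out, ref)
```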

index_select

index_select(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:index_select' to NNEF.

masked_fill

masked_fill(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:masked_fill' to NNEF.

narrow

narrow(node, op_helper, **kwargs)

Fancy slice in PyTorch: returns a narrowed view of the input along one dimension.

torch.narrow(input, dim, start, length)

Example:

```python
>>> import torch
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)
tensor([[1, 2, 3],
        [4, 5, 6]])
```

scatter

scatter(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:scatter' to NNEF.

scatter_add

scatter_add(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:scatter_add' to NNEF.

scatter_add(input, dim, index, src) accumulates src values into input at positions selected by index along dim. Equivalent to tract_core_scatter_elements with reduction="add".
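
A small demonstration of the semantics (along dim=0, out[index[i][j]][j] += src[i][j]):

```python
import torch

x = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2, 0]])
src = torch.ones(1, 4)

out = x.scatter_add(0, index, src)
assert torch.equal(out, torch.tensor([[1., 0., 0., 1.],
                                      [0., 1., 0., 0.],
                                      [0., 0., 1., 0.]]))
```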

scatter_reduce

scatter_reduce(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:scatter_reduce' to NNEF.

Maps torch's reduce mode to tract's ScatterReduction. mean is not in tract's set ({add, mul, min, max}) so we raise; the same goes for include_self=False, since tract always reduces against the pre-existing destination value.
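
A sketch of that gate (the dict and function name are illustrative; the real code raises T2NErrorNotImplemented rather than NotImplementedError):

```python
# torch reduce mode -> tract ScatterReduction (illustrative names)
TORCH_TO_TRACT = {"sum": "add", "prod": "mul", "amin": "min", "amax": "max"}

def to_tract_reduction(reduce_mode: str, include_self: bool) -> str:
    if not include_self:
        # tract always reduces against the pre-existing destination value
        raise NotImplementedError("include_self=False is unsupported")
    if reduce_mode not in TORCH_TO_TRACT:   # e.g. 'mean'
        raise NotImplementedError(f"unsupported reduce mode: {reduce_mode}")
    return TORCH_TO_TRACT[reduce_mode]
```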

searchsorted

searchsorted(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:searchsorted' to NNEF.

searchsorted(sorted_seq, values, out_int32, right, side, sorter) is bucketize with the args swapped (sorted_seq plays the role of boundaries, values plays the role of input). The side string overload supersedes right when present.
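
The relationship, stated in eager torch:

```python
import torch

seq = torch.tensor([1.0, 3.0, 5.0, 7.0])
vals = torch.tensor([2.0, 6.0])

# searchsorted is bucketize with the arguments swapped ...
assert torch.equal(torch.searchsorted(seq, vals),
                   torch.bucketize(vals, seq))
# ... and the side string supersedes the right flag.
assert torch.equal(torch.searchsorted(seq, vals, side='right'),
                   torch.bucketize(vals, seq, right=True))
```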

select

select(node, op_helper, **kwargs)

Map PyTorch: 'aten:select' to NNEF.

select_scatter

select_scatter(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:select_scatter' to NNEF.

out = input.clone(); out.select(dim, index).copy_(src) -- the functional select-write. Decomposes to slice + unsqueeze + concat: replace the (size-1) slab at position index along dim with src (which has rank input.rank - 1). Static-shape only.
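
The decomposition in eager form, for dim=0 (other dims transpose to the front first; all shapes static):

```python
import torch

x = torch.zeros(4, 3)
src = torch.ones(3)      # rank = x.ndim - 1
index = 2

# slice around the slab, unsqueeze src back to rank, concat
out = torch.cat([x[:index], src.unsqueeze(0), x[index + 1:]], dim=0)
assert torch.equal(out, torch.select_scatter(x, src, 0, index))
```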

slice_

slice_(node, torch_graph, inference_target, op_helper, **kwargs)

Map PyTorch: 'aten:slice' to NNEF.

slice_scatter

slice_scatter(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:slice_scatter' to NNEF.

out = input.clone(); out[..., start:end:step, ...] = src -- the functional slice-write. Decomposes to slice + concat. step != 1 is rejected (would need an interleave path). Static-shape only.
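
The decomposition in eager form, for dim=0 and step=1:

```python
import torch

x = torch.zeros(6, 3)
src = torch.ones(2, 3)
start, end = 2, 4

# slice the untouched prefix and suffix, concat src in between
out = torch.cat([x[:start], src, x[end:]], dim=0)
assert torch.equal(out,
                   torch.slice_scatter(x, src, dim=0, start=start, end=end))
```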

sort

sort(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:sort' to NNEF.

take

take(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:take' to NNEF.

take(self, index) flattens self to 1-D and gathers along axis 0. Lowered to a reshape(input, shape=[-1]) followed by tract_core_gather on axis 0. The -1 lets NNEF / tract derive the flat size at runtime from the input shape, so dynamic input axes work without any special-casing.
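
Eager equivalent of the lowering:

```python
import torch

x = torch.arange(12.).reshape(3, 4)
index = torch.tensor([0, 5, 11])

# reshape to 1-D (the -1 defers the flat size), then gather on axis 0
out = x.reshape(-1)[index]
assert torch.equal(out, torch.take(x, index))
```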

take_along_dim

take_along_dim(node, op_helper, inference_target, **kwargs)

Map PyTorch: aten::take_along_dim(self, indices, dim?) to NNEF.

With an explicit dim, the op is gather along that axis. The dim=None form (flatten both tensors and do a 1-D take) needs an extra reshape pair; left for follow-up.
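
The explicit-dim case, shown in eager torch:

```python
import torch

x = torch.randn(2, 5)
idx = x.argsort(dim=1)

# with an explicit dim, take_along_dim is gather along that axis
assert torch.equal(torch.take_along_dim(x, idx, dim=1),
                   torch.gather(x, 1, idx))
```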

topk

topk(node, op_helper, inference_target, **kwargs)

Map PyTorch: 'aten:topk' to NNEF.

tract_pre_0_21_7_slice

tract_pre_0_21_7_slice(node, torch_graph, nnef_spec_strict, has_dynamic_axes, op_helper, **kwargs)

Old version of slice for tract versions prior to 0.21.7.

where

where(node, op_helper, **kwargs)

Map PyTorch: 'aten:where' to NNEF.