
axes_change

torch_to_nnef.op.aten.axes_change

atleast_1d

atleast_1d(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch 'aten::atleast_1d' to NNEF.

atleast_2d

atleast_2d(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch 'aten::atleast_2d' to NNEF.

atleast_3d

atleast_3d(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch 'aten::atleast_3d' to NNEF.

Note: for rank-1 inputs, torch's atleast_3d does not simply prepend axes the way atleast_2d does: [N] is reshaped to [1, N, 1] (not [1, 1, N]), and a rank-2 input [M, N] becomes [M, N, 1]. This mapping is matched explicitly. The rank >= 3 passthrough aliases the input via remap_node so that symbolic dims survive under dynamic axes.
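The shape rules of the three atleast_* ops can be summarised in plain Python. This is a minimal sketch that only computes output shapes; `atleast_nd_shape` is a hypothetical helper, not part of torch_to_nnef.

```python
def atleast_nd_shape(shape, n):
    """Output shape of torch.atleast_{n}d for an input of the given shape.

    Hypothetical helper for illustration; n must be 1, 2 or 3.
    """
    shape = tuple(shape)
    rank = len(shape)
    if n not in (1, 2, 3):
        raise ValueError("n must be 1, 2 or 3")
    if rank >= n:
        # Passthrough: the input already has at least n dims.
        return shape
    if n == 1:
        return (1,)  # scalar -> [1]
    if n == 2:
        # atleast_2d only prepends axes: [] -> [1, 1], [N] -> [1, N]
        return (1,) * (2 - rank) + shape
    # n == 3
    if rank == 0:
        return (1, 1, 1)
    if rank == 1:
        # The special case: [N] -> [1, N, 1], not [1, 1, N]
        return (1, shape[0], 1)
    # rank 2: [M, N] -> [M, N, 1]
    return shape + (1,)
```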

broadcast_tensors

broadcast_tensors(node, op_helper, inference_target, **kwargs)

Map PyTorch 'aten::broadcast_tensors' to NNEF.

broadcast_tensors([t0, t1, ...]) returns each input expanded to the common broadcast shape. Each output is emitted as a separate tract_core_broadcast(t_i, shape=common) call, where common is the shape torch traced into node.outputs[i].shape (identical for all outputs).
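For static shapes, the common shape that torch records follows the usual NumPy/PyTorch broadcasting rule. A minimal sketch of that rule in plain Python (`broadcast_shape` is a hypothetical helper; the converter itself reads the shape from the traced outputs rather than recomputing it):

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Common broadcast shape of several static shapes.

    Hypothetical helper: shapes are aligned on their trailing axes, and on
    each axis the sizes must either match or be 1.
    """
    result = []
    # Walk axes from the right; missing leading axes behave like size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible sizes {dims}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))
```

Each t_i would then be broadcast to `broadcast_shape(*all_shapes)`, one op per output.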

channel_shuffle

channel_shuffle(node, op_helper, **kwargs)

Map PyTorch aten::channel_shuffle(self, groups) to NNEF.

Reshape (N, C, *spatial) -> (N, g, C/g, *spatial), transpose axes 1 and 2, then reshape back to (N, C, *spatial). Used by ShuffleNet-family architectures.
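Since the spatial axes are untouched, the reshape / transpose / reshape chain is just a fixed permutation of channels. The sketch below (hypothetical helper `channel_shuffle_order`) computes which input channel each output channel reads:

```python
def channel_shuffle_order(channels, groups):
    """Input channel read by each output channel of channel_shuffle.

    Mirrors reshape (C,) -> (g, C/g), transpose, reshape back to (C,).
    Hypothetical helper; only the channel axis matters.
    """
    if channels % groups:
        raise ValueError("channels must be divisible by groups")
    per_group = channels // groups
    # After the transpose the view is (C/g, g): output index k*g + gi
    # holds input channel gi*(C/g) + k.
    return [gi * per_group + k for k in range(per_group) for gi in range(groups)]
```

For example, 6 channels in 2 groups are reordered as [0, 3, 1, 4, 2, 5], interleaving the two groups.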

col2im

col2im(node, op_helper, **kwargs)

Map PyTorch aten::col2im (a.k.a. F.fold) to NNEF.

Signature: col2im(self, output_size, kernel_size, dilation, padding, stride) for a rank-3 input (N, C * kH * kW, L). Inverse of im2col: places each of the kH * kW "kernel offsets" of the input at strided positions inside a (N, C, output_H + 2*pH, output_W + 2*pW) canvas, summing overlaps, then crops the canvas to (N, C, output_H, output_W).

Tract has no NNEF-level col2im / scatter_add-with-reduction on the version we target, so we decompose per kernel offset:

  1. reshape input (N, C*kH*kW, n_h*n_w) -> (N, C, kH, kW, n_h, n_w), where n_h and n_w are the numbers of sliding positions along H and W (n_h * n_w = L);
  2. for every (di, dj):
     a. slice the per-offset feature map (N, C, n_h, n_w);
     b. spread it to (N, C, n_h*sH, n_w*sW): reshape to (N, C, n_h, 1, n_w, 1), pad axis 3 by (0, sH-1) and axis 5 by (0, sW-1) to insert zeros between elements, then reshape to flatten the spread axes; trim the trailing zeros by slicing to ((n_h-1)*sH + 1, (n_w-1)*sW + 1);
     c. pad to the canvas size (padded_H, padded_W) = (output_H + 2*pH, output_W + 2*pW) with top/left offsets (di*dH, dj*dW);
  3. sum the kH * kW placed contributions (add chain);
  4. crop the leading / trailing (pH, pW) rows / cols of the canvas to land on (N, C, output_H, output_W).

A future tract release that exposes a native col2im (or a scatter-add with sum reduction) would let us replace this chain with a single op behind a version gate.
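The per-offset decomposition can be checked on a single spatial axis. The sketch below is a simplification (one channel, one spatial axis instead of two, plain lists, hypothetical helper names); it compares a direct scatter-add fold with the spread / trim / pad / sum chain from the steps above.

```python
def n_blocks(out_len, k, d, p, s):
    # Number of sliding positions along one axis (same formula as im2col).
    return (out_len + 2 * p - d * (k - 1) - 1) // s + 1


def fold_1d_direct(cols, out_len, k, d, p, s):
    """Reference col2im on one axis: scatter-add into a padded canvas."""
    n = n_blocks(out_len, k, d, p, s)
    canvas = [0] * (out_len + 2 * p)
    for di in range(k):
        for i in range(n):
            canvas[di * d + i * s] += cols[di][i]
    return canvas[p:p + out_len]  # crop the padding


def fold_1d_decomposed(cols, out_len, k, d, p, s):
    """Same result using only spread, slice, pad and add steps."""
    n = n_blocks(out_len, k, d, p, s)
    padded = out_len + 2 * p
    total = [0] * padded
    for di in range(k):
        # Spread: put (s - 1) zeros after every element ...
        spread = []
        for v in cols[di]:
            spread.append(v)
            spread.extend([0] * (s - 1))
        # ... then trim the trailing zeros to (n - 1) * s + 1 elements.
        spread = spread[:(n - 1) * s + 1]
        # Pad onto the canvas with a leading offset of di * d.
        left = di * d
        right = padded - left - len(spread)
        placed = [0] * left + spread + [0] * right
        total = [a + b for a, b in zip(total, placed)]
    return total[p:p + out_len]  # crop the padding
```

Both functions agree, including when windows overlap (stride 1) and when dilation is above 1; the 2-D lowering applies the same spread/trim/pad on H and W simultaneously.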

expand_as

expand_as(g, node, name_to_tensor, op_helper, **kwargs)

Map PyTorch 'aten::expand_as' to NNEF.

x.expand_as(y) is x.expand(y.size()): broadcast (tile) along size-1 axes to match y's shape.

Static shapes only for now: when y carries non-int (TDim) shape entries, the converter raises an error with a hint to use aten::expand directly. The dynamic path would need the runtime per-axis repeat machinery that already lives in aten::expand (see op/aten/expand.py::_append_repeats_on_existing_dims); a follow-up should refactor that helper so this op can share it.
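For static shapes, the tile counts follow directly from the two shapes. A minimal sketch, assuming x is first left-padded with size-1 axes up to y's rank (as torch's expand does); `expand_as_repeats` is a hypothetical helper:

```python
def expand_as_repeats(x_shape, y_shape):
    """Per-axis tile counts that turn x into x.expand_as(y) (static shapes).

    Hypothetical helper: x is left-padded with size-1 axes to y's rank,
    then every size-1 axis is repeated y_dim times.
    """
    if len(x_shape) > len(y_shape):
        raise ValueError("x cannot have more dims than y")
    padded = (1,) * (len(y_shape) - len(x_shape)) + tuple(x_shape)
    repeats = []
    for xd, yd in zip(padded, y_shape):
        if xd == yd:
            repeats.append(1)  # axis already matches
        elif xd == 1:
            repeats.append(yd)  # broadcast a size-1 axis
        else:
            raise ValueError(f"cannot expand size {xd} to {yd}")
    return padded, tuple(repeats)
```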

flatten

flatten(g, node, name_to_tensor, inference_target, **kwargs)

Translate operator: aten::flatten to NNEF.

PyTorch flatten(start_dim, end_dim) flattens the dims in [start_dim, end_dim] inclusive, while NNEF reshape takes axis_count (the number of axes to replace). After normalizing negative indices via pick_axis, the conversion is axis_start = start_dim and axis_count = end_dim - start_dim + 1.

fragment reshape<?>( input: tensor<?>, shape: integer[], axis_start: integer = 0, axis_count: integer = -1 ) -> ( output: tensor<?> );
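The index conversion can be sketched in plain Python. This assumes a static, rank >= 1 input; the negative-index normalisation stands in for pick_axis, and `flatten_to_reshape` is a hypothetical helper. Passing the merged size as `shape` is one valid choice for a static input; -1 would also work.

```python
def flatten_to_reshape(shape, start_dim=0, end_dim=-1):
    """Translate torch.flatten(start_dim, end_dim) into NNEF reshape attributes.

    Hypothetical helper; returns (reshape attributes, resulting shape).
    """
    rank = len(shape)
    # Normalise negative indices (the converter uses pick_axis for this).
    start = start_dim + rank if start_dim < 0 else start_dim
    end = end_dim + rank if end_dim < 0 else end_dim
    axis_count = end - start + 1  # torch's range is inclusive
    merged = 1
    for d in shape[start:end + 1]:
        merged *= d
    out_shape = tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])
    attrs = {"axis_start": start, "axis_count": axis_count, "shape": [merged]}
    return attrs, out_shape
```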

flip

flip(g, node, name_to_tensor, inference_target, **kwargs)

Map PyTorch aten::flip(input, dims) to NNEF.

fliplr

fliplr(g, node, name_to_tensor, inference_target, **kwargs)

Map aten::fliplr (torch.fliplr) to NNEF.

Reverses elements along axis 1; torch requires the input to be at least 2-D.

flipud

flipud(g, node, name_to_tensor, inference_target, **kwargs)

Map aten::flipud (torch.flipud) to NNEF.

Reverses elements along axis 0; torch requires the input to be at least 1-D.
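On a rank-2 input, fliplr and flipud are flip(dims=[1]) and flip(dims=[0]) respectively. A tiny illustration on nested lists (plain Python, not the converter's code):

```python
def flipud(rows):
    """Reverse axis 0 of a rank-2 nested list (like torch.flipud)."""
    return rows[::-1]


def fliplr(rows):
    """Reverse axis 1 of a rank-2 nested list (like torch.fliplr)."""
    return [row[::-1] for row in rows]
```

For m = [[1, 2, 3], [4, 5, 6]], flipud swaps the two rows while fliplr reverses each row in place.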

im2col

im2col(node, op_helper, **kwargs)

Map PyTorch aten::im2col (a.k.a. F.unfold) to NNEF.

Signature: im2col(self, kernel_size, dilation, padding, stride) for a rank-4 input (N, C, H, W). Output is (N, C * kH * kW, L) where L = oH * oW and:

  • oH = (H + 2*pH - dH*(kH - 1) - 1) // sH + 1
  • oW = (W + 2*pW - dW*(kW - 1) - 1) // sW + 1

No tract / NNEF op exposes this directly (probing tract_core_im2col and im2col showed both are unknown), so we decompose:

  1. zero-pad the input along H and W if padding > 0;
  2. for every kernel position (di, dj), take a strided 2-axis slice -- begin=[di*dH, dj*dW], stride=[sH, sW], length oH x oW;
  3. stack the kH * kW slices along a new axis at position 2, then reshape (N, C, kH*kW, oH, oW) -> (N, C*kH*kW, oH*oW).

Iteration order is di outer, dj inner, matching torch's flat output-channel index c*kH*kW + di*kW + dj.
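A pure-Python reference is handy for checking the output layout, in particular the channel index c*kH*kW + di*kW + dj. This is a sketch for a single sample given as nested lists x[C][H][W]; `im2col` here is a hypothetical helper that reads the zero-padded input lazily instead of materialising it.

```python
def im2col(x, kernel, dilation=(1, 1), padding=(0, 0), stride=(1, 1)):
    """Reference im2col for one sample x[C][H][W].

    Returns rows[C*kH*kW][oH*oW], with di as the outer and dj as the inner
    kernel loop. Hypothetical pure-Python helper.
    """
    (kH, kW), (dH, dW), (pH, pW), (sH, sW) = kernel, dilation, padding, stride
    C, H, W = len(x), len(x[0]), len(x[0][0])
    oH = (H + 2 * pH - dH * (kH - 1) - 1) // sH + 1
    oW = (W + 2 * pW - dW * (kW - 1) - 1) // sW + 1

    def at(c, h, w):
        # Read from the zero-padded input: shift back by the padding.
        h, w = h - pH, w - pW
        return x[c][h][w] if 0 <= h < H and 0 <= w < W else 0

    rows = []
    for c in range(C):
        for di in range(kH):  # outer: kernel row
            for dj in range(kW):  # inner: kernel column
                # This row's index is c*kH*kW + di*kW + dj by construction.
                rows.append([
                    at(c, oh * sH + di * dH, ow * sW + dj * dW)
                    for oh in range(oH)
                    for ow in range(oW)
                ])
    return rows
```

On a 3x3 input holding 1..9 with a 2x2 kernel, the first row is [1, 2, 4, 5], the top-left element of each of the four windows.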

matrix_transpose

matrix_transpose(g, node, name_to_tensor, **kwargs)

Map aten::mT / aten::mH / aten::matrix_H to NNEF.

mT and mH both swap the last two axes of a rank >= 2 tensor. mH is the conjugate transpose; for real-valued tensors (the only ones NNEF / tract carry without the complex feature flag) it is identical to mT, so a single emitter handles both.

matrix_H is the native-functions schema name for the Hermitian transpose property (Tensor.H); it is aliased here because for real dtypes it has the same semantics as mT / mH (a swap of the last two axes).

meshgrid

meshgrid(node, op_helper, inference_target, **kwargs)

Map PyTorch 'aten::meshgrid' to NNEF.

meshgrid([t0, ..., tN-1], indexing) returns N rank-N tensors. Each output is the corresponding input reshaped so that its size sits on the proper axis, then broadcast to the full N-dim shape. With indexing='ij', output axis i carries input i; with indexing='xy', the first two axes are swapped ('ij' is matrix indexing, 'xy' is Cartesian indexing).
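The per-input reshape and the full broadcast shape can be computed up front. A minimal sketch from 1-D input sizes; `meshgrid_plan` is a hypothetical helper:

```python
def meshgrid_plan(sizes, indexing="ij"):
    """Full output shape and per-input reshape shapes for torch.meshgrid.

    Hypothetical helper: input i is reshaped so its length sits on one
    axis (all other axes 1), then broadcast to the full shape.
    """
    if indexing not in ("ij", "xy"):
        raise ValueError("indexing must be 'ij' or 'xy'")
    n = len(sizes)
    axes = list(range(n))
    if indexing == "xy" and n >= 2:
        # Cartesian indexing: inputs 0 and 1 trade places on axes 0 and 1.
        axes[0], axes[1] = 1, 0
    full = [1] * n
    reshapes = []
    for i, size in enumerate(sizes):
        full[axes[i]] = size
        shape = [1] * n
        shape[axes[i]] = size
        reshapes.append(tuple(shape))
    return tuple(full), reshapes
```

With inputs of length 3 and 4, 'ij' yields outputs of shape (3, 4) while 'xy' yields (4, 3), the first input varying along axis 1.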

movedim

movedim(g, node, name_to_tensor, **kwargs)

Map PyTorch 'aten::movedim' to NNEF as a transpose.

movedim(x, src, dst) repositions the source axis so it ends up at dst, sliding the others left/right; the result is a permutation of the input's axes. Builds the explicit [axes] list and emits a single transpose.
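The explicit axes list can be built by placing each moved axis at its destination and filling the remaining slots with the other axes in their original order. A sketch of that construction; `movedim_perm` is a hypothetical helper that also accepts sequences, as torch.movedim does:

```python
def movedim_perm(rank, source, destination):
    """Transpose axes list equivalent to torch.movedim(x, source, destination).

    Hypothetical helper; source/destination may be ints or sequences of ints.
    """
    src = [source] if isinstance(source, int) else list(source)
    dst = [destination] if isinstance(destination, int) else list(destination)
    src = [s % rank for s in src]  # normalise negative axes
    dst = [d % rank for d in dst]
    perm = [None] * rank
    for s, d in zip(src, dst):
        perm[d] = s  # moved axes land exactly at their destination
    # Remaining axes keep their relative order and fill the free slots.
    rest = iter(a for a in range(rank) if a not in src)
    return [p if p is not None else next(rest) for p in perm]
```

For a rank-4 input, moving axis 1 to position 3 gives the permutation [0, 2, 3, 1], i.e. shape (a, b, c, d) becomes (a, c, d, b).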

numpy_t

numpy_t(g, node, name_to_tensor, **kwargs)

Map PyTorch aten::numpy_T (Tensor.T) to NNEF.

Tensor.T reverses every axis -- it is the rank-N generalisation of matrix transpose. Equivalent to permute([N-1, N-2, ..., 0]).

permute

permute(g, node, name_to_tensor, **kwargs)

Map PyTorch 'aten::permute' to NNEF.

pixel_shuffle

pixel_shuffle(node, op_helper, **kwargs)

Map PyTorch 'aten::pixel_shuffle' to NNEF.

Standard sub-pixel rearrangement: pulls every r*r channel block out as an r*r spatial tile, multiplying H/W by r and dividing C by r*r. Lowered to reshape + transpose + reshape (no fragment needed: stdlib only).

pixel_unshuffle

pixel_unshuffle(node, op_helper, **kwargs)

Map PyTorch 'aten::pixel_unshuffle' (the inverse of pixel_shuffle) to NNEF.
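Both ops are fixed index mappings: out[c][h*r + i][w*r + j] = in[c*r*r + i*r + j][h][w] for pixel_shuffle, and the reverse for pixel_unshuffle. The reshape + transpose + reshape lowering computes exactly this mapping. A pure-Python sketch for one sample as nested lists (hypothetical helpers, not the converter's code):

```python
def pixel_shuffle(x, r):
    """x[C*r*r][H][W] -> y[C][H*r][W*r] for one sample (nested lists)."""
    C, H, W = len(x) // (r * r), len(x[0]), len(x[0][0])
    return [
        [
            # Output pixel (oh, ow) = (h*r + i, w*r + j) reads channel c*r*r + i*r + j.
            [x[c * r * r + (oh % r) * r + (ow % r)][oh // r][ow // r] for ow in range(W * r)]
            for oh in range(H * r)
        ]
        for c in range(C)
    ]


def pixel_unshuffle(y, r):
    """Inverse mapping: y[C][H*r][W*r] -> x[C*r*r][H][W]."""
    C, H, W = len(y), len(y[0]) // r, len(y[0][0]) // r
    return [
        [
            # Channel k = c*r*r + i*r + j reads spatial offset (i, j) of channel c.
            [y[k // (r * r)][h * r + (k % (r * r)) // r][w * r + k % r] for w in range(W)]
            for h in range(H)
        ]
        for k in range(C * r * r)
    ]
```

For r = 2, four 1x1 channels holding 1, 2, 3, 4 become a single 2x2 tile [[1, 2], [3, 4]], and pixel_unshuffle restores the original four channels.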

reshape

reshape(g, node, name_to_tensor, torch_graph, inference_target, **kwargs)

Map PyTorch 'aten::reshape' to NNEF.

reshape_as

reshape_as(g, node, name_to_tensor, op_helper, inference_target, **kwargs)

Map PyTorch 'aten::reshape_as' to NNEF.

Equivalent to reshape with shape borrowed from the second input. Supports dynamic-axes via runtime tract_core_shape_of(other).

rot90

rot90(g, node, name_to_tensor, inference_target, op_helper, **kwargs)

Map aten::rot90(input, k, dims) to NNEF.

Rotates by 90 * k degrees in the plane (dims[0], dims[1]). The rotation direction is from dims[0] toward dims[1], matching torch's convention. Decomposed per the standard flip + transpose identity:

  • k % 4 == 0: identity (single reshape with the same shape so the named output tensor is materialised).
  • k % 4 == 1: flip(dims[1]) -> transpose(dims[0], dims[1]).
  • k % 4 == 2: flip([dims[0], dims[1]]).
  • k % 4 == 3: transpose(dims[0], dims[1]) -> flip(dims[1]).
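The flip + transpose identities can be checked on a rank-2 nested list with dims = [0, 1]. A plain-Python sketch (hypothetical helpers):

```python
def transpose2d(m):
    """Swap axes 0 and 1 of a rank-2 nested list."""
    return [list(col) for col in zip(*m)]


def flip_axis(m, axis):
    """Reverse axis 0 or axis 1 of a rank-2 nested list."""
    return m[::-1] if axis == 0 else [row[::-1] for row in m]


def rot90(m, k=1):
    """torch.rot90(m, k, dims=[0, 1]) via the flip + transpose identities."""
    k %= 4  # Python's % keeps k in 0..3 even for negative k
    if k == 0:
        return [row[:] for row in m]
    if k == 1:
        return transpose2d(flip_axis(m, 1))
    if k == 2:
        return flip_axis(flip_axis(m, 0), 1)
    return flip_axis(transpose2d(m), 1)
```

For m = [[1, 2], [3, 4]], k = 1 gives [[2, 4], [1, 3]] (counter-clockwise, from dims[0] toward dims[1]), and k = -1 behaves like k = 3.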

squeeze

squeeze(g, node, name_to_tensor, **kwargs)

Map PyTorch 'aten::squeeze' to NNEF.

t

t(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch 'aten::t' to NNEF.

Tensor.t() is a 2D-only transpose: rank 0 / 1 inputs pass through, rank 2 swaps axes (0, 1). Higher ranks are a torch error and we don't try to be friendlier than the source.

The rank<2 passthrough uses remap_node instead of emitting a no-op reshape so the input flows through untouched -- correct for dynamic-axes graphs (a literal shape= attr would lose symbolic dims).

transpose

transpose(g, node, name_to_tensor, inference_target, **kwargs)

Map PyTorch 'aten::transpose' to NNEF.

unflatten

unflatten(g, node, name_to_tensor, torch_graph, op_helper, inference_target, **kwargs)

Map PyTorch 'aten::unflatten' to NNEF.

unfold

unfold(node, op_helper, **kwargs)

Map PyTorch aten::unfold (Tensor.unfold) to NNEF.

Signature: unfold(self, dimension, size, step). Extracts windows of length size along dimension, advancing by step (windows overlap when step < size). The result has rank R + 1: the dimension axis becomes n_windows = (D - size) // step + 1, where D is the input's length along dimension, and a new trailing axis of length size is appended.

Decomposed as n_windows slice ops along dimension followed by a stack along that same axis; if dimension is not the last axis of the input, an extra transpose moves the size axis to the end (matching torch's "appended-at-back" layout).
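The window count and the slice-per-window decomposition can be sketched in plain Python. `unfold_1d` and `unfold_shape` are hypothetical helpers: the first shows the slices on a 1-D list, the second the resulting shape for any rank.

```python
def unfold_1d(seq, size, step):
    """Tensor.unfold(0, size, step) on a 1-D list: one slice per window."""
    n_windows = (len(seq) - size) // step + 1
    return [seq[w * step : w * step + size] for w in range(n_windows)]


def unfold_shape(shape, dimension, size, step):
    """Output shape of Tensor.unfold: dimension -> n_windows, size appended."""
    n_windows = (shape[dimension] - size) // step + 1
    out = list(shape)
    out[dimension] = n_windows
    return tuple(out) + (size,)
```

With 7 elements, size 3 and step 2, there are (7 - 3) // 2 + 1 = 3 windows: [0, 1, 2], [2, 3, 4] and [4, 5, 6].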

unsqueeze

unsqueeze(g, node, name_to_tensor, **kwargs)

Map PyTorch 'aten::unsqueeze' to NNEF.

view

view(g, node, name_to_tensor, torch_graph, inference_target, **kwargs)

Map PyTorch 'aten::view' to NNEF.

view_as

view_as(g, node, name_to_tensor, op_helper, inference_target, **kwargs)

Map PyTorch 'aten::view_as' to NNEF.

Equivalent to reshape with shape borrowed from the second input; NNEF reshape covers torch's view semantics for contiguous inputs. Supports dynamic-axes via runtime tract_core_shape_of(other).