concat

torch_to_nnef.op.aten.concat

cat

cat(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch: 'aten:cat' to NNEF.

dstack

dstack(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch: 'aten:dstack' to NNEF (concat(axis=2)).

dstack stacks tensors along the third axis (depth). Torch promotes 1-D / 2-D inputs to 3-D before reaching the aten op, so we only see rank>=3 inputs here.
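The promotion rule can be sketched in plain Python. This follows `torch.atleast_3d` semantics; the helper name is illustrative only, not part of the library:

```python
def promoted_shape(shape):
    """Shape a tensor takes after torch's promotion to rank >= 3,
    mirroring torch.atleast_3d semantics (illustrative sketch)."""
    if len(shape) == 0:
        return (1, 1, 1)
    if len(shape) == 1:
        return (1, shape[0], 1)  # (N,) -> (1, N, 1)
    if len(shape) == 2:
        return shape + (1,)      # (M, N) -> (M, N, 1)
    return shape                 # rank >= 3 passes through unchanged
```

After this promotion, concatenating along axis 2 is always well-defined, which is why the NNEF mapping only needs `concat(axis=2)`.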

hstack

hstack(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch: 'aten:hstack' to NNEF (concat(axis=1)).

roll

roll(g, node, name_to_tensor, torch_graph, inference_target, **kwargs)

Map PyTorch: 'aten:roll' to NNEF.

PyTorch normalizes shifts modulo the dim size; tract does not, and the slice/concat decomposition we emit produces an empty slice for shift=0 or |shift| >= dim_size, which tract then mishandles, yielding an output of doubled shape. We reproduce PyTorch's normalization here:

  • Drop any (shift, dim) pair where the normalized shift is 0 (no-op).
  • Replace each remaining shift with shift % dim_size so the slice indices stay in (0, dim_size).

If every pair normalizes away, the entire op is a graph identity, so we remap the output node directly to the input.
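The normalization above can be sketched as a small pure-Python helper. The function name is hypothetical (not the library's actual helper); it only illustrates the drop-and-modulo rule:

```python
def normalize_roll(shifts, dims, shape):
    """Return (shift, dim) pairs with each shift reduced modulo the
    size of its dimension, dropping pairs whose normalized shift is 0.
    An empty result means the roll is a no-op (graph identity)."""
    normalized = []
    for shift, dim in zip(shifts, dims):
        dim_size = shape[dim]
        shift = shift % dim_size  # Python's % yields a value in [0, dim_size)
        if shift != 0:
            normalized.append((shift, dim))
    return normalized
```

For example, rolling a dim of size 5 by 5 (or by 0) drops out entirely, while a shift of -1 on a dim of size 4 normalizes to 3, keeping every subsequent slice index strictly inside (0, dim_size).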

stack

stack(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch: 'aten:stack' to NNEF.
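Unlike `cat`, `stack` inserts a new axis: stacking k same-shaped tensors along dim d is equivalent to unsqueezing each input at d and concatenating. A shape-level sketch of that semantics (helper name is illustrative, not the library's code):

```python
def stacked_shape(input_shape, num_tensors, dim):
    """Output shape of stacking `num_tensors` tensors of identical
    `input_shape` along `dim`: a new axis of length `num_tensors`
    is inserted at position `dim` (negative dims allowed)."""
    dim = dim % (len(input_shape) + 1)  # output rank is input rank + 1
    return input_shape[:dim] + (num_tensors,) + input_shape[dim:]
```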

vstack

vstack(g, node, name_to_tensor, torch_graph, **kwargs)

Map PyTorch: 'aten:vstack' to NNEF (concat(axis=0)).