
split

torch_to_nnef.op.aten.split

chunk

chunk(g, node, name_to_tensor, **kwargs)

Map PyTorch: 'aten::chunk' (and unsafe_chunk) to NNEF.

unsafe_chunk has identical inference-time semantics to chunk; the only difference is the autograd-graph promise around in-place writes, which doesn't apply on the export path.
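The chunk boundaries can be sketched in plain Python. This is a minimal illustration that assumes a static, non-empty split axis. `chunk_bounds` is a hypothetical helper and is not part of torch_to_nnef. It follows the documented torch.chunk rule: every chunk has a fixed size of ceil(dim_size / chunks), so fewer slices than requested can come out.

```python
import math


def chunk_bounds(dim_size: int, chunks: int) -> list[tuple[int, int]]:
    """Hypothetical helper: (begin, end) bounds of each slice torch.chunk yields.

    Assumes a static, non-empty split axis (dim_size > 0).
    """
    if chunks <= 0 or dim_size <= 0:
        raise ValueError("expected chunks > 0 and dim_size > 0")
    # Every chunk gets ceil(dim_size / chunks) elements except possibly the last,
    # so fewer than `chunks` slices can be produced.
    size = math.ceil(dim_size / chunks)
    return [(begin, min(begin + size, dim_size)) for begin in range(0, dim_size, size)]


print(chunk_bounds(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
print(chunk_bounds(6, 4))   # [(0, 2), (2, 4), (4, 6)]: only 3 slices
```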

dsplit

dsplit(g, node, name_to_tensor, **kwargs)

Map PyTorch: 'aten::dsplit' (tensor_split along axis 2) to NNEF.

hsplit

hsplit(g, node, name_to_tensor, **kwargs)

Map PyTorch: 'aten::hsplit' (tensor_split along axis 1) to NNEF.

split

split(g, node, name_to_tensor, **kwargs)

Map PyTorch: aten::split.Tensor (and unsafe_split) to NNEF.

split(self, split_size: int, dim) produces ceil(dim_size / split_size) chunks, where every chunk has split_size elements along dim except possibly the last (which holds the remainder).

unsafe_split differs only in autograd semantics around in-place writes; inference behaviour is identical so it shares this path.

The chunk count is fixed by len(node.outputs); each output's per-axis size comes from its traced shape, so we walk them and emit a slice per output (same pattern as chunk / tensor_split).
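Here is a minimal sketch of both parts of this description, assuming static shapes. `split_sizes` reproduces the ceil(dim_size / split_size) rule above. `bounds_from_traced_outputs` walks the per-output sizes the way the traced-shape approach does. Both helpers are hypothetical and only illustrate the arithmetic, not the exporter's code.

```python
import math


def split_sizes(dim_size: int, split_size: int) -> list[int]:
    """Sizes torch.split(x, split_size, dim) produces along `dim` (static axis assumed)."""
    if split_size <= 0:
        raise ValueError("split_size must be positive")
    count = math.ceil(dim_size / split_size)
    # Full chunks of split_size, with the remainder (if any) in the last chunk.
    return [min(split_size, dim_size - i * split_size) for i in range(count)]


def bounds_from_traced_outputs(
    output_shapes: list[tuple[int, ...]], dim: int
) -> list[tuple[int, int]]:
    """Hypothetical helper: walk traced output shapes, accumulating begin/end along `dim`."""
    bounds = []
    begin = 0
    for shape in output_shapes:
        end = begin + shape[dim]  # per-axis size read from the traced shape
        bounds.append((begin, end))
        begin = end
    return bounds


# A (2, 10) input split with split_size=4 along dim=1 gives outputs (2, 4), (2, 4), (2, 2).
print(split_sizes(10, 4))  # [4, 4, 2]
print(bounds_from_traced_outputs([(2, 4), (2, 4), (2, 2)], dim=1))  # [(0, 4), (4, 8), (8, 10)]
```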

split_with_sizes

split_with_sizes(g, node, name_to_tensor, **kwargs)

Translate aten::split_with_sizes to NNEF.

The NNEF spec defines a split op (value, axis, ratios -> tensor[]), but tract does not register it, so we re-express each output as a slice.

ratio_node may be a PythonConstant (literal sizes from the trace) or a TensorVariable whose data is shape-derived (e.g. fused-qkv splits like x.shape[-1] // 3); both cases are unwrapped to plain ints.
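The sketch below illustrates the slice rewrite. It turns a list of sizes into one NNEF `slice` statement per output, using NNEF's textual syntax for readability. `split_with_sizes_as_slices` is a hypothetical helper. `int(size)` stands in for the PythonConstant / TensorVariable unwrapping described above. The sketch does not show how torch_to_nnef actually builds the slice ops.

```python
def split_with_sizes_as_slices(
    input_name: str,
    rank: int,
    sizes: list,
    dim: int,
    output_names: list[str],
) -> list[str]:
    """Hypothetical helper: one NNEF `slice` statement per split_with_sizes output."""
    if len(sizes) != len(output_names):
        raise ValueError("expected one size per output")
    axis = dim % rank  # normalise negative dims such as -1 to a non-negative axis
    lines = []
    begin = 0
    for name, size in zip(output_names, sizes):
        # int() stands in for unwrapping constant or shape-derived sizes to plain ints.
        end = begin + int(size)
        lines.append(
            f"{name} = slice({input_name}, axes = [{axis}], begin = [{begin}], end = [{end}]);"
        )
        begin = end
    return lines


# Fused-QKV style split: hidden size 12, each part x.shape[-1] // 3 == 4 wide.
for line in split_with_sizes_as_slices(
    "x", rank=3, sizes=[12 // 3] * 3, dim=-1, output_names=["q", "k", "v"]
):
    print(line)
# q = slice(x, axes = [2], begin = [0], end = [4]);
# k = slice(x, axes = [2], begin = [4], end = [8]);
# v = slice(x, axes = [2], begin = [8], end = [12]);
```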

tensor_split

tensor_split(g, node, name_to_tensor, **kwargs)

Map PyTorch: 'aten::tensor_split' to NNEF.

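Generalised split that allows uneven sections. Split / chunk leave any remainder in a shorter final chunk; tensor_split instead spreads the remainder over the leading chunks or cuts at arbitrary indices. Two overloads are supported: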

  • tensor_split(self, sections: int, dim) -- divide into N approximately-equal chunks; the first dim_size % N chunks take one extra element.
  • tensor_split(self, indices: int[], dim) -- split at the given boundary indices; produces len(indices) + 1 chunks.

Each output is a slice of the input along dim. Static-axis only: the boundaries depend on dim_size, which we resolve at trace time.
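Both overloads reduce to a list of (begin, end) pairs along `dim`. Here is a minimal sketch that assumes a static axis. For the indices overload, it also assumes sorted, in-range indices; PyTorch itself is more lenient, for example with out-of-range indices. `tensor_split_bounds` is a hypothetical helper.

```python
from typing import Sequence, Union


def tensor_split_bounds(
    dim_size: int, sections_or_indices: Union[int, Sequence[int]]
) -> list[tuple[int, int]]:
    """Hypothetical helper: (begin, end) of each tensor_split output along a static axis."""
    if isinstance(sections_or_indices, int):
        n = sections_or_indices
        if n <= 0:
            raise ValueError("sections must be positive")
        # The first dim_size % n chunks take one extra element.
        base, extra = divmod(dim_size, n)
        sizes = [base + 1 if i < extra else base for i in range(n)]
        edges = [0]
        for size in sizes:
            edges.append(edges[-1] + size)
    else:
        indices = list(sections_or_indices)
        # Simplifying assumption: sorted indices within [0, dim_size].
        if indices != sorted(indices) or any(i < 0 or i > dim_size for i in indices):
            raise ValueError("indices must be sorted and within [0, dim_size]")
        edges = [0, *indices, dim_size]  # len(indices) + 1 chunks
    return list(zip(edges[:-1], edges[1:]))


print(tensor_split_bounds(7, 3))       # [(0, 3), (3, 5), (5, 7)]
print(tensor_split_bounds(7, [1, 6]))  # [(0, 1), (1, 6), (6, 7)]
```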

unbind

unbind(g, node, name_to_tensor, **kwargs)

Map PyTorch: 'aten::unbind' to NNEF. Unbind corresponds to NNEF's unstack op.
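The mapping can be sketched as follows. Unbinding along `dim` yields `shape[dim]` tensors, each with that axis removed, which matches what NNEF's `unstack(value, axis)` returns. The NNEF text appears for illustration only. `unbind_as_unstack` and the `out0`, `out1` output naming are hypothetical.

```python
def unbind_as_unstack(input_name: str, shape: tuple[int, ...], dim: int, prefix: str = "out"):
    """Hypothetical helper: NNEF `unstack` statement plus output shapes for aten::unbind."""
    axis = dim % len(shape)  # normalise negative dims such as -1
    count = shape[axis]  # unbind yields one tensor per index along `dim`
    names = [f"{prefix}{i}" for i in range(count)]
    out_shape = shape[:axis] + shape[axis + 1:]  # the unbound axis is removed
    statement = f"[{', '.join(names)}] = unstack({input_name}, axis = {axis});"
    return statement, [out_shape] * count


stmt, shapes = unbind_as_unstack("x", (2, 3, 4), dim=1)
print(stmt)    # [out0, out1, out2] = unstack(x, axis = 1);
print(shapes)  # [(2, 4), (2, 4), (2, 4)]
```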

vsplit

vsplit(g, node, name_to_tensor, **kwargs)

Map PyTorch: 'aten::vsplit' (tensor_split along axis 0) to NNEF.
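The three directional variants (vsplit, hsplit, dsplit) differ only in which axis they forward to tensor_split. The sketch below encodes the axis choice and minimum input rank from PyTorch's documentation of torch.vsplit, torch.hsplit and torch.dsplit. `directional_split_axis` is a hypothetical helper. Note that PyTorch's hsplit uses axis 0 for 1-D inputs. The one-line summaries above do not mention this case, so whether this exporter handles it is not stated here.

```python
def directional_split_axis(op: str, rank: int) -> int:
    """Hypothetical helper: the tensor_split axis PyTorch documents for vsplit/hsplit/dsplit."""
    if op == "vsplit":
        min_rank, axis = 2, 0
    elif op == "hsplit":
        # PyTorch splits 1-D inputs along axis 0 and higher-rank inputs along axis 1.
        min_rank, axis = 1, (0 if rank == 1 else 1)
    elif op == "dsplit":
        min_rank, axis = 3, 2
    else:
        raise ValueError(f"unsupported op: {op}")
    if rank < min_rank:
        raise ValueError(f"{op} needs an input with at least {min_rank} dimensions")
    return axis


print(directional_split_axis("vsplit", 2))  # 0
print(directional_split_axis("hsplit", 1))  # 0
print(directional_split_axis("hsplit", 3))  # 1
print(directional_split_axis("dsplit", 4))  # 2
```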