Model (aka Typed Model)

Classes

Model

Main model object

Central focus point of the model transformation pipeline

The Model is the central point of tract model loading and "model cooking". ONNX and NNEF serialized models are converted to a Model (more or less directly) before anything useful can be done with them. A Model can be dumped to NNEF (or to tract-opl, which is NNEF plus tract proprietary extensions).

A Model can be optimized with optimize(), which substitutes the "high level" operators of the tract-core operator set with the best implementation available for the current system. From there it can be transformed into a Runnable object, which is used to run inference.

Model cooking

Several model transformations can also be performed on the Model class:

  • declutter (getting rid of training artefacts)
  • "pulsification" (transforming a batch-oriented model into a streaming model)
  • symbol substitution (make N or Batch a fixed number, unlocking potential optimisation later on)
  • static cost evaluation and dynamic profiling
  • ...

In some situations, these operations are done "on-the-fly" when an ONNX or NNEF model is loaded, at start-up time. In other situations, when start-up time becomes an issue, it may be beneficial to "pre-cook" the model: apply the transformations once, then serialize the model as NNEF (with tract-opl extensions if needed). At start-up, this pre-cooked model can be significantly less expensive to "cook" for inference.
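The "on-the-fly" flow described above can be sketched with the methods documented on this page. This is only an illustration under assumptions: `cook_for_inference` is a hypothetical helper name, `model` is assumed to be a Model already obtained from an ONNX or NNEF loader (loading and NNEF serialization are not covered in this section), and the symbol values are placeholders.

```python
from typing import Dict, Optional


def cook_for_inference(model, symbols: Optional[Dict[str, int]] = None):
    """Hypothetical helper: cook a Model and return the resulting Runnable.

    `model` is assumed to expose the Model methods documented on this page.
    """
    if symbols:
        # Fix symbolic dimensions first so that later passes can exploit them.
        model.concretize_symbols(symbols)
    # declutter() and optimize() both transform the model in place.
    model.declutter()
    model.optimize()
    # Assumption: the model should not be reused once turned into a Runnable.
    return model.into_runnable()
```

A call such as `cook_for_inference(model, {"N": 1})` would fix a batch symbol named `N` to 1 before decluttering and optimizing; the symbol name depends on the model at hand.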

Model and TypedModel

This class is actually a wrapper around the "TypedModel" of the Rust codebase. The "Typed" bit means that all shapes and element types of all input, output and temporary values must be known. tract supports symbols in dimensions, with some limited computation capabilities on symbolic expressions. For instance, it is relatively frequent to work with a Model where all tensor shapes start with an N or Batch symbol.

Source code in tract/model.py
class Model:
    """
    # Main model object

    ## Central focus point of the model transformation pipeline

    The Model is the central point of tract model loading and "model cooking". ONNX and NNEF
    serialized models are converted to a Model (more or less directly) before anything useful
    can be done with them. A Model can be dumped to NNEF (or to tract-opl, which is NNEF plus
    tract proprietary extensions).

    A Model can be optimized with `optimize()`, which substitutes the "high level" operators of
    the tract-core operator set with the best implementation available for the current system.
    From there it can be transformed into a Runnable object, which is used to run inference.

    ## Model cooking

    Several model transformations can also be performed on the `Model` class:

    * declutter (getting rid of training artefacts)
    * "pulsification" (transforming a batch-oriented model into a streaming model)
    * symbol substitution (make N or Batch a fixed number, unlocking potential optimisation later on)
    * static cost evaluation and dynamic profiling
    * ...

    In some situations, these operations are done "on-the-fly" when an ONNX or NNEF model is
    loaded, at start-up time. In other situations, when start-up time becomes an issue, it may be
    beneficial to "pre-cook" the model: apply the transformations once, then serialize the model
    as NNEF (with tract-opl extensions if needed). At start-up, this pre-cooked model can be
    significantly less expensive to "cook" for inference.

    ## Model and TypedModel

    This class is actually a wrapper around the "TypedModel" of the Rust codebase. The "Typed"
    bit means that all shapes and element types of all input, output and temporary values must
    be known. tract supports symbols in dimensions, with some limited computation capabilities
    on symbolic expressions. For instance, it is relatively frequent to work with a Model where
    all tensor shapes start with an `N` or `Batch` symbol.
    """

    def __init__(self, ptr):
        self.ptr = ptr

    def __del__(self):
        if self.ptr:
            check(lib.tract_model_destroy(byref(self.ptr)))

    def _valid(self):
        if self.ptr is None:
            raise TractError("invalid model (maybe already consumed ?)")

    def input_count(self) -> int:
        """Return the number of inputs of the model"""
        self._valid()
        i = c_size_t()
        check(lib.tract_model_nbio(self.ptr, byref(i), None))
        return i.value

    def output_count(self) -> int:
        """Return the number of outputs of the model"""
        self._valid()
        i = c_size_t()
        check(lib.tract_model_nbio(self.ptr, None, byref(i)))
        return i.value

    def input_name(self, input_id: int) -> str:
        """Return the name of the input_id-th input"""
        self._valid()
        cstring = c_char_p()
        check(lib.tract_model_input_name(self.ptr, input_id, byref(cstring)))
        result = str(cstring.value, "utf-8")
        lib.tract_free_cstring(cstring)
        return result

    def input_fact(self, input_id: int) -> Fact:
        """Return the fact of the input_id-th input"""
        self._valid()
        fact = c_void_p()
        check(lib.tract_model_input_fact(self.ptr, input_id, byref(fact)))
        return Fact(fact)

    def set_output_names(self, names: List[str]):
        """Change the output nodes of the model"""
        self._valid()
        nb = len(names)
        names_str = []
        names_ptr = (c_char_p * nb)()
        for ix, n in enumerate(names):
            names_str.append(str(n).encode("utf-8"))
            names_ptr[ix] = names_str[ix]
        check(lib.tract_model_set_output_names(self.ptr, nb, names_ptr))

    def output_name(self, output_id: int) -> str:
        """Return the name of the output_id-th output"""
        self._valid()
        cstring = c_char_p()
        check(lib.tract_model_output_name(self.ptr, output_id, byref(cstring)))
        result = str(cstring.value, "utf-8")
        lib.tract_free_cstring(cstring)
        return result

    def output_fact(self, input_id: int) -> Fact:
        """Return the fact of the output_id-th output"""
        self._valid()
        fact = c_void_p()
        check(lib.tract_model_output_fact(self.ptr, input_id, byref(fact)))
        return Fact(fact)

    def concretize_symbols(self, values: Dict[str, int]) -> None:
        """Substitute symbols by a value

        Replace all occurrences of the symbols in the dictionary, in the shapes of all the Model facts.

        While this is not strictly necessary, the optimizing steps may make better choices if the model
        is informed of some specific symbol values.
        """
        self._valid()
        nb = len(values)
        names_str = []
        names = (c_char_p * nb)()
        values_list = (c_int64 * nb)()
        for ix, (k, v) in enumerate(values.items()):
            names_str.append(str(k).encode("utf-8"))
            names[ix] = names_str[ix]
            values_list[ix] = v
        check(lib.tract_model_concretize_symbols(self.ptr, c_size_t(nb), names, values_list))

    def pulse(self, symbol: str, pulse: Union[str, int]) -> None:
        """Pulsify a model.

        `symbol` names the time dimension symbol and `pulse` is the pulse length, given as an integer or a string.
        """
        self._valid()
        check(lib.tract_model_pulse_simple(byref(self.ptr), symbol.encode("utf-8"), str(pulse).encode("utf-8")))

    def declutter(self) -> None:
        """Declutter the model in place"""
        self._valid()
        check(lib.tract_model_declutter(self.ptr))

    def optimize(self) -> None:
        """Optimize the model in place"""
        self._valid()
        check(lib.tract_model_optimize(self.ptr))

    def into_decluttered(self) -> "Model":
        """Convenience method performing `declutter()` and returning the model"""
        self.declutter()
        return self

    def into_optimized(self) -> "Model":
        """Convenience method performing `optimize()` and returning the model"""
        self.optimize()
        return self

    def into_runnable(self) -> Runnable:
        """Transform the model into a Runnable model ready to be used"""
        self._valid()
        runnable = c_void_p()
        check(lib.tract_model_into_runnable(byref(self.ptr), byref(runnable)))
        return Runnable(runnable)

    def property_keys(self) -> List[str]:
        """Extract the list of properties from a model"""
        self._valid()
        count = c_size_t()
        check(lib.tract_model_property_count(self.ptr, byref(count)))
        count = count.value
        cstrings = (POINTER(c_char) * count)()
        check(lib.tract_model_property_names(self.ptr, cstrings))
        names = []
        for i in range(0, count):
            names.append(str(cast(cstrings[i], c_char_p).value, "utf-8"))
            lib.tract_free_cstring(cstrings[i])
        return names

    def property(self, name: str) -> Value:
        """Query a property by name"""
        self._valid()
        value = c_void_p()
        check(lib.tract_model_property(self.ptr, str(name).encode("utf-8"), byref(value)))
        return Value(value)

    def profile_json(self, inputs: Union[None, List[Union[Value, numpy.ndarray]]]) -> str:
        """Profile the model on the provided input"""
        self._valid()
        cstring = c_char_p()
        input_values = []
        input_ptrs = None
        if inputs is not None:
            for v in inputs:
                if isinstance(v, Value):
                    input_values.append(v)
                elif isinstance(v, numpy.ndarray):
                    input_values.append(Value.from_numpy(v))
                else:
                    raise TractError(f"Inputs must be of type tract.Value or numpy.ndarray, got {v}")
            input_ptrs = (c_void_p * len(inputs))()
            for ix, v in enumerate(input_values):
                input_ptrs[ix] = v.ptr
        check(lib.tract_model_profile_json(self.ptr, input_ptrs, byref(cstring)))
        result = str(cstring.value, "utf-8")
        lib.tract_free_cstring(cstring)
        return result

Functions

input_count() -> int

Return the number of inputs of the model

Source code in tract/model.py
def input_count(self) -> int:
    """Return the number of inputs of the model"""
    self._valid()
    i = c_size_t()
    check(lib.tract_model_nbio(self.ptr, byref(i), None))
    return i.value

output_count() -> int

Return the number of outputs of the model

Source code in tract/model.py
def output_count(self) -> int:
    """Return the number of outputs of the model"""
    self._valid()
    i = c_size_t()
    check(lib.tract_model_nbio(self.ptr, None, byref(i)))
    return i.value

input_name(input_id: int) -> str

Return the name of the input_id-th input

Source code in tract/model.py
def input_name(self, input_id: int) -> str:
    """Return the name of the input_id-th input"""
    self._valid()
    cstring = c_char_p()
    check(lib.tract_model_input_name(self.ptr, input_id, byref(cstring)))
    result = str(cstring.value, "utf-8")
    lib.tract_free_cstring(cstring)
    return result

input_fact(input_id: int) -> Fact

Return the fact of the input_id-th input

Source code in tract/model.py
def input_fact(self, input_id: int) -> Fact:
    """Return the fact of the input_id-th input"""
    self._valid()
    fact = c_void_p()
    check(lib.tract_model_input_fact(self.ptr, input_id, byref(fact)))
    return Fact(fact)

set_output_names(names: List[str])

Change the output nodes of the model

Source code in tract/model.py
def set_output_names(self, names: List[str]):
    """Change the output nodes of the model"""
    self._valid()
    nb = len(names)
    names_str = []
    names_ptr = (c_char_p * nb)()
    for ix, n in enumerate(names):
        names_str.append(str(n).encode("utf-8"))
        names_ptr[ix] = names_str[ix]
    check(lib.tract_model_set_output_names(self.ptr, nb, names_ptr))
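The listing above marshals a list of Python strings into a C array of `char*` before calling into the library. A standalone sketch of that step, using only `ctypes`, could look as follows; `to_c_string_array` is a hypothetical helper name, not part of tract.

```python
from ctypes import c_char_p
from typing import List


def to_c_string_array(names: List[str]):
    """Encode Python strings as UTF-8 and pack them into a C array of char*."""
    encoded = [str(n).encode("utf-8") for n in names]
    array = (c_char_p * len(encoded))(*encoded)
    # Return the encoded list as well, mirroring how the listing keeps
    # `names_str` alongside `names_ptr`.
    return array, encoded
```

Reading an element back, for instance `array[0]`, yields the encoded `bytes` value of the first name.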

output_name(output_id: int) -> str

Return the name of the output_id-th output

Source code in tract/model.py
def output_name(self, output_id: int) -> str:
    """Return the name of the output_id-th output"""
    self._valid()
    cstring = c_char_p()
    check(lib.tract_model_output_name(self.ptr, output_id, byref(cstring)))
    result = str(cstring.value, "utf-8")
    lib.tract_free_cstring(cstring)
    return result

output_fact(input_id: int) -> Fact

Return the fact of the output_id-th output

Source code in tract/model.py
def output_fact(self, input_id: int) -> Fact:
    """Return the fact of the output_id-th output"""
    self._valid()
    fact = c_void_p()
    check(lib.tract_model_output_fact(self.ptr, input_id, byref(fact)))
    return Fact(fact)
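The count, name and fact accessors documented so far combine naturally to inspect a model's interface. The sketch below assumes `model` is a Model as documented on this page; `describe_io` is a hypothetical helper name.

```python
from typing import List, Tuple


def describe_io(model) -> Tuple[List[tuple], List[tuple]]:
    """Hypothetical helper: collect (name, fact) pairs for inputs and outputs."""
    inputs = [
        (model.input_name(i), model.input_fact(i))
        for i in range(model.input_count())
    ]
    outputs = [
        (model.output_name(i), model.output_fact(i))
        for i in range(model.output_count())
    ]
    return inputs, outputs
```

The returned facts are whatever `input_fact()` and `output_fact()` produce; how a Fact is displayed is outside the scope of this section.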

concretize_symbols(values: Dict[str, int]) -> None

Substitute symbols by a value

Replace all occurrences of the symbols in the dictionary, in the shapes of all the Model facts.

While this is not strictly necessary, the optimizing steps may make better choices if the model is informed of some specific symbol values.

Source code in tract/model.py
def concretize_symbols(self, values: Dict[str, int]) -> None:
    """Substitute symbols by a value

    Replace all occurrences of the symbols in the dictionary, in the shapes of all the Model facts.

    While this is not strictly necessary, the optimizing steps may make better choices if the model
    is informed of some specific symbol values.
    """
    self._valid()
    nb = len(values)
    names_str = []
    names = (c_char_p * nb)()
    values_list = (c_int64 * nb)()
    for ix, (k, v) in enumerate(values.items()):
        names_str.append(str(k).encode("utf-8"))
        names[ix] = names_str[ix]
        values_list[ix] = v
    check(lib.tract_model_concretize_symbols(self.ptr, c_size_t(nb), names, values_list))
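To make the effect of symbol substitution concrete, here is a toy illustration in plain Python. It is not tract code: shapes are modelled as lists mixing integers and symbol names, and `concretize_shape` is a hypothetical function. tract itself also handles symbolic expressions, which this sketch ignores.

```python
from typing import Dict, List, Union

Dim = Union[int, str]


def concretize_shape(shape: List[Dim], values: Dict[str, int]) -> List[Dim]:
    """Toy model of symbol substitution: replace known symbols by integers.

    Symbols missing from `values` are left untouched.
    """
    return [values.get(d, d) if isinstance(d, str) else d for d in shape]


# A shape starting with a batch symbol "N", made fully concrete.
print(concretize_shape(["N", 3, 224, 224], {"N": 1}))  # [1, 3, 224, 224]
```

With every dimension known, later steps can work with a fixed shape instead of a symbolic one.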

pulse(symbol: str, pulse: Union[str, int]) -> None

Pulsify a model.

symbol names the time dimension symbol and pulse is the pulse length, given as an integer or a string.

Source code in tract/model.py
def pulse(self, symbol: str, pulse: Union[str, int]) -> None:
    """Pulsify a model.

    `symbol` names the time dimension symbol and `pulse` is the pulse length, given as an integer or a string.
    """
    self._valid()
    check(lib.tract_model_pulse_simple(byref(self.ptr), symbol.encode("utf-8"), str(pulse).encode("utf-8")))

declutter() -> None

Declutter the model in place

Source code in tract/model.py
def declutter(self) -> None:
    """Declutter the model in place"""
    self._valid()
    check(lib.tract_model_declutter(self.ptr))

optimize() -> None

Optimize the model in place

Source code in tract/model.py
def optimize(self) -> None:
    """Optimize the model in place"""
    self._valid()
    check(lib.tract_model_optimize(self.ptr))

into_decluttered() -> Model

Convenience method performing declutter() and returning the model

Source code in tract/model.py
def into_decluttered(self) -> "Model":
    """Convenience method performing `declutter()` and returning the model"""
    self.declutter()
    return self

into_optimized() -> Model

Convenience method performing optimize() and returning the model

Source code in tract/model.py
def into_optimized(self) -> "Model":
    """Convenience method performing `optimize()` and returning the model"""
    self.optimize()
    return self

into_runnable() -> Runnable

Transform the model into a Runnable model ready to be used

Source code in tract/model.py
def into_runnable(self) -> Runnable:
    """Transform the model into a Runnable model ready to be used"""
    self._valid()
    runnable = c_void_p()
    check(lib.tract_model_into_runnable(byref(self.ptr), byref(runnable)))
    return Runnable(runnable)

property_keys() -> List[str]

Extract the list of properties from a model

Source code in tract/model.py
def property_keys(self) -> List[str]:
    """Extract the list of properties from a model"""
    self._valid()
    count = c_size_t()
    check(lib.tract_model_property_count(self.ptr, byref(count)))
    count = count.value
    cstrings = (POINTER(c_char) * count)()
    check(lib.tract_model_property_names(self.ptr, cstrings))
    names = []
    for i in range(0, count):
        names.append(str(cast(cstrings[i], c_char_p).value, "utf-8"))
        lib.tract_free_cstring(cstrings[i])
    return names

property(name: str) -> Value

Query a property by name

Source code in tract/model.py
def property(self, name: str) -> Value:
    """Query a property by name"""
    self._valid()
    value = c_void_p()
    check(lib.tract_model_property(self.ptr, str(name).encode("utf-8"), byref(value)))
    return Value(value)
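`property_keys()` and `property()` can be combined to fetch every property in one pass. A minimal sketch, assuming `model` is a Model as documented above; `all_properties` is a hypothetical helper name.

```python
from typing import Any, Dict


def all_properties(model) -> Dict[str, Any]:
    """Hypothetical helper: map each property key to the Value it holds."""
    return {key: model.property(key) for key in model.property_keys()}
```

Each dictionary value is the Value returned by `property()`; converting it to another representation is left to the caller.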

profile_json(inputs: Union[None, List[Union[Value, numpy.ndarray]]]) -> str

Profile the model on the provided input

Source code in tract/model.py
def profile_json(self, inputs: Union[None, List[Union[Value, numpy.ndarray]]]) -> str:
    """Profile the model on the provided input"""
    self._valid()
    cstring = c_char_p()
    input_values = []
    input_ptrs = None
    if inputs is not None:
        for v in inputs:
            if isinstance(v, Value):
                input_values.append(v)
            elif isinstance(v, numpy.ndarray):
                input_values.append(Value.from_numpy(v))
            else:
                raise TractError(f"Inputs must be of type tract.Value or numpy.ndarray, got {v}")
        input_ptrs = (c_void_p * len(inputs))()
        for ix, v in enumerate(input_values):
            input_ptrs[ix] = v.ptr
    check(lib.tract_model_profile_json(self.ptr, input_ptrs, byref(cstring)))
    result = str(cstring.value, "utf-8")
    lib.tract_free_cstring(cstring)
    return result
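Since `profile_json()` returns its report as a JSON string, a caller will usually decode it before use. The sketch below assumes `model` is a Model as documented above; `profile` is a hypothetical helper name, and the layout of the decoded report is not specified in this section, so it is returned as is.

```python
import json
from typing import Any, List, Optional


def profile(model, inputs: Optional[List[Any]] = None) -> Any:
    """Hypothetical helper: run profile_json() and decode the JSON report."""
    return json.loads(model.profile_json(inputs))
```

Passing `inputs=None` forwards `None` to `profile_json()`, which its signature explicitly allows.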