
Inference model

Classes

InferenceModel

ONNX models are loaded as InferenceModels instead of Models: many ONNX models come with partial shape and element type information, while tract's Model assumes full shape and element type knowledge. In this case, it is generally sufficient to inform tract about the input shape and type, then let tract infer the rest of the missing shape information before converting the InferenceModel to a regular Model.

# load the model as an InferenceModel
model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")

# set the shape and type of its first and only input
model.set_input_fact(0, "1,3,224,224,f32")

# get ready to run the model
model = model.into_optimized().into_runnable()
Source code in tract/inference_model.py
class InferenceModel:
    """
    ONNX models are loaded as
    `InferenceModel`s instead of `Model`s: many ONNX models come with partial shape and
    element type information, while tract's `Model` assumes full shape and element type
    knowledge. In this case, it is generally sufficient to inform tract about the input
    shape and type, then let tract *infer* the rest of the missing shape information
    before converting the `InferenceModel` to a regular `Model`.

    ```python
    # load the model as an InferenceModel
    model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")

    # set the shape and type of its first and only input
    model.set_input_fact(0, "1,3,224,224,f32")

    # get ready to run the model
    model = model.into_optimized().into_runnable()
    ```
    """
    def __init__(self, ptr):
        self.ptr = ptr

    def __del__(self):
        if self.ptr:
            check(lib.tract_inference_model_destroy(byref(self.ptr)))

    def _valid(self):
        if self.ptr == None:
            raise TractError("invalid inference model (maybe already consumed ?)")

    def into_optimized(self) -> Model:
        """
        Run the InferenceModel through the full tract optimisation pipeline to get an
        optimised Model.
        """
        self._valid()
        model = c_void_p()
        check(lib.tract_inference_model_into_optimized(byref(self.ptr), byref(model)))
        return Model(model)

    def into_typed(self) -> Model:
        """
        Convert an InferenceModel to a regular typed `Model`.

        This leaves the opportunity to run more transformations on the intermediary form of the
        model, before optimising it all the way.
        """
        self._valid()
        model = c_void_p()
        check(lib.tract_inference_model_into_typed(byref(self.ptr), byref(model)))
        return Model(model)

    def input_count(self) -> int:
        """Return the number of inputs of the model"""
        self._valid()
        i = c_size_t()
        check(lib.tract_inference_model_nbio(self.ptr, byref(i), None))
        return i.value

    def output_count(self) -> int:
        """Return the number of outputs of the model"""
        self._valid()
        i = c_size_t()
        check(lib.tract_inference_model_nbio(self.ptr, None, byref(i)))
        return i.value

    def input_name(self, input_id: int) -> str:
        """Return the name of the `input_id`th input."""
        self._valid()
        cstring = c_char_p()
        check(lib.tract_inference_model_input_name(self.ptr, input_id, byref(cstring)))
        result = str(cstring.value, "utf-8")
        lib.tract_free_cstring(cstring)
        return result

    def input_fact(self, input_id: int) -> InferenceFact:
        """Extract the InferenceFact of the `input_id`th input."""
        self._valid()
        fact = c_void_p()
        check(lib.tract_inference_model_input_fact(self.ptr, input_id, byref(fact)))
        return InferenceFact(fact)

    def set_input_fact(self, input_id: int, fact: Union[InferenceFact, str, None]) -> None:
        """Change the InferenceFact of the `input_id`th input."""
        self._valid()
        if isinstance(fact, str):
            fact = self.fact(fact)
        if fact == None:
            check(lib.tract_inference_model_set_input_fact(self.ptr, input_id, None))
        else:
            check(lib.tract_inference_model_set_input_fact(self.ptr, input_id, fact.ptr))

    def set_output_names(self, names: List[str]):
        """Change the output nodes of the model"""
        self._valid()
        nb = len(names)
        names_str = []
        names_ptr = (c_char_p * nb)()
        for ix, n in enumerate(names):
            names_str.append(str(n).encode("utf-8"))
            names_ptr[ix] = names_str[ix]
        check(lib.tract_inference_model_set_output_names(self.ptr, nb, names_ptr))

    def output_name(self, output_id: int) -> str:
        """Return the name of the `output_id`th output."""
        self._valid()
        cstring = c_char_p()
        check(lib.tract_inference_model_output_name(self.ptr, output_id, byref(cstring)))
        result = str(cstring.value, "utf-8")
        lib.tract_free_cstring(cstring)
        return result

    def output_fact(self, output_id: int) -> InferenceFact:
        """Extract the InferenceFact of the `output_id`th output."""
        self._valid()
        fact = c_void_p()
        check(lib.tract_inference_model_output_fact(self.ptr, output_id, byref(fact)))
        return InferenceFact(fact)

    def set_output_fact(self, output_id: int, fact: Union[InferenceFact, str, None]) -> None:
        """Change the InferenceFact of the `output_id`th output."""
        self._valid()
        if isinstance(fact, str):
            fact = self.fact(fact)
        if fact == None:
            check(lib.tract_inference_model_set_output_fact(self.ptr, output_id, None))
        else:
            check(lib.tract_inference_model_set_output_fact(self.ptr, output_id, fact.ptr))

    def fact(self, spec:str) -> InferenceFact:
        """
        Parse a fact specification as an `InferenceFact`.

        A typical `InferenceFact` specification has the form "1,224,224,3,f32": a comma-separated
        list of dimensions, one per axis, plus a mnemonic for the element type. f32 is a
        single-precision float, i16 is a 16-bit signed integer, and u8 an 8-bit unsigned integer.
        """
        self._valid()
        spec = str(spec).encode("utf-8")
        fact = c_void_p();
        check(lib.tract_inference_fact_parse(self.ptr, spec, byref(fact)))
        return InferenceFact(fact)

    def analyse(self) -> None:
        """
        Perform shape and element type inference on the model.
        """
        self._valid()
        check(lib.tract_inference_model_analyse(self.ptr, False))

    def into_analysed(self) -> "InferenceModel":
        """
        Perform shape and element type inference on the model.
        """
        self.analyse()
        return self

Functions

into_optimized() -> Model

Run the InferenceModel through the full tract optimisation pipeline to get an optimised Model.

Source code in tract/inference_model.py
def into_optimized(self) -> Model:
    """
    Run the InferenceModel through the full tract optimisation pipeline to get an
    optimised Model.
    """
    self._valid()
    model = c_void_p()
    check(lib.tract_inference_model_into_optimized(byref(self.ptr), byref(model)))
    return Model(model)
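
Note that into_optimized consumes the InferenceModel: the call takes ownership of the handle, and using the original object afterwards raises the "maybe already consumed" error from _valid. A minimal sketch, reusing the mobilenetv2 example from the introduction:

model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "1,3,224,224,f32")

# into_optimized() consumes `model`; keep the returned optimised Model instead
optimized = model.into_optimized()
runnable = optimized.into_runnable()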
into_typed() -> Model

Convert an InferenceModel to a regular typed Model.

This leaves the opportunity to run more transformations on the intermediary form of the model, before optimising it all the way.

Source code in tract/inference_model.py
def into_typed(self) -> Model:
    """
    Convert an InferenceModel to a regular typed `Model`.

    This leaves the opportunity to run more transformations on the intermediary form of the
    model, before optimising it all the way.
    """
    self._valid()
    model = c_void_p()
    check(lib.tract_inference_model_into_typed(byref(self.ptr), byref(model)))
    return Model(model)
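
A sketch of the alternative path through into_typed, for cases where the typed Model should be inspected or transformed before being made runnable. The extra transformation step is only indicated by a comment; the concrete Model methods that would go there are outside the scope of this file:

model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "1,3,224,224,f32")

typed = model.into_typed()
# ... apply additional transformations on the typed Model here ...
runnable = typed.into_runnable()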
input_count() -> int

Return the number of inputs of the model

Source code in tract/inference_model.py
def input_count(self) -> int:
    """Return the number of inputs of the model"""
    self._valid()
    i = c_size_t()
    check(lib.tract_inference_model_nbio(self.ptr, byref(i), None))
    return i.value
output_count() -> int

Return the number of outputs of the model

Source code in tract/inference_model.py
def output_count(self) -> int:
    """Return the number of outputs of the model"""
    self._valid()
    i = c_size_t()
    check(lib.tract_inference_model_nbio(self.ptr, None, byref(i)))
    return i.value
input_name(input_id: int) -> str

Return the name of the input_id-th input.

Source code in tract/inference_model.py
def input_name(self, input_id: int) -> str:
    """Return the name of the `input_id`th input."""
    self._valid()
    cstring = c_char_p()
    check(lib.tract_inference_model_input_name(self.ptr, input_id, byref(cstring)))
    result = str(cstring.value, "utf-8")
    lib.tract_free_cstring(cstring)
    return result
input_fact(input_id: int) -> InferenceFact

Extract the InferenceFact of the input_id-th input.

Source code in tract/inference_model.py
def input_fact(self, input_id: int) -> InferenceFact:
    """Extract the InferenceFact of the `input_id`th input."""
    self._valid()
    fact = c_void_p()
    check(lib.tract_inference_model_input_fact(self.ptr, input_id, byref(fact)))
    return InferenceFact(fact)
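
A small sketch combining input_count, input_name and input_fact to inspect the inputs of a freshly loaded model (assuming InferenceFact prints a readable description):

model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
for i in range(model.input_count()):
    # print the index, name and currently known fact of each input
    print(i, model.input_name(i), model.input_fact(i))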
set_input_fact(input_id: int, fact: Union[InferenceFact, str, None]) -> None

Change the InferenceFact of the input_id-th input.

Source code in tract/inference_model.py
def set_input_fact(self, input_id: int, fact: Union[InferenceFact, str, None]) -> None:
    """Change the InferenceFact of the `input_id`th input."""
    self._valid()
    if isinstance(fact, str):
        fact = self.fact(fact)
    if fact == None:
        check(lib.tract_inference_model_set_input_fact(self.ptr, input_id, None))
    else:
        check(lib.tract_inference_model_set_input_fact(self.ptr, input_id, fact.ptr))
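
As the code above shows, the fact argument can be a string (parsed through self.fact), an InferenceFact, or None (which presumably clears the input fact). A short sketch of the three forms:

model.set_input_fact(0, "1,3,224,224,f32")               # string, parsed via model.fact()
model.set_input_fact(0, model.fact("1,3,224,224,f32"))   # explicit InferenceFact
model.set_input_fact(0, None)                            # clear the input fact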
set_output_names(names: List[str])

Change the output nodes of the model

Source code in tract/inference_model.py
def set_output_names(self, names: List[str]):
    """Change the output nodes of the model"""
    self._valid()
    nb = len(names)
    names_str = []
    names_ptr = (c_char_p * nb)()
    for ix, n in enumerate(names):
        names_str.append(str(n).encode("utf-8"))
        names_ptr[ix] = names_str[ix]
    check(lib.tract_inference_model_set_output_names(self.ptr, nb, names_ptr))
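
A sketch of redirecting the model output to an internal node; the node name below is purely hypothetical and must match a node that actually exists in the loaded graph:

# use a (hypothetical) intermediate node as the sole model output
model.set_output_names(["features_pool0_fwd"])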
output_name(output_id: int) -> str

Return the name of the output_id-th output.

Source code in tract/inference_model.py
def output_name(self, output_id: int) -> str:
    """Return the name of the `output_id`th output."""
    self._valid()
    cstring = c_char_p()
    check(lib.tract_inference_model_output_name(self.ptr, output_id, byref(cstring)))
    result = str(cstring.value, "utf-8")
    lib.tract_free_cstring(cstring)
    return result
output_fact(output_id: int) -> InferenceFact

Extract the InferenceFact of the output_id-th output.

Source code in tract/inference_model.py
def output_fact(self, output_id: int) -> InferenceFact:
    """Extract the InferenceFact of the `output_id`th output."""
    self._valid()
    fact = c_void_p()
    check(lib.tract_inference_model_output_fact(self.ptr, output_id, byref(fact)))
    return InferenceFact(fact)
set_output_fact(output_id: int, fact: Union[InferenceFact, str, None]) -> None

Change the InferenceFact of the output_id-th output.

Source code in tract/inference_model.py
def set_output_fact(self, output_id: int, fact: Union[InferenceFact, str, None]) -> None:
    """Change the InferenceFact of the `output_id`th output."""
    self._valid()
    if isinstance(fact, str):
        fact = self.fact(fact)
    if fact == None:
        check(lib.tract_inference_model_set_output_fact(self.ptr, output_id, None))
    else:
        check(lib.tract_inference_model_set_output_fact(self.ptr, output_id, fact.ptr))
fact(spec: str) -> InferenceFact

Parse a fact specification as an InferenceFact.

A typical InferenceFact specification has the form "1,224,224,3,f32": a comma-separated list of dimensions, one per axis, plus a mnemonic for the element type. f32 is a single-precision float, i16 is a 16-bit signed integer, and u8 an 8-bit unsigned integer.

Source code in tract/inference_model.py
def fact(self, spec:str) -> InferenceFact:
    """
    Parse a fact specification as an `InferenceFact`.

    A typical `InferenceFact` specification has the form "1,224,224,3,f32": a comma-separated
    list of dimensions, one per axis, plus a mnemonic for the element type. f32 is a
    single-precision float, i16 is a 16-bit signed integer, and u8 an 8-bit unsigned integer.
    """
    self._valid()
    spec = str(spec).encode("utf-8")
    fact = c_void_p();
    check(lib.tract_inference_fact_parse(self.ptr, spec, byref(fact)))
    return InferenceFact(fact)
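
A sketch of parsing a specification and applying it, equivalent to passing the string directly to set_input_fact:

# parse a shape-and-type specification into an InferenceFact
fact = model.fact("1,3,224,224,f32")
model.set_input_fact(0, fact)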
analyse() -> None

Perform shape and element type inference on the model.

Source code in tract/inference_model.py
def analyse(self) -> None:
    """
    Perform shape and element type inference on the model.
    """
    self._valid()
    check(lib.tract_inference_model_analyse(self.ptr, False))
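
A sketch of running inference explicitly before conversion, assuming that pinning down the input fact is enough for the output fact to become fully determined:

model = tract.onnx().model_for_path("./mobilenetv2-7.onnx")
model.set_input_fact(0, "1,3,224,224,f32")
model.analyse()                 # propagate shapes and element types through the graph
print(model.output_fact(0))     # should now carry a concrete shape and type
typed = model.into_typed()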
into_analysed() -> InferenceModel

Perform shape and element type inference on the model.

Source code in tract/inference_model.py
def into_analysed(self) -> "InferenceModel":
    """
    Perform shape and element type inference on the model.
    """
    self.analyse()
    return self
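
Since into_analysed returns the model itself, the analysis step can be chained:

# analyse and convert in one expression
typed = model.into_analysed().into_typed()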